From 043127b6fdc8b52867cbef1ccf45712cd664c632 Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Tue, 29 Jun 2021 02:34:31 +0000 Subject: [PATCH] revise collator --- deepspeech/frontend/augmentor/augmentation.py | 6 +++--- .../frontend/augmentor/shift_perturb.py | 18 ++++++++-------- deepspeech/frontend/augmentor/spec_augment.py | 2 +- deepspeech/io/collator.py | 21 ++++++++++++------- examples/aishell/s0/conf/deepspeech2.yaml | 3 ++- 5 files changed, 29 insertions(+), 21 deletions(-) diff --git a/deepspeech/frontend/augmentor/augmentation.py b/deepspeech/frontend/augmentor/augmentation.py index 360fc215c..ac9957ebe 100644 --- a/deepspeech/frontend/augmentor/augmentation.py +++ b/deepspeech/frontend/augmentor/augmentation.py @@ -104,7 +104,7 @@ class AugmentationPipeline(): for augmentor, rate in zip(self._augmentors, self._rates): augmentor.randomize_parameters() - def randomize_parameters_feature_transform(self, audio): + def randomize_parameters_feature_transform(self, n_frames, n_bins): """Run the pre-processing pipeline for data augmentation. Note that this is an in-place transformation. @@ -112,8 +112,8 @@ class AugmentationPipeline(): :param audio_segment: Audio segment to process. :type audio_segment: AudioSegmenet|SpeechSegment """ - for augmentor, rate in zip(self._augmentors, self._rates): - augmentor.randomize_parameters(audio) + for augmentor, rate in zip(self._spec_augmentors, self._rates): + augmentor.randomize_parameters(n_frames, n_bins) def apply_audio_transform(self, audio_segment): """Run the pre-processing pipeline for data augmentation. 
diff --git a/deepspeech/frontend/augmentor/shift_perturb.py b/deepspeech/frontend/augmentor/shift_perturb.py index 8acfb5e54..cc91b402b 100644 --- a/deepspeech/frontend/augmentor/shift_perturb.py +++ b/deepspeech/frontend/augmentor/shift_perturb.py @@ -37,17 +37,17 @@ class ShiftPerturbAugmentor(AugmentorBase): def apply(self, audio_segment): audio_segment.shift(self.shift_ms) - def transform_audio(self, audio_segment, single): - """Shift audio. + # def transform_audio(self, audio_segment, single): + # """Shift audio. - Note that this is an in-place transformation. + # Note that this is an in-place transformation. - :param audio_segment: Audio segment to add effects to. - :type audio_segment: AudioSegmenet|SpeechSegment - """ - if(single): - self.randomize_parameters() - self.apply(audio_segment) + # :param audio_segment: Audio segment to add effects to. + # :type audio_segment: AudioSegmenet|SpeechSegment + # """ + # if(single): + # self.randomize_parameters() + # self.apply(audio_segment) # def transform_audio(self, audio_segment): diff --git a/deepspeech/frontend/augmentor/spec_augment.py b/deepspeech/frontend/augmentor/spec_augment.py index 332c07095..f0e6a5ece 100644 --- a/deepspeech/frontend/augmentor/spec_augment.py +++ b/deepspeech/frontend/augmentor/spec_augment.py @@ -124,7 +124,7 @@ class SpecAugmentor(AugmentorBase): def time_warp(xs, W=40): raise NotImplementedError - def randomize_parameters(self, n_frame, n_bins): + def randomize_parameters(self, n_frames, n_bins): # n_bins = xs.shape[0] # n_frames = xs.shape[1] diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index bfba3c55b..f105acc06 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -110,7 +110,8 @@ class SpeechCollator(): use_dB_normalization=config.collator.use_dB_normalization, target_dB=config.collator.target_dB, dither=config.collator.dither, - keep_transcription_text=config.collator.keep_transcription_text) + 
keep_transcription_text=config.collator.keep_transcription_text, + randomize_each_batch=config.collator.randomize_each_batch) return speech_collator def __init__( @@ -132,7 +133,8 @@ class SpeechCollator(): use_dB_normalization=True, target_dB=-20, dither=1.0, - keep_transcription_text=True): + keep_transcription_text=True, + randomize_each_batch=False): """SpeechCollator Collator Args: @@ -160,6 +162,7 @@ class SpeechCollator(): a user-defined shape) within one batch. """ self._keep_transcription_text = keep_transcription_text + self._randomize_each_batch = randomize_each_batch self._local_data = TarLocalData(tar2info={}, tar2object={}) self._augmentation_pipeline = AugmentationPipeline( @@ -170,6 +173,7 @@ class SpeechCollator(): self._stride_ms = stride_ms self._target_sample_rate = target_sample_rate + self._speech_featurizer = SpeechFeaturizer( unit_type=unit_type, @@ -224,10 +228,10 @@ class SpeechCollator(): return speech_segment def randomize_audio_parameters(self): - self._augmentation_pipeline.andomize_parameters_audio_transform() + self._augmentation_pipeline.randomize_parameters_audio_transform() - def randomize_feature_parameters(self, n_bins, n_frames): - self._augmentation_pipeline.andomize_parameters_feature_transform(n_bins, n_frames) + def randomize_feature_parameters(self, n_frames, n_bins): + self._augmentation_pipeline.randomize_parameters_feature_transform(n_frames, n_bins) def process_feature_and_transform(self, audio_file, transcript): """Load, augment, featurize and normalize for speech data. 
@@ -317,12 +321,15 @@ class SpeechCollator(): # print(len(batch)) self.randomize_audio_parameters() for utt, audio, text in batch: - if not self.config.randomize_each_batch: + if not self._randomize_each_batch: self.randomize_audio_parameters() audio, text = self.process_feature_and_transform(audio, text) #utt utts.append(utt) # audio + # print("---debug---") + # print(audio.shape) + audio=audio.T audios.append(audio) # [T, D] audio_lens.append(audio.shape[0]) # text @@ -350,7 +357,7 @@ class SpeechCollator(): n_bins=padded_audios.shape[2] self.randomize_feature_parameters(min(audio_lens), n_bins) for i in range(len(padded_audios)): - if not self.config.randomize_each_batch: - self.randomize_feature_parameters(n_bins, audio_lens[i]) + if not self._randomize_each_batch: + self.randomize_feature_parameters(audio_lens[i], n_bins) padded_audios[i] = self._augmentation_pipeline.apply_feature_transform(padded_audios[i]) diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 1004fde0e..1fe21a406 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -11,7 +11,8 @@ data: max_output_input_ratio: .inf collator: - batch_size: 64 # one gpu + batch_size: 32 #64 # one gpu + randomize_each_batch: False mean_std_filepath: data/mean_std.json unit_type: char vocab_filepath: data/vocab.txt