From 698d7a9bdb3de1a763ed8ba7a71b68241e3eea17 Mon Sep 17 00:00:00 2001
From: Haoxin Ma <745165806@qq.com>
Date: Thu, 17 Jun 2021 07:16:52 +0000
Subject: [PATCH] move batch_size, num_workers, shuffle_method, sortagrad to
 collator

---
 deepspeech/exps/deepspeech2/config.py        | 20 +++++------------
 deepspeech/exps/deepspeech2/model.py         | 18 +++++++--------
 deepspeech/exps/u2/config.py                 |  6 ++++-
 .../frontend/featurizer/speech_featurizer.py | 10 ---------
 deepspeech/io/collator.py                    | 22 -------------------
 examples/aishell/s0/conf/deepspeech2.yaml    |  9 ++++----
 examples/tiny/s0/conf/deepspeech2.yaml       |  9 ++++----
 7 files changed, 29 insertions(+), 65 deletions(-)

diff --git a/deepspeech/exps/deepspeech2/config.py b/deepspeech/exps/deepspeech2/config.py
index 1ce5346f6..faaff1aad 100644
--- a/deepspeech/exps/deepspeech2/config.py
+++ b/deepspeech/exps/deepspeech2/config.py
@@ -28,20 +28,6 @@ _C.data = CN(
         augmentation_config="",
         max_duration=float('inf'),
         min_duration=0.0,
-        stride_ms=10.0,  # ms
-        window_ms=20.0,  # ms
-        n_fft=None,  # fft points
-        max_freq=None,  # None for samplerate/2
-        specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
-        feat_dim=0,  # 'mfcc', 'fbank'
-        delat_delta=False,  # 'mfcc', 'fbank'
-        target_sample_rate=16000,  # target sample rate
-        use_dB_normalization=True,
-        target_dB=-20,
-        batch_size=32,  # batch size
-        num_workers=0,  # data loader workers
-        sortagrad=False,  # sorted in first epoch when True
-        shuffle_method="batch_shuffle",  # 'batch_shuffle', 'instance_shuffle'
     ))
 
 _C.model = CN(
@@ -72,7 +58,11 @@ _C.collator =CN(
         use_dB_normalization=True,
         target_dB=-20,
         dither=1.0,  # feature dither
-        keep_transcription_text=False
+        keep_transcription_text=False,
+        batch_size=32,  # batch size
+        num_workers=0,  # data loader workers
+        sortagrad=False,  # sorted in first epoch when True
+        shuffle_method="batch_shuffle",  # 'batch_shuffle', 'instance_shuffle'
     ))
 
 DeepSpeech2Model.params(_C.model)
diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index 5833382a4..b54192dd3 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -55,7 +55,7 @@ class DeepSpeech2Trainer(Trainer):
             'train_loss': float(loss),
         }
         msg += "train time: {:>.3f}s, ".format(iteration_time)
-        msg += "batch size: {}, ".format(self.config.data.batch_size)
+        msg += "batch size: {}, ".format(self.config.collator.batch_size)
         msg += ', '.join('{}: {:>.6f}'.format(k, v)
                          for k, v in losses_np.items())
         logger.info(msg)
@@ -149,31 +149,31 @@ class DeepSpeech2Trainer(Trainer):
         if self.parallel:
             batch_sampler = SortagradDistributedBatchSampler(
                 train_dataset,
-                batch_size=config.data.batch_size,
+                batch_size=config.collator.batch_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 drop_last=True,
-                sortagrad=config.data.sortagrad,
-                shuffle_method=config.data.shuffle_method)
+                sortagrad=config.collator.sortagrad,
+                shuffle_method=config.collator.shuffle_method)
         else:
             batch_sampler = SortagradBatchSampler(
                 train_dataset,
                 shuffle=True,
-                batch_size=config.data.batch_size,
+                batch_size=config.collator.batch_size,
                 drop_last=True,
-                sortagrad=config.data.sortagrad,
-                shuffle_method=config.data.shuffle_method)
+                sortagrad=config.collator.sortagrad,
+                shuffle_method=config.collator.shuffle_method)
 
         collate_fn = SpeechCollator.from_config(config)
         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
             collate_fn=collate_fn,
-            num_workers=config.data.num_workers)
+            num_workers=config.collator.num_workers)
         self.valid_loader = DataLoader(
             dev_dataset,
-            batch_size=config.data.batch_size,
+            batch_size=config.collator.batch_size,
             shuffle=False,
             drop_last=False,
             collate_fn=collate_fn)
diff --git a/deepspeech/exps/u2/config.py b/deepspeech/exps/u2/config.py
index 19080be76..42725c74f 100644
--- a/deepspeech/exps/u2/config.py
+++ b/deepspeech/exps/u2/config.py
@@ -26,7 +26,11 @@ _C.collator =CfgNode(
     dict(
         augmentation_config="",
         unit_type="char",
-        keep_transcription_text=False
+        keep_transcription_text=False,
+        batch_size=32,  # batch size
+        num_workers=0,  # data loader workers
+        sortagrad=False,  # sorted in first epoch when True
+        shuffle_method="batch_shuffle"  # 'batch_shuffle', 'instance_shuffle'
     ))
 
 _C.model = U2Model.params()
diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py
index 852d26c9a..0fbbc5648 100644
--- a/deepspeech/frontend/featurizer/speech_featurizer.py
+++ b/deepspeech/frontend/featurizer/speech_featurizer.py
@@ -151,13 +151,3 @@ class SpeechFeaturizer(object):
             TextFeaturizer: object.
         """
         return self._text_featurizer
-
-
-    # @property
-    # def text_feature(self):
-    #     """Return the text feature object.
-
-    #     Returns:
-    #         TextFeaturizer: object.
-    #     """
-    #     return self._text_featurizer
diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py
index 8b8575dbd..ac817a192 100644
--- a/deepspeech/io/collator.py
+++ b/deepspeech/io/collator.py
@@ -203,34 +203,22 @@ class SpeechCollator():
             where transcription part could be token ids or text.
         :rtype: tuple of (2darray, list)
         """
-        start_time = time.time()
         if isinstance(audio_file, str) and audio_file.startswith('tar:'):
             speech_segment = SpeechSegment.from_file(
                 self._subfile_from_tar(audio_file), transcript)
         else:
             speech_segment = SpeechSegment.from_file(audio_file, transcript)
-        load_wav_time = time.time() - start_time
-        #logger.debug(f"load wav time: {load_wav_time}")
 
         # audio augment
-        start_time = time.time()
         self._augmentation_pipeline.transform_audio(speech_segment)
-        audio_aug_time = time.time() - start_time
-        #logger.debug(f"audio augmentation time: {audio_aug_time}")
 
-        start_time = time.time()
         specgram, transcript_part = self._speech_featurizer.featurize(
             speech_segment, self._keep_transcription_text)
         if self._normalizer:
             specgram = self._normalizer.apply(specgram)
-        feature_time = time.time() - start_time
-        #logger.debug(f"audio & test feature time: {feature_time}")
 
         # specgram augment
-        start_time = time.time()
         specgram = self._augmentation_pipeline.transform_feature(specgram)
-        feature_aug_time = time.time() - start_time
-        #logger.debug(f"audio feature augmentation time: {feature_aug_time}")
         return specgram, transcript_part
 
     def __call__(self, batch):
@@ -283,16 +271,6 @@
 
         return utts, padded_audios, audio_lens, padded_texts, text_lens
 
-    # @property
-    # def text_feature(self):
-    #     return self._speech_featurizer.text_feature
-
-
-    # @property
-    # def stride_ms(self):
-    #     return self._speech_featurizer.stride_ms
-
-###########
 
     @property
     def manifest(self):
diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml
index e5ab8e046..54ce240e7 100644
--- a/examples/aishell/s0/conf/deepspeech2.yaml
+++ b/examples/aishell/s0/conf/deepspeech2.yaml
@@ -5,16 +5,13 @@ data:
   test_manifest: data/manifest.test
   mean_std_filepath: data/mean_std.json
   vocab_filepath: data/vocab.txt
-  batch_size: 64 # one gpu
   min_input_len: 0.0
   max_input_len: 27.0 # second
   min_output_len: 0.0
   max_output_len: .inf
   min_output_input_ratio: 0.00
   max_output_input_ratio: .inf
-  sortagrad: True
-  shuffle_method: batch_shuffle
-  num_workers: 0
+
 
 collator:
   augmentation_config: conf/augmentation.json
@@ -32,6 +29,10 @@
   target_dB: -20
   dither: 1.0
   keep_transcription_text: False
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 0
+  batch_size: 64 # one gpu
 
 model:
   num_conv_layers: 2
diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml
index 6680e5686..434cf264c 100644
--- a/examples/tiny/s0/conf/deepspeech2.yaml
+++ b/examples/tiny/s0/conf/deepspeech2.yaml
@@ -6,16 +6,13 @@ data:
   mean_std_filepath: data/mean_std.json
   unit_type: char
   vocab_filepath: data/vocab.txt
-  batch_size: 4
   min_input_len: 0.0
   max_input_len: 27.0
   min_output_len: 0.0
   max_output_len: 400.0
   min_output_input_ratio: 0.05
   max_output_input_ratio: 10.0
-  sortagrad: True
-  shuffle_method: batch_shuffle
-  num_workers: 0
+
 
 collator:
   augmentation_config: conf/augmentation.json
@@ -33,6 +30,10 @@
   target_dB: -20
   dither: 1.0
   keep_transcription_text: False
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 0
+  batch_size: 4
 
 model:
   num_conv_layers: 2
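
Note (illustrative sketch, not part of the patch): after this change the dataloader settings live under the collator config group, so call sites read config.collator.batch_size instead of config.data.batch_size, and the YAML files carry batch_size, num_workers, sortagrad and shuffle_method under collator:. The snippet below only assumes the yacs package (CfgNode is what the repo's config.py files build on); the values are the defaults shown in deepspeech/exps/deepspeech2/config.py above.

    from yacs.config import CfgNode as CN

    # Defaults for the collator group, mirroring the patched config.py layout.
    _C = CN()
    _C.collator = CN(
        dict(
            batch_size=32,  # batch size consumed by the batch samplers / DataLoader
            num_workers=0,  # data loader workers
            sortagrad=False,  # sort utterances by duration in the first epoch
            shuffle_method="batch_shuffle",  # or 'instance_shuffle'
        ))

    config = _C.clone()
    # Call sites now read the collator group, as in model.py above.
    print(config.collator.batch_size)  # -> 32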