From 279348d7860cc3ba45a80c86f3d2c9194972db53 Mon Sep 17 00:00:00 2001
From: Haoxin Ma <745165806@qq.com>
Date: Tue, 8 Jun 2021 10:32:05 +0000
Subject: [PATCH] move process_utterance to collator

---
 deepspeech/exps/deepspeech2/model.py   |   2 +-
 deepspeech/io/collator.py              | 119 ++++++++++++++++++++++++-
 deepspeech/io/dataset.py               |  86 ++----------------
 examples/tiny/s0/conf/deepspeech2.yaml |   4 +-
 4 files changed, 123 insertions(+), 88 deletions(-)

diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index 468bc652..50ff3c17 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -165,7 +165,7 @@ class DeepSpeech2Trainer(Trainer):
             sortagrad=config.data.sortagrad,
             shuffle_method=config.data.shuffle_method)
 
-        collate_fn = SpeechCollator(keep_transcription_text=False)
+        collate_fn = SpeechCollator(config, keep_transcription_text=False)
         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py
index 3bec9875..d725b0b1 100644
--- a/deepspeech/io/collator.py
+++ b/deepspeech/io/collator.py
@@ -16,14 +16,24 @@
 import numpy as np
 
 from deepspeech.frontend.utility import IGNORE_ID
 from deepspeech.io.utility import pad_sequence
 from deepspeech.utils.log import Log
+from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
+from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
+from deepspeech.frontend.normalizer import FeatureNormalizer
+from deepspeech.frontend.speech import SpeechSegment
+import io
+import tarfile
+import time
+from collections import namedtuple
 
 __all__ = ["SpeechCollator"]
 
 logger = Log(__name__).getlog()
 
+# namedtuple needs to be global for pickle.
+TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
 
 class SpeechCollator():
-    def __init__(self, keep_transcription_text=True):
+    def __init__(self, config, keep_transcription_text=True):
         """
         Padding audio features with zeros to make them have the same shape (or
         a user-defined shape) within one batch.
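Note: SpeechCollator now takes the experiment config as its first argument and pulls every featurization setting from config.data, so the collator can be rebuilt inside DataLoader worker processes. A minimal construction sketch, assuming plain attribute access (the project itself passes a parsed yacs config); the field values below only mirror the tiny/s0 yaml and are illustrative, not authoritative:

from types import SimpleNamespace

# Stand-in for the repo's yacs CfgNode; only the fields read by
# SpeechCollator.__init__ are filled in.
data = SimpleNamespace(
    augmentation_config='',   # falsy str -> io.StringIO('{}'), i.e. no augmentation
    random_seed=0,
    mean_std_filepath='',     # falsy -> FeatureNormalizer is disabled
    unit_type='char',
    vocab_filepath='data/vocab.txt',
    spm_model_prefix='',
    specgram_type='linear',
    feat_dim=None,
    delta_delta=False,
    stride_ms=10.0,
    window_ms=20.0,
    n_fft=None,
    max_freq=None,
    target_sample_rate=16000,
    use_dB_normalization=True,
    target_dB=-20,
    dither=1.0)
config = SimpleNamespace(data=data)
# collate_fn = SpeechCollator(config, keep_transcription_text=False)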
@@ -32,6 +42,112 @@ class SpeechCollator():
         """
         self._keep_transcription_text = keep_transcription_text
 
+        if isinstance(config.data.augmentation_config, (str, bytes)):
+            if config.data.augmentation_config:
+                aug_file = io.open(
+                    config.data.augmentation_config, mode='r', encoding='utf8')
+            else:
+                aug_file = io.StringIO(initial_value='{}', newline='')
+        else:
+            aug_file = config.data.augmentation_config
+            assert isinstance(aug_file, io.StringIO)
+
+        self._local_data = TarLocalData(tar2info={}, tar2object={})
+        self._augmentation_pipeline = AugmentationPipeline(
+            augmentation_config=aug_file.read(),
+            random_seed=config.data.random_seed)
+
+        self._normalizer = FeatureNormalizer(
+            config.data.mean_std_filepath) if config.data.mean_std_filepath else None
+
+        self._stride_ms = config.data.stride_ms
+        self._target_sample_rate = config.data.target_sample_rate
+
+        self._speech_featurizer = SpeechFeaturizer(
+            unit_type=config.data.unit_type,
+            vocab_filepath=config.data.vocab_filepath,
+            spm_model_prefix=config.data.spm_model_prefix,
+            specgram_type=config.data.specgram_type,
+            feat_dim=config.data.feat_dim,
+            delta_delta=config.data.delta_delta,
+            stride_ms=config.data.stride_ms,
+            window_ms=config.data.window_ms,
+            n_fft=config.data.n_fft,
+            max_freq=config.data.max_freq,
+            target_sample_rate=config.data.target_sample_rate,
+            use_dB_normalization=config.data.use_dB_normalization,
+            target_dB=config.data.target_dB,
+            dither=config.data.dither)
+
+    def _parse_tar(self, file):
+        """Parse a tar file to get a tarfile object
+        and a map containing tarinfos.
+        """
+        result = {}
+        f = tarfile.open(file)
+        for tarinfo in f.getmembers():
+            result[tarinfo.name] = tarinfo
+        return f, result
+
+    def _subfile_from_tar(self, file):
+        """Get subfile object from tar.
+
+        It will return a subfile object from the tar file
+        and cache the tar file info for the next reading request.
+        """
+        tarpath, filename = file.split(':', 1)[1].split('#', 1)
+        if 'tar2info' not in self._local_data.__dict__:
+            self._local_data.tar2info = {}
+        if 'tar2object' not in self._local_data.__dict__:
+            self._local_data.tar2object = {}
+        if tarpath not in self._local_data.tar2info:
+            obj, infos = self._parse_tar(tarpath)
+            self._local_data.tar2info[tarpath] = infos
+            self._local_data.tar2object[tarpath] = obj
+        return self._local_data.tar2object[tarpath].extractfile(
+            self._local_data.tar2info[tarpath][filename])
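Note: _subfile_from_tar expects manifest "feat" entries of the form tar:<tar path>#<member name>. The tarfile handle and its member index are built once per archive by _parse_tar and cached in TarLocalData, so later utterances from the same archive skip reopening and rescanning it. A self-contained sketch of just the URI parsing (the archive path and member name are made up for the example):

# "tar:" URI -> (archive path, member name), mirroring the split logic above.
uri = 'tar:/data/librispeech/part-00.tar#utt_001.wav'
tarpath, filename = uri.split(':', 1)[1].split('#', 1)
assert tarpath == '/data/librispeech/part-00.tar'
assert filename == 'utt_001.wav'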
+
+    def process_utterance(self, audio_file, transcript):
+        """Load, augment, featurize and normalize for speech data.
+
+        :param audio_file: Filepath or file object of audio file.
+        :type audio_file: str | file
+        :param transcript: Transcription text.
+        :type transcript: str
+        :return: Tuple of audio feature tensor and data of transcription part,
+                 where transcription part could be token ids or text.
+        :rtype: tuple of (2darray, list)
+        """
+        start_time = time.time()
+        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
+            speech_segment = SpeechSegment.from_file(
+                self._subfile_from_tar(audio_file), transcript)
+        else:
+            speech_segment = SpeechSegment.from_file(audio_file, transcript)
+        load_wav_time = time.time() - start_time
+        #logger.debug(f"load wav time: {load_wav_time}")
+
+        # audio augment
+        start_time = time.time()
+        self._augmentation_pipeline.transform_audio(speech_segment)
+        audio_aug_time = time.time() - start_time
+        #logger.debug(f"audio augmentation time: {audio_aug_time}")
+
+        start_time = time.time()
+        specgram, transcript_part = self._speech_featurizer.featurize(
+            speech_segment, self._keep_transcription_text)
+        if self._normalizer:
+            specgram = self._normalizer.apply(specgram)
+        feature_time = time.time() - start_time
+        #logger.debug(f"audio & text feature time: {feature_time}")
+
+        # specgram augment
+        start_time = time.time()
+        specgram = self._augmentation_pipeline.transform_feature(specgram)
+        feature_aug_time = time.time() - start_time
+        #logger.debug(f"audio feature augmentation time: {feature_aug_time}")
+        return specgram, transcript_part
+
     def __call__(self, batch):
         """batch examples
 
@@ -53,6 +169,7 @@ class SpeechCollator():
         text_lens = []
         utts = []
         for utt, audio, text in batch:
+            audio, text = self.process_utterance(audio, text)
             #utt
             utts.append(utt)
             # audio
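Note: with featurization moved here, __call__ receives raw (utt, audio, text) triples and runs process_utterance on each one before padding, so waveform loading, augmentation, and feature extraction all happen at collate time, inside DataLoader workers. A simplified stand-in for the batching path, assuming process_utterance returns a (feat_dim, time) float array; the real collator pads with deepspeech.io.utility.pad_sequence and also returns length tensors:

import numpy as np

def collate_sketch(batch, process_utterance):
    """Featurize raw examples, then zero-pad features to the batch maximum."""
    utts, feats, texts = [], [], []
    for utt, audio, text in batch:
        specgram, tokens = process_utterance(audio, text)
        utts.append(utt)
        feats.append(specgram)
        texts.append(tokens)
    max_len = max(f.shape[1] for f in feats)
    padded = np.zeros((len(feats), feats[0].shape[0], max_len), 'float32')
    for i, f in enumerate(feats):
        padded[i, :, :f.shape[1]] = f  # zero-pad along the time axis
    return utts, padded, texts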
- """ - tarpath, filename = file.split(':', 1)[1].split('#', 1) - if 'tar2info' not in self._local_data.__dict__: - self._local_data.tar2info = {} - if 'tar2object' not in self._local_data.__dict__: - self._local_data.tar2object = {} - if tarpath not in self._local_data.tar2info: - object, infoes = self._parse_tar(tarpath) - self._local_data.tar2info[tarpath] = infoes - self._local_data.tar2object[tarpath] = object - return self._local_data.tar2object[tarpath].extractfile( - self._local_data.tar2info[tarpath][filename]) - - def process_utterance(self, utt, audio_file, transcript): - """Load, augment, featurize and normalize for speech data. - - :param audio_file: Filepath or file object of audio file. - :type audio_file: str | file - :param transcript: Transcription text. - :type transcript: str - :return: Tuple of audio feature tensor and data of transcription part, - where transcription part could be token ids or text. - :rtype: tuple of (2darray, list) - """ - start_time = time.time() - if isinstance(audio_file, str) and audio_file.startswith('tar:'): - speech_segment = SpeechSegment.from_file( - self._subfile_from_tar(audio_file), transcript) - else: - speech_segment = SpeechSegment.from_file(audio_file, transcript) - load_wav_time = time.time() - start_time - #logger.debug(f"load wav time: {load_wav_time}") - - # audio augment - start_time = time.time() - self._augmentation_pipeline.transform_audio(speech_segment) - audio_aug_time = time.time() - start_time - #logger.debug(f"audio augmentation time: {audio_aug_time}") - - start_time = time.time() - specgram, transcript_part = self._speech_featurizer.featurize( - speech_segment, self._keep_transcription_text) - if self._normalizer: - specgram = self._normalizer.apply(specgram) - feature_time = time.time() - start_time - #logger.debug(f"audio & test feature time: {feature_time}") - - # specgram augment - start_time = time.time() - specgram = self._augmentation_pipeline.transform_feature(specgram) - feature_aug_time = time.time() - start_time - #logger.debug(f"audio feature augmentation time: {feature_aug_time}") - return utt, specgram, transcript_part def _instance_reader_creator(self, manifest): """ @@ -336,8 +260,6 @@ class ManifestDataset(Dataset): def reader(): for instance in manifest: - # inst = self.process_utterance(instance["feat"], - # instance["text"]) inst = self.process_utterance(instance["utt"], instance["feat"], instance["text"]) yield inst @@ -349,6 +271,4 @@ class ManifestDataset(Dataset): def __getitem__(self, idx): instance = self._manifest[idx] - return self.process_utterance(instance["utt"], instance["feat"], - instance["text"]) - # return self.process_utterance(instance["feat"], instance["text"]) + return(instance["utt"], instance["feat"], instance["text"]) diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml index dd9ce51f..aeb4f099 100644 --- a/examples/tiny/s0/conf/deepspeech2.yaml +++ b/examples/tiny/s0/conf/deepspeech2.yaml @@ -6,7 +6,7 @@ data: mean_std_filepath: data/mean_std.json vocab_filepath: data/vocab.txt augmentation_config: conf/augmentation.json - batch_size: 4 + batch_size: 2 min_input_len: 0.0 max_input_len: 27.0 min_output_len: 0.0 @@ -37,7 +37,7 @@ model: share_rnn_weights: True training: - n_epoch: 20 + n_epoch: 10 lr: 1e-5 lr_decay: 1.0 weight_decay: 1e-06