diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index bcd66d19..679261cf 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -137,7 +137,7 @@ class DeepSpeech2Trainer(Trainer): def setup_dataloader(self): config = self.config.clone() config.defrost() - config.data.keep_transcription_text = False + config.collator.keep_transcription_text = False config.data.manifest = config.data.train_manifest train_dataset = ManifestDataset.from_config(config) @@ -165,7 +165,7 @@ class DeepSpeech2Trainer(Trainer): sortagrad=config.data.sortagrad, shuffle_method=config.data.shuffle_method) - collate_fn = SpeechCollator(config=config, keep_transcription_text=False) + collate_fn = SpeechCollator.from_config(config) self.train_loader = DataLoader( train_dataset, batch_sampler=batch_sampler, diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py index e6761cb5..bcb8e3f4 100644 --- a/deepspeech/frontend/featurizer/speech_featurizer.py +++ b/deepspeech/frontend/featurizer/speech_featurizer.py @@ -104,50 +104,7 @@ class SpeechFeaturizer(object): speech_segment.transcript) return spec_feature, text_ids - @property - def vocab_size(self): - """Return the vocabulary size. - - Returns: - int: Vocabulary size. - """ - return self._text_featurizer.vocab_size - - @property - def vocab_list(self): - """Return the vocabulary in list. - Returns: - List[str]: - """ - return self._text_featurizer.vocab_list - - @property - def vocab_dict(self): - """Return the vocabulary in dict. - - Returns: - Dict[str, int]: - """ - return self._text_featurizer.vocab_dict - - @property - def feature_size(self): - """Return the audio feature size. - - Returns: - int: audio feature size. 
- """ - return self._audio_featurizer.feature_size - - @property - def stride_ms(self): - """time length in `ms` unit per frame - - Returns: - float: time(ms)/frame - """ - return self._audio_featurizer.stride_ms @property def text_feature(self): diff --git a/deepspeech/frontend/utility.py b/deepspeech/frontend/utility.py index b2dd9601..610104f9 100644 --- a/deepspeech/frontend/utility.py +++ b/deepspeech/frontend/utility.py @@ -82,7 +82,7 @@ def read_manifest( ] if all(conditions): manifest.append(json_data) - return manifest + return manifest, json_data["feat_shape"][-1] def rms_to_db(rms: float): diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 0f86b8e7..4efc69a0 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -22,6 +22,8 @@ from deepspeech.frontend.normalizer import FeatureNormalizer from deepspeech.frontend.speech import SpeechSegment import io import time +from yacs.config import CfgNode +from typing import Optional from collections import namedtuple @@ -33,51 +35,134 @@ logger = Log(__name__).getlog() TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object']) class SpeechCollator(): - def __init__(self, config, keep_transcription_text=True): - """ - Padding audio features with zeros to make them have the same shape (or - a user-defined shape) within one bach. 
# Methods of ``SpeechCollator``; the ``class SpeechCollator():`` header sits
# just before this span.

@classmethod
def params(cls, config: Optional["CfgNode"]=None) -> "CfgNode":
    """Return the default collator options.

    Args:
        config (yacs.config.CfgNode, optional): when given, the defaults are
            merged into it in place through ``merge_from_other_cfg``.

    Returns:
        CfgNode: node holding the default value of every collator option.
    """
    default = CfgNode(
        dict(
            augmentation_config="",
            random_seed=0,
            mean_std_filepath="",
            unit_type="char",
            vocab_filepath="",
            spm_model_prefix="",
            specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
            feat_dim=0,  # 'mfcc', 'fbank'
            delta_delta=False,  # 'mfcc', 'fbank'
            stride_ms=10.0,  # ms
            window_ms=20.0,  # ms
            n_fft=None,  # fft points
            max_freq=None,  # None for samplerate/2
            target_sample_rate=16000,  # target sample rate
            use_dB_normalization=True,
            target_dB=-20,
            dither=1.0,  # feature dither
            keep_transcription_text=True))

    if config is not None:
        config.merge_from_other_cfg(default)
    return default


@classmethod
def from_config(cls, config):
    """Build a SpeechCollator object from a config.

    Args:
        config (yacs.config.CfgNode): configs object. Collator options are
            read from ``config.collator``; the vocabulary path is read from
            ``config.data.vocab_filepath``.

    Returns:
        SpeechCollator: collator object.
    """
    collator_cfg = config.collator
    # Validate the section itself before probing individual keys.
    assert collator_cfg
    assert 'augmentation_config' in collator_cfg
    assert 'keep_transcription_text' in collator_cfg
    assert 'mean_std_filepath' in collator_cfg
    assert 'vocab_filepath' in config.data
    assert 'specgram_type' in collator_cfg
    assert 'n_fft' in collator_cfg

    aug_source = collator_cfg.augmentation_config
    if isinstance(aug_source, (str, bytes)):
        if aug_source:
            # Read the file eagerly and hand over an in-memory buffer so the
            # file descriptor is released here instead of leaking.
            with io.open(aug_source, mode='r', encoding='utf8') as fin:
                aug_file = io.StringIO(fin.read())
        else:
            # Empty path: fall back to an empty JSON document.
            aug_file = io.StringIO(initial_value='{}', newline='')
    else:
        aug_file = aug_source
        assert isinstance(aug_file, io.StringIO)

    return cls(
        aug_file=aug_file,
        # Honour the configured seed; default to 0 when the key is absent so
        # configs written before the key existed keep working.
        random_seed=getattr(collator_cfg, 'random_seed', 0),
        mean_std_filepath=collator_cfg.mean_std_filepath,
        unit_type=collator_cfg.unit_type,
        vocab_filepath=config.data.vocab_filepath,
        spm_model_prefix=collator_cfg.spm_model_prefix,
        specgram_type=collator_cfg.specgram_type,
        feat_dim=collator_cfg.feat_dim,
        delta_delta=collator_cfg.delta_delta,
        stride_ms=collator_cfg.stride_ms,
        window_ms=collator_cfg.window_ms,
        n_fft=collator_cfg.n_fft,
        max_freq=collator_cfg.max_freq,
        target_sample_rate=collator_cfg.target_sample_rate,
        use_dB_normalization=collator_cfg.use_dB_normalization,
        target_dB=collator_cfg.target_dB,
        dither=collator_cfg.dither,
        keep_transcription_text=collator_cfg.keep_transcription_text)


def __init__(self, aug_file, mean_std_filepath,
             vocab_filepath, spm_model_prefix,
             random_seed=0,
             unit_type="char",
             specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
             feat_dim=0,  # 'mfcc', 'fbank'
             delta_delta=False,  # 'mfcc', 'fbank'
             stride_ms=10.0,  # ms
             window_ms=20.0,  # ms
             n_fft=None,  # fft points
             max_freq=None,  # None for samplerate/2
             target_sample_rate=16000,  # target sample rate
             use_dB_normalization=True,
             target_dB=-20,
             dither=1.0,
             keep_transcription_text=True):
    """
    Padding audio features with zeros to make them have the same shape (or
    a user-defined shape) within one batch.

    if ``keep_transcription_text`` is False, text is token ids else is raw string.

    Args:
        aug_file: readable text object; its ``read()`` result is passed to
            ``AugmentationPipeline`` as the augmentation config.
        mean_std_filepath: path given to ``FeatureNormalizer``; a falsy
            value disables normalization.
        vocab_filepath, spm_model_prefix, unit_type, specgram_type,
        feat_dim, delta_delta, stride_ms, window_ms, n_fft, max_freq,
        target_sample_rate, use_dB_normalization, target_dB, dither:
            forwarded unchanged to ``SpeechFeaturizer``.
        random_seed: seed handed to ``AugmentationPipeline``.
        keep_transcription_text: see above.
    """
    self._keep_transcription_text = keep_transcription_text

    self._local_data = TarLocalData(tar2info={}, tar2object={})
    self._augmentation_pipeline = AugmentationPipeline(
        augmentation_config=aug_file.read(),
        random_seed=random_seed)

    self._normalizer = FeatureNormalizer(
        mean_std_filepath) if mean_std_filepath else None

    self._stride_ms = stride_ms
    self._target_sample_rate = target_sample_rate

    self._speech_featurizer = SpeechFeaturizer(
        unit_type=unit_type,
        vocab_filepath=vocab_filepath,
        spm_model_prefix=spm_model_prefix,
        specgram_type=specgram_type,
        feat_dim=feat_dim,
        delta_delta=delta_delta,
        stride_ms=stride_ms,
        window_ms=window_ms,
        n_fft=n_fft,
        max_freq=max_freq,
        target_sample_rate=target_sample_rate,
        use_dB_normalization=use_dB_normalization,
        target_dB=target_dB,
        dither=dither)


# NOTE(review): the speech_featurizer.py hunk of this same patch removes
# ``vocab_size``, ``vocab_list``, ``vocab_dict`` and ``feature_size`` from
# SpeechFeaturizer. Confirm these delegations still resolve.
@property
def vocab_size(self):
    """Vocabulary size reported by the speech featurizer."""
    return self._speech_featurizer.vocab_size


@property
def vocab_list(self):
    """Vocabulary list reported by the speech featurizer."""
    return self._speech_featurizer.vocab_list


@property
def vocab_dict(self):
    """Vocabulary dict reported by the speech featurizer."""
    return self._speech_featurizer.vocab_dict


@property
def text_feature(self):
    """Text feature object exposed by the speech featurizer."""
    # The collator has no ``_text_featurizer`` attribute of its own; the
    # object is reached through the speech featurizer.
    return self._speech_featurizer.text_feature


@property
def feature_size(self):
    """Audio feature size reported by the speech featurizer."""
    return self._speech_featurizer.feature_size


@property
def stride_ms(self):
    """Frame stride in milliseconds, as passed to ``__init__``."""
    # Same value that was forwarded to SpeechFeaturizer; reading the local
    # copy does not rely on a featurizer property.
    return self._stride_ms
# Methods of ``ManifestDataset`` visible in this span; the class header and
# the constructor signature lie outside it.

def _read_vocab(self, vocab_filepath):
    """Load vocabulary from file.

    Args:
        vocab_filepath (str): path of a UTF-8 text file, one token per line.

    Returns:
        List[str]: tokens in file order.
    """
    with open(vocab_filepath, 'r', encoding='utf-8') as file:
        # Remove only the line terminator. Slicing off the last character
        # would truncate a final line that has no trailing newline, and
        # ``strip()`` would also alter tokens that consist of whitespace.
        return [line.rstrip('\n') for line in file]


@property
def manifest(self):
    """Manifest entries held by the dataset."""
    return self._manifest


@property
def vocab_size(self):
    """Return the vocabulary size.

    Returns:
        int: Vocabulary size.
    """
    return len(self._vocab_list)


@property
def vocab_list(self):
    """Return the vocabulary in list.

    Returns:
        List[str]:
    """
    return self._vocab_list


@property
def vocab_dict(self):
    """Return the vocabulary in dict.

    Returns:
        Dict[str, int]: token -> index in ``vocab_list``.
    """
    return {token: idx for idx, token in enumerate(self._vocab_list)}


@property
def feature_size(self):
    """Return the audio feature size.

    Returns:
        int: audio feature size.
    """
    return self._feature_size


@property
def stride_ms(self):
    """time length in `ms` unit per frame

    Returns:
        float: time(ms)/frame

    Raises:
        AttributeError: when the dataset carries no audio featurizer.
    """
    # NOTE(review): the constructor shown in this patch no longer creates a
    # featurizer, so fail with an explicit message instead of an opaque
    # missing-attribute error. SpeechCollator exposes ``stride_ms``.
    featurizer = getattr(self, '_audio_featurizer', None)
    if featurizer is None:
        raise AttributeError(
            "ManifestDataset has no audio featurizer; "
            "use SpeechCollator.stride_ms instead")
    return featurizer.stride_ms


def __len__(self):
    """Number of manifest entries."""
    return len(self._manifest)