diff --git a/.notebook/jit_infer.ipynb b/.notebook/jit_infer.ipynb index 019fcf393..ba50d8743 100644 --- a/.notebook/jit_infer.ipynb +++ b/.notebook/jit_infer.ipynb @@ -307,6 +307,8 @@ " max_freq=config.data.max_freq,\n", " target_sample_rate=config.data.target_sample_rate,\n", " specgram_type=config.data.specgram_type,\n", + " feat_dim=config.data.feat_dim,\n", + " delta_delta=config.data.delat_delta,\n", " use_dB_normalization=config.data.use_dB_normalization,\n", " target_dB=config.data.target_dB,\n", " random_seed=config.data.random_seed,\n", diff --git a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py index 5948fbd48..7967a325a 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py @@ -98,6 +98,8 @@ def start_server(config, args): max_freq=config.data.max_freq, target_sample_rate=config.data.target_sample_rate, specgram_type=config.data.specgram_type, + feat_dim=config.data.feat_dim, + delta_delta=config.data.delat_delta, use_dB_normalization=config.data.use_dB_normalization, target_dB=config.data.target_dB, random_seed=config.data.random_seed, diff --git a/deepspeech/exps/deepspeech2/bin/deploy/server.py b/deepspeech/exps/deepspeech2/bin/deploy/server.py index 5f72b1600..f0e803380 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/server.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/server.py @@ -50,6 +50,8 @@ def start_server(config, args): max_freq=config.data.max_freq, target_sample_rate=config.data.target_sample_rate, specgram_type=config.data.specgram_type, + feat_dim=config.data.feat_dim, + delta_delta=config.data.delat_delta, use_dB_normalization=config.data.use_dB_normalization, target_dB=config.data.target_dB, random_seed=config.data.random_seed, diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py index 3df9fb314..5f75c9d0f 100644 --- a/deepspeech/exps/deepspeech2/bin/tune.py +++ b/deepspeech/exps/deepspeech2/bin/tune.py @@ -56,6 +56,8 @@ def tune(config, args): max_freq=config.data.max_freq, target_sample_rate=config.data.target_sample_rate, specgram_type=config.data.specgram_type, + feat_dim=config.data.feat_dim, + delta_delta=config.data.delat_delta, use_dB_normalization=config.data.use_dB_normalization, target_dB=config.data.target_dB, random_seed=config.data.random_seed, diff --git a/deepspeech/exps/deepspeech2/config.py b/deepspeech/exps/deepspeech2/config.py index 1762aeadf..279a035fa 100644 --- a/deepspeech/exps/deepspeech2/config.py +++ b/deepspeech/exps/deepspeech2/config.py @@ -32,8 +32,10 @@ _C.data = CN( window_ms=20.0, # ms n_fft=None, # fft points max_freq=None, # None for samplerate/2 - specgram_type='linear', # 'linear', 'mfcc' - target_sample_rate=16000, # sample rate + specgram_type='linear', # 'linear', 'mfcc', 'fbank' + feat_dim=0, # 'mfcc', 'fbank' + delat_delta=False, # 'mfcc', 'fbank' + target_sample_rate=16000, # target sample rate use_dB_normalization=True, target_dB=-20, random_seed=0, diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 13fe0dca5..d74edfffe 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -163,6 +163,8 @@ class DeepSpeech2Trainer(Trainer): max_freq=config.data.max_freq, target_sample_rate=config.data.target_sample_rate, specgram_type=config.data.specgram_type, + feat_dim=config.data.feat_dim, + delta_delta=config.data.delat_delta, use_dB_normalization=config.data.use_dB_normalization, target_dB=config.data.target_dB, random_seed=config.data.random_seed, @@ -183,6 +185,8 @@ class DeepSpeech2Trainer(Trainer): max_freq=config.data.max_freq, target_sample_rate=config.data.target_sample_rate, specgram_type=config.data.specgram_type, + feat_dim=config.data.feat_dim, + delta_delta=config.data.delat_delta, use_dB_normalization=config.data.use_dB_normalization, target_dB=config.data.target_dB, random_seed=config.data.random_seed, @@ -378,6 +382,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): max_freq=config.data.max_freq, target_sample_rate=config.data.target_sample_rate, specgram_type=config.data.specgram_type, + feat_dim=config.data.feat_dim, + delta_delta=config.data.delat_delta, use_dB_normalization=config.data.use_dB_normalization, target_dB=config.data.target_dB, random_seed=config.data.random_seed, diff --git a/deepspeech/frontend/featurizer/audio_featurizer.py b/deepspeech/frontend/featurizer/audio_featurizer.py index 6aef47622..52a1b8f20 100644 --- a/deepspeech/frontend/featurizer/audio_featurizer.py +++ b/deepspeech/frontend/featurizer/audio_featurizer.py @@ -61,7 +61,9 @@ class AudioFeaturizer(object): use_dB_normalization=True, target_dB=-20): self._specgram_type = specgram_type + # mfcc and fbank using `feat_dim` self._feat_dim = feat_dim + # mfcc and fbank using `delta-delta` self._delta_delta = delta_delta self._stride_ms = stride_ms self._window_ms = window_ms @@ -130,25 +132,28 @@ class AudioFeaturizer(object): """Extract various audio features.""" if self._specgram_type == 'linear': return self._compute_linear_specgram( - samples, sample_rate, self._stride_ms, self._window_ms, - self._max_freq) + samples, + sample_rate, + stride_ms=self._stride_ms, + window_ms=self._window_ms, + max_freq=self._max_freq) elif self._specgram_type == 'mfcc': return self._compute_mfcc( samples, sample_rate, - self._stride_ms, - self._feat_dim, - self._window_ms, - self._max_freq, + feat_dim=self._feat_dim, + stride_ms=self._stride_ms, + window_ms=self._window_ms, + max_freq=self._max_freq, delta_delta=self._delta_delta) elif self._specgram_type == 'fbank': return self._compute_fbank( samples, sample_rate, - self._stride_ms, - self._feat_dim, - self._window_ms, - self._max_freq, + feat_dim=self._feat_dim, + stride_ms=self._stride_ms, + window_ms=self._window_ms, + max_freq=self._max_freq, delta_delta=self._delta_delta) else: raise ValueError("Unknown specgram_type %s. " @@ -323,10 +328,9 @@ class AudioFeaturizer(object): winstep=0.001 * stride_ms, nfilt=feat_dim, nfft=512, - lowfreq=max_freq, - highfreq=None, - preemph=0.97, - winfunc=lambda x: np.ones((x, ))) + lowfreq=0, + highfreq=max_freq, + preemph=0.97,) fbank_feat = np.transpose(fbank_feat) if delta_delta: fbank_feat = self._concat_delta_delta(fbank_feat) diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py index e8f92798b..6530fc937 100644 --- a/deepspeech/frontend/featurizer/speech_featurizer.py +++ b/deepspeech/frontend/featurizer/speech_featurizer.py @@ -56,8 +56,8 @@ class SpeechFeaturizer(object): vocab_filepath, spm_model_prefix=None, specgram_type='linear', - feat_dim=13, - delta_delta=True, + feat_dim=None, + delta_delta=False, stride_ms=10.0, window_ms=20.0, n_fft=None, diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py index 13b404e86..8facacc05 100644 --- a/deepspeech/frontend/featurizer/text_featurizer.py +++ b/deepspeech/frontend/featurizer/text_featurizer.py @@ -43,6 +43,15 @@ class TextFeaturizer(object): self.sp = spm.SentencePieceProcessor() self.sp.Load(spm_model) + def tokenize(self, text): + if self.unit_type == 'char': + tokens = self.char_tokenize(text) + elif self.unit_type == 'word': + tokens = self.word_tokenize(text) + else: # spm + tokens = self.spm_tokenize(text) + return tokens + def featurize(self, text): """Convert text string to a list of token indices in char-level.Note that the token indexing order follows the given vocabulary file. @@ -52,13 +61,7 @@ class TextFeaturizer(object): :return: List of char-level token indices. :rtype: List[int] """ - if self.unit_type == 'char': - tokens = self.char_tokenize(text) - elif self.unit_type == 'word': - tokens = self.word_tokenize(text) - else: - tokens = self.spm_tokenize(text) - + tokens = self.tokenize(text) ids = [] for token in tokens: token = token if token in self._vocab_dict else self.unk diff --git a/deepspeech/io/__init__.py b/deepspeech/io/__init__.py index aa638179e..3446228fc 100644 --- a/deepspeech/io/__init__.py +++ b/deepspeech/io/__init__.py @@ -55,6 +55,8 @@ def create_dataloader(manifest_path, window_ms=window_ms, max_freq=max_freq, specgram_type=specgram_type, + feat_dim=config.data.feat_dim, + delta_delta=config.data.delat_delta, use_dB_normalization=use_dB_normalization, random_seed=random_seed, keep_transcription_text=keep_transcription_text) diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index c22e9d16d..149cf6bd2 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -51,6 +51,8 @@ class ManifestDataset(Dataset): max_freq=None, target_sample_rate=16000, specgram_type='linear', + feat_dim=None, + delta_delta=False, use_dB_normalization=True, target_dB=-20, random_seed=0, @@ -71,7 +73,9 @@ class ManifestDataset(Dataset): n_fft (int, optional): fft points for rfft. Defaults to None. max_freq (int, optional): max cut freq. Defaults to None. target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000. - specgram_type (str, optional): 'linear' or 'mfcc'. Defaults to 'linear'. + specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'. + feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None. + delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False. use_dB_normalization (bool, optional): do dB normalization. Defaults to True. target_dB (int, optional): target dB. Defaults to -20. random_seed (int, optional): for random generator. Defaults to 0. @@ -89,6 +93,8 @@ class ManifestDataset(Dataset): vocab_filepath=vocab_filepath, spm_model_prefix=spm_model_prefix, specgram_type=specgram_type, + feat_dim=feat_dim, + delta_delta=delta_delta, stride_ms=stride_ms, window_ms=window_ms, n_fft=n_fft, diff --git a/examples/tiny/s0/local/data.sh b/examples/tiny/s0/local/data.sh index 9794da349..a54f80e5d 100644 --- a/examples/tiny/s0/local/data.sh +++ b/examples/tiny/s0/local/data.sh @@ -40,7 +40,9 @@ fi python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ --manifest_path="data/manifest.tiny.raw" \ --num_samples=64 \ ---specgram_type="linear" \ +--specgram_type="fbank" \ +--feat_dim=80 \ +--delta_delta=false \ --output_path="data/mean_std.npz" if [ $? -ne 0 ]; then diff --git a/utils/build_vocab.py b/utils/build_vocab.py index cbd3339d3..3ef566b12 100644 --- a/utils/build_vocab.py +++ b/utils/build_vocab.py @@ -54,17 +54,13 @@ add_arg('spm_model_prefix', str, "spm_model_%(spm_mode)_%(count_threshold)", args = parser.parse_args() -def count_manifest(counter, manifest_path): +def count_manifest(counter, text_feature, manifest_path): manifest_jsons = read_manifest(manifest_path) for line_json in manifest_jsons: - if args.unit_type == 'char': - for char in line_json['text']: - counter.update(char) - elif args.unit_type == 'word': - for word in line_json['text'].split(): - counter.update(word) - -def read_text_manifest(fileobj, manifest_path): + line = text_feature.tokenize(line_json['text']) + counter.update(line) + +def dump_text_manifest(fileobj, manifest_path): manifest_jsons = read_manifest(manifest_path) for line_json in manifest_jsons: fileobj.write(line_json['text'] + "\n") @@ -77,9 +73,11 @@ def main(): fout.write(UNK + '\n') # must be 1 if args.unit_type != 'spm': + text_feature = TextFeaturizer(args.unit_type, args.vocab_path) counter = Counter() + for manifest_path in args.manifest_paths: - count_manifest(counter, manifest_path) + count_manifest(counter, text_feature, manifest_path) count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True) for char, count in count_sorted: @@ -93,7 +91,7 @@ def main(): fp = tempfile.NamedTemporaryFile(mode='w', delete=False) for manifest_path in args.manifest_paths: - read_text_manifest(fp, manifest_path) + dump_text_manifest(fp, manifest_path) fp.close() # train spm.SentencePieceTrainer.Train( @@ -108,20 +106,29 @@ def main(): # encode text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix) - vocabs = set() + # vocabs = set() + # for manifest_path in args.manifest_paths: + # manifest_jsons = read_manifest(manifest_path) + # for line_json in manifest_jsons: + # line = line_json['text'] + # enc_line = text_feature.spm_tokenize(line) + # for code in enc_line: + # vocabs.add(code) + # #print(" ".join(enc_line)) + # vocabs_sorted = sorted(vocabs) + # for unit in vocabs_sorted: + # fout.write(unit + "\n") + + counter = Counter() + for manifest_path in args.manifest_paths: - manifest_jsons = read_manifest(manifest_path) - for line_json in manifest_jsons: - line = line_json['text'] - enc_line = text_feature.spm_tokenize(line) - for code in enc_line: - vocabs.add(code) - #print(" ".join(enc_line)) - vocabs_sorted = sorted(vocabs) - for unit in vocabs_sorted: - fout.write(unit + "\n") - - print(f"spm vocab size: {len(vocabs_sorted)}") + count_manifest(counter, text_feature, manifest_path) + + count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True) + for token, count in count_sorted: + fout.write(token + '\n') + + print(f"spm vocab size: {len(count_sorted)}") fout.write(SOS + "\n") # fout.close() diff --git a/utils/compute_mean_std.py b/utils/compute_mean_std.py index 339813748..29fadbada 100644 --- a/utils/compute_mean_std.py +++ b/utils/compute_mean_std.py @@ -28,12 +28,13 @@ add_arg('specgram_type', str, 'linear', "Audio feature type. Options: linear, mfcc, fbank.", choices=['linear', 'mfcc', 'fbank']) -add_arg('feat_dim', int, - 13, - "Audio feature dim.") +add_arg('feat_dim', int, 13, "Audio feature dim.") add_arg('delta_delta', bool, False, "Audio feature with delta delta.") +add_arg('stride_ms', float, 10.0, "stride length in ms.") +add_arg('window_ms', float, 20.0, "stride length in ms.") +add_arg('sample_rate', int, 16000, "target sample rate.") add_arg('manifest_path', str, 'data/librispeech/manifest.train', "Filepath of manifest to compute normalizer's mean and stddev.") @@ -51,7 +52,14 @@ def main(): audio_featurizer = AudioFeaturizer( specgram_type=args.specgram_type, feat_dim=args.feat_dim, - delta_delta=args.delta_delta) + delta_delta=args.delta_delta, + stride_ms=args.stride_ms, + window_ms=args.window_ms, + n_fft=None, + max_freq=None, + target_sample_rate=args.sample_rate, + use_dB_normalization=True, + target_dB=-20) def augment_and_featurize(audio_segment): augmentation_pipeline.transform_audio(audio_segment)