diff --git a/.notebook/jit_infer.ipynb b/.notebook/jit_infer.ipynb index a62e76a2e..019fcf393 100644 --- a/.notebook/jit_infer.ipynb +++ b/.notebook/jit_infer.ipynb @@ -295,6 +295,7 @@ "source": [ "dataset = ManifestDataset(\n", " config.data.test_manifest,\n", + " config.data.unit_type,\n", " config.data.vocab_filepath,\n", " config.data.mean_std_filepath,\n", " augmentation_config=\"{}\",\n", diff --git a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py index 22dc9ad57..5948fbd48 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py @@ -85,8 +85,10 @@ def start_server(config, args): """Start the ASR server""" dataset = ManifestDataset( config.data.test_manifest, + config.data.unit_type, config.data.vocab_filepath, config.data.mean_std_filepath, + spm_model_prefix=config.data.spm_model_prefix, augmentation_config="{}", max_duration=config.data.max_duration, min_duration=config.data.min_duration, diff --git a/deepspeech/exps/deepspeech2/bin/deploy/server.py b/deepspeech/exps/deepspeech2/bin/deploy/server.py index 6b99adc3f..5f72b1600 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/server.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/server.py @@ -37,8 +37,10 @@ def start_server(config, args): """Start the ASR server""" dataset = ManifestDataset( config.data.test_manifest, + config.data.unit_type, config.data.vocab_filepath, config.data.mean_std_filepath, + spm_model_prefix=config.data.spm_model_prefix, augmentation_config="{}", max_duration=config.data.max_duration, min_duration=config.data.min_duration, diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py index 50de94c3b..3df9fb314 100644 --- a/deepspeech/exps/deepspeech2/bin/tune.py +++ b/deepspeech/exps/deepspeech2/bin/tune.py @@ -43,8 +43,10 @@ def tune(config, args): dev_dataset = ManifestDataset( config.data.dev_manifest, + config.data.unit_type, 
config.data.vocab_filepath, config.data.mean_std_filepath, + spm_model_prefix=config.data.spm_model_prefix, augmentation_config="{}", max_duration=config.data.max_duration, min_duration=config.data.min_duration, diff --git a/deepspeech/exps/deepspeech2/config.py b/deepspeech/exps/deepspeech2/config.py index 968899d75..1762aeadf 100644 --- a/deepspeech/exps/deepspeech2/config.py +++ b/deepspeech/exps/deepspeech2/config.py @@ -21,7 +21,9 @@ _C.data = CN( train_manifest="", dev_manifest="", test_manifest="", + unit_type="char", vocab_filepath="", + spm_model_prefix="", mean_std_filepath="", augmentation_config="", max_duration=float('inf'), diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index e6779be63..13fe0dca5 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -148,8 +148,10 @@ class DeepSpeech2Trainer(Trainer): train_dataset = ManifestDataset( config.data.train_manifest, + config.data.unit_type, config.data.vocab_filepath, config.data.mean_std_filepath, + spm_model_prefix=config.data.spm_model_prefix, augmentation_config=io.open( config.data.augmentation_config, mode='r', encoding='utf8').read(), @@ -168,8 +170,10 @@ class DeepSpeech2Trainer(Trainer): dev_dataset = ManifestDataset( config.data.dev_manifest, + config.data.unit_type, config.data.vocab_filepath, config.data.mean_std_filepath, + spm_model_prefix=config.data.spm_model_prefix, augmentation_config="{}", max_duration=config.data.max_duration, min_duration=config.data.min_duration, @@ -361,8 +365,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): # return raw text test_dataset = ManifestDataset( config.data.test_manifest, + config.data.unit_type, config.data.vocab_filepath, config.data.mean_std_filepath, + spm_model_prefix=config.data.spm_model_prefix, augmentation_config="{}", max_duration=config.data.max_duration, min_duration=config.data.min_duration, diff --git a/deepspeech/frontend/featurizer/audio_featurizer.py 
b/deepspeech/frontend/featurizer/audio_featurizer.py index b5edb32d5..799525e55 100644 --- a/deepspeech/frontend/featurizer/audio_featurizer.py +++ b/deepspeech/frontend/featurizer/audio_featurizer.py @@ -109,7 +109,7 @@ class AudioFeaturizer(object): feat_dim = int(fft_point * (self._target_sample_rate / 1000) / 2 + 1) elif self._specgram_type == 'mfcc': - # mfcc,delta, delta-delta + # mfcc, delta, delta-delta feat_dim = int(13 * 3) else: raise ValueError("Unknown specgram_type %s. " diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py index d4de96adc..894c684bf 100644 --- a/deepspeech/frontend/featurizer/speech_featurizer.py +++ b/deepspeech/frontend/featurizer/speech_featurizer.py @@ -52,7 +52,9 @@ class SpeechFeaturizer(object): """ def __init__(self, + unit_type, vocab_filepath, + spm_model_prefix=None, specgram_type='linear', stride_ms=10.0, window_ms=20.0, @@ -70,7 +72,8 @@ class SpeechFeaturizer(object): target_sample_rate=target_sample_rate, use_dB_normalization=use_dB_normalization, target_dB=target_dB) - self._text_featurizer = TextFeaturizer(vocab_filepath) + self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath, + spm_model_prefix) def featurize(self, speech_segment, keep_transcription_text): """Extract features for speech segment. @@ -116,4 +119,4 @@ class SpeechFeaturizer(object): :return: audio feature size. 
:rtype: int """ - return self._audio_featurizer.feature_size \ No newline at end of file + return self._audio_featurizer.feature_size diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py index a1e8cdbb1..db9734f1e 100644 --- a/deepspeech/frontend/featurizer/text_featurizer.py +++ b/deepspeech/frontend/featurizer/text_featurizer.py @@ -15,25 +15,35 @@ import os import codecs +import sentencepiece as spm +from deepspeech.frontend.utility import UNK -class TextFeaturizer(object): - """Text featurizer, for processing or extracting features from text. - Currently, it only supports char-level tokenizing and conversion into - a list of token indices. Note that the token indexing order follows the - given vocabulary file. +class TextFeaturizer(object): + def __init__(self, unit_type, vocab_filepath, spm_model_prefix=None): + """Text featurizer, for processing or extracting features from text. - :param vocab_filepath: Filepath to load vocabulary for token indices - conversion. - :type specgram_type: str - """ + Currently, it supports char/word/sentence-piece level tokenizing and conversion into + a list of token indices. Note that the token indexing order follows the + given vocabulary file. - def __init__(self, vocab_filepath): - self.unk = '' + Args: + unit_type (str): unit type, e.g. char, word, spm + vocab_filepath (str): Filepath to load vocabulary for token indices conversion. + spm_model_prefix (str, optional): spm model prefix. Defaults to None. 
+ """ + assert unit_type in ('char', 'spm', 'word') + self.unk = UNK + self.unit_type = unit_type self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file( vocab_filepath) + if unit_type == 'spm': + spm_model = spm_model_prefix + '.model' + self.sp = spm.SentencePieceProcessor() + self.sp.Load(spm_model) + def featurize(self, text): """Convert text string to a list of token indices in char-level.Note that the token indexing order follows the given vocabulary file. @@ -43,7 +53,13 @@ class TextFeaturizer(object): :return: List of char-level token indices. :rtype: list """ - tokens = self._char_tokenize(text) + if self.unit_type == 'char': + tokens = self._char_tokenize(text) + elif self.unit_type == 'word': + tokens = self._word_tokenize(text) + else: + tokens = self._spm_tokenize(text) + ids = [] for token in tokens: token = token if token in self._vocab_dict else self.unk @@ -72,6 +88,42 @@ class TextFeaturizer(object): """Character tokenizer.""" return list(text.strip()) + def _word_tokenize(self, text): + """Word tokenizer, separated by whitespace.""" + return text.strip().split() + + def _spm_tokenize(self, text): + """spm tokenize. + + Args: + text (str): text string. 
+ + Returns: + List[str]: sentence pieces str code + """ + stats = {"num_empty": 0, "num_filtered": 0} + + def valid(line): + return True + + def encode(l): + return self.sp.EncodeAsPieces(l) + + def encode_line(line): + line = line.strip() + if len(line) > 0: + line = encode(line) + if valid(line): + return line + else: + stats["num_filtered"] += 1 + else: + stats["num_empty"] += 1 + return None + + enc_line = encode_line(text) + return enc_line + def _load_vocabulary_from_file(self, vocab_filepath): """Load vocabulary from file.""" vocab_lines = [] diff --git a/deepspeech/frontend/normalizer.py b/deepspeech/frontend/normalizer.py index f8ee52f03..a57b247ad 100644 --- a/deepspeech/frontend/normalizer.py +++ b/deepspeech/frontend/normalizer.py @@ -16,6 +16,7 @@ import numpy as np import random from deepspeech.frontend.utility import read_manifest +from deepspeech.frontend.utility import load_cmvn from deepspeech.frontend.audio import AudioSegment @@ -79,10 +80,8 @@ class FeatureNormalizer(object): def _read_mean_std_from_file(self, filepath, eps=1e-20): """Load mean and std from file.""" - npzfile = np.load(filepath) - self._mean = npzfile["mean"] - std = npzfile["std"] - std = np.clip(std, eps, None) - self._istd = 1.0 / std + mean, istd = load_cmvn(filepath, filetype='npz') + self._mean = mean + self._istd = istd def _compute_mean_std(self, manifest_path, featurize_func, num_samples): @@ -92,8 +91,7 @@ class FeatureNormalizer(object): features = [] for instance in sampled_manifest: features.append( - featurize_func( - AudioSegment.from_file(instance["audio_filepath"]))) + featurize_func(AudioSegment.from_file(instance["feat"]))) features = np.hstack(features) #(D, T) self._mean = np.mean(features, axis=1).reshape([-1, 1]) #(D, 1) self._std = np.std(features, axis=1).reshape([-1, 1]) #(D, 1) diff --git a/deepspeech/frontend/utility.py b/deepspeech/frontend/utility.py index 0f35b1ef5..f2a53833b 100644 --- a/deepspeech/frontend/utility.py +++ b/deepspeech/frontend/utility.py @@ -20,6 
+20,7 @@ import os import tarfile import time import logging +from typing import List from threading import Thread from multiprocessing import Process, Manager, Value @@ -39,31 +40,32 @@ EOS = SOS UNK = "" BLANK = "" -# """Load and parse manifest file. - -# Instances with durations outside [min_duration, max_duration] will be -# filtered out. - -# :param manifest_path: Manifest file to load and parse. -# :type manifest_path: str -# :param max_duration:maximum output seq length, in seconds for raw wav, in frame numbers for feature data. -# :type max_duration: float -# :param min_duration: minimum input seq length, in seconds for raw wav, in frame numbers for feature data. -# :type min_duration: float -# :return: Manifest parsing results. List of dict. -# :rtype: list -# :raises IOError: If failed to parse the manifest. -# """ - def read_manifest( manifest_path, max_input_len=float('inf'), min_input_len=0.0, - max_output_len=500.0, + max_output_len=float('inf'), min_output_len=0.0, - max_output_input_ratio=10.0, - min_output_input_ratio=0.05, ): + max_output_input_ratio=float('inf'), + min_output_input_ratio=0.0, ): + """Load and parse manifest file. + + Args: + manifest_path (str): Manifest file to load and parse. + max_input_len (float, optional): maximum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf'). + min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0. + max_output_len (float, optional): maximum output seq length, in modeling units. Defaults to float('inf'). + min_output_len (float, optional): minimum output seq length, in modeling units. Defaults to 0.0. + max_output_input_ratio (float, optional): maximum output seq length/input seq length ratio. Defaults to float('inf'). + min_output_input_ratio (float, optional): minimum output seq length/input seq length ratio. Defaults to 0.0. 
+ + Raises: + IOError: If failed to parse the manifest. + + Returns: + List[dict]: Manifest parsing results. + """ manifest = [] for json_line in codecs.open(manifest_path, 'r', 'utf-8'): @@ -71,33 +73,23 @@ def read_manifest( json_data = json.loads(json_line) except Exception as e: raise IOError("Error reading manifest: %s" % str(e)) - feat_len = json_data["feat_shape"][0] - token_len = json_data["token_shape"][0] + + feat_len = json_data["feat_shape"][ + 0] if 'feat_shape' in json_data else 1.0 + token_len = json_data["token_shape"][ + 0] if 'token_shape' in json_data else 1.0 conditions = [ - feat_len > min_input_len, - feat_len < max_input_len, - token_len > min_output_len, - token_len < max_output_len, - token_len / feat_len > min_output_input_ratio, - token_len / feat_len < max_output_input_ratio, + feat_len >= min_input_len, + feat_len <= max_input_len, + token_len >= min_output_len, + token_len <= max_output_len, + token_len / feat_len >= min_output_input_ratio, + token_len / feat_len <= max_output_input_ratio, ] if all(conditions): manifest.append(json_data) return manifest - # parser.add_argument('--max_input_len', type=float, - # default=20, - # help='maximum output seq length, in seconds for raw wav, in frame numbers for feature data') - # parser.add_argument('--min_output_len', type=float, - # default=0, help='minimum input seq length, in modeling units') - # parser.add_argument('--max_output_len', type=float, - # default=500, - # help='maximum output seq length, in modeling units') - # parser.add_argument('--min_output_input_ratio', type=float, default=0.05, - # help='minimum output seq length/output seq length ratio') - # parser.add_argument('--max_output_input_ratio', type=float, default=10, - # help='maximum output seq length/output seq length ratio') - def rms_to_db(rms: float): """Root Mean Square to dB. 
@@ -251,8 +243,8 @@ def _load_kaldi_cmvn(kaldi_cmvn_file): def _load_npz_cmvn(npz_cmvn_file, eps=1e-20): npzfile = np.load(npz_cmvn_file) - means = npzfile["mean"] - std = npzfile["std"] + means = npzfile["mean"] #(D, 1) + std = npzfile["std"] #(D, 1) std = np.clip(std, eps, None) variance = 1.0 / std cmvn = np.array([means, variance]) @@ -278,7 +270,7 @@ def load_cmvn(cmvn_file: str, filetype: str): cmvn = _load_json_cmvn(cmvn_file) elif filetype == "kaldi": cmvn = _load_kaldi_cmvn(cmvn_file) - elif filtype == "npz": + elif filetype == "npz": cmvn = _load_npz_cmvn(cmvn_file) else: raise ValueError(f"cmvn file type no support: {filetype}") diff --git a/deepspeech/io/__init__.py b/deepspeech/io/__init__.py index 12e1d4d91..aa638179e 100644 --- a/deepspeech/io/__init__.py +++ b/deepspeech/io/__init__.py @@ -21,8 +21,10 @@ from deepspeech.io.dataset import ManifestDataset def create_dataloader(manifest_path, + unit_type, vocab_filepath, mean_std_filepath, + spm_model_prefix, augmentation_config='{}', max_duration=float('inf'), min_duration=0.0, @@ -42,8 +44,10 @@ def create_dataloader(manifest_path, dataset = ManifestDataset( manifest_path, + unit_type, vocab_filepath, mean_std_filepath, + spm_model_prefix=spm_model_prefix, augmentation_config=augmentation_config, max_duration=max_duration, min_duration=min_duration, diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index b4c1c7afd..c22e9d16d 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -38,8 +38,10 @@ __all__ = [ class ManifestDataset(Dataset): def __init__(self, manifest_path, + unit_type, vocab_filepath, mean_std_filepath, + spm_model_prefix=None, augmentation_config='{}', max_duration=float('inf'), min_duration=0.0, @@ -57,8 +59,10 @@ class ManifestDataset(Dataset): Args: manifest_path (str): manifest josn file path - vocab_filepath (str): vocab file path + unit_type(str): token unit type, e.g. char, word, spm + vocab_filepath (str): vocab file path. 
mean_std_filepath (str): mean and std file path, which suffix is *.npy + spm_model_prefix (str): spm model prefix, need if `unit_type` is spm. augmentation_config (str, optional): augmentation json str. Defaults to '{}'. max_duration (float, optional): audio length in seconds must less than this. Defaults to float('inf'). min_duration (float, optional): audio length is seconds must greater than this. Defaults to 0.0. @@ -78,10 +82,12 @@ class ManifestDataset(Dataset): self._max_duration = max_duration self._min_duration = min_duration self._normalizer = FeatureNormalizer(mean_std_filepath) - self._augmentation_pipeline = AugmentationPipeline( + self._audio_augmentation_pipeline = AugmentationPipeline( augmentation_config=augmentation_config, random_seed=random_seed) self._speech_featurizer = SpeechFeaturizer( + unit_type=unit_type, vocab_filepath=vocab_filepath, + spm_model_prefix=spm_model_prefix, specgram_type=specgram_type, stride_ms=stride_ms, window_ms=window_ms, @@ -174,7 +180,7 @@ class ManifestDataset(Dataset): self._subfile_from_tar(audio_file), transcript) else: speech_segment = SpeechSegment.from_file(audio_file, transcript) - self._augmentation_pipeline.transform_audio(speech_segment) + self._audio_augmentation_pipeline.transform_audio(speech_segment) specgram, transcript_part = self._speech_featurizer.featurize( speech_segment, self._keep_transcription_text) specgram = self._normalizer.apply(specgram) @@ -191,7 +197,7 @@ class ManifestDataset(Dataset): def reader(): for instance in manifest: - inst = self.process_utterance(instance["audio_filepath"], + inst = self.process_utterance(instance["feat"], instance["text"]) yield inst @@ -202,5 +208,4 @@ class ManifestDataset(Dataset): def __getitem__(self, idx): instance = self._manifest[idx] - return self.process_utterance(instance["audio_filepath"], - instance["text"]) + return self.process_utterance(instance["feat"], instance["text"]) diff --git a/examples/tiny/s0/local/data.sh 
b/examples/tiny/s0/local/data.sh index 0117218f8..240655847 100644 --- a/examples/tiny/s0/local/data.sh +++ b/examples/tiny/s0/local/data.sh @@ -52,6 +52,7 @@ fi # format manifest with tokenids, vocab size python3 ${MAIN_ROOT}/utils/format_data.py \ --feat_type "raw" \ +--cmvn_path "data/mean_std.npz" \ --unit_type "bpe" \ --bpe_model_prefix ${bpeprefix} \ --vocab_path="data/vocab.txt" \ diff --git a/requirements.txt b/requirements.txt index d95db07bd..1ef11e17d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ resampy==0.2.2 SoundFile==0.9.0.post1 python_speech_features tensorboardX +sentencepiece yacs typeguard pre-commit diff --git a/utils/format_data.py b/utils/format_data.py index fc10c0385..4788f8579 100644 --- a/utils/format_data.py +++ b/utils/format_data.py @@ -24,6 +24,7 @@ from deepspeech.frontend.utility import read_manifest from deepspeech.frontend.utility import UNK from deepspeech.frontend.utility import BLANK from deepspeech.frontend.utility import SOS +from deepspeech.frontend.utility import load_cmvn from deepspeech.utils.utility import add_arguments from deepspeech.utils.utility import print_arguments @@ -31,10 +32,13 @@ parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) # yapf: disable add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kaldi") +add_arg('cmvn_path', str, + 'examples/librispeech/data/mean_std.npz', + "Filepath of cmvn.") add_arg('unit_type', str, "character", "Unit type, e.g. character, word, bpe") add_arg('vocab_path', str, 'examples/librispeech/data/vocab.txt', - "Filepath to write the vocabulary.") + "Filepath of the vocabulary.") add_arg('manifest_paths', str, None, "Filepaths of manifests for building vocabulary. 
" @@ -51,6 +55,11 @@ args = parser.parse_args() def main(): print_arguments(args) + # get feat dim + mean, std = load_cmvn(args.cmvn_path, filetype='npz') + feat_dim = mean.shape[0] + print(f"Feature dim: {feat_dim}") + # read vocab vocab = dict() with open(args.vocab_path, 'r', encoding='utf-8') as fin: @@ -58,6 +67,7 @@ def main(): token = line.strip() vocab[token] = len(vocab) vocab_size = len(vocab) + print(f"Vocab size: {vocab_size}") fout = open(args.output_path, 'w', encoding='utf-8') @@ -78,6 +88,12 @@ def main(): line_json['token'] = tokens line_json['token_id'] = tokenids line_json['token_shape'] = (len(tokenids), vocab_size) + feat_shape = line_json['feat_shape'] + assert isinstance(feat_shape, (list, tuple)), type(feat_shape) + if args.feat_type == 'raw': + feat_shape.append(feat_dim) + else: # kaldi + raise NotImplementedError('no support kaldi feat now!') fout.write(json.dumps(line_json) + '\n') else: import sentencepiece as spm @@ -118,6 +134,12 @@ def main(): line_json['token'] = tokens line_json['token_id'] = tokenids line_json['token_shape'] = (len(tokenids), vocab_size) + feat_shape = line_json['feat_shape'] + assert isinstance(feat_shape, (list, tuple)), type(feat_shape) + if args.feat_type == 'raw': + feat_shape.append(feat_dim) + else: # kaldi + raise NotImplementedError('no support kaldi feat now!') fout.write(json.dumps(line_json) + '\n') fout.close()