diff --git a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py
index f3125e04d..01f01b651 100644
--- a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py
+++ b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py
@@ -18,8 +18,10 @@ import numpy as np
 import paddle
 from paddle.inference import Config
 from paddle.inference import create_predictor
+from paddle.io import DataLoader
 
 from deepspeech.exps.deepspeech2.config import get_cfg_defaults
+from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.models.deepspeech2 import DeepSpeech2Model
 from deepspeech.training.cli import default_argument_parser
@@ -78,26 +80,31 @@ def inference(config, args):
 
 def start_server(config, args):
     """Start the ASR server"""
     config.defrost()
-    config.data.manfiest = config.data.test_manifest
-    config.data.augmentation_config = ""
-    config.data.keep_transcription_text = True
+    config.data.manifest = config.data.test_manifest
     dataset = ManifestDataset.from_config(config)
-    model = DeepSpeech2Model.from_pretrained(dataset, config,
+    config.collator.augmentation_config = ""
+    config.collator.keep_transcription_text = True
+    config.collator.batch_size = 1
+    config.collator.num_workers = 0
+    collate_fn = SpeechCollator.from_config(config)
+    test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
+
+    model = DeepSpeech2Model.from_pretrained(test_loader, config,
                                              args.checkpoint_path)
     model.eval()
 
     # prepare ASR inference handler
     def file_to_transcript(filename):
-        feature = dataset.process_utterance(filename, "")
-        audio = np.array([feature[0]]).astype('float32')  #[1, D, T]
-        audio_len = feature[0].shape[1]
+        feature = test_loader.collate_fn.process_utterance(filename, "")
+        audio = np.array([feature[0]]).astype('float32')  #[1, T, D]
+        audio_len = feature[0].shape[0]
         audio_len = np.array([audio_len]).astype('int64')  # [1]
 
         result_transcript = model.decode(
             paddle.to_tensor(audio),
             paddle.to_tensor(audio_len),
-            vocab_list=dataset.vocab_list,
+            vocab_list=test_loader.collate_fn.vocab_list,
             decoding_method=config.decoding.decoding_method,
             lang_model_path=config.decoding.lang_model_path,
             beam_alpha=config.decoding.alpha,
@@ -138,7 +145,7 @@ if __name__ == "__main__":
     add_arg('host_ip', str,
             'localhost',
             "Server's IP address.")
-    add_arg('host_port', int, 8086, "Server's IP port.")
+    add_arg('host_port', int, 8089, "Server's IP port.")
     add_arg('speech_save_dir', str,
             'demo_cache',
             "Directory to save demo audios.")
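The handler above now feeds the model a time-major batch, which is why the utterance length switches from `shape[1]` to `shape[0]`. A minimal numpy sketch of the single-utterance case, assuming `feature[0]` is a `(T, D)` spectrogram as in the patched `file_to_transcript` (the 200×161 shape is illustrative only):

```python
import numpy as np

# Stand-in for test_loader.collate_fn.process_utterance(...)[0]:
# 200 frames by 161 frequency bins, time-major.
feature0 = np.random.randn(200, 161).astype('float32')        # (T, D)

audio = np.array([feature0]).astype('float32')                # [1, T, D]
audio_len = np.array([feature0.shape[0]]).astype('int64')     # [1], in frames

# Under the old (D, T) layout, shape[1] held the frame count; with the
# time-major layout the length must be read from shape[0].
assert audio.shape == (1, 200, 161)
assert audio_len[0] == 200
```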
diff --git a/deepspeech/exps/deepspeech2/bin/deploy/server.py b/deepspeech/exps/deepspeech2/bin/deploy/server.py
index b2ff37e06..b473a8fd4 100644
--- a/deepspeech/exps/deepspeech2/bin/deploy/server.py
+++ b/deepspeech/exps/deepspeech2/bin/deploy/server.py
@@ -16,8 +16,10 @@ import functools
 
 import numpy as np
 import paddle
+from paddle.io import DataLoader
 
 from deepspeech.exps.deepspeech2.config import get_cfg_defaults
+from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.models.deepspeech2 import DeepSpeech2Model
 from deepspeech.training.cli import default_argument_parser
@@ -31,26 +33,35 @@ from deepspeech.utils.utility import print_arguments
 
 def start_server(config, args):
     """Start the ASR server"""
     config.defrost()
-    config.data.manfiest = config.data.test_manifest
-    config.data.augmentation_config = ""
-    config.data.keep_transcription_text = True
+    config.data.manifest = config.data.test_manifest
     dataset = ManifestDataset.from_config(config)
-    model = DeepSpeech2Model.from_pretrained(dataset, config,
+    config.collator.augmentation_config = ""
+    config.collator.keep_transcription_text = True
+    config.collator.batch_size = 1
+    config.collator.num_workers = 0
+    collate_fn = SpeechCollator.from_config(config)
+    test_loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=0)
+
+    model = DeepSpeech2Model.from_pretrained(test_loader, config,
                                              args.checkpoint_path)
     model.eval()
 
     # prepare ASR inference handler
     def file_to_transcript(filename):
-        feature = dataset.process_utterance(filename, "")
-        audio = np.array([feature[0]]).astype('float32')  #[1, D, T]
-        audio_len = feature[0].shape[1]
+        feature = test_loader.collate_fn.process_utterance(filename, "")
+        audio = np.array([feature[0]]).astype('float32')  #[1, T, D]
+        # audio = audio.swapaxes(1,2)
+        print('---file_to_transcript feature----')
+        print(audio.shape)
+        audio_len = feature[0].shape[0]
+        print(audio_len)
         audio_len = np.array([audio_len]).astype('int64')  # [1]
 
         result_transcript = model.decode(
             paddle.to_tensor(audio),
             paddle.to_tensor(audio_len),
-            vocab_list=dataset.vocab_list,
+            vocab_list=test_loader.collate_fn.vocab_list,
             decoding_method=config.decoding.decoding_method,
             lang_model_path=config.decoding.lang_model_path,
             beam_alpha=config.decoding.alpha,
@@ -91,7 +102,7 @@ if __name__ == "__main__":
     add_arg('host_ip', str,
             'localhost',
             "Server's IP address.")
-    add_arg('host_port', int, 8086, "Server's IP port.")
+    add_arg('host_port', int, 8088, "Server's IP port.")
     add_arg('speech_save_dir', str,
             'demo_cache',
             "Directory to save demo audios.")
diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py
index 02e329a11..f10dc27ce 100644
--- a/deepspeech/exps/deepspeech2/bin/tune.py
+++ b/deepspeech/exps/deepspeech2/bin/tune.py
@@ -47,7 +47,7 @@ def tune(config, args):
         drop_last=False,
         collate_fn=SpeechCollator(keep_transcription_text=True))
 
-    model = DeepSpeech2Model.from_pretrained(dev_dataset, config,
+    model = DeepSpeech2Model.from_pretrained(valid_loader, config,
                                              args.checkpoint_path)
     model.eval()
diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index deb8752b7..209e8b023 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -318,7 +318,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
 
     def export(self):
         infer_model = DeepSpeech2InferModel.from_pretrained(
-            self.test_loader.dataset, self.config, self.args.checkpoint_path)
+            self.test_loader, self.config, self.args.checkpoint_path)
         infer_model.eval()
         feat_dim = self.test_loader.collate_fn.feature_size
         static_model = paddle.jit.to_static(
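These call sites now hand `from_pretrained` the loader rather than its dataset, and the model sizes are read off the attached collator. A minimal sketch of that calling convention, assuming Paddle's `DataLoader` exposes the `collate_fn` attribute the patched code relies on (the dataset and collator here are illustrative stubs, not repo classes):

```python
import numpy as np
from paddle.io import DataLoader, Dataset

class ToyDataset(Dataset):
    def __getitem__(self, idx):
        return np.zeros((10, 161), dtype='float32')   # (T, D) feature stub
    def __len__(self):
        return 4

class ToyCollator:
    feature_size = 161   # model input dim
    vocab_size = 4300    # model output dim
    def __call__(self, batch):
        return np.stack(batch)

loader = DataLoader(ToyDataset(), batch_size=2, collate_fn=ToyCollator())

# from_pretrained(loader, ...) no longer touches loader.dataset; it reads
# the feature/vocab metadata from the collator attached to the loader:
feat_dim = loader.collate_fn.feature_size
vocab_dim = loader.collate_fn.vocab_size
```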
""" from deepspeech.models.u2 import U2InferModel - infer_model = U2InferModel.from_pretrained(self.test_loader.dataset, + infer_model = U2InferModel.from_pretrained(self.test_loader, self.config.model.clone(), self.args.checkpoint_path) feat_dim = self.test_loader.collate_fn.feature_size diff --git a/deepspeech/frontend/augmentor/spec_augment.py b/deepspeech/frontend/augmentor/spec_augment.py index 453aadcba..332c07095 100644 --- a/deepspeech/frontend/augmentor/spec_augment.py +++ b/deepspeech/frontend/augmentor/spec_augment.py @@ -124,7 +124,7 @@ class SpecAugmentor(AugmentorBase): def time_warp(xs, W=40): raise NotImplementedError - def randomize_parameters(self, n_bins, n_frame): + def randomize_parameters(self, n_frame, n_bins): # n_bins = xs.shape[0] # n_frames = xs.shape[1] @@ -156,66 +156,69 @@ class SpecAugmentor(AugmentorBase): self.t_0.append(int(self._rng.uniform(low=0, high=n_frames - t))) def apply(self, xs: np.ndarray): - n_bins = xs.shape[0] - n_frames = xs.shape[1] + ''' + input xs [T, D] + ''' + n_frames = xs.shape[0] + n_bins = xs.shape[1] for i in range(0, self.n_freq_masks): f = self.f[i] f_0 = self.f_0[i] - xs[f_0:f_0 + f, :] = 0 + xs[:, f_0:f_0 + f] = 0 assert f_0 <= f_0 + f for i in range(self.n_masks): t = self.t[i] t_0 = self.t_0[i] - xs[:, t_0:t_0 + t] = 0 + xs[t_0:t_0 + t, :] = 0 assert t_0 <= t_0 + t return xs - def mask_freq(self, xs, replace_with_zero=False): - n_bins = xs.shape[0] - for i in range(0, self.n_freq_masks): - f = int(self._rng.uniform(low=0, high=self.F)) - f_0 = int(self._rng.uniform(low=0, high=n_bins - f)) - xs[f_0:f_0 + f, :] = 0 - assert f_0 <= f_0 + f - self._freq_mask = (f_0, f_0 + f) - return xs - - def mask_time(self, xs, replace_with_zero=False): - n_frames = xs.shape[1] - - if self.adaptive_number_ratio > 0: - n_masks = int(n_frames * self.adaptive_number_ratio) - n_masks = min(n_masks, self.max_n_time_masks) - else: - n_masks = self.n_time_masks - - if self.adaptive_size_ratio > 0: - T = self.adaptive_size_ratio * n_frames - else: - T = self.T + # def mask_freq(self, xs, replace_with_zero=False): + # n_bins = xs.shape[0] + # for i in range(0, self.n_freq_masks): + # f = int(self._rng.uniform(low=0, high=self.F)) + # f_0 = int(self._rng.uniform(low=0, high=n_bins - f)) + # xs[f_0:f_0 + f, :] = 0 + # assert f_0 <= f_0 + f + # self._freq_mask = (f_0, f_0 + f) + # return xs - for i in range(n_masks): - t = int(self._rng.uniform(low=0, high=T)) - t = min(t, int(n_frames * self.p)) - t_0 = int(self._rng.uniform(low=0, high=n_frames - t)) - xs[:, t_0:t_0 + t] = 0 - assert t_0 <= t_0 + t - self._time_mask = (t_0, t_0 + t) - return xs + # def mask_time(self, xs, replace_with_zero=False): + # n_frames = xs.shape[1] + + # if self.adaptive_number_ratio > 0: + # n_masks = int(n_frames * self.adaptive_number_ratio) + # n_masks = min(n_masks, self.max_n_time_masks) + # else: + # n_masks = self.n_time_masks + + # if self.adaptive_size_ratio > 0: + # T = self.adaptive_size_ratio * n_frames + # else: + # T = self.T + + # for i in range(n_masks): + # t = int(self._rng.uniform(low=0, high=T)) + # t = min(t, int(n_frames * self.p)) + # t_0 = int(self._rng.uniform(low=0, high=n_frames - t)) + # xs[:, t_0:t_0 + t] = 0 + # assert t_0 <= t_0 + t + # self._time_mask = (t_0, t_0 + t) + # return xs - def transform_feature(self, xs: np.ndarray, single=True): - """ - Args: - xs (FloatTensor): `[F, T]` - Returns: - xs (FloatTensor): `[F, T]` - """ - if(single): - self.randomize_parameters(xs) - return self.apply(xs) + # def transform_feature(self, xs: np.ndarray, 
+    #     """
+    #     Args:
+    #         xs (FloatTensor): `[F, T]`
+    #     Returns:
+    #         xs (FloatTensor): `[F, T]`
+    #     """
+    #     if(single):
+    #         self.randomize_parameters(xs)
+    #     return self.apply(xs)
 
     # def transform_feature(self, xs: np.ndarray):
     #     """
diff --git a/deepspeech/frontend/featurizer/audio_featurizer.py b/deepspeech/frontend/featurizer/audio_featurizer.py
index 11c1fa2d4..f209d305d 100644
--- a/deepspeech/frontend/featurizer/audio_featurizer.py
+++ b/deepspeech/frontend/featurizer/audio_featurizer.py
@@ -221,17 +221,19 @@ class AudioFeaturizer(object):
         """append delat, delta-delta feature.
 
         Args:
-            feat (np.ndarray): (D, T)
+            feat (np.ndarray): (T, D)
 
         Returns:
-            np.ndarray: feat with delta-delta, (3*D, T)
+            np.ndarray: feat with delta-delta, (T, 3*D)
         """
+        # transpose (T, D) --> (D, T)
         feat = np.transpose(feat)
         # Deltas
         d_feat = delta(feat, 2)
         # Deltas-Deltas
         dd_feat = delta(feat, 2)
         # transpose
+        # transpose (D, T) --> (T, D)
         feat = np.transpose(feat)
         d_feat = np.transpose(d_feat)
         dd_feat = np.transpose(dd_feat)
@@ -264,7 +266,7 @@ class AudioFeaturizer(object):
             ValueError: stride_ms > window_ms
 
         Returns:
-            np.ndarray: mfcc feature, (D, T).
+            np.ndarray: mfcc feature, (T, D).
         """
         if max_freq is None:
             max_freq = sample_rate / 2
@@ -322,7 +324,7 @@ class AudioFeaturizer(object):
             ValueError: stride_ms > window_ms
 
         Returns:
-            np.ndarray: mfcc feature, (D, T).
+            np.ndarray: mfcc feature, (T, D).
         """
         if max_freq is None:
             max_freq = sample_rate / 2
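With features now time-major throughout, frequency masks zero columns and time masks zero rows, and the delta-delta featurizer returns `(T, 3*D)` instead of `(3*D, T)`. A self-contained numpy sketch of the masking convention in `apply()` above (mask widths and the 100×80 spectrogram are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
xs = np.ones((100, 80), dtype='float32')    # (n_frames, n_bins) == [T, D]

# frequency mask: zero a band of bins across all frames (columns)
f = 10
f_0 = int(rng.uniform(low=0, high=80 - f))
xs[:, f_0:f_0 + f] = 0

# time mask: zero a span of frames across all bins (rows)
t = 20
t_0 = int(rng.uniform(low=0, high=100 - t))
xs[t_0:t_0 + t, :] = 0

assert (xs[:, f_0] == 0).all()              # whole column (bin) masked
assert (xs[t_0, :] == 0).all()              # whole row (frame) masked
```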
diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py
index 8a2a78ef3..bfba3c55b 100644
--- a/deepspeech/io/collator.py
+++ b/deepspeech/io/collator.py
@@ -229,7 +229,7 @@ class SpeechCollator():
     def randomize_feature_parameters(self, n_bins, n_frames):
         self._augmentation_pipeline.andomize_parameters_feature_transform(n_bins, n_frames)
 
-    def process_utterance(self, audio_file, transcript):
+    def process_feature_and_transform(self, audio_file, transcript):
         """Load, augment, featurize and normalize for speech data.
 
         :param audio_file: Filepath or file object of audio file.
@@ -254,6 +254,7 @@ class SpeechCollator():
 
         #     # apply specgram augment
         #     specgram = self._augmentation_pipeline.apply_feature_transform(specgram)
+
         return specgram, transcript_part
 
@@ -318,12 +319,12 @@ class SpeechCollator():
         for utt, audio, text in batch:
             if not self.config.randomize_each_batch:
                 self.randomize_audio_parameters()
-            audio, text = self.process_utterance(audio, text)
+            audio, text = self.process_feature_and_transform(audio, text)
             #utt
             utts.append(utt)
             # audio
-            audios.append(audio.T)  # [T, D]
-            audio_lens.append(audio.shape[1])
+            audios.append(audio)  # [T, D]
+            audio_lens.append(audio.shape[0])
             # text
             # for training, text is token ids
             # else text is string, convert to unicode ord
@@ -346,8 +347,8 @@ class SpeechCollator():
         text_lens = np.array(text_lens).astype(np.int64)
 
         #spec augment
-        n_bins=padded_audios[0]
-        self.randomize_feature_parameters(n_bins, min(audio_lens))
+        n_bins=padded_audios.shape[2]
+        self.randomize_feature_parameters(min(audio_lens), n_bins)
         for i in range(len(padded_audios)):
             if not self.config.randomize_each_batch:
                 self.randomize_feature_parameters(n_bins, audio_lens[i])
diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py
index 0ff5514de..d2c03a18e 100644
--- a/deepspeech/models/deepspeech2.py
+++ b/deepspeech/models/deepspeech2.py
@@ -198,11 +198,11 @@ class DeepSpeech2Model(nn.Layer):
                               cutoff_top_n, num_processes)
 
     @classmethod
-    def from_pretrained(cls, dataset, config, checkpoint_path):
+    def from_pretrained(cls, dataloader, config, checkpoint_path):
         """Build a DeepSpeech2Model model from a pretrained model.
 
         Parameters
         ----------
-        dataset: paddle.io.Dataset
+        dataloader: paddle.io.DataLoader
         config: yacs.config.CfgNode
             model configs
@@ -215,8 +215,8 @@ class DeepSpeech2Model(nn.Layer):
         DeepSpeech2Model
             The model built from pretrained result.
         """
-        model = cls(feat_size=dataset.feature_size,
-                    dict_size=dataset.vocab_size,
+        model = cls(feat_size=dataloader.collate_fn.feature_size,
+                    dict_size=dataloader.collate_fn.vocab_size,
                     num_conv_layers=config.model.num_conv_layers,
                     num_rnn_layers=config.model.num_rnn_layers,
                     rnn_size=config.model.rnn_layer_size,
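The collator now appends `(T, D)` features directly (no transpose) and measures length as `shape[0]`, so the padded batch is `[B, T_max, D]` and `n_bins` sits at `padded_audios.shape[2]`. A minimal sketch of that padding step, assuming zero-padding to the longest utterance in the batch (array sizes are illustrative):

```python
import numpy as np

def pad_batch(audios):
    """Zero-pad a list of (T_i, D) features to (B, T_max, D)."""
    max_len = max(a.shape[0] for a in audios)
    dim = audios[0].shape[1]
    padded = np.zeros((len(audios), max_len, dim), dtype='float32')
    lens = np.array([a.shape[0] for a in audios], dtype='int64')
    for i, a in enumerate(audios):
        padded[i, :a.shape[0], :] = a
    return padded, lens

batch = [np.ones((5, 3), 'float32'), np.ones((8, 3), 'float32')]
padded_audios, audio_lens = pad_batch(batch)
assert padded_audios.shape == (2, 8, 3)   # n_bins is padded_audios.shape[2]
assert audio_lens.tolist() == [5, 8]      # lengths are frame counts (shape[0])
```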
""" config.defrost() - config.input_dim = dataset.feature_size - config.output_dim = dataset.vocab_size + config.input_dim = dataloader.collate_fn.feature_size + config.output_dim = dataloader.collate_fn.vocab_size config.freeze() model = cls.from_config(config) diff --git a/deepspeech/utils/socket_server.py b/deepspeech/utils/socket_server.py index adcbf3bb2..45c659f60 100644 --- a/deepspeech/utils/socket_server.py +++ b/deepspeech/utils/socket_server.py @@ -48,9 +48,9 @@ def warm_up_test(audio_process_handler, rng = random.Random(random_seed) samples = rng.sample(manifest, num_test_cases) for idx, sample in enumerate(samples): - print("Warm-up Test Case %d: %s", idx, sample['audio_filepath']) + print("Warm-up Test Case %d: %s" % (idx, sample['feat'])) start_time = time.time() - transcript = audio_process_handler(sample['audio_filepath']) + transcript = audio_process_handler(sample['feat']) finish_time = time.time() print("Response Time: %f, Transcript: %s" % (finish_time - start_time, transcript)) diff --git a/examples/aishell/s0/README.md b/examples/aishell/s0/README.md index 8c1a51b62..c25888457 100644 --- a/examples/aishell/s0/README.md +++ b/examples/aishell/s0/README.md @@ -2,9 +2,10 @@ ## Deepspeech2 -| Model | release | Config | Test set | Loss | CER | -| --- | --- | --- | --- | --- | --- | -| DeepSpeech2 | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 | -| DeepSpeech2 | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 | -| DeepSpeech2 | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 | -| DeepSpeech2 | 1.8.5 | - | test | - | 0.080447 | +| Model | Params | Release | Config | Test set | Loss | CER | +| --- | --- | --- | --- | --- | --- | --- | +| DeepSpeech2 | 58.4M | 2.2.0 | conf/deepspeech2.yaml + spec aug + new datapipe | test | 6.396368026733398 | 0.068382 | +| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml + spec aug | test | 7.483316898345947 | 0.077860 | +| DeepSpeech2 | 58.4M | 2.1.0 | conf/deepspeech2.yaml | test | 7.299022197723389 | 0.078671 | +| DeepSpeech2 | 58.4M | 2.0.0 | conf/deepspeech2.yaml | test | - | 0.078977 | +| DeepSpeech2 | 58.4M | 1.8.5 | - | test | - | 0.080447 | diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 8cc4c4c9c..1004fde0e 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -10,8 +10,8 @@ data: min_output_input_ratio: 0.00 max_output_input_ratio: .inf - collator: + batch_size: 64 # one gpu mean_std_filepath: data/mean_std.json unit_type: char vocab_filepath: data/vocab.txt @@ -33,7 +33,6 @@ collator: sortagrad: True shuffle_method: batch_shuffle num_workers: 0 - batch_size: 64 # one gpu model: num_conv_layers: 2 diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index 4073c81b9..c9708dcc9 100755 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -2,7 +2,7 @@ set -e source path.sh -gpus=0 +gpus=0,1,2,3 stage=0 stop_stage=100 conf_path=conf/deepspeech2.yaml @@ -31,10 +31,10 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} 
diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh
index 4073c81b9..c9708dcc9 100755
--- a/examples/aishell/s0/run.sh
+++ b/examples/aishell/s0/run.sh
@@ -2,7 +2,7 @@
 set -e
 source path.sh
 
-gpus=0
+gpus=0,1,2,3
 stage=0
 stop_stage=100
 conf_path=conf/deepspeech2.yaml
@@ -31,10 +31,10 @@ fi
 
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi
 
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
     # export ckpt avg_n
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
+    CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
 fi
diff --git a/examples/aishell/s1/README.md b/examples/aishell/s1/README.md
index 601b0a8d0..72a03b618 100644
--- a/examples/aishell/s1/README.md
+++ b/examples/aishell/s1/README.md
@@ -2,21 +2,21 @@
 
 ## Conformer
 
-| Model | Config | Augmentation| Test set | Decode method | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- |
-| conformer | conf/conformer.yaml | spec_aug + shift | test | attention | - | 0.059858 |
-| conformer | conf/conformer.yaml | spec_aug + shift | test | ctc_greedy_search | - | 0.062311 |
-| conformer | conf/conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | - | 0.062196 |
-| conformer | conf/conformer.yaml | spec_aug + shift | test | attention_rescoring | - | 0.054694 |
+| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| conformer | 47.06M | conf/conformer.yaml | spec_aug + shift | test | attention | - | 0.059858 |
+| conformer | 47.06M | conf/conformer.yaml | spec_aug + shift | test | ctc_greedy_search | - | 0.062311 |
+| conformer | 47.06M | conf/conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | - | 0.062196 |
+| conformer | 47.06M | conf/conformer.yaml | spec_aug + shift | test | attention_rescoring | - | 0.054694 |
 
 ## Chunk Conformer
 
-| Model | Config | Augmentation| Test set | Decode method | Chunk | Loss | WER |
-| --- | --- | --- | --- | --- | --- | --- | --- |
-| conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | attention | 16 | - | 0.061939 |
-| conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_greedy_search | 16 | - | 0.070806 |
-| conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | 16 | - | 0.070739 |
-| conformer | conf/chunk_conformer.yaml | spec_aug + shift | test | attention_rescoring | 16 | - | 0.059400 |
+| Model | Params | Config | Augmentation| Test set | Decode method | Chunk | Loss | WER |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug + shift | test | attention | 16 | - | 0.061939 |
+| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_greedy_search | 16 | - | 0.070806 |
+| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | 16 | - | 0.070739 |
+| conformer | 47.06M | conf/chunk_conformer.yaml | spec_aug + shift | test | attention_rescoring | 16 | - | 0.059400 |
 
 ## Transformer
diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md
index 393dd4579..dde288bdd 100644
--- a/examples/librispeech/s0/README.md
+++ b/examples/librispeech/s0/README.md
@@ -2,8 +2,8 @@
 
 ## Deepspeech2
 
-| Model | release | Config | Test set | Loss | WER |
-| --- | --- | --- | --- | --- | --- |
-| DeepSpeech2 | 2.1.0 | conf/deepspeech2.yaml | 15.184467315673828 | test-clean | 0.072154 |
-| DeepSpeech2 | 2.0.0 | conf/deepspeech2.yaml | - | test-clean | 0.073973 |
-| DeepSpeech2 | 1.8.5 | - | test-clean | - | 0.074939 |
+| Model | Params | Release | Config | Test set | Loss | WER |
+| --- | --- | --- | --- | --- | --- | --- |
+| DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | test-clean | 15.184467315673828 | 0.072154 |
+| DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | test-clean | - | 0.073973 |
+| DeepSpeech2 | 42.96M | 1.8.5 | - | test-clean | - | 0.074939 |
diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml
index d1746bff3..b419cbe26 100644
--- a/examples/librispeech/s0/conf/deepspeech2.yaml
+++ b/examples/librispeech/s0/conf/deepspeech2.yaml
@@ -3,16 +3,21 @@ data:
   train_manifest: data/manifest.train
   dev_manifest: data/manifest.dev-clean
   test_manifest: data/manifest.test-clean
-  mean_std_filepath: data/mean_std.json
-  vocab_filepath: data/vocab.txt
-  augmentation_config: conf/augmentation.json
-  batch_size: 20
   min_input_len: 0.0
   max_input_len: 27.0 # second
   min_output_len: 0.0
   max_output_len: .inf
   min_output_input_ratio: 0.00
   max_output_input_ratio: .inf
+
+collator:
+  batch_size: 20
+  mean_std_filepath: data/mean_std.json
+  unit_type: char
+  vocab_filepath: data/vocab.txt
+  augmentation_config: conf/augmentation.json
+  random_seed: 0
+  spm_model_prefix:
   specgram_type: linear
   target_sample_rate: 16000
   max_freq: None
diff --git a/tools/Makefile b/tools/Makefile
index c129bf5a2..dd5902373 100644
--- a/tools/Makefile
+++ b/tools/Makefile
@@ -31,5 +31,5 @@ sox.done:
 
 soxbindings.done:
 	test -d soxbindings || git clone https://github.com/pseeth/soxbindings.git
-	source venv/bin/activate; cd soxbindings && python3 setup.py install
-	touch soxbindings.done
+	source venv/bin/activate; cd soxbindings && python setup.py install
+	touch soxbindings.done
\ No newline at end of file