From 2a42421a63fa22a1bb7547fffc19ecadbe3633f6 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Thu, 27 Jan 2022 06:06:08 +0000 Subject: [PATCH] cli add ds2-librispeech offline, fix versionm, test=asr --- paddlespeech/cli/asr/infer.py | 35 ++++++++++++------- paddlespeech/cli/utils.py | 2 +- paddlespeech/s2t/io/sampler.py | 2 +- .../t2s/modules/transformer/repeat.py | 2 +- 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 447b0a1a0..64b325201 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -91,6 +91,20 @@ pretrained_models = { 'lm_md5': '29e02312deb2e59b3c8686c7966d4fe3' }, + "deepspeech2offline_librispeech-en-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz', + 'md5': + 'f5666c81ad015c8de03aac2bc92e5762', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2/checkpoints/avg_1', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm', + 'lm_md5': + '099a601759d467cd0a8523ff939819c5' + }, } model_alias = { @@ -328,18 +342,15 @@ class ASRExecutor(BaseExecutor): audio = self._inputs["audio"] audio_len = self._inputs["audio_len"] if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: - result_transcripts = self.model.decode( - audio, - audio_len, - self.text_feature.vocab_list, - decoding_method=cfg.decoding_method, - lang_model_path=cfg.lang_model_path, - beam_alpha=cfg.alpha, - beam_beta=cfg.beta, - beam_size=cfg.beam_size, - cutoff_prob=cfg.cutoff_prob, - cutoff_top_n=cfg.cutoff_top_n, - num_processes=cfg.num_proc_bsearch) + decode_batch_size = audio.shape[0] + self.model.decoder.init_decoder( + decode_batch_size, self.text_feature.vocab_list, + cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, + cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, + cfg.num_proc_bsearch) + + result_transcripts = self.model.decode(audio, audio_len) + self.model.decoder.del_decoder() self._outputs["result"] = result_transcripts[0] elif "conformer" in model_type or "transformer" in model_type: diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py index 4f2c89065..d7dcc90c7 100644 --- a/paddlespeech/cli/utils.py +++ b/paddlespeech/cli/utils.py @@ -34,7 +34,7 @@ from .entry import commands try: from .. import __version__ except ImportError: - __version__ = 0.0.0 # for develop branch + __version__ = "0.0.0" # for develop branch requests.adapters.DEFAULT_RETRIES = 3 diff --git a/paddlespeech/s2t/io/sampler.py b/paddlespeech/s2t/io/sampler.py index ac55af123..89752bb9f 100644 --- a/paddlespeech/s2t/io/sampler.py +++ b/paddlespeech/s2t/io/sampler.py @@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False): """ rng = np.random.RandomState(epoch) shift_len = rng.randint(0, batch_size - 1) - batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) + batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) rng.shuffle(batch_indices) batch_indices = [item for batch in batch_indices for item in batch] assert clipped is False diff --git a/paddlespeech/t2s/modules/transformer/repeat.py b/paddlespeech/t2s/modules/transformer/repeat.py index 0325a6382..f738b5569 100644 --- a/paddlespeech/t2s/modules/transformer/repeat.py +++ b/paddlespeech/t2s/modules/transformer/repeat.py @@ -41,4 +41,4 @@ def repeat(N, fn): MultiSequential Repeated model instance. """ - return MultiSequential(* [fn(n) for n in range(N)]) + return MultiSequential(*[fn(n) for n in range(N)])