diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py
index aa4e31d9..447b0a1a 100644
--- a/paddlespeech/cli/asr/infer.py
+++ b/paddlespeech/cli/asr/infer.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
-import io
 import os
 import sys
 from typing import List
@@ -23,9 +22,9 @@ import librosa
 import numpy as np
 import paddle
 import soundfile
-import yaml
 from yacs.config import CfgNode
 
+from ..download import get_path_from_url
 from ..executor import BaseExecutor
 from ..log import logger
 from ..utils import cli_register
@@ -64,14 +63,47 @@ pretrained_models = {
         'ckpt_path':
         'exp/transformer/checkpoints/avg_10',
     },
+    "deepspeech2offline_aishell-zh-16k": {
+        'url':
+        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
+        'md5':
+        '932c3593d62fe5c741b59b31318aa314',
+        'cfg_path':
+        'model.yaml',
+        'ckpt_path':
+        'exp/deepspeech2/checkpoints/avg_1',
+        'lm_url':
+        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
+        'lm_md5':
+        '29e02312deb2e59b3c8686c7966d4fe3'
+    },
+    "deepspeech2online_aishell-zh-16k": {
+        'url':
+        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz',
+        'md5':
+        'd5e076217cf60486519f72c217d21b9b',
+        'cfg_path':
+        'model.yaml',
+        'ckpt_path':
+        'exp/deepspeech2_online/checkpoints/avg_1',
+        'lm_url':
+        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
+        'lm_md5':
+        '29e02312deb2e59b3c8686c7966d4fe3'
+    },
 }
 
 model_alias = {
-    "deepspeech2offline": "paddlespeech.s2t.models.ds2:DeepSpeech2Model",
-    "deepspeech2online": "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
-    "conformer": "paddlespeech.s2t.models.u2:U2Model",
-    "transformer": "paddlespeech.s2t.models.u2:U2Model",
-    "wenetspeech": "paddlespeech.s2t.models.u2:U2Model",
+    "deepspeech2offline":
+    "paddlespeech.s2t.models.ds2:DeepSpeech2Model",
+    "deepspeech2online":
+    "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
+    "conformer":
+    "paddlespeech.s2t.models.u2:U2Model",
+    "transformer":
+    "paddlespeech.s2t.models.u2:U2Model",
+    "wenetspeech":
+    "paddlespeech.s2t.models.u2:U2Model",
 }
 
 
@@ -95,7 +127,8 @@ class ASRExecutor(BaseExecutor):
             '--lang',
             type=str,
             default='zh',
-            help='Choose model language. zh or en, zh:[conformer_wenetspeech-zh-16k], en:[transformer_librispeech-en-16k]')
+            help='Choose model language. zh or en, zh:[conformer_wenetspeech-zh-16k], en:[transformer_librispeech-en-16k]'
+        )
         self.parser.add_argument(
             "--sample_rate",
             type=int,
@@ -111,7 +144,10 @@ class ASRExecutor(BaseExecutor):
             '--decode_method',
             type=str,
             default='attention_rescoring',
-            choices=['ctc_greedy_search', 'ctc_prefix_beam_search', 'attention', 'attention_rescoring'],
+            choices=[
+                'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention',
+                'attention_rescoring'
+            ],
             help='only support transformer and conformer model')
         self.parser.add_argument(
             '--ckpt_path',
@@ -187,13 +223,21 @@ class ASRExecutor(BaseExecutor):
         if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
             from paddlespeech.s2t.io.collator import SpeechCollator
             self.vocab = self.config.vocab_filepath
-            self.config.decode.lang_model_path = os.path.join(res_path, self.config.decode.lang_model_path)
+            self.config.decode.lang_model_path = os.path.join(
+                MODEL_HOME, 'language_model',
+                self.config.decode.lang_model_path)
             self.collate_fn_test = SpeechCollator.from_config(self.config)
             self.text_feature = TextFeaturizer(
-                unit_type=self.config.unit_type,
-                vocab=self.vocab)
+                unit_type=self.config.unit_type, vocab=self.vocab)
+            lm_url = pretrained_models[tag]['lm_url']
+            lm_md5 = pretrained_models[tag]['lm_md5']
+            self.download_lm(
+                lm_url,
+                os.path.dirname(self.config.decode.lang_model_path), lm_md5)
+
         elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
-            self.config.spm_model_prefix = os.path.join(self.res_path, self.config.spm_model_prefix)
+            self.config.spm_model_prefix = os.path.join(
+                self.res_path, self.config.spm_model_prefix)
             self.text_feature = TextFeaturizer(
                 unit_type=self.config.unit_type,
                 vocab=self.config.vocab_filepath,
@@ -319,6 +363,13 @@ class ASRExecutor(BaseExecutor):
         """
         return self._outputs["result"]
 
+    def download_lm(self, url, lm_dir, md5sum):
+        download_path = get_path_from_url(
+            url=url,
+            root_dir=lm_dir,
+            md5sum=md5sum,
+            decompress=False, )
+
     def _pcm16to32(self, audio):
         assert (audio.dtype == np.int16)
         audio = audio.astype("float32")
@@ -411,7 +462,7 @@ class ASRExecutor(BaseExecutor):
 
         try:
             res = self(audio_file, model, lang, sample_rate, config, ckpt_path,
-                  decode_method, force_yes, device)
+                       decode_method, force_yes, device)
             logger.info('ASR Result: {}'.format(res))
             return True
         except Exception as e:
@@ -435,7 +486,8 @@ class ASRExecutor(BaseExecutor):
         audio_file = os.path.abspath(audio_file)
         self._check(audio_file, sample_rate, force_yes)
         paddle.set_device(device)
-        self._init_from_path(model, lang, sample_rate, config, decode_method, ckpt_path)
+        self._init_from_path(model, lang, sample_rate, config, decode_method,
+                             ckpt_path)
         self.preprocess(model, audio_file)
         self.infer(model)
         res = self.postprocess()  # Retrieve result of asr.
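
A quick usage sketch for reviewers. This is a minimal, hypothetical example of exercising the new deepspeech2 path through `ASRExecutor.__call__`, whose positional signature is visible in `execute()` above. `zh.wav` is a placeholder for any 16 kHz, 16-bit mono WAV, and `deepspeech2offline_aishell` assumes the pretrained tag is assembled as `<model>-<lang>-<sample_rate>` to match the new `pretrained_models` keys:

```python
# Hedged sketch, not part of this diff: drives the new deepspeech2 models
# through the Python API added to the CLI.
import paddle

from paddlespeech.cli.asr.infer import ASRExecutor

asr = ASRExecutor()
text = asr(
    'zh.wav',                      # audio_file (hypothetical path)
    'deepspeech2offline_aishell',  # model; assumed to resolve to the new
                                   # "deepspeech2offline_aishell-zh-16k" entry
    'zh',                          # lang
    16000,                         # sample_rate
    None,                          # config: fall back to the bundled model.yaml
    None,                          # ckpt_path: use the downloaded checkpoint
    'attention_rescoring',         # decode_method (per the help text, only
                                   # used by transformer/conformer models)
    False,                         # force_yes: prompt on sample-rate mismatch
    paddle.get_device())           # device
print(text)
```

On first use, `_init_from_path` now also fetches the n-gram language model: `download_lm` delegates to `get_path_from_url` with `decompress=False`, so the `.klm` binary at `lm_url` is downloaded into `MODEL_HOME/language_model` and md5-verified but never unpacked.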