diff --git a/paddlespeech/cli/ssl/infer.py b/paddlespeech/cli/ssl/infer.py
index 154c25f5..dce7c778 100644
--- a/paddlespeech/cli/ssl/infer.py
+++ b/paddlespeech/cli/ssl/infer.py
@@ -25,6 +25,7 @@ import librosa
 import numpy as np
 import paddle
 import soundfile
+from paddlenlp.transformers import AutoTokenizer
 from yacs.config import CfgNode
 
 from ..executor import BaseExecutor
@@ -50,7 +51,7 @@ class SSLExecutor(BaseExecutor):
         self.parser.add_argument(
             '--model',
             type=str,
-            default='wav2vec2ASR_librispeech',
+            default=None,
             choices=[
                 tag[:tag.index('-')]
                 for tag in self.task_resource.pretrained_models.keys()
@@ -123,7 +124,7 @@
             help='Increase logger verbosity of current task.')
 
     def _init_from_path(self,
-                        model_type: str='wav2vec2ASR_librispeech',
+                        model_type: str=None,
                         task: str='asr',
                         lang: str='en',
                         sample_rate: int=16000,
@@ -134,6 +135,18 @@
         Init model and other resources from a specific path.
         """
         logger.debug("start to init the model")
+
+        if model_type is None:
+            if lang == 'en':
+                model_type = 'wav2vec2ASR_librispeech'
+            elif lang == 'zh':
+                model_type = 'wav2vec2ASR_aishell1'
+            else:
+                logger.error(
+                    "invalid lang, please input --lang en or --lang zh")
+            logger.debug(
+                "Model type was not specified; default {} is used.".
+                format(model_type))
         # default max_len: unit:second
         self.max_len = 50
         if hasattr(self, 'model'):
@@ -167,9 +180,13 @@
             self.config.merge_from_file(self.cfg_path)
             if task == 'asr':
                 with UpdateConfig(self.config):
-                    self.text_feature = TextFeaturizer(
-                        unit_type=self.config.unit_type,
-                        vocab=self.config.vocab_filepath)
+                    if lang == 'en':
+                        self.text_feature = TextFeaturizer(
+                            unit_type=self.config.unit_type,
+                            vocab=self.config.vocab_filepath)
+                    elif lang == 'zh':
+                        self.text_feature = AutoTokenizer.from_pretrained(
+                            self.config.tokenizer)
                     self.config.decode.decoding_method = decode_method
             model_name = model_type[:model_type.rindex(
                 '_')]  # model_type: {model_name}_{dataset}
@@ -253,7 +270,8 @@
                     audio,
                     text_feature=self.text_feature,
                     decoding_method=cfg.decoding_method,
-                    beam_size=cfg.beam_size)
+                    beam_size=cfg.beam_size,
+                    tokenizer=getattr(self.config, 'tokenizer', None))
                 self._outputs["result"] = result_transcripts[0][0]
             except Exception as e:
                 logger.exception(e)
@@ -413,7 +431,7 @@
     @stats_wrapper
     def __call__(self,
                  audio_file: os.PathLike,
-                 model: str='wav2vec2ASR_librispeech',
+                 model: str=None,
                  task: str='asr',
                  lang: str='en',
                  sample_rate: int=16000,
diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py
index 06724674..d3b5a8f3 100644
--- a/paddlespeech/resource/pretrained_models.py
+++ b/paddlespeech/resource/pretrained_models.py
@@ -70,6 +70,38 @@ ssl_dynamic_pretrained_models = {
             'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
         },
     },
+    "wav2vec2-zh-16k": {
+        '1.3': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2-large-wenetspeech-self_ckpt_1.3.0.model.tar.gz',
+            'md5':
+            '00ea4975c05d1bb58181205674052fe1',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'chinese-wav2vec2-large',
+            'model':
+            'chinese-wav2vec2-large.pdparams',
+            'params':
+            'chinese-wav2vec2-large.pdparams',
+        },
+    },
+    "wav2vec2ASR_aishell1-zh-16k": {
+        '1.3': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2ASR-large-aishell1_ckpt_1.3.0.model.tar.gz',
+            'md5':
+            'ac8fa0a6345e6a7535f6fabb5e59e218',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'exp/wav2vec2ASR/checkpoints/avg_1',
+            'model':
+            'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
+            'params':
+            'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
+        },
+    },
 }
 
 # ---------------------------------
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py b/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
index 5670cb53..688bf5f8 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
@@ -1173,10 +1173,6 @@ class Wav2Vec2ConfigPure():
         self.proj_codevector_dim = config.proj_codevector_dim
         self.diversity_loss_weight = config.diversity_loss_weight
 
-        # ctc loss
-        self.ctc_loss_reduction = config.ctc_loss_reduction
-        self.ctc_zero_infinity = config.ctc_zero_infinity
-
         # adapter
         self.add_adapter = config.add_adapter
         self.adapter_kernel_size = config.adapter_kernel_size
diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
index eda188da..dc6c6d1d 100644
--- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
+++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
@@ -76,28 +76,66 @@ class Wav2vec2ASR(nn.Layer):
               feats: paddle.Tensor,
               text_feature: Dict[str, int],
               decoding_method: str,
-              beam_size: int):
+              beam_size: int,
+              tokenizer: str=None):
         batch_size = feats.shape[0]
         if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
-            logger.error(
-                f'decoding mode {decoding_method} must be running with batch_size == 1'
+            raise ValueError(
+                f"decoding mode {decoding_method} must be run with batch_size == 1"
             )
-            logger.error(f"current batch_size is {batch_size}")
-            sys.exit(1)
 
         if decoding_method == 'ctc_greedy_search':
-            hyps = self.ctc_greedy_search(feats)
-            res = [text_feature.defeaturize(hyp) for hyp in hyps]
-            res_tokenids = [hyp for hyp in hyps]
+            if tokenizer is None:
+                hyps = self.ctc_greedy_search(feats)
+                res = [text_feature.defeaturize(hyp) for hyp in hyps]
+                res_tokenids = [hyp for hyp in hyps]
+            else:
+                hyps = self.ctc_greedy_search(feats)
+                res = []
+                res_tokenids = []
+                for sequence in hyps:
+                    # Decode token ids to words
+                    predicted_tokens = text_feature.convert_ids_to_tokens(
+                        sequence)
+                    tmp_res = []
+                    tmp_res_tokenids = []
+                    for c in predicted_tokens:
+                        if c == "[CLS]":
+                            continue
+                        elif c == "[SEP]" or c == "[PAD]":
+                            break
+                        else:
+                            tmp_res.append(c)
+                            tmp_res_tokenids.append(text_feature.vocab[c])
+                    res.append(''.join(tmp_res))
+                    res_tokenids.append(tmp_res_tokenids)
 
         # ctc_prefix_beam_search and attention_rescoring only return one
         # result in List[int], change it to List[List[int]] for compatible
         # with other batch decoding mode
         elif decoding_method == 'ctc_prefix_beam_search':
             assert feats.shape[0] == 1
-            hyp = self.ctc_prefix_beam_search(feats, beam_size)
-            res = [text_feature.defeaturize(hyp)]
-            res_tokenids = [hyp]
+            if tokenizer is None:
+                hyp = self.ctc_prefix_beam_search(feats, beam_size)
+                res = [text_feature.defeaturize(hyp)]
+                res_tokenids = [hyp]
+            else:
+                hyp = self.ctc_prefix_beam_search(feats, beam_size)
+                res = []
+                res_tokenids = []
+                predicted_tokens = text_feature.convert_ids_to_tokens(hyp)
+                tmp_res = []
+                tmp_res_tokenids = []
+                for c in predicted_tokens:
+                    if c == "[CLS]":
+                        continue
+                    elif c == "[SEP]" or c == "[PAD]":
+                        break
+                    else:
+                        tmp_res.append(c)
+                        tmp_res_tokenids.append(text_feature.vocab[c])
+                res.append(''.join(tmp_res))
+                res_tokenids.append(tmp_res_tokenids)
         else:
             raise ValueError(
                 f"wav2vec2 not support decoding method: {decoding_method}")