[ASR] support wav2vec2-zh cli, test=asr (#2697)

* support wav2vec2-zh cli, test=asr

* support wav2vec2-zh cli, test=asr

* support wav2vec2-zh cli, test=asr

* support wav2vec2-zh cli, test=asr
Zth9730 committed c67bf7b4ef (parent a01c163dc3)
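With these resources and defaults in place, the Chinese model can be selected by language alone. A minimal usage sketch of the Python API touched by this patch (the audio path is illustrative, not part of the diff):

    from paddlespeech.cli.ssl import SSLExecutor

    ssl_executor = SSLExecutor()
    # With the new default model=None, lang='zh' resolves to wav2vec2ASR_aishell1
    # and lang='en' keeps the previous wav2vec2ASR_librispeech behaviour.
    text = ssl_executor(
        audio_file='./zh.wav',  # illustrative input file
        task='asr',
        lang='zh',
        sample_rate=16000)
    print(text)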

@@ -25,6 +25,7 @@ import librosa
 import numpy as np
 import paddle
 import soundfile
+from paddlenlp.transformers import AutoTokenizer
 from yacs.config import CfgNode

 from ..executor import BaseExecutor
@@ -50,7 +51,7 @@ class SSLExecutor(BaseExecutor):
         self.parser.add_argument(
             '--model',
             type=str,
-            default='wav2vec2ASR_librispeech',
+            default=None,
             choices=[
                 tag[:tag.index('-')]
                 for tag in self.task_resource.pretrained_models.keys()
@@ -123,7 +124,7 @@ class SSLExecutor(BaseExecutor):
             help='Increase logger verbosity of current task.')

     def _init_from_path(self,
-                        model_type: str='wav2vec2ASR_librispeech',
+                        model_type: str=None,
                         task: str='asr',
                         lang: str='en',
                         sample_rate: int=16000,
@@ -134,6 +135,18 @@ class SSLExecutor(BaseExecutor):
         Init model and other resources from a specific path.
         """
         logger.debug("start to init the model")
+        if model_type is None:
+            if lang == 'en':
+                model_type = 'wav2vec2ASR_librispeech'
+            elif lang == 'zh':
+                model_type = 'wav2vec2ASR_aishell1'
+            else:
+                logger.error(
+                    "invalid lang, please input --lang en or --lang zh")
+            logger.debug(
+                "Model type had not been specified, default {} was used.".
+                format(model_type))
         # default max_len: unit:second
         self.max_len = 50
         if hasattr(self, 'model'):
@@ -167,9 +180,13 @@ class SSLExecutor(BaseExecutor):
         self.config.merge_from_file(self.cfg_path)
         if task == 'asr':
             with UpdateConfig(self.config):
-                self.text_feature = TextFeaturizer(
-                    unit_type=self.config.unit_type,
-                    vocab=self.config.vocab_filepath)
+                if lang == 'en':
+                    self.text_feature = TextFeaturizer(
+                        unit_type=self.config.unit_type,
+                        vocab=self.config.vocab_filepath)
+                elif lang == 'zh':
+                    self.text_feature = AutoTokenizer.from_pretrained(
+                        self.config.tokenizer)
                 self.config.decode.decoding_method = decode_method
                 model_name = model_type[:model_type.rindex(
                     '_')]  # model_type: {model_name}_{dataset}
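For the zh branch above, self.config.tokenizer is expected to name a PaddleNLP tokenizer. A rough sketch of what that object provides to the decode step (the tokenizer name here is an assumption for illustration; the real value comes from the downloaded model.yaml):

    from paddlenlp.transformers import AutoTokenizer

    # 'bert-base-chinese' is only an illustrative name, not taken from this diff.
    tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')

    ids = tokenizer('今天天气很好')['input_ids']
    tokens = tokenizer.convert_ids_to_tokens(ids)
    # tokens looks like ['[CLS]', '今', '天', ..., '[SEP]'], which is why the
    # tokenizer-based decode path skips '[CLS]' and stops at '[SEP]'/'[PAD]'.
    print(tokens)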
@@ -253,7 +270,8 @@ class SSLExecutor(BaseExecutor):
                 audio,
                 text_feature=self.text_feature,
                 decoding_method=cfg.decoding_method,
-                beam_size=cfg.beam_size)
+                beam_size=cfg.beam_size,
+                tokenizer=getattr(self.config, 'tokenizer', None))
             self._outputs["result"] = result_transcripts[0][0]
         except Exception as e:
             logger.exception(e)
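The getattr default matters here: configs without a tokenizer field (the English models) pass tokenizer=None to decode(), so they keep the original TextFeaturizer path. A tiny standalone sketch of that behaviour using yacs directly:

    from yacs.config import CfgNode

    en_cfg = CfgNode(new_allowed=True)       # no 'tokenizer' key
    zh_cfg = CfgNode(new_allowed=True)
    zh_cfg.tokenizer = 'bert-base-chinese'   # illustrative value only

    assert getattr(en_cfg, 'tokenizer', None) is None
    assert getattr(zh_cfg, 'tokenizer', None) == 'bert-base-chinese'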
@@ -413,7 +431,7 @@ class SSLExecutor(BaseExecutor):
     @stats_wrapper
     def __call__(self,
                  audio_file: os.PathLike,
-                 model: str='wav2vec2ASR_librispeech',
+                 model: str=None,
                  task: str='asr',
                  lang: str='en',
                  sample_rate: int=16000,

@@ -70,6 +70,38 @@ ssl_dynamic_pretrained_models = {
             'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
         },
     },
+    "wav2vec2-zh-16k": {
+        '1.3': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2-large-wenetspeech-self_ckpt_1.3.0.model.tar.gz',
+            'md5':
+            '00ea4975c05d1bb58181205674052fe1',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'chinese-wav2vec2-large',
+            'model':
+            'chinese-wav2vec2-large.pdparams',
+            'params':
+            'chinese-wav2vec2-large.pdparams',
+        },
+    },
+    "wav2vec2ASR_aishell1-zh-16k": {
+        '1.3': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2ASR-large-aishell1_ckpt_1.3.0.model.tar.gz',
+            'md5':
+            'ac8fa0a6345e6a7535f6fabb5e59e218',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'exp/wav2vec2ASR/checkpoints/avg_1',
+            'model':
+            'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
+            'params':
+            'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
+        },
+    },
 }

 # ---------------------------------
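The new keys follow the same {model}-{lang}-{sample_rate} tag convention as the existing entries, which is what the CLI's choices expression tag[:tag.index('-')] relies on. A small sketch of that mapping (the helper names are illustrative, not part of the resource module):

    # Illustrative helpers; they only show how a tag such as
    # 'wav2vec2ASR_aishell1-zh-16k' relates to (model, lang, sample rate).
    def build_tag(model_type: str, lang: str, sample_rate: int) -> str:
        sample_rate_str = '16k' if sample_rate == 16000 else '8k'
        return f'{model_type}-{lang}-{sample_rate_str}'

    def model_name_from_tag(tag: str) -> str:
        # mirrors the CLI choices expression: tag[:tag.index('-')]
        return tag[:tag.index('-')]

    assert build_tag('wav2vec2ASR_aishell1', 'zh', 16000) == 'wav2vec2ASR_aishell1-zh-16k'
    assert model_name_from_tag('wav2vec2ASR_aishell1-zh-16k') == 'wav2vec2ASR_aishell1'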

@@ -1173,10 +1173,6 @@ class Wav2Vec2ConfigPure():
         self.proj_codevector_dim = config.proj_codevector_dim
         self.diversity_loss_weight = config.diversity_loss_weight

-        # ctc loss
-        self.ctc_loss_reduction = config.ctc_loss_reduction
-        self.ctc_zero_infinity = config.ctc_zero_infinity
-
         # adapter
         self.add_adapter = config.add_adapter
         self.adapter_kernel_size = config.adapter_kernel_size

@@ -76,28 +76,66 @@ class Wav2vec2ASR(nn.Layer):
                feats: paddle.Tensor,
                text_feature: Dict[str, int],
                decoding_method: str,
-               beam_size: int):
+               beam_size: int,
+               tokenizer: str=None):
         batch_size = feats.shape[0]

         if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
-            logger.error(
-                f'decoding mode {decoding_method} must be running with batch_size == 1'
+            raise ValueError(
+                f"decoding mode {decoding_method} must be running with batch_size == 1"
             )
-            logger.error(f"current batch_size is {batch_size}")
-            sys.exit(1)

         if decoding_method == 'ctc_greedy_search':
-            hyps = self.ctc_greedy_search(feats)
-            res = [text_feature.defeaturize(hyp) for hyp in hyps]
-            res_tokenids = [hyp for hyp in hyps]
+            if tokenizer is None:
+                hyps = self.ctc_greedy_search(feats)
+                res = [text_feature.defeaturize(hyp) for hyp in hyps]
+                res_tokenids = [hyp for hyp in hyps]
+            else:
+                hyps = self.ctc_greedy_search(feats)
+                res = []
+                res_tokenids = []
+                for sequence in hyps:
+                    # Decode token terms to words
+                    predicted_tokens = text_feature.convert_ids_to_tokens(
+                        sequence)
+                    tmp_res = []
+                    tmp_res_tokenids = []
+                    for c in predicted_tokens:
+                        if c == "[CLS]":
+                            continue
+                        elif c == "[SEP]" or c == "[PAD]":
+                            break
+                        else:
+                            tmp_res.append(c)
+                            tmp_res_tokenids.append(text_feature.vocab[c])
+                    res.append(''.join(tmp_res))
+                    res_tokenids.append(tmp_res_tokenids)
         # ctc_prefix_beam_search and attention_rescoring only return one
         # result in List[int], change it to List[List[int]] for compatible
         # with other batch decoding mode
         elif decoding_method == 'ctc_prefix_beam_search':
             assert feats.shape[0] == 1
-            hyp = self.ctc_prefix_beam_search(feats, beam_size)
-            res = [text_feature.defeaturize(hyp)]
-            res_tokenids = [hyp]
+            if tokenizer is None:
+                hyp = self.ctc_prefix_beam_search(feats, beam_size)
+                res = [text_feature.defeaturize(hyp)]
+                res_tokenids = [hyp]
+            else:
+                hyp = self.ctc_prefix_beam_search(feats, beam_size)
+                res = []
+                res_tokenids = []
+                predicted_tokens = text_feature.convert_ids_to_tokens(hyp)
+                tmp_res = []
+                tmp_res_tokenids = []
+                for c in predicted_tokens:
+                    if c == "[CLS]":
+                        continue
+                    elif c == "[SEP]" or c == "[PAD]":
+                        break
+                    else:
+                        tmp_res.append(c)
+                        tmp_res_tokenids.append(text_feature.vocab[c])
+                res.append(''.join(tmp_res))
+                res_tokenids.append(tmp_res_tokenids)
         else:
             raise ValueError(
                 f"wav2vec2 not support decoding method: {decoding_method}")
