support wav2vec2-zh cli, test=asr

pull/2697/head
tianhao zhang 3 years ago
parent bd01bc155d
commit d9f28a87d4

@@ -25,6 +25,7 @@ import librosa
 import numpy as np
 import paddle
 import soundfile
+import transformers
 from yacs.config import CfgNode
 from ..executor import BaseExecutor
@ -50,7 +51,7 @@ class SSLExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--model', '--model',
type=str, type=str,
default='wav2vec2ASR_librispeech', default=None,
choices=[ choices=[
tag[:tag.index('-')] tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys() for tag in self.task_resource.pretrained_models.keys()
@@ -123,7 +124,7 @@ class SSLExecutor(BaseExecutor):
             help='Increase logger verbosity of current task.')

     def _init_from_path(self,
-                        model_type: str='wav2vec2ASR_librispeech',
+                        model_type: str=None,
                         task: str='asr',
                         lang: str='en',
                         sample_rate: int=16000,
@@ -134,6 +135,18 @@ class SSLExecutor(BaseExecutor):
         Init model and other resources from a specific path.
         """
         logger.debug("start to init the model")
+        if model_type is None:
+            if lang == 'en':
+                model_type = 'wav2vec2ASR_librispeech'
+            elif lang == 'zh':
+                model_type = 'wav2vec2ASR_aishell1'
+            else:
+                logger.error(
+                    "invalid lang, please input --lang en or --lang zh")
+            logger.debug(
+                "Model type not specified, default {} is used.".format(
+                    model_type))
         # default max_len: unit:second
         self.max_len = 50
         if hasattr(self, 'model'):
@@ -167,9 +180,13 @@ class SSLExecutor(BaseExecutor):
             self.config.merge_from_file(self.cfg_path)
             if task == 'asr':
                 with UpdateConfig(self.config):
-                    self.text_feature = TextFeaturizer(
-                        unit_type=self.config.unit_type,
-                        vocab=self.config.vocab_filepath)
+                    if lang == 'en':
+                        self.text_feature = TextFeaturizer(
+                            unit_type=self.config.unit_type,
+                            vocab=self.config.vocab_filepath)
+                    elif lang == 'zh':
+                        self.text_feature = transformers.BertTokenizer.from_pretrained(
+                            self.config.tokenizer)
                     self.config.decode.decoding_method = decode_method
             model_name = model_type[:model_type.rindex(
                 '_')]  # model_type: {model_name}_{dataset}
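For `lang == 'zh'` the text feature is a Hugging Face tokenizer rather than PaddleSpeech's `TextFeaturizer`. A minimal sketch of the tokenizer interface the decode path relies on, assuming `'bert-base-chinese'` as a stand-in for whatever name `self.config.tokenizer` actually carries:

import transformers

# Stand-in checkpoint name; the real one comes from model.yaml (config.tokenizer).
tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-chinese')

ids = tokenizer.convert_tokens_to_ids(['[CLS]', '今', '天', '[SEP]'])
tokens = tokenizer.convert_ids_to_tokens(ids)  # ['[CLS]', '今', '天', '[SEP]']
assert tokenizer.vocab['今'] == ids[1]  # .vocab maps token -> id, used by decode()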
@@ -253,7 +270,8 @@ class SSLExecutor(BaseExecutor):
                     audio,
                     text_feature=self.text_feature,
                     decoding_method=cfg.decoding_method,
-                    beam_size=cfg.beam_size)
+                    beam_size=cfg.beam_size,
+                    tokenizer=getattr(self.config, 'tokenizer', None))
                 self._outputs["result"] = result_transcripts[0][0]
             except Exception as e:
                 logger.exception(e)
@@ -413,7 +431,7 @@ class SSLExecutor(BaseExecutor):
     @stats_wrapper
     def __call__(self,
                  audio_file: os.PathLike,
-                 model: str='wav2vec2ASR_librispeech',
+                 model: str=None,
                  task: str='asr',
                  lang: str='en',
                  sample_rate: int=16000,
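With `model=None` the executor now resolves the checkpoint from `lang`, so the Python API needs no explicit model name for Chinese input. A minimal usage sketch, assuming a 16 kHz mono `zh.wav` in the working directory and the import path exposed by the released package:

from paddlespeech.cli.ssl import SSLExecutor

ssl_executor = SSLExecutor()
text = ssl_executor(
    audio_file='./zh.wav',
    model=None,  # falls back to wav2vec2ASR_aishell1 because lang == 'zh'
    task='asr',
    lang='zh',
    sample_rate=16000)
print(text)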

@@ -70,6 +70,38 @@ ssl_dynamic_pretrained_models = {
             'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
         },
     },
+    "wav2vec2-zh-16k": {
+        '1.3': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2-large-wenetspeech-self_ckpt_1.3.0.model.tar.gz',
+            'md5':
+            '00ea4975c05d1bb58181205674052fe1',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'chinese-wav2vec2-large',
+            'model':
+            'chinese-wav2vec2-large.pdparams',
+            'params':
+            'chinese-wav2vec2-large.pdparams',
+        },
+    },
+    "wav2vec2ASR_aishell1-zh-16k": {
+        '1.3': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2ASR-large-aishell1_ckpt_1.3.0.model.tar.gz',
+            'md5':
+            'ac8fa0a6345e6a7535f6fabb5e59e218',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'exp/wav2vec2ASR/checkpoints/avg_1',
+            'model':
+            'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
+            'params':
+            'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
+        },
+    },
 }

 # ---------------------------------
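The CLI's `--model` choices are derived from these registry keys via `tag[:tag.index('-')]` (see the argument-parser hunk above), so new entries must keep the `{model}-{lang}-{sample_rate}` naming. A quick sketch of that parsing over the two new tags:

# Sketch of how --model choices and tag fields fall out of the key format
# {model}-{lang}-{sample_rate} used by the entries above.
tags = ['wav2vec2-zh-16k', 'wav2vec2ASR_aishell1-zh-16k']
choices = [tag[:tag.index('-')] for tag in tags]
print(choices)  # ['wav2vec2', 'wav2vec2ASR_aishell1']

model, lang, rate = 'wav2vec2ASR_aishell1-zh-16k'.split('-')
print(model, lang, rate)  # wav2vec2ASR_aishell1 zh 16k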

@@ -1173,10 +1173,6 @@ class Wav2Vec2ConfigPure():
         self.proj_codevector_dim = config.proj_codevector_dim
         self.diversity_loss_weight = config.diversity_loss_weight

-        # ctc loss
-        self.ctc_loss_reduction = config.ctc_loss_reduction
-        self.ctc_zero_infinity = config.ctc_zero_infinity
-
         # adapter
         self.add_adapter = config.add_adapter
         self.adapter_kernel_size = config.adapter_kernel_size

@@ -76,28 +76,66 @@ class Wav2vec2ASR(nn.Layer):
                feats: paddle.Tensor,
                text_feature: Dict[str, int],
                decoding_method: str,
-               beam_size: int):
+               beam_size: int,
+               tokenizer: str=None):
         batch_size = feats.shape[0]

         if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
-            logger.error(
-                f'decoding mode {decoding_method} must be running with batch_size == 1'
-            )
-            logger.error(f"current batch_size is {batch_size}")
-            sys.exit(1)
+            raise ValueError(
+                f"decoding mode {decoding_method} must be run with batch_size == 1"
+            )

         if decoding_method == 'ctc_greedy_search':
-            hyps = self.ctc_greedy_search(feats)
-            res = [text_feature.defeaturize(hyp) for hyp in hyps]
-            res_tokenids = [hyp for hyp in hyps]
+            if tokenizer is None:
+                hyps = self.ctc_greedy_search(feats)
+                res = [text_feature.defeaturize(hyp) for hyp in hyps]
+                res_tokenids = [hyp for hyp in hyps]
+            else:
+                hyps = self.ctc_greedy_search(feats)
+                res = []
+                res_tokenids = []
+                for sequence in hyps:
+                    # Decode token ids to words
+                    predicted_tokens = text_feature.convert_ids_to_tokens(
+                        sequence)
+                    tmp_res = []
+                    tmp_res_tokenids = []
+                    for c in predicted_tokens:
+                        if c == "[CLS]":
+                            continue
+                        elif c == "[SEP]" or c == "[PAD]":
+                            break
+                        else:
+                            tmp_res.append(c)
+                            tmp_res_tokenids.append(text_feature.vocab[c])
+                    res.append(''.join(tmp_res))
+                    res_tokenids.append(tmp_res_tokenids)
         # ctc_prefix_beam_search and attention_rescoring only return one
         # result in List[int], change it to List[List[int]] for compatibility
         # with other batch decoding modes
         elif decoding_method == 'ctc_prefix_beam_search':
             assert feats.shape[0] == 1
-            hyp = self.ctc_prefix_beam_search(feats, beam_size)
-            res = [text_feature.defeaturize(hyp)]
-            res_tokenids = [hyp]
+            if tokenizer is None:
+                hyp = self.ctc_prefix_beam_search(feats, beam_size)
+                res = [text_feature.defeaturize(hyp)]
+                res_tokenids = [hyp]
+            else:
+                hyp = self.ctc_prefix_beam_search(feats, beam_size)
+                res = []
+                res_tokenids = []
+                predicted_tokens = text_feature.convert_ids_to_tokens(hyp)
+                tmp_res = []
+                tmp_res_tokenids = []
+                for c in predicted_tokens:
+                    if c == "[CLS]":
+                        continue
+                    elif c == "[SEP]" or c == "[PAD]":
+                        break
+                    else:
+                        tmp_res.append(c)
+                        tmp_res_tokenids.append(text_feature.vocab[c])
+                res.append(''.join(tmp_res))
+                res_tokenids.append(tmp_res_tokenids)
         else:
             raise ValueError(
                 f"wav2vec2 does not support decoding method: {decoding_method}")
