@@ -39,7 +39,11 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
 __all__ = ['ASRExecutor']
 
 pretrained_models = {
-    "wenetspeech_zh_16k": {
+    # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
+    # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
+    # Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
+    # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
+    "conformer_wenetspeech-zh-16k": {
         'url':
         'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz',
         'md5':
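
Note: the comment block added above fixes the tag grammar, and a concrete round-trip makes it easier to check. A minimal sketch of parsing such a tag (`parse_tag` is illustrative, not part of this change; it assumes the full `{model}_{dataset}-{lang}-{sr}` form used by the ASR entries):

```python
# Hypothetical helper, for illustration only: split an ASR pretrained-model tag
# into its parts. "-" separates model/lang/sample-rate; the "_" inside the first
# field separates model name from dataset.
def parse_tag(tag: str):
    model_field, lang, sr = tag.split('-')
    model_name, _, dataset = model_field.partition('_')
    return model_name, dataset, lang, sr

assert parse_tag('conformer_wenetspeech-zh-16k') == ('conformer', 'wenetspeech', 'zh', '16k')
assert parse_tag('transformer_aishell-zh-16k') == ('transformer', 'aishell', 'zh', '16k')
```
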
@@ -49,7 +53,7 @@ pretrained_models = {
         'ckpt_path':
         'exp/conformer/checkpoints/wenetspeech',
     },
-    "transformer_zh_16k": {
+    "transformer_aishell-zh-16k": {
         'url':
         'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz',
         'md5':
@@ -83,7 +87,7 @@ class ASRExecutor(BaseExecutor):
         self.parser.add_argument(
             '--model',
             type=str,
-            default='wenetspeech',
+            default='conformer_wenetspeech',
             help='Choose model type of asr task.')
         self.parser.add_argument(
             '--lang',
@@ -143,7 +147,7 @@ class ASRExecutor(BaseExecutor):
 
         if cfg_path is None or ckpt_path is None:
             sample_rate_str = '16k' if sample_rate == 16000 else '8k'
-            tag = model_type + '_' + lang + '_' + sample_rate_str
+            tag = model_type + '-' + lang + '-' + sample_rate_str
             res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
             self.res_path = res_path
             self.cfg_path = os.path.join(res_path,
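
Note: tracing the changed line with the new default `--model` value shows the runtime tag lining up with the renamed `pretrained_models` keys (values below are illustrative):

```python
# With the separators switched from "_" to "-", the runtime tag now matches
# the renamed pretrained_models keys.
model_type, lang, sample_rate = 'conformer_wenetspeech', 'zh', 16000
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
assert tag == 'conformer_wenetspeech-zh-16k'  # a key defined above
```
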
@@ -165,7 +169,7 @@ class ASRExecutor(BaseExecutor):
             self.config.decoding.decoding_method = "attention_rescoring"
 
         with UpdateConfig(self.config):
-            if model_type == "ds2_online" or model_type == "ds2_offline":
+            if "ds2_online" in model_type or "ds2_offline" in model_type:
                 from paddlespeech.s2t.io.collator import SpeechCollator
                 self.config.collator.vocab_filepath = os.path.join(
                     res_path, self.config.collator.vocab_filepath)
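
Note: the equality checks in this and the following hunks are loosened to substring checks because `model_type` now carries a dataset suffix; a quick check of why the old comparison would no longer match:

```python
# model_type now looks like "conformer_wenetspeech", so equality against the
# bare architecture name fails while substring membership still selects the branch.
model_type = 'conformer_wenetspeech'
assert model_type != 'conformer'   # old "==" check would fall through
assert 'conformer' in model_type   # new "in" check matches
```
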
@@ -178,7 +182,7 @@ class ASRExecutor(BaseExecutor):
                     spm_model_prefix=self.config.collator.spm_model_prefix)
                 self.config.model.input_dim = self.collate_fn_test.feature_size
                 self.config.model.output_dim = text_feature.vocab_size
-            elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
+            elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
                 self.config.collator.vocab_filepath = os.path.join(
                     res_path, self.config.collator.vocab_filepath)
                 self.config.collator.augmentation_config = os.path.join(
@@ -196,7 +200,9 @@ class ASRExecutor(BaseExecutor):
                 raise Exception("wrong type")
         # Enter the path of model root
 
-        model_class = dynamic_import(model_type, model_alias)
+        model_name = ''.join(
+            model_type.split('_')[:-1])  # model_type: {model_name}_{dataset}
+        model_class = dynamic_import(model_name, model_alias)
         model_conf = self.config.model
         logger.info(model_conf)
         model = model_class.from_config(model_conf)
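
Note: the dataset suffix has to be stripped before `dynamic_import`, since `model_alias` is keyed by bare architecture names. A quick trace of the new derivation (the multi-underscore caveat is an observation about this expression, not something the hunk addresses):

```python
model_type = 'conformer_wenetspeech'
model_name = ''.join(model_type.split('_')[:-1])  # drops the dataset suffix
assert model_name == 'conformer'
# Caveat: names with more than one underscore collapse, e.g.
# ''.join('ds2_online_aishell'.split('_')[:-1]) == 'ds2online'
```
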
@@ -217,7 +223,7 @@ class ASRExecutor(BaseExecutor):
         logger.info("Preprocess audio_file:" + audio_file)
 
         # Get the object for feature extraction
-        if model_type == "ds2_online" or model_type == "ds2_offline":
+        if "ds2_online" in model_type or "ds2_offline" in model_type:
             audio, _ = self.collate_fn_test.process_utterance(
                 audio_file=audio_file, transcript=" ")
             audio_len = audio.shape[0]
@@ -229,7 +235,7 @@ class ASRExecutor(BaseExecutor):
             self._inputs["audio_len"] = audio_len
             logger.info(f"audio feat shape: {audio.shape}")
 
-        elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
+        elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
             logger.info("get the preprocess conf")
             preprocess_conf_file = self.config.collator.augmentation_config
             # redirect the cmvn path
@@ -293,7 +299,7 @@ class ASRExecutor(BaseExecutor):
         cfg = self.config.decoding
         audio = self._inputs["audio"]
         audio_len = self._inputs["audio_len"]
-        if model_type == "ds2_online" or model_type == "ds2_offline":
+        if "ds2_online" in model_type or "ds2_offline" in model_type:
             result_transcripts = self.model.decode(
                 audio,
                 audio_len,
@@ -308,7 +314,7 @@ class ASRExecutor(BaseExecutor):
                 num_processes=cfg.num_proc_bsearch)
             self._outputs["result"] = result_transcripts[0]
 
-        elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
+        elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
             result_transcripts = self.model.decode(
                 audio,
                 audio_len,