@@ -51,11 +51,8 @@ class SSLExecutor(BaseExecutor):
         self.parser.add_argument(
             '--model',
             type=str,
-            default=None,
-            choices=[
-                tag[:tag.index('-')]
-                for tag in self.task_resource.pretrained_models.keys()
-            ],
+            default='wav2vec2',
+            choices=['wav2vec2', 'hubert'],
             help='Choose model type of asr task.')
         self.parser.add_argument(
             '--task',
@@ -67,7 +64,7 @@ class SSLExecutor(BaseExecutor):
             '--lang',
             type=str,
             default='en',
-            help='Choose model language. zh or en, zh:[wav2vec2ASR_aishell1-zh-16k], en:[wav2vec2ASR_librispeech-en-16k]'
+            help='Choose model language. zh or en, zh:[wav2vec2ASR_aishell1-zh-16k], en:[wav2vec2ASR_librispeech-en-16k, hubertASR_librispeech_100-en-16k]'
         )
         self.parser.add_argument(
             "--sample_rate",
@@ -137,13 +134,6 @@ class SSLExecutor(BaseExecutor):
         logger.debug("start to init the model")

-        if model_type is None:
-            if lang == 'en':
-                model_type = 'wav2vec2ASR_librispeech'
-            elif lang == 'zh':
-                model_type = 'wav2vec2ASR_aishell1'
-            else:
-                logger.error(
-                    "invalid lang, please input --lang en or --lang zh")
-            logger.debug(
-                "Model type had not been specified, default {} was used.".
-                format(model_type))
@@ -155,9 +145,20 @@ class SSLExecutor(BaseExecutor):
         if cfg_path is None or ckpt_path is None:
             sample_rate_str = '16k' if sample_rate == 16000 else '8k'
             if task == 'asr':
-                tag = model_type + '-' + lang + '-' + sample_rate_str
+                if model_type == 'wav2vec2':
+                    if lang == 'en':
+                        model_prefix = 'wav2vec2ASR_librispeech'
+                    elif lang == 'zh':
+                        model_prefix = 'wav2vec2ASR_aishell1'
+                    tag = model_prefix + '-' + lang + '-' + sample_rate_str
+                elif model_type == 'hubert':
+                    if lang == 'en':
+                        model_prefix = 'hubertASR_librispeech_100'
+                    elif lang == 'zh':
+                        logger.error("zh hubertASR is not supported yet")
+                    tag = model_prefix + '-' + lang + '-' + sample_rate_str
             else:
-                tag = 'wav2vec2' + '-' + lang + '-' + sample_rate_str
+                tag = model_type + '-' + lang + '-' + sample_rate_str
             self.task_resource.set_task_model(tag, version=None)
             self.res_path = self.task_resource.res_dir

@@ -191,7 +192,7 @@ class SSLExecutor(BaseExecutor):
             model_name = model_type[:model_type.rindex(
                 '_')]  # model_type: {model_name}_{dataset}
         else:
-            model_name = 'wav2vec2'
+            model_name = model_type
         model_class = self.task_resource.get_model_class(model_name)

         model_conf = self.config
@@ -204,9 +205,9 @@ class SSLExecutor(BaseExecutor):
         if task == 'asr':
             self.model.set_state_dict(model_dict)
         else:
-            self.model.wav2vec2.set_state_dict(model_dict)
+            getattr(self.model, model_type).set_state_dict(model_dict)

-    def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
+    def preprocess(self, input: Union[str, os.PathLike]):
         """
         Input preprocess and return paddle.Tensor stored in self.input.
         Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
@@ -264,7 +265,7 @@ class SSLExecutor(BaseExecutor):
         if task == 'asr':
             cfg = self.config.decode
             logger.debug(
-                f"we will use the wav2vec2ASR like model : {model_type}")
+                f"we will use the {model_type}ASR like model.")
             try:
                 result_transcripts = self.model.decode(
                     audio,
@@ -277,7 +278,7 @@ class SSLExecutor(BaseExecutor):
             logger.exception(e)
         else:
             logger.debug(
-                "we will use the wav2vec2 like model to extract audio feature")
+                f"we will use the {model_type} like model to extract audio feature.")
             try:
                 out_feature = self.model(audio[:, :, 0])
                 self._outputs["result"] = out_feature[0]
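
Note on the new tag-selection logic (hunk @@ -155,9 +145,20 @@): it reduces to the standalone sketch below. `build_tag` is a hypothetical name used for illustration only; in the PR the logic lives inline in `SSLExecutor._init_from_path`. One caveat carried over from the diff as written: for model_type='hubert' with lang='zh' the original only logs an error, so the following `tag = model_prefix + ...` would raise UnboundLocalError; the sketch raises explicitly instead.

def build_tag(model_type: str, task: str, lang: str, sample_rate: int) -> str:
    # Mirrors the tag construction added in this diff (illustrative only).
    sample_rate_str = '16k' if sample_rate == 16000 else '8k'
    if task != 'asr':
        # Feature-extraction path: the raw model type names the resource.
        return model_type + '-' + lang + '-' + sample_rate_str
    if model_type == 'wav2vec2':
        prefixes = {'en': 'wav2vec2ASR_librispeech', 'zh': 'wav2vec2ASR_aishell1'}
        model_prefix = prefixes[lang]
    elif model_type == 'hubert':
        if lang != 'en':
            raise ValueError('zh hubertASR is not supported yet')
        model_prefix = 'hubertASR_librispeech_100'
    else:
        raise ValueError(f'unsupported model type: {model_type}')
    return model_prefix + '-' + lang + '-' + sample_rate_str

# Examples:
#   build_tag('hubert', 'asr', 'en', 16000)   -> 'hubertASR_librispeech_100-en-16k'
#   build_tag('wav2vec2', 'asr', 'zh', 16000) -> 'wav2vec2ASR_aishell1-zh-16k'

With these tags resolving to pretrained resources, the new model should be selectable from the CLI along the lines of `paddlespeech ssl --task asr --model hubert --lang en --sample_rate 16000 --input ./en.wav`, where flags other than the new `--model` follow the existing `paddlespeech ssl` usage.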