|
|
@ -25,6 +25,7 @@ import librosa
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
import soundfile
|
|
|
|
import soundfile
|
|
|
|
|
|
|
|
from paddlenlp.transformers import AutoTokenizer
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
|
|
|
|
|
|
from ..executor import BaseExecutor
|
|
|
|
from ..executor import BaseExecutor
|
|
|
@ -50,7 +51,7 @@ class SSLExecutor(BaseExecutor):
|
|
|
|
self.parser.add_argument(
|
|
|
|
self.parser.add_argument(
|
|
|
|
'--model',
|
|
|
|
'--model',
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
|
default='wav2vec2ASR_librispeech',
|
|
|
|
default=None,
|
|
|
|
choices=[
|
|
|
|
choices=[
|
|
|
|
tag[:tag.index('-')]
|
|
|
|
tag[:tag.index('-')]
|
|
|
|
for tag in self.task_resource.pretrained_models.keys()
|
|
|
|
for tag in self.task_resource.pretrained_models.keys()
|
|
|
@ -123,7 +124,7 @@ class SSLExecutor(BaseExecutor):
|
|
|
|
help='Increase logger verbosity of current task.')
|
|
|
|
help='Increase logger verbosity of current task.')
|
|
|
|
|
|
|
|
|
|
|
|
def _init_from_path(self,
|
|
|
|
def _init_from_path(self,
|
|
|
|
model_type: str='wav2vec2ASR_librispeech',
|
|
|
|
model_type: str=None,
|
|
|
|
task: str='asr',
|
|
|
|
task: str='asr',
|
|
|
|
lang: str='en',
|
|
|
|
lang: str='en',
|
|
|
|
sample_rate: int=16000,
|
|
|
|
sample_rate: int=16000,
|
|
|
@ -134,6 +135,18 @@ class SSLExecutor(BaseExecutor):
|
|
|
|
Init model and other resources from a specific path.
|
|
|
|
Init model and other resources from a specific path.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
logger.debug("start to init the model")
|
|
|
|
logger.debug("start to init the model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if model_type is None:
|
|
|
|
|
|
|
|
if lang == 'en':
|
|
|
|
|
|
|
|
model_type = 'wav2vec2ASR_librispeech'
|
|
|
|
|
|
|
|
elif lang == 'zh':
|
|
|
|
|
|
|
|
model_type = 'wav2vec2ASR_aishell1'
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
logger.error(
|
|
|
|
|
|
|
|
"invalid lang, please input --lang en or --lang zh")
|
|
|
|
|
|
|
|
logger.debug(
|
|
|
|
|
|
|
|
"Model type had not been specified, default {} was used.".
|
|
|
|
|
|
|
|
format(model_type))
|
|
|
|
# default max_len: unit:second
|
|
|
|
# default max_len: unit:second
|
|
|
|
self.max_len = 50
|
|
|
|
self.max_len = 50
|
|
|
|
if hasattr(self, 'model'):
|
|
|
|
if hasattr(self, 'model'):
|
|
|
@ -167,9 +180,13 @@ class SSLExecutor(BaseExecutor):
|
|
|
|
self.config.merge_from_file(self.cfg_path)
|
|
|
|
self.config.merge_from_file(self.cfg_path)
|
|
|
|
if task == 'asr':
|
|
|
|
if task == 'asr':
|
|
|
|
with UpdateConfig(self.config):
|
|
|
|
with UpdateConfig(self.config):
|
|
|
|
|
|
|
|
if lang == 'en':
|
|
|
|
self.text_feature = TextFeaturizer(
|
|
|
|
self.text_feature = TextFeaturizer(
|
|
|
|
unit_type=self.config.unit_type,
|
|
|
|
unit_type=self.config.unit_type,
|
|
|
|
vocab=self.config.vocab_filepath)
|
|
|
|
vocab=self.config.vocab_filepath)
|
|
|
|
|
|
|
|
elif lang == 'zh':
|
|
|
|
|
|
|
|
self.text_feature = AutoTokenizer.from_pretrained(
|
|
|
|
|
|
|
|
self.config.tokenizer)
|
|
|
|
self.config.decode.decoding_method = decode_method
|
|
|
|
self.config.decode.decoding_method = decode_method
|
|
|
|
model_name = model_type[:model_type.rindex(
|
|
|
|
model_name = model_type[:model_type.rindex(
|
|
|
|
'_')] # model_type: {model_name}_{dataset}
|
|
|
|
'_')] # model_type: {model_name}_{dataset}
|
|
|
@ -253,7 +270,8 @@ class SSLExecutor(BaseExecutor):
|
|
|
|
audio,
|
|
|
|
audio,
|
|
|
|
text_feature=self.text_feature,
|
|
|
|
text_feature=self.text_feature,
|
|
|
|
decoding_method=cfg.decoding_method,
|
|
|
|
decoding_method=cfg.decoding_method,
|
|
|
|
beam_size=cfg.beam_size)
|
|
|
|
beam_size=cfg.beam_size,
|
|
|
|
|
|
|
|
tokenizer=getattr(self.config, 'tokenizer', None))
|
|
|
|
self._outputs["result"] = result_transcripts[0][0]
|
|
|
|
self._outputs["result"] = result_transcripts[0][0]
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logger.exception(e)
|
|
|
|
logger.exception(e)
|
|
|
@ -413,7 +431,7 @@ class SSLExecutor(BaseExecutor):
|
|
|
|
@stats_wrapper
|
|
|
|
@stats_wrapper
|
|
|
|
def __call__(self,
|
|
|
|
def __call__(self,
|
|
|
|
audio_file: os.PathLike,
|
|
|
|
audio_file: os.PathLike,
|
|
|
|
model: str='wav2vec2ASR_librispeech',
|
|
|
|
model: str=None,
|
|
|
|
task: str='asr',
|
|
|
|
task: str='asr',
|
|
|
|
lang: str='en',
|
|
|
|
lang: str='en',
|
|
|
|
sample_rate: int=16000,
|
|
|
|
sample_rate: int=16000,
|
|
|
|