|
|
@ -21,9 +21,21 @@ import paddle
|
|
|
|
|
|
|
|
|
|
|
|
from ..executor import BaseExecutor
|
|
|
|
from ..executor import BaseExecutor
|
|
|
|
from ..utils import cli_register
|
|
|
|
from ..utils import cli_register
|
|
|
|
|
|
|
|
from ..utils import download_and_decompress
|
|
|
|
|
|
|
|
from ..utils import logger
|
|
|
|
|
|
|
|
from ..utils import MODEL_HOME
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['S2TExecutor']
|
|
|
|
__all__ = ['S2TExecutor']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Registry of downloadable pretrained resources, keyed by a tag string.
# Each entry carries the archive 'url' and the 'md5' string stored with it.
pretrained_models = {
    "wenetspeech_zh": {
        'url': 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz',
        'md5': '54e7a558a6e020c2f5fb224874943f97',
    },
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@cli_register(
|
|
|
|
@cli_register(
|
|
|
|
name='paddlespeech.s2t', description='Speech to text infer command.')
|
|
|
|
name='paddlespeech.s2t', description='Speech to text infer command.')
|
|
|
@ -33,11 +45,23 @@ class S2TExecutor(BaseExecutor):
|
|
|
|
|
|
|
|
|
|
|
|
self.parser = argparse.ArgumentParser(
|
|
|
|
self.parser = argparse.ArgumentParser(
|
|
|
|
prog='paddlespeech.s2t', add_help=True)
|
|
|
|
prog='paddlespeech.s2t', add_help=True)
|
|
|
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
|
|
|
'--model',
|
|
|
|
|
|
|
|
type=str,
|
|
|
|
|
|
|
|
default='wenetspeech',
|
|
|
|
|
|
|
|
help='Choose model type of asr task.')
|
|
|
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
|
|
|
'--lang', type=str, default='zh', help='Choose model language.')
|
|
|
|
self.parser.add_argument(
|
|
|
|
self.parser.add_argument(
|
|
|
|
'--config',
|
|
|
|
'--config',
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
|
default=None,
|
|
|
|
default=None,
|
|
|
|
help='Config of s2t task. Use deault config when it is None.')
|
|
|
|
help='Config of s2t task. Use deault config when it is None.')
|
|
|
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
|
|
|
'--ckpt_path',
|
|
|
|
|
|
|
|
type=str,
|
|
|
|
|
|
|
|
default=None,
|
|
|
|
|
|
|
|
help='Checkpoint file of model.')
|
|
|
|
self.parser.add_argument(
|
|
|
|
self.parser.add_argument(
|
|
|
|
'--input', type=str, help='Audio file to recognize.')
|
|
|
|
'--input', type=str, help='Audio file to recognize.')
|
|
|
|
self.parser.add_argument(
|
|
|
|
self.parser.add_argument(
|
|
|
@ -46,16 +70,39 @@ class S2TExecutor(BaseExecutor):
|
|
|
|
default='cpu',
|
|
|
|
default='cpu',
|
|
|
|
help='Choose device to execute model inference.')
|
|
|
|
help='Choose device to execute model inference.')
|
|
|
|
|
|
|
|
|
|
|
|
def _get_default_cfg_path(self):
|
|
|
|
def _get_pretrained_path(self, tag: str) -> os.PathLike:
    """
    Download and return the pretrained resources path of the current task.

    Args:
        tag: Key into ``pretrained_models`` (for example ``wenetspeech_zh``).

    Returns:
        The value returned by ``download_and_decompress`` for the selected
        entry, called with ``MODEL_HOME/<tag>`` as the target directory.
        NOTE(review): presumably the directory holding the unpacked
        archive -- confirm against ``download_and_decompress``.

    Raises:
        AssertionError: If ``tag`` is not a key of ``pretrained_models``.
    """
    if tag not in pretrained_models:
        # Raised explicitly instead of via an `assert` statement so the check
        # is not stripped when Python runs with -O; the exception type and
        # message are the same as the assert would have produced.
        raise AssertionError(
            'Can not find pretrained resources of {}.'.format(tag))

    res_path = os.path.join(MODEL_HOME, tag)
    decompressed_path = download_and_decompress(pretrained_models[tag],
                                                res_path)
    logger.info(
        'Use pretrained model stored in: {}'.format(decompressed_path))
    return decompressed_path
|
|
|
|
|
|
|
|
|
|
|
|
def _init_from_cfg(self, cfg_path: Optional[os.PathLike]=None):
|
|
|
|
def _init_from_path(self,
|
|
|
|
|
|
|
|
model_type: str='wenetspeech',
|
|
|
|
|
|
|
|
lang: str='zh',
|
|
|
|
|
|
|
|
cfg_path: Optional[os.PathLike]=None,
|
|
|
|
|
|
|
|
ckpt_path: Optional[os.PathLike]=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Init model from a specific config file.
|
|
|
|
Init model and other resources from a specific path.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
if cfg_path is None or ckpt_path is None:
|
|
|
|
|
|
|
|
res_path = self._get_pretrained_path(
|
|
|
|
|
|
|
|
model_type + '_' + lang) # wenetspeech_zh
|
|
|
|
|
|
|
|
cfg_path = os.path.join(res_path, 'conf/conformer.yaml')
|
|
|
|
|
|
|
|
ckpt_path = os.path.join(
|
|
|
|
|
|
|
|
res_path, 'exp/conformer/checkpoints/wenetspeech.pdparams')
|
|
|
|
|
|
|
|
logger.info(res_path)
|
|
|
|
|
|
|
|
logger.info(cfg_path)
|
|
|
|
|
|
|
|
logger.info(ckpt_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Init body.
|
|
|
|
pass
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess(self, input: Union[str, os.PathLike]):
|
|
|
|
def preprocess(self, input: Union[str, os.PathLike]):
|
|
|
@ -82,17 +129,15 @@ class S2TExecutor(BaseExecutor):
|
|
|
|
parser_args = self.parser.parse_args(argv)
|
|
|
|
parser_args = self.parser.parse_args(argv)
|
|
|
|
print(parser_args)
|
|
|
|
print(parser_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = parser_args.model
|
|
|
|
|
|
|
|
lang = parser_args.lang
|
|
|
|
config = parser_args.config
|
|
|
|
config = parser_args.config
|
|
|
|
|
|
|
|
ckpt_path = parser_args.ckpt_path
|
|
|
|
audio_file = parser_args.input
|
|
|
|
audio_file = parser_args.input
|
|
|
|
device = parser_args.device
|
|
|
|
device = parser_args.device
|
|
|
|
|
|
|
|
|
|
|
|
if config is not None:
|
|
|
|
|
|
|
|
assert os.path.isfile(config), 'Config file is not valid.'
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
config = self._get_default_cfg_path()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
self._init_from_cfg(config)
|
|
|
|
self._init_from_path(model, lang, config, ckpt_path)
|
|
|
|
self.preprocess(audio_file)
|
|
|
|
self.preprocess(audio_file)
|
|
|
|
self.infer()
|
|
|
|
self.infer()
|
|
|
|
res = self.postprocess() # Retrieve result of s2t.
|
|
|
|
res = self.postprocess() # Retrieve result of s2t.
|
|
|
|