add the call in infer.py

pull/1048/head
huangyuxin 3 years ago
commit 43f4d47bfa

@ -14,6 +14,7 @@
import os import os
from abc import ABC from abc import ABC
from abc import abstractmethod from abc import abstractmethod
from typing import List
from typing import Union from typing import Union
import paddle import paddle
@ -64,3 +65,17 @@ class BaseExecutor(ABC):
Output postprocess and return human-readable results such as texts and audio files. Output postprocess and return human-readable results such as texts and audio files.
""" """
pass pass
@abstractmethod
def execute(self, argv: List[str]) -> bool:
"""
Command line entry.
"""
pass
@abstractmethod
def __call__(self, *arg, **kwargs):
"""
Python API to call an executor.
"""
pass

@ -18,13 +18,14 @@ from typing import List
from typing import Optional from typing import Optional
from typing import Union from typing import Union
import soundfile
import paddle import paddle
from paddlespeech.cli.executor import BaseExecutor import soundfile
from paddlespeech.cli.utils import cli_register
from paddlespeech.cli.utils import download_and_decompress from ..executor import BaseExecutor
from paddlespeech.cli.utils import logger from ..utils import cli_register
from paddlespeech.cli.utils import MODEL_HOME from ..utils import download_and_decompress
from ..utils import logger
from ..utils import MODEL_HOME
from paddlespeech.s2t.exps.u2.config import get_cfg_defaults from paddlespeech.s2t.exps.u2.config import get_cfg_defaults
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.collator import SpeechCollator
@ -55,29 +56,6 @@ model_alias = {
"wenetspeech": "paddlespeech.s2t.models.u2:U2Model", "wenetspeech": "paddlespeech.s2t.models.u2:U2Model",
} }
pretrain_model_alias = {
"ds2_online_zn": [
"https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/aishell_ds2_online_cer8.00_release.tar.gz",
"", ""
],
"ds2_offline_zn": [
"https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/ds2.model.tar.gz",
"", ""
],
"transformer_zn": [
"https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz",
"", ""
],
"conformer_zn": [
"https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz",
"", ""
],
"wenetspeech_zn": [
"https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz",
"conf/conformer.yaml", "exp/conformer/checkpoints/wenetspeech"
],
}
@cli_register( @cli_register(
name='paddlespeech.s2t', description='Speech to text infer command.') name='paddlespeech.s2t', description='Speech to text infer command.')
@ -107,7 +85,6 @@ class S2TExecutor(BaseExecutor):
self.parser.add_argument( self.parser.add_argument(
'--input', '--input',
type=str, type=str,
default="../Downloads/asr-demo-1.wav",
help='Audio file to recognize.') help='Audio file to recognize.')
self.parser.add_argument( self.parser.add_argument(
'--device', '--device',
@ -155,7 +132,9 @@ class S2TExecutor(BaseExecutor):
res_path = os.path.dirname( res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
# Enter the path of model root
os.chdir(res_path) os.chdir(res_path)
#Init body. #Init body.
parser_args = self.parser_args parser_args = self.parser_args
paddle.set_device(parser_args.device) paddle.set_device(parser_args.device)
@ -206,7 +185,7 @@ class S2TExecutor(BaseExecutor):
config = self.config config = self.config
audio_file = input audio_file = input
#print("audio_file", audio_file) #print("audio_file", audio_file)
logger.info("audio_file"+ audio_file) logger.info("audio_file" + audio_file)
self.sr = config.collator.target_sample_rate self.sr = config.collator.target_sample_rate
@ -307,7 +286,11 @@ class S2TExecutor(BaseExecutor):
return self.result_transcripts return self.result_transcripts
def execute(self, argv: List[str]) -> bool: def execute(self, argv: List[str]) -> bool:
"""
Command line entry.
"""
self.parser_args = self.parser.parse_args(argv) self.parser_args = self.parser.parse_args(argv)
print(self.parser_args)
model = self.parser_args.model model = self.parser_args.model
lang = self.parser_args.lang lang = self.parser_args.lang
@ -317,17 +300,20 @@ class S2TExecutor(BaseExecutor):
device = self.parser_args.device device = self.parser_args.device
try: try:
self._init_from_path(model, lang, config, ckpt_path) res = self(model, lang, config, ckpt_path, audio_file, device)
self.preprocess(audio_file) print(res)
self.infer()
res = self.postprocess() # Retrieve result of s2t.
logger.info(res)
return True return True
except Exception as e: except Exception as e:
print(e) print(e)
return False return False
def __call__(self, model, lang, config, ckpt_path, audio_file, device):
"""
Python API to call an executor.
"""
self._init_from_path(model, lang, config, ckpt_path)
self.preprocess(audio_file)
self.infer()
res = self.postprocess() # Retrieve result of s2t.
if __name__ == "__main__": return res
exe = S2TExecutor()
exe.execute('')

Loading…
Cancel
Save