|
|
@ -33,7 +33,7 @@ from paddlespeech.s2t.transform.transformation import Transformation
|
|
|
|
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
|
|
|
|
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
|
|
|
|
from paddlespeech.s2t.utils.utility import UpdateConfig
|
|
|
|
from paddlespeech.s2t.utils.utility import UpdateConfig
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['S2TExecutor']
|
|
|
|
__all__ = ['ASRExecutor']
|
|
|
|
|
|
|
|
|
|
|
|
pretrained_models = {
|
|
|
|
pretrained_models = {
|
|
|
|
"wenetspeech_zh": {
|
|
|
|
"wenetspeech_zh": {
|
|
|
@ -58,13 +58,15 @@ model_alias = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@cli_register(
|
|
|
|
@cli_register(
|
|
|
|
name='paddlespeech.s2t', description='Speech to text infer command.')
|
|
|
|
name='paddlespeech.asr', description='Speech to text infer command.')
|
|
|
|
class S2TExecutor(BaseExecutor):
|
|
|
|
class ASRExecutor(BaseExecutor):
|
|
|
|
def __init__(self):
|
|
|
|
def __init__(self):
|
|
|
|
super(S2TExecutor, self).__init__()
|
|
|
|
super(ASRExecutor, self).__init__()
|
|
|
|
|
|
|
|
|
|
|
|
self.parser = argparse.ArgumentParser(
|
|
|
|
self.parser = argparse.ArgumentParser(
|
|
|
|
prog='paddlespeech.s2t', add_help=True)
|
|
|
|
prog='paddlespeech.asr', add_help=True)
|
|
|
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
|
|
|
'--input', type=str, required=True, help='Audio file to recognize.')
|
|
|
|
self.parser.add_argument(
|
|
|
|
self.parser.add_argument(
|
|
|
|
'--model',
|
|
|
|
'--model',
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
@ -76,16 +78,12 @@ class S2TExecutor(BaseExecutor):
|
|
|
|
'--config',
|
|
|
|
'--config',
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
|
default=None,
|
|
|
|
default=None,
|
|
|
|
help='Config of s2t task. Use deault config when it is None.')
|
|
|
|
help='Config of asr task. Use deault config when it is None.')
|
|
|
|
self.parser.add_argument(
|
|
|
|
self.parser.add_argument(
|
|
|
|
'--ckpt_path',
|
|
|
|
'--ckpt_path',
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
|
default=None,
|
|
|
|
default=None,
|
|
|
|
help='Checkpoint file of model.')
|
|
|
|
help='Checkpoint file of model.')
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
|
|
|
'--input',
|
|
|
|
|
|
|
|
type=str,
|
|
|
|
|
|
|
|
help='Audio file to recognize.')
|
|
|
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
self.parser.add_argument(
|
|
|
|
'--device',
|
|
|
|
'--device',
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
@ -178,13 +176,12 @@ class S2TExecutor(BaseExecutor):
|
|
|
|
def preprocess(self, input: Union[str, os.PathLike]):
|
|
|
|
def preprocess(self, input: Union[str, os.PathLike]):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Input preprocess and return paddle.Tensor stored in self.input.
|
|
|
|
Input preprocess and return paddle.Tensor stored in self.input.
|
|
|
|
Input content can be a text(t2s), a file(s2t, cls) or a streaming(not supported yet).
|
|
|
|
Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
parser_args = self.parser_args
|
|
|
|
parser_args = self.parser_args
|
|
|
|
config = self.config
|
|
|
|
config = self.config
|
|
|
|
audio_file = input
|
|
|
|
audio_file = input
|
|
|
|
#print("audio_file", audio_file)
|
|
|
|
|
|
|
|
logger.info("audio_file" + audio_file)
|
|
|
|
logger.info("audio_file" + audio_file)
|
|
|
|
|
|
|
|
|
|
|
|
self.sr = config.collator.target_sample_rate
|
|
|
|
self.sr = config.collator.target_sample_rate
|
|
|
@ -290,7 +287,6 @@ class S2TExecutor(BaseExecutor):
|
|
|
|
Command line entry.
|
|
|
|
Command line entry.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
self.parser_args = self.parser.parse_args(argv)
|
|
|
|
self.parser_args = self.parser.parse_args(argv)
|
|
|
|
print(self.parser_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = self.parser_args.model
|
|
|
|
model = self.parser_args.model
|
|
|
|
lang = self.parser_args.lang
|
|
|
|
lang = self.parser_args.lang
|
|
|
@ -301,7 +297,7 @@ class S2TExecutor(BaseExecutor):
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
res = self(model, lang, config, ckpt_path, audio_file, device)
|
|
|
|
res = self(model, lang, config, ckpt_path, audio_file, device)
|
|
|
|
print(res)
|
|
|
|
logger.info('ASR Result: {}'.format(res))
|
|
|
|
return True
|
|
|
|
return True
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
print(e)
|
|
|
|
print(e)
|
|
|
@ -314,6 +310,6 @@ class S2TExecutor(BaseExecutor):
|
|
|
|
self._init_from_path(model, lang, config, ckpt_path)
|
|
|
|
self._init_from_path(model, lang, config, ckpt_path)
|
|
|
|
self.preprocess(audio_file)
|
|
|
|
self.preprocess(audio_file)
|
|
|
|
self.infer()
|
|
|
|
self.infer()
|
|
|
|
res = self.postprocess() # Retrieve result of s2t.
|
|
|
|
res = self.postprocess() # Retrieve result of asr.
|
|
|
|
|
|
|
|
|
|
|
|
return res
|
|
|
|
return res
|