|
|
@ -83,6 +83,12 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
'attention_rescoring'
|
|
|
|
'attention_rescoring'
|
|
|
|
],
|
|
|
|
],
|
|
|
|
help='only support transformer and conformer model')
|
|
|
|
help='only support transformer and conformer model')
|
|
|
|
|
|
|
|
self.parser.add_argument(
|
|
|
|
|
|
|
|
'--num_decoding_left_chunks',
|
|
|
|
|
|
|
|
'-num_left',
|
|
|
|
|
|
|
|
type=str,
|
|
|
|
|
|
|
|
default=-1,
|
|
|
|
|
|
|
|
help='only support transformer and conformer model')
|
|
|
|
self.parser.add_argument(
|
|
|
|
self.parser.add_argument(
|
|
|
|
'--ckpt_path',
|
|
|
|
'--ckpt_path',
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
@ -122,6 +128,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
sample_rate: int=16000,
|
|
|
|
sample_rate: int=16000,
|
|
|
|
cfg_path: Optional[os.PathLike]=None,
|
|
|
|
cfg_path: Optional[os.PathLike]=None,
|
|
|
|
decode_method: str='attention_rescoring',
|
|
|
|
decode_method: str='attention_rescoring',
|
|
|
|
|
|
|
|
num_decoding_left_chunks: int=-1,
|
|
|
|
ckpt_path: Optional[os.PathLike]=None):
|
|
|
|
ckpt_path: Optional[os.PathLike]=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Init model and other resources from a specific path.
|
|
|
|
Init model and other resources from a specific path.
|
|
|
@ -129,6 +136,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
logger.info("start to init the model")
|
|
|
|
logger.info("start to init the model")
|
|
|
|
# default max_len: unit:second
|
|
|
|
# default max_len: unit:second
|
|
|
|
self.max_len = 50
|
|
|
|
self.max_len = 50
|
|
|
|
|
|
|
|
assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0
|
|
|
|
if hasattr(self, 'model'):
|
|
|
|
if hasattr(self, 'model'):
|
|
|
|
logger.info('Model had been initialized.')
|
|
|
|
logger.info('Model had been initialized.')
|
|
|
|
return
|
|
|
|
return
|
|
|
@ -179,6 +187,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
|
|
|
|
|
|
|
|
elif "conformer" in model_type or "transformer" in model_type:
|
|
|
|
elif "conformer" in model_type or "transformer" in model_type:
|
|
|
|
self.config.decode.decoding_method = decode_method
|
|
|
|
self.config.decode.decoding_method = decode_method
|
|
|
|
|
|
|
|
self.config.num_decoding_left_chunks = num_decoding_left_chunks
|
|
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
raise Exception("wrong type")
|
|
|
|
raise Exception("wrong type")
|
|
|
@ -451,6 +460,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
config: os.PathLike=None,
|
|
|
|
config: os.PathLike=None,
|
|
|
|
ckpt_path: os.PathLike=None,
|
|
|
|
ckpt_path: os.PathLike=None,
|
|
|
|
decode_method: str='attention_rescoring',
|
|
|
|
decode_method: str='attention_rescoring',
|
|
|
|
|
|
|
|
num_decoding_left_chunks: int=-1,
|
|
|
|
force_yes: bool=False,
|
|
|
|
force_yes: bool=False,
|
|
|
|
rtf: bool=False,
|
|
|
|
rtf: bool=False,
|
|
|
|
device=paddle.get_device()):
|
|
|
|
device=paddle.get_device()):
|
|
|
@ -460,7 +470,7 @@ class ASRExecutor(BaseExecutor):
|
|
|
|
audio_file = os.path.abspath(audio_file)
|
|
|
|
audio_file = os.path.abspath(audio_file)
|
|
|
|
paddle.set_device(device)
|
|
|
|
paddle.set_device(device)
|
|
|
|
self._init_from_path(model, lang, sample_rate, config, decode_method,
|
|
|
|
self._init_from_path(model, lang, sample_rate, config, decode_method,
|
|
|
|
ckpt_path)
|
|
|
|
num_decoding_left_chunks, ckpt_path)
|
|
|
|
if not self._check(audio_file, sample_rate, force_yes):
|
|
|
|
if not self._check(audio_file, sample_rate, force_yes):
|
|
|
|
sys.exit(-1)
|
|
|
|
sys.exit(-1)
|
|
|
|
if rtf:
|
|
|
|
if rtf:
|
|
|
|