Support editing num_decoding_left_chunks in the CLI and server

pull/2016/head
huangyuxin 3 years ago
parent 8641608f08
commit 6ebe476532

@ -28,6 +28,7 @@ asr_online:
sample_rate: 16000 sample_rate: 16000
cfg_path: cfg_path:
decode_method: decode_method:
num_decoding_left_chunks: -1
force_yes: True force_yes: True
device: 'cpu' # cpu or gpu:id device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring" decode_method: "attention_rescoring"

@ -32,7 +32,7 @@ asr_online:
device: 'cpu' # cpu or gpu:id device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring" decode_method: "attention_rescoring"
continuous_decoding: True # enable continuous decoding when an endpoint is detected continuous_decoding: True # enable continuous decoding when an endpoint is detected
num_decoding_left_chunks: -1
am_predictor_conf: am_predictor_conf:
device: # set 'gpu:id' or 'cpu' device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True switch_ir_optim: True

@ -28,6 +28,7 @@ asr_online:
sample_rate: 16000 sample_rate: 16000
cfg_path: cfg_path:
decode_method: decode_method:
num_decoding_left_chunks:
force_yes: True force_yes: True
device: 'cpu' # cpu or gpu:id device: 'cpu' # cpu or gpu:id

@ -83,6 +83,12 @@ class ASRExecutor(BaseExecutor):
'attention_rescoring' 'attention_rescoring'
], ],
help='only support transformer and conformer model') help='only support transformer and conformer model')
self.parser.add_argument(
'--num_decoding_left_chunks',
'-num_left',
type=str,
default=-1,
help='only support transformer and conformer model')
self.parser.add_argument( self.parser.add_argument(
'--ckpt_path', '--ckpt_path',
type=str, type=str,
@ -122,6 +128,7 @@ class ASRExecutor(BaseExecutor):
sample_rate: int=16000, sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None, cfg_path: Optional[os.PathLike]=None,
decode_method: str='attention_rescoring', decode_method: str='attention_rescoring',
num_decoding_left_chunks: int=-1,
ckpt_path: Optional[os.PathLike]=None): ckpt_path: Optional[os.PathLike]=None):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
@ -129,6 +136,7 @@ class ASRExecutor(BaseExecutor):
logger.info("start to init the model") logger.info("start to init the model")
# default max_len: unit:second # default max_len: unit:second
self.max_len = 50 self.max_len = 50
assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0
if hasattr(self, 'model'): if hasattr(self, 'model'):
logger.info('Model had been initialized.') logger.info('Model had been initialized.')
return return
@ -179,6 +187,7 @@ class ASRExecutor(BaseExecutor):
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
self.config.decode.decoding_method = decode_method self.config.decode.decoding_method = decode_method
self.config.num_decoding_left_chunks = num_decoding_left_chunks
else: else:
raise Exception("wrong type") raise Exception("wrong type")
@ -451,6 +460,7 @@ class ASRExecutor(BaseExecutor):
config: os.PathLike=None, config: os.PathLike=None,
ckpt_path: os.PathLike=None, ckpt_path: os.PathLike=None,
decode_method: str='attention_rescoring', decode_method: str='attention_rescoring',
num_decoding_left_chunks: int=-1,
force_yes: bool=False, force_yes: bool=False,
rtf: bool=False, rtf: bool=False,
device=paddle.get_device()): device=paddle.get_device()):
@ -460,7 +470,7 @@ class ASRExecutor(BaseExecutor):
audio_file = os.path.abspath(audio_file) audio_file = os.path.abspath(audio_file)
paddle.set_device(device) paddle.set_device(device)
self._init_from_path(model, lang, sample_rate, config, decode_method, self._init_from_path(model, lang, sample_rate, config, decode_method,
ckpt_path) num_decoding_left_chunks, ckpt_path)
if not self._check(audio_file, sample_rate, force_yes): if not self._check(audio_file, sample_rate, force_yes):
sys.exit(-1) sys.exit(-1)
if rtf: if rtf:

@ -28,6 +28,7 @@ asr_online:
sample_rate: 16000 sample_rate: 16000
cfg_path: cfg_path:
decode_method: decode_method:
num_decoding_left_chunks:
force_yes: True force_yes: True
device: # cpu or gpu:id device: # cpu or gpu:id

@ -28,6 +28,7 @@ asr_online:
sample_rate: 16000 sample_rate: 16000
cfg_path: cfg_path:
decode_method: decode_method:
num_decoding_left_chunks: -1
force_yes: True force_yes: True
device: # cpu or gpu:id device: # cpu or gpu:id
continuous_decoding: True # enable continuous decoding when an endpoint is detected continuous_decoding: True # enable continuous decoding when an endpoint is detected

@ -703,6 +703,7 @@ class ASRServerExecutor(ASRExecutor):
sample_rate: int=16000, sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None, cfg_path: Optional[os.PathLike]=None,
decode_method: str='attention_rescoring', decode_method: str='attention_rescoring',
num_decoding_left_chunks: int=-1,
am_predictor_conf: dict=None): am_predictor_conf: dict=None):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
@ -788,7 +789,10 @@ class ASRServerExecutor(ASRExecutor):
# update the decoding method # update the decoding method
if decode_method: if decode_method:
self.config.decode.decoding_method = decode_method self.config.decode.decoding_method = decode_method
# update num_decoding_left_chunks
if num_decoding_left_chunks:
self.config.decode.num_decoding_left_chunks = num_decoding_left_chunks
assert self.config.decode.num_decoding_left_chunks == -1 or self.config.decode.num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0"
# we only support the ctc_prefix_beam_search and attention_rescoring decoding methods # we only support the ctc_prefix_beam_search and attention_rescoring decoding methods
# Generally we set the decoding_method to attention_rescoring # Generally we set the decoding_method to attention_rescoring
if self.config.decode.decoding_method not in [ if self.config.decode.decoding_method not in [
@ -862,6 +866,7 @@ class ASREngine(BaseEngine):
sample_rate=self.config.sample_rate, sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path, cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method, decode_method=self.config.decode_method,
num_decoding_left_chunks=self.config.num_decoding_left_chunks,
am_predictor_conf=self.config.am_predictor_conf): am_predictor_conf=self.config.am_predictor_conf):
logger.error( logger.error(
"Init the ASR server occurs error, please check the server configuration yaml" "Init the ASR server occurs error, please check the server configuration yaml"

Loading…
Cancel
Save