From 6ebe4765320eea44b10bd5dd9730ae248c9a031b Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Wed, 8 Jun 2022 02:43:37 +0000 Subject: [PATCH] support editing num_decode_left_chunks in cli and server --- .../conf/ws_conformer_application.yaml | 1 + .../conf/ws_conformer_wenetspeech_application.yaml | 2 +- .../conf/ws_ds2_application.yaml | 1 + paddlespeech/cli/asr/infer.py | 12 +++++++++++- paddlespeech/server/conf/ws_application.yaml | 1 + .../server/conf/ws_conformer_application.yaml | 3 ++- paddlespeech/server/engine/asr/online/asr_engine.py | 7 ++++++- 7 files changed, 23 insertions(+), 4 deletions(-) diff --git a/demos/streaming_asr_server/conf/ws_conformer_application.yaml b/demos/streaming_asr_server/conf/ws_conformer_application.yaml index 9dbc82b6..01bb1e9c 100644 --- a/demos/streaming_asr_server/conf/ws_conformer_application.yaml +++ b/demos/streaming_asr_server/conf/ws_conformer_application.yaml @@ -28,6 +28,7 @@ asr_online: sample_rate: 16000 cfg_path: decode_method: + num_decoding_left_chunks: -1 force_yes: True device: 'cpu' # cpu or gpu:id decode_method: "attention_rescoring" diff --git a/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml b/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml index 683d86f0..d30bcd02 100644 --- a/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml +++ b/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml @@ -32,7 +32,7 @@ asr_online: device: 'cpu' # cpu or gpu:id decode_method: "attention_rescoring" continuous_decoding: True # enable continue decoding when endpoint detected - + num_decoding_left_chunks: -1 am_predictor_conf: device: # set 'gpu:id' or 'cpu' switch_ir_optim: True diff --git a/demos/streaming_asr_server/conf/ws_ds2_application.yaml b/demos/streaming_asr_server/conf/ws_ds2_application.yaml index f2ea6330..d19bd26d 100644 --- a/demos/streaming_asr_server/conf/ws_ds2_application.yaml +++ b/demos/streaming_asr_server/conf/ws_ds2_application.yaml @@ -28,6 +28,7 @@ asr_online: sample_rate: 16000 cfg_path: decode_method: + num_decoding_left_chunks: force_yes: True device: 'cpu' # cpu or gpu:id diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index f26901a1..ad83bc20 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -83,6 +83,12 @@ class ASRExecutor(BaseExecutor): 'attention_rescoring' ], help='only support transformer and conformer model') + self.parser.add_argument( + '--num_decoding_left_chunks', + '-num_left', + type=str, + default=-1, + help='only support transformer and conformer model') self.parser.add_argument( '--ckpt_path', type=str, @@ -122,6 +128,7 @@ class ASRExecutor(BaseExecutor): sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, decode_method: str='attention_rescoring', + num_decoding_left_chunks: int=-1, ckpt_path: Optional[os.PathLike]=None): """ Init model and other resources from a specific path. @@ -129,6 +136,7 @@ class ASRExecutor(BaseExecutor): logger.info("start to init the model") # default max_len: unit:second self.max_len = 50 + assert num_decoding_left_chunks == -1 or num_decoding_left_chunks >= 0 if hasattr(self, 'model'): logger.info('Model had been initialized.') return @@ -179,6 +187,7 @@ class ASRExecutor(BaseExecutor): elif "conformer" in model_type or "transformer" in model_type: self.config.decode.decoding_method = decode_method + self.config.num_decoding_left_chunks = num_decoding_left_chunks else: raise Exception("wrong type") @@ -451,6 +460,7 @@ class ASRExecutor(BaseExecutor): config: os.PathLike=None, ckpt_path: os.PathLike=None, decode_method: str='attention_rescoring', + num_decoding_left_chunks: int=-1, force_yes: bool=False, rtf: bool=False, device=paddle.get_device()): @@ -460,7 +470,7 @@ class ASRExecutor(BaseExecutor): audio_file = os.path.abspath(audio_file) paddle.set_device(device) self._init_from_path(model, lang, sample_rate, config, decode_method, - ckpt_path) + num_decoding_left_chunks, ckpt_path) if not self._check(audio_file, sample_rate, force_yes): sys.exit(-1) if rtf: diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml index d6f5a227..43d83f2d 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_application.yaml @@ -28,6 +28,7 @@ asr_online: sample_rate: 16000 cfg_path: decode_method: + num_decoding_left_chunks: force_yes: True device: # cpu or gpu:id diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml index dd5e67ca..d72eb237 100644 --- a/paddlespeech/server/conf/ws_conformer_application.yaml +++ b/paddlespeech/server/conf/ws_conformer_application.yaml @@ -28,6 +28,7 @@ asr_online: sample_rate: 16000 cfg_path: decode_method: + num_decoding_left_chunks: -1 force_yes: True device: # cpu or gpu:id continuous_decoding: True # enable continue decoding when endpoint detected @@ -44,4 +45,4 @@ asr_online: window_ms: 25 # ms shift_ms: 10 # ms sample_rate: 16000 - sample_width: 2 \ No newline at end of file + sample_width: 2 diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 8fc210e5..3eefa9d7 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -703,6 +703,7 @@ class ASRServerExecutor(ASRExecutor): sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, decode_method: str='attention_rescoring', + num_decoding_left_chunks: int=-1, am_predictor_conf: dict=None): """ Init model and other resources from a specific path. @@ -788,7 +789,10 @@ class ASRServerExecutor(ASRExecutor): # update the decoding method if decode_method: self.config.decode.decoding_method = decode_method - + # update num_decoding_left_chunks + if num_decoding_left_chunks: + self.config.decode.num_decoding_left_chunks = num_decoding_left_chunks + assert self.config.decode.num_decoding_left_chunks == -1 or self.config.decode.num_decoding_left_chunks >= 0, f"num_decoding_left_chunks should be -1 or >=0" # we only support ctc_prefix_beam_search and attention_rescoring dedoding method # Generally we set the decoding_method to attention_rescoring if self.config.decode.decoding_method not in [ @@ -862,6 +866,7 @@ class ASREngine(BaseEngine): sample_rate=self.config.sample_rate, cfg_path=self.config.cfg_path, decode_method=self.config.decode_method, + num_decoding_left_chunks=self.config.num_decoding_left_chunks, am_predictor_conf=self.config.am_predictor_conf): logger.error( "Init the ASR server occurs error, please check the server configuration yaml"