Merge pull request #1956 from zh794390558/asr_stream

[server] fix streaming asr
pull/1959/head
Hui Zhang 3 years ago committed by GitHub
commit 1a8b478b0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,7 +4,7 @@
# SERVER SETTING # # SERVER SETTING #
################################################################################# #################################################################################
host: 0.0.0.0 host: 0.0.0.0
port: 8090 port: 8091
# The task format in the engin_list is: <speech task>_<engine type> # The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online'] # task choices = ['asr_online']

@ -53,7 +53,7 @@ class PaddleASRConnectionHanddler:
logger.info( logger.info(
"create an paddle asr connection handler to process the websocket connection" "create an paddle asr connection handler to process the websocket connection"
) )
self.config = asr_engine.config self.config = asr_engine.config # server config
self.model_config = asr_engine.executor.config self.model_config = asr_engine.executor.config
self.asr_engine = asr_engine self.asr_engine = asr_engine
@ -249,10 +249,13 @@ class PaddleASRConnectionHanddler:
def reset(self): def reset(self):
if "deepspeech2" in self.model_type: if "deepspeech2" in self.model_type:
# for deepspeech2 # for deepspeech2
self.chunk_state_h_box = copy.deepcopy( # init state
self.asr_engine.executor.chunk_state_h_box) self.chunk_state_h_box = np.zeros(
self.chunk_state_c_box = copy.deepcopy( (self.model_config .num_rnn_layers, 1, self.model_config.rnn_layer_size),
self.asr_engine.executor.chunk_state_c_box) dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.model_config.num_rnn_layers, 1, self.model_config.rnn_layer_size),
dtype=float32)
self.decoder.reset_decoder(batch_size=1) self.decoder.reset_decoder(batch_size=1)
self.device = None self.device = None
@ -803,36 +806,6 @@ class ASRServerExecutor(ASRExecutor):
model_file=self.am_model, model_file=self.am_model,
params_file=self.am_params, params_file=self.am_params,
predictor_conf=self.am_predictor_conf) predictor_conf=self.am_predictor_conf)
# decoder
logger.info("ASR engine start to create the ctc decoder instance")
self.decoder = CTCDecoder(
odim=self.config.output_dim, # <blank> is in vocab
enc_n_units=self.config.rnn_layer_size * 2,
blank_id=self.config.blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
# init decoder
logger.info("ASR engine start to init the ctc decoder")
cfg = self.config.decode
decode_batch_size = 1 # for online
self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
# init state box
self.chunk_state_h_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32)
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
model_name = model_type[:model_type.rindex( model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset} '_')] # model_type: {model_name}_{dataset}
@ -847,15 +820,11 @@ class ASRServerExecutor(ASRExecutor):
model_dict = paddle.load(self.am_model) model_dict = paddle.load(self.am_model)
self.model.set_state_dict(model_dict) self.model.set_state_dict(model_dict)
logger.info("create the transformer like model success") logger.info("create the transformer like model success")
# update the ctc decoding
self.searcher = CTCPrefixBeamSearch(self.config.decode)
self.transformer_decode_reset()
else: else:
raise ValueError(f"Not support: {model_type}") raise ValueError(f"Not support: {model_type}")
return True return True
class ASREngine(BaseEngine): class ASREngine(BaseEngine):
"""ASR server resource """ASR server resource
@ -881,8 +850,8 @@ class ASREngine(BaseEngine):
self.executor = ASRServerExecutor() self.executor = ASRServerExecutor()
try: try:
default_dev = paddle.get_device() self.device = self.config.get("device", paddle.get_device())
paddle.set_device(self.config.get("device", default_dev)) paddle.set_device(self.device)
except BaseException as e: except BaseException as e:
logger.error( logger.error(
f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file" f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"

Loading…
Cancel
Save