|
|
@ -53,7 +53,7 @@ class PaddleASRConnectionHanddler:
|
|
|
|
logger.info(
|
|
|
|
logger.info(
|
|
|
|
"create an paddle asr connection handler to process the websocket connection"
|
|
|
|
"create an paddle asr connection handler to process the websocket connection"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
self.config = asr_engine.config
|
|
|
|
self.config = asr_engine.config # server config
|
|
|
|
self.model_config = asr_engine.executor.config
|
|
|
|
self.model_config = asr_engine.executor.config
|
|
|
|
self.asr_engine = asr_engine
|
|
|
|
self.asr_engine = asr_engine
|
|
|
|
|
|
|
|
|
|
|
@ -249,10 +249,13 @@ class PaddleASRConnectionHanddler:
|
|
|
|
def reset(self):
|
|
|
|
def reset(self):
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
# for deepspeech2
|
|
|
|
# for deepspeech2
|
|
|
|
self.chunk_state_h_box = copy.deepcopy(
|
|
|
|
# init state
|
|
|
|
self.asr_engine.executor.chunk_state_h_box)
|
|
|
|
self.chunk_state_h_box = np.zeros(
|
|
|
|
self.chunk_state_c_box = copy.deepcopy(
|
|
|
|
(self.model_config .num_rnn_layers, 1, self.model_config.rnn_layer_size),
|
|
|
|
self.asr_engine.executor.chunk_state_c_box)
|
|
|
|
dtype=float32)
|
|
|
|
|
|
|
|
self.chunk_state_c_box = np.zeros(
|
|
|
|
|
|
|
|
(self.model_config.num_rnn_layers, 1, self.model_config.rnn_layer_size),
|
|
|
|
|
|
|
|
dtype=float32)
|
|
|
|
self.decoder.reset_decoder(batch_size=1)
|
|
|
|
self.decoder.reset_decoder(batch_size=1)
|
|
|
|
|
|
|
|
|
|
|
|
self.device = None
|
|
|
|
self.device = None
|
|
|
@ -803,36 +806,6 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
model_file=self.am_model,
|
|
|
|
model_file=self.am_model,
|
|
|
|
params_file=self.am_params,
|
|
|
|
params_file=self.am_params,
|
|
|
|
predictor_conf=self.am_predictor_conf)
|
|
|
|
predictor_conf=self.am_predictor_conf)
|
|
|
|
|
|
|
|
|
|
|
|
# decoder
|
|
|
|
|
|
|
|
logger.info("ASR engine start to create the ctc decoder instance")
|
|
|
|
|
|
|
|
self.decoder = CTCDecoder(
|
|
|
|
|
|
|
|
odim=self.config.output_dim, # <blank> is in vocab
|
|
|
|
|
|
|
|
enc_n_units=self.config.rnn_layer_size * 2,
|
|
|
|
|
|
|
|
blank_id=self.config.blank_id,
|
|
|
|
|
|
|
|
dropout_rate=0.0,
|
|
|
|
|
|
|
|
reduction=True, # sum
|
|
|
|
|
|
|
|
batch_average=True, # sum / batch_size
|
|
|
|
|
|
|
|
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# init decoder
|
|
|
|
|
|
|
|
logger.info("ASR engine start to init the ctc decoder")
|
|
|
|
|
|
|
|
cfg = self.config.decode
|
|
|
|
|
|
|
|
decode_batch_size = 1 # for online
|
|
|
|
|
|
|
|
self.decoder.init_decoder(
|
|
|
|
|
|
|
|
decode_batch_size, self.text_feature.vocab_list,
|
|
|
|
|
|
|
|
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
|
|
|
|
|
|
|
|
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
|
|
|
|
|
|
|
|
cfg.num_proc_bsearch)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# init state box
|
|
|
|
|
|
|
|
self.chunk_state_h_box = np.zeros(
|
|
|
|
|
|
|
|
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
|
|
|
|
|
|
|
|
dtype=float32)
|
|
|
|
|
|
|
|
self.chunk_state_c_box = np.zeros(
|
|
|
|
|
|
|
|
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
|
|
|
|
|
|
|
|
dtype=float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
elif "conformer" in model_type or "transformer" in model_type:
|
|
|
|
elif "conformer" in model_type or "transformer" in model_type:
|
|
|
|
model_name = model_type[:model_type.rindex(
|
|
|
|
model_name = model_type[:model_type.rindex(
|
|
|
|
'_')] # model_type: {model_name}_{dataset}
|
|
|
|
'_')] # model_type: {model_name}_{dataset}
|
|
|
@ -847,15 +820,11 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
model_dict = paddle.load(self.am_model)
|
|
|
|
model_dict = paddle.load(self.am_model)
|
|
|
|
self.model.set_state_dict(model_dict)
|
|
|
|
self.model.set_state_dict(model_dict)
|
|
|
|
logger.info("create the transformer like model success")
|
|
|
|
logger.info("create the transformer like model success")
|
|
|
|
|
|
|
|
|
|
|
|
# update the ctc decoding
|
|
|
|
|
|
|
|
self.searcher = CTCPrefixBeamSearch(self.config.decode)
|
|
|
|
|
|
|
|
self.transformer_decode_reset()
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
raise ValueError(f"Not support: {model_type}")
|
|
|
|
raise ValueError(f"Not support: {model_type}")
|
|
|
|
|
|
|
|
|
|
|
|
return True
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ASREngine(BaseEngine):
|
|
|
|
class ASREngine(BaseEngine):
|
|
|
|
"""ASR server resource
|
|
|
|
"""ASR server resource
|
|
|
@ -881,8 +850,8 @@ class ASREngine(BaseEngine):
|
|
|
|
self.executor = ASRServerExecutor()
|
|
|
|
self.executor = ASRServerExecutor()
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
default_dev = paddle.get_device()
|
|
|
|
self.device = self.config.get("device", paddle.get_device())
|
|
|
|
paddle.set_device(self.config.get("device", default_dev))
|
|
|
|
paddle.set_device(self.device)
|
|
|
|
except BaseException as e:
|
|
|
|
except BaseException as e:
|
|
|
|
logger.error(
|
|
|
|
logger.error(
|
|
|
|
f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
|
|
|
|
f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
|
|
|
|