|
|
|
@ -306,12 +306,13 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
assert (len(input_names) == len(output_names))
|
|
|
|
|
assert isinstance(input_names[0], str)
|
|
|
|
|
|
|
|
|
|
input_datas = [self.chunk_state_c_box, self.chunk_state_h_box, x_chunk_lens, x_chunk]
|
|
|
|
|
input_datas = [
|
|
|
|
|
self.chunk_state_c_box, self.chunk_state_h_box, x_chunk_lens,
|
|
|
|
|
x_chunk
|
|
|
|
|
]
|
|
|
|
|
feeds = dict(zip(input_names, input_datas))
|
|
|
|
|
|
|
|
|
|
outputs = self.am_predictor.run(
|
|
|
|
|
[*output_names],
|
|
|
|
|
{**feeds})
|
|
|
|
|
outputs = self.am_predictor.run([*output_names], {**feeds})
|
|
|
|
|
|
|
|
|
|
output_chunk_probs, output_chunk_lens, self.chunk_state_h_box, self.chunk_state_c_box = outputs
|
|
|
|
|
self.decoder.next(output_chunk_probs, output_chunk_lens)
|
|
|
|
@ -335,7 +336,7 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.task_resource = CommonTaskResource(
|
|
|
|
|
task='asr', model_format='static', inference_mode='online')
|
|
|
|
|
task='asr', model_format='onnx', inference_mode='online')
|
|
|
|
|
|
|
|
|
|
def update_config(self) -> None:
|
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
@ -407,10 +408,11 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
self.res_path = os.path.dirname(
|
|
|
|
|
os.path.dirname(os.path.abspath(self.cfg_path)))
|
|
|
|
|
|
|
|
|
|
self.am_model = os.path.join(self.res_path,
|
|
|
|
|
self.task_resource.res_dict['model']) if am_model is None else os.path.abspath(am_model)
|
|
|
|
|
self.am_params = os.path.join(self.res_path,
|
|
|
|
|
self.task_resource.res_dict['params']) if am_params is None else os.path.abspath(am_params)
|
|
|
|
|
self.am_model = os.path.join(self.res_path, self.task_resource.res_dict[
|
|
|
|
|
'model']) if am_model is None else os.path.abspath(am_model)
|
|
|
|
|
self.am_params = os.path.join(
|
|
|
|
|
self.res_path, self.task_resource.res_dict[
|
|
|
|
|
'params']) if am_params is None else os.path.abspath(am_params)
|
|
|
|
|
|
|
|
|
|
logger.info("Load the pretrained model:")
|
|
|
|
|
logger.info(f" tag = {tag}")
|
|
|
|
|