add ds2 steaming asr onnx

pull/2036/head
Hui Zhang 3 years ago
parent 9106daa2a3
commit 5e03d753ac

@ -22,8 +22,8 @@ engine_list: ['asr_online-inference']
################### speech task: asr; engine_type: online-inference ####################### ################### speech task: asr; engine_type: online-inference #######################
asr_online-inference: asr_online-inference:
model_type: 'deepspeech2online_aishell' model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional] am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional] am_params: # the pdiparams file of am static model [optional]
lang: 'zh' lang: 'zh'
sample_rate: 16000 sample_rate: 16000
cfg_path: cfg_path:
@ -54,7 +54,7 @@ asr_online-inference:
################### speech task: asr; engine_type: online-onnx ####################### ################### speech task: asr; engine_type: online-onnx #######################
asr_online-onnx: asr_online-onnx:
model_type: 'deepspeech2online_aishell' model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional] am_model: # the pdmodel file of onnx am static model [optional]
am_params: # the pdiparams file of am static model [optional] am_params: # the pdiparams file of am static model [optional]
lang: 'zh' lang: 'zh'
sample_rate: 16000 sample_rate: 16000

@ -168,7 +168,6 @@ class CommonTaskResource:
exec('from .pretrained_models import {}'.format(import_models)) exec('from .pretrained_models import {}'.format(import_models))
models = OrderedDict(locals()[import_models]) models = OrderedDict(locals()[import_models])
except Exception as e: except Exception as e:
print(e)
models = OrderedDict({}) # no models. models = OrderedDict({}) # no models.
finally: finally:
return models return models

@ -11,7 +11,7 @@ port: 8090
# protocol = ['websocket'] (only one can be selected). # protocol = ['websocket'] (only one can be selected).
# websocket only support online engine type. # websocket only support online engine type.
protocol: 'websocket' protocol: 'websocket'
engine_list: ['asr_online-onnx'] engine_list: ['asr_online-inference']
################################################################################# #################################################################################
@ -22,8 +22,8 @@ engine_list: ['asr_online-onnx']
################### speech task: asr; engine_type: online-inference ####################### ################### speech task: asr; engine_type: online-inference #######################
asr_online-inference: asr_online-inference:
model_type: 'deepspeech2online_aishell' model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional] am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional] am_params: # the pdiparams file of am static model [optional]
lang: 'zh' lang: 'zh'
sample_rate: 16000 sample_rate: 16000
cfg_path: cfg_path:
@ -54,7 +54,7 @@ asr_online-inference:
################### speech task: asr; engine_type: online-onnx ####################### ################### speech task: asr; engine_type: online-onnx #######################
asr_online-onnx: asr_online-onnx:
model_type: 'deepspeech2online_aishell' model_type: 'deepspeech2online_aishell'
am_model: # the pdmodel file of am static model [optional] am_model: # the pdmodel file of onnx am static model [optional]
am_params: # the pdiparams file of am static model [optional] am_params: # the pdiparams file of am static model [optional]
lang: 'zh' lang: 'zh'
sample_rate: 16000 sample_rate: 16000
@ -81,4 +81,4 @@ asr_online-onnx:
window_n: 7 # frame window_n: 7 # frame
shift_n: 4 # frame shift_n: 4 # frame
window_ms: 20 # ms window_ms: 20 # ms
shift_ms: 10 # ms shift_ms: 10 # ms

@ -331,6 +331,13 @@ class PaddleASRConnectionHanddler:
else: else:
return '' return ''
def get_word_time_stamp(self):
return []
@paddle.no_grad()
def rescoring(self):
...
class ASRServerExecutor(ASRExecutor): class ASRServerExecutor(ASRExecutor):
def __init__(self): def __init__(self):
@ -409,17 +416,18 @@ class ASRServerExecutor(ASRExecutor):
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
self.am_model = os.path.join(self.res_path, self.task_resource.res_dict[ self.am_model = os.path.join(self.res_path, self.task_resource.res_dict[
'model']) if am_model is None else os.path.abspath(am_model) 'onnx_model']) if am_model is None else os.path.abspath(am_model)
self.am_params = os.path.join(
self.res_path, self.task_resource.res_dict[ # self.am_params = os.path.join(
'params']) if am_params is None else os.path.abspath(am_params) # self.res_path, self.task_resource.res_dict[
# 'params']) if am_params is None else os.path.abspath(am_params)
logger.info("Load the pretrained model:") logger.info("Load the pretrained model:")
logger.info(f" tag = {tag}") logger.info(f" tag = {tag}")
logger.info(f" res_path: {self.res_path}") logger.info(f" res_path: {self.res_path}")
logger.info(f" cfg path: {self.cfg_path}") logger.info(f" cfg path: {self.cfg_path}")
logger.info(f" am_model path: {self.am_model}") logger.info(f" am_model path: {self.am_model}")
logger.info(f" am_params path: {self.am_params}") # logger.info(f" am_params path: {self.am_params}")
#Init body. #Init body.
self.config = CfgNode(new_allowed=True) self.config = CfgNode(new_allowed=True)

@ -345,6 +345,12 @@ class PaddleASRConnectionHanddler:
else: else:
return '' return ''
def get_word_time_stamp(self):
return []
@paddle.no_grad()
def rescoring(self):
...
class ASRServerExecutor(ASRExecutor): class ASRServerExecutor(ASRExecutor):
def __init__(self): def __init__(self):

Loading…
Cancel
Save