diff --git a/demos/streaming_asr_server/conf/ws_ds2_application.yaml b/demos/streaming_asr_server/conf/ws_ds2_application.yaml
index f0a98e72..f67d3157 100644
--- a/demos/streaming_asr_server/conf/ws_ds2_application.yaml
+++ b/demos/streaming_asr_server/conf/ws_ds2_application.yaml
@@ -11,7 +11,7 @@ port: 8090
 # protocol = ['websocket'] (only one can be selected).
 # websocket only support online engine type.
 protocol: 'websocket'
-engine_list: ['asr_online-onnx']
+engine_list: ['asr_online-inference']
 
 
 #################################################################################
diff --git a/paddlespeech/resource/resource.py b/paddlespeech/resource/resource.py
index 369dba90..2e637f0f 100644
--- a/paddlespeech/resource/resource.py
+++ b/paddlespeech/resource/resource.py
@@ -164,9 +164,11 @@ class CommonTaskResource:
         try:
             import_models = '{}_{}_pretrained_models'.format(self.task,
                                                              self.model_format)
+            print(f"from .pretrained_models import {import_models}")
             exec('from .pretrained_models import {}'.format(import_models))
             models = OrderedDict(locals()[import_models])
-        except ImportError:
+        except Exception as e:
+            print(e)
             models = OrderedDict({})  # no models.
         finally:
             return models
diff --git a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py
index 0bd2f950..97addc7a 100644
--- a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py
@@ -306,12 +306,13 @@ class PaddleASRConnectionHanddler:
         assert (len(input_names) == len(output_names))
         assert isinstance(input_names[0], str)
 
-        input_datas = [self.chunk_state_c_box, self.chunk_state_h_box, x_chunk_lens, x_chunk]
+        input_datas = [
+            self.chunk_state_c_box, self.chunk_state_h_box, x_chunk_lens,
+            x_chunk
+        ]
         feeds = dict(zip(input_names, input_datas))
 
-        outputs = self.am_predictor.run(
-            [*output_names],
-            {**feeds})
+        outputs = self.am_predictor.run([*output_names], {**feeds})
 
         output_chunk_probs, output_chunk_lens, self.chunk_state_h_box, self.chunk_state_c_box = outputs
         self.decoder.next(output_chunk_probs, output_chunk_lens)
@@ -335,7 +336,7 @@ class ASRServerExecutor(ASRExecutor):
     def __init__(self):
         super().__init__()
         self.task_resource = CommonTaskResource(
-            task='asr', model_format='static', inference_mode='online')
+            task='asr', model_format='onnx', inference_mode='online')
 
     def update_config(self) -> None:
         if "deepspeech2" in self.model_type:
@@ -407,10 +408,11 @@
         self.res_path = os.path.dirname(
             os.path.dirname(os.path.abspath(self.cfg_path)))
 
-        self.am_model = os.path.join(self.res_path,
-                                     self.task_resource.res_dict['model']) if am_model is None else os.path.abspath(am_model)
-        self.am_params = os.path.join(self.res_path,
-                                      self.task_resource.res_dict['params']) if am_params is None else os.path.abspath(am_params)
+        self.am_model = os.path.join(self.res_path, self.task_resource.res_dict[
+            'model']) if am_model is None else os.path.abspath(am_model)
+        self.am_params = os.path.join(
+            self.res_path, self.task_resource.res_dict[
+                'params']) if am_params is None else os.path.abspath(am_params)
 
         logger.info("Load the pretrained model:")
         logger.info(f"  tag = {tag}")
diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py
index 3c1c3d53..6a66a002 100644
--- a/paddlespeech/server/engine/engine_factory.py
+++ b/paddlespeech/server/engine/engine_factory.py
@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Text
+
 from ..utils.log import logger
 
 __all__ = ['EngineFactory']
 
+
 class EngineFactory(object):
     @staticmethod
     def get_engine(engine_name: Text, engine_type: Text):
         logger.info(f"{engine_name} : {engine_type} engine.")
+
         if engine_name == 'asr' and engine_type == 'inference':
             from paddlespeech.server.engine.asr.paddleinference.asr_engine import ASREngine
             return ASREngine()
diff --git a/paddlespeech/server/utils/onnx_infer.py b/paddlespeech/server/utils/onnx_infer.py
index 4287477f..1c9d878f 100644
--- a/paddlespeech/server/utils/onnx_infer.py
+++ b/paddlespeech/server/utils/onnx_infer.py
@@ -35,14 +35,15 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
         if sess_conf.get("use_trt", 0):
             providers = ['TensorrtExecutionProvider']
     logger.info(f"ort providers: {providers}")
-    
+
     if 'cpu_threads' in sess_conf:
-        sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0)
+        sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0)
     else:
-        sess_options.intra_op_num_threads = sess_conf.get("intra_op_num_threads", 0)
+        sess_options.intra_op_num_threads = sess_conf.get(
+            "intra_op_num_threads", 0)
 
     sess_options.inter_op_num_threads = sess_conf.get("inter_op_num_threads", 0)
-    
+
     sess = ort.InferenceSession(
         model_path, providers=providers, sess_options=sess_options)
     return sess
diff --git a/speechx/examples/ds2_ol/onnx/local/infer_check.py b/speechx/examples/ds2_ol/onnx/local/infer_check.py
index a5ec7ce3..f821baa1 100755
--- a/speechx/examples/ds2_ol/onnx/local/infer_check.py
+++ b/speechx/examples/ds2_ol/onnx/local/infer_check.py
@@ -27,7 +27,8 @@ def parse_args():
         '--input_file',
         type=str,
        default="static_ds2online_inputs.pickle",
-        help="aishell ds2 input data file. For wenetspeech, we only feed for infer model", )
+        help="aishell ds2 input data file. For wenetspeech, we only feed for infer model",
+    )
     parser.add_argument(
         '--model_type',
         type=str,
@@ -57,7 +58,6 @@ if __name__ == '__main__':
         iodict = pickle.load(f)
         print(iodict.keys())
 
-
     audio_chunk = iodict['audio_chunk']
     audio_chunk_lens = iodict['audio_chunk_lens']
     chunk_state_h_box = iodict['chunk_state_h_box']
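
For context, here is a minimal sketch of the onnxruntime session setup and feed/run pattern that the `onnx_infer.py` and `asr_engine.py` hunks above touch. It assumes only that `onnxruntime` is installed; the `sess_conf` values, the `model.onnx` path, and the tensor names are placeholder assumptions, not values from the repository:

```python
# Hedged sketch, not the repository's code: it mirrors the thread configuration
# in get_sess() and the feeds/run call in PaddleASRConnectionHanddler above.
import onnxruntime as ort

sess_conf = {"cpu_threads": 4, "inter_op_num_threads": 0}  # hypothetical config

sess_options = ort.SessionOptions()
# The patched helper prefers the legacy 'cpu_threads' key when present and
# falls back to 'intra_op_num_threads' otherwise.
if "cpu_threads" in sess_conf:
    sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0)
else:
    sess_options.intra_op_num_threads = sess_conf.get("intra_op_num_threads", 0)
sess_options.inter_op_num_threads = sess_conf.get("inter_op_num_threads", 0)

sess = ort.InferenceSession(
    "model.onnx",  # placeholder path to an exported DS2 online model
    providers=["CPUExecutionProvider"],
    sess_options=sess_options)

# Query the graph's input/output names and zip them with data tensors, as the
# asr_engine.py hunk does (tensors omitted here, so the run call is commented):
input_names = [item.name for item in sess.get_inputs()]
output_names = [item.name for item in sess.get_outputs()]
# feeds = dict(zip(input_names, [state_c, state_h, x_chunk_lens, x_chunk]))
# outputs = sess.run(output_names, feeds)
```

Note that the `dict(zip(...))` pattern relies on `input_datas` being listed in the same order as the graph's inputs, which is why the `asr_engine.py` hunk pins the exact ordering of the state and chunk tensors.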