code format

pull/2036/head
Hui Zhang 3 years ago
parent 42d28b961c
commit 9106daa2a3

@ -11,7 +11,7 @@ port: 8090
# protocol = ['websocket'] (only one can be selected).
# websocket only support online engine type.
protocol: 'websocket'
engine_list: ['asr_online-onnx']
engine_list: ['asr_online-inference']
#################################################################################

@ -164,9 +164,11 @@ class CommonTaskResource:
try:
import_models = '{}_{}_pretrained_models'.format(self.task,
self.model_format)
print(f"from .pretrained_models import {import_models}")
exec('from .pretrained_models import {}'.format(import_models))
models = OrderedDict(locals()[import_models])
except ImportError:
except Exception as e:
print(e)
models = OrderedDict({}) # no models.
finally:
return models

@ -306,12 +306,13 @@ class PaddleASRConnectionHanddler:
assert (len(input_names) == len(output_names))
assert isinstance(input_names[0], str)
input_datas = [self.chunk_state_c_box, self.chunk_state_h_box, x_chunk_lens, x_chunk]
input_datas = [
self.chunk_state_c_box, self.chunk_state_h_box, x_chunk_lens,
x_chunk
]
feeds = dict(zip(input_names, input_datas))
outputs = self.am_predictor.run(
[*output_names],
{**feeds})
outputs = self.am_predictor.run([*output_names], {**feeds})
output_chunk_probs, output_chunk_lens, self.chunk_state_h_box, self.chunk_state_c_box = outputs
self.decoder.next(output_chunk_probs, output_chunk_lens)
@ -335,7 +336,7 @@ class ASRServerExecutor(ASRExecutor):
def __init__(self):
super().__init__()
self.task_resource = CommonTaskResource(
task='asr', model_format='static', inference_mode='online')
task='asr', model_format='onnx', inference_mode='online')
def update_config(self) -> None:
if "deepspeech2" in self.model_type:
@ -407,10 +408,11 @@ class ASRServerExecutor(ASRExecutor):
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
self.am_model = os.path.join(self.res_path,
self.task_resource.res_dict['model']) if am_model is None else os.path.abspath(am_model)
self.am_params = os.path.join(self.res_path,
self.task_resource.res_dict['params']) if am_params is None else os.path.abspath(am_params)
self.am_model = os.path.join(self.res_path, self.task_resource.res_dict[
'model']) if am_model is None else os.path.abspath(am_model)
self.am_params = os.path.join(
self.res_path, self.task_resource.res_dict[
'params']) if am_params is None else os.path.abspath(am_params)
logger.info("Load the pretrained model:")
logger.info(f" tag = {tag}")

@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Text
from ..utils.log import logger
__all__ = ['EngineFactory']
class EngineFactory(object):
@staticmethod
def get_engine(engine_name: Text, engine_type: Text):
logger.info(f"{engine_name} : {engine_type} engine.")
if engine_name == 'asr' and engine_type == 'inference':
from paddlespeech.server.engine.asr.paddleinference.asr_engine import ASREngine
return ASREngine()

@ -35,14 +35,15 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
if sess_conf.get("use_trt", 0):
providers = ['TensorrtExecutionProvider']
logger.info(f"ort providers: {providers}")
if 'cpu_threads' in sess_conf:
sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0)
sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0)
else:
sess_options.intra_op_num_threads = sess_conf.get("intra_op_num_threads", 0)
sess_options.intra_op_num_threads = sess_conf.get(
"intra_op_num_threads", 0)
sess_options.inter_op_num_threads = sess_conf.get("inter_op_num_threads", 0)
sess = ort.InferenceSession(
model_path, providers=providers, sess_options=sess_options)
return sess

@ -27,7 +27,8 @@ def parse_args():
'--input_file',
type=str,
default="static_ds2online_inputs.pickle",
help="aishell ds2 input data file. For wenetspeech, we only feed for infer model", )
help="aishell ds2 input data file. For wenetspeech, we only feed for infer model",
)
parser.add_argument(
'--model_type',
type=str,
@ -57,7 +58,6 @@ if __name__ == '__main__':
iodict = pickle.load(f)
print(iodict.keys())
audio_chunk = iodict['audio_chunk']
audio_chunk_lens = iodict['audio_chunk_lens']
chunk_state_h_box = iodict['chunk_state_h_box']

Loading…
Cancel
Save