code format

pull/2036/head
Hui Zhang 2 years ago
parent 42d28b961c
commit 9106daa2a3

@ -11,7 +11,7 @@ port: 8090
# protocol = ['websocket'] (only one can be selected). # protocol = ['websocket'] (only one can be selected).
# websocket only support online engine type. # websocket only support online engine type.
protocol: 'websocket' protocol: 'websocket'
engine_list: ['asr_online-onnx'] engine_list: ['asr_online-inference']
################################################################################# #################################################################################

@ -164,9 +164,11 @@ class CommonTaskResource:
try: try:
import_models = '{}_{}_pretrained_models'.format(self.task, import_models = '{}_{}_pretrained_models'.format(self.task,
self.model_format) self.model_format)
print(f"from .pretrained_models import {import_models}")
exec('from .pretrained_models import {}'.format(import_models)) exec('from .pretrained_models import {}'.format(import_models))
models = OrderedDict(locals()[import_models]) models = OrderedDict(locals()[import_models])
except ImportError: except Exception as e:
print(e)
models = OrderedDict({}) # no models. models = OrderedDict({}) # no models.
finally: finally:
return models return models

@ -306,12 +306,13 @@ class PaddleASRConnectionHanddler:
assert (len(input_names) == len(output_names)) assert (len(input_names) == len(output_names))
assert isinstance(input_names[0], str) assert isinstance(input_names[0], str)
input_datas = [self.chunk_state_c_box, self.chunk_state_h_box, x_chunk_lens, x_chunk] input_datas = [
self.chunk_state_c_box, self.chunk_state_h_box, x_chunk_lens,
x_chunk
]
feeds = dict(zip(input_names, input_datas)) feeds = dict(zip(input_names, input_datas))
outputs = self.am_predictor.run( outputs = self.am_predictor.run([*output_names], {**feeds})
[*output_names],
{**feeds})
output_chunk_probs, output_chunk_lens, self.chunk_state_h_box, self.chunk_state_c_box = outputs output_chunk_probs, output_chunk_lens, self.chunk_state_h_box, self.chunk_state_c_box = outputs
self.decoder.next(output_chunk_probs, output_chunk_lens) self.decoder.next(output_chunk_probs, output_chunk_lens)
@ -335,7 +336,7 @@ class ASRServerExecutor(ASRExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.task_resource = CommonTaskResource( self.task_resource = CommonTaskResource(
task='asr', model_format='static', inference_mode='online') task='asr', model_format='onnx', inference_mode='online')
def update_config(self) -> None: def update_config(self) -> None:
if "deepspeech2" in self.model_type: if "deepspeech2" in self.model_type:
@ -407,10 +408,11 @@ class ASRServerExecutor(ASRExecutor):
self.res_path = os.path.dirname( self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
self.am_model = os.path.join(self.res_path, self.am_model = os.path.join(self.res_path, self.task_resource.res_dict[
self.task_resource.res_dict['model']) if am_model is None else os.path.abspath(am_model) 'model']) if am_model is None else os.path.abspath(am_model)
self.am_params = os.path.join(self.res_path, self.am_params = os.path.join(
self.task_resource.res_dict['params']) if am_params is None else os.path.abspath(am_params) self.res_path, self.task_resource.res_dict[
'params']) if am_params is None else os.path.abspath(am_params)
logger.info("Load the pretrained model:") logger.info("Load the pretrained model:")
logger.info(f" tag = {tag}") logger.info(f" tag = {tag}")

@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Text from typing import Text
from ..utils.log import logger from ..utils.log import logger
__all__ = ['EngineFactory'] __all__ = ['EngineFactory']
class EngineFactory(object): class EngineFactory(object):
@staticmethod @staticmethod
def get_engine(engine_name: Text, engine_type: Text): def get_engine(engine_name: Text, engine_type: Text):
logger.info(f"{engine_name} : {engine_type} engine.") logger.info(f"{engine_name} : {engine_type} engine.")
if engine_name == 'asr' and engine_type == 'inference': if engine_name == 'asr' and engine_type == 'inference':
from paddlespeech.server.engine.asr.paddleinference.asr_engine import ASREngine from paddlespeech.server.engine.asr.paddleinference.asr_engine import ASREngine
return ASREngine() return ASREngine()

@ -39,7 +39,8 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
if 'cpu_threads' in sess_conf: if 'cpu_threads' in sess_conf:
sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0) sess_options.intra_op_num_threads = sess_conf.get("cpu_threads", 0)
else: else:
sess_options.intra_op_num_threads = sess_conf.get("intra_op_num_threads", 0) sess_options.intra_op_num_threads = sess_conf.get(
"intra_op_num_threads", 0)
sess_options.inter_op_num_threads = sess_conf.get("inter_op_num_threads", 0) sess_options.inter_op_num_threads = sess_conf.get("inter_op_num_threads", 0)

@ -27,7 +27,8 @@ def parse_args():
'--input_file', '--input_file',
type=str, type=str,
default="static_ds2online_inputs.pickle", default="static_ds2online_inputs.pickle",
help="aishell ds2 input data file. For wenetspeech, we only feed for infer model", ) help="aishell ds2 input data file. For wenetspeech, we only feed for infer model",
)
parser.add_argument( parser.add_argument(
'--model_type', '--model_type',
type=str, type=str,
@ -57,7 +58,6 @@ if __name__ == '__main__':
iodict = pickle.load(f) iodict = pickle.load(f)
print(iodict.keys()) print(iodict.keys())
audio_chunk = iodict['audio_chunk'] audio_chunk = iodict['audio_chunk']
audio_chunk_lens = iodict['audio_chunk_lens'] audio_chunk_lens = iodict['audio_chunk_lens']
chunk_state_h_box = iodict['chunk_state_h_box'] chunk_state_h_box = iodict['chunk_state_h_box']

Loading…
Cancel
Save