Merge pull request #1475 from WilliamZhang06/server_asr

[server] added engine type and asr inference
pull/1483/head
Hui Zhang 2 years ago committed by GitHub
commit fe350ddddf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -11,3 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -291,7 +291,8 @@ class ASRExecutor(BaseExecutor):
"""
audio_file = input
logger.info("Preprocess audio_file:" + audio_file)
if isinstance(audio_file, (str, os.PathLike)):
logger.info("Preprocess audio_file:" + audio_file)
# Get the object for feature extraction
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
@ -412,13 +413,13 @@ class ASRExecutor(BaseExecutor):
def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error("please input --sr 8000 or --sr 16000")
raise Exception("invalid sample rate")
sys.exit(-1)
logger.error("invalid sample rate, please input --sr 8000 or --sr 16000")
return False
if not os.path.isfile(audio_file):
logger.error("Please input the right audio file path")
sys.exit(-1)
if isinstance(audio_file, (str, os.PathLike)):
if not os.path.isfile(audio_file):
logger.error("Please input the right audio file path")
return False
logger.info("checking the audio file format......")
try:
@ -435,7 +436,7 @@ class ASRExecutor(BaseExecutor):
sample rate: 8k \n \
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
")
sys.exit(-1)
return False
logger.info("The sample rate is %d" % audio_sample_rate)
if audio_sample_rate != self.sample_rate:
logger.warning("The sample rate of the input file is not {}.\n \
@ -469,6 +470,8 @@ class ASRExecutor(BaseExecutor):
logger.info("The audio file format is right")
self.change_format = False
return True
def execute(self, argv: List[str]) -> bool:
"""
Command line entry.
@ -523,7 +526,8 @@ class ASRExecutor(BaseExecutor):
Python API to call an executor.
"""
audio_file = os.path.abspath(audio_file)
self._check(audio_file, sample_rate, force_yes)
if not self._check(audio_file, sample_rate, force_yes):
sys.exit(-1)
paddle.set_device(device)
self._init_from_path(model, lang, sample_rate, config, decode_method,
ckpt_path)

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from io import BytesIO
import numpy as np
@ -88,6 +89,10 @@ def pad_sequence(sequences: List[np.ndarray],
def feat_type(filepath):
# deal with Byteio type for paddlespeech server
if isinstance(filepath, BytesIO):
return 'sound'
suffix = filepath.split(":")[0].split('.')[-1].lower()
if suffix == 'ark':
return 'mat'

@ -16,10 +16,9 @@ import uvicorn
import yaml
from fastapi import FastAPI
from paddlespeech.server.engine.engine_factory import EngineFactory
from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.restful.api import setup_router
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.log import logger
app = FastAPI(
title="PaddleSpeech Serving API", description="Api", version="0.0.1")
@ -39,12 +38,8 @@ def init(config):
api_router = setup_router(api_list)
app.include_router(api_router)
# init engine
engine_pool = []
for engine in config.engine_backend:
engine_pool.append(EngineFactory.get_engine(engine_name=engine))
if not engine_pool[-1].init(config_file=config.engine_backend[engine]):
return False
if not init_engine_pool(config):
return False
return True

@ -241,4 +241,4 @@ class ASRClientExecutor(BaseExecutor):
print(r.json())
print("time cost %f s." % (time_end - time_start))
except:
print("Failed to speech recognition.")
print("Failed to speech recognition.")

@ -51,10 +51,8 @@ class ServerExecutor(BaseExecutor):
def init(self, config) -> bool:
"""system initialization
Args:
config (CfgNode): config object
Returns:
bool:
"""

@ -9,9 +9,12 @@ port: 8090
##################################################################
# CONFIG FILE #
##################################################################
# add engine type (Options: asr, tts) and config file here.
# add engine type (Options: python, inference)
engine_type:
asr: 'inference'
# tts: 'inference'
# add engine backend type (Options: asr, tts) and config file here.
engine_backend:
asr: 'conf/asr/asr.yaml'
tts: 'conf/tts/tts_pd.yaml'
asr: 'conf/asr/asr_pd.yaml'
#tts: 'conf/tts/tts_pd.yaml'

@ -1,7 +1,7 @@
model: 'conformer_wenetspeech'
lang: 'zh'
sample_rate: 16000
cfg_path:
ckpt_path:
cfg_path: # [optional]
ckpt_path: # [optional]
decode_method: 'attention_rescoring'
force_yes: False
force_yes: True

@ -0,0 +1,25 @@
# This is the parameter configuration file for ASR server.
# These are the static models that support paddle inference.
##################################################################
# ACOUSTIC MODEL SETTING #
# am choices=['deepspeech2offline_aishell'] TODO
##################################################################
model_type: 'deepspeech2offline_aishell'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
force_yes: True
am_predictor_conf:
use_gpu: True
enable_mkldnn: True
switch_ir_optim: True
##################################################################
# OTHERS #
##################################################################

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,244 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from typing import List
from typing import Optional
from typing import Union
import librosa
import paddle
import soundfile
from yacs.config import CfgNode
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
from paddlespeech.server.engine.base_engine import BaseEngine
__all__ = ['ASREngine']
pretrained_models = {
"deepspeech2offline_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'932c3593d62fe5c741b59b31318aa314',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'model':
'exp/deepspeech2/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
}
class ASRServerExecutor(ASRExecutor):
def __init__(self):
super().__init__()
pass
def _init_from_path(self,
model_type: str='wenetspeech',
am_model: Optional[os.PathLike]=None,
am_params: Optional[os.PathLike]=None,
lang: str='zh',
sample_rate: int=16000,
cfg_path: Optional[os.PathLike]=None,
decode_method: str='attention_rescoring',
am_predictor_conf: dict=None):
"""
Init model and other resources from a specific path.
"""
if cfg_path is None or am_model is None or am_params is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path'])
self.am_model = os.path.join(res_path,
pretrained_models[tag]['model'])
self.am_params = os.path.join(res_path,
pretrained_models[tag]['params'])
logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.am_model)
logger.info(self.am_params)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.am_model = os.path.abspath(am_model)
self.am_params = os.path.abspath(am_params)
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
with UpdateConfig(self.config):
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
from paddlespeech.s2t.io.collator import SpeechCollator
self.vocab = self.config.vocab_filepath
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
self.collate_fn_test = SpeechCollator.from_config(self.config)
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type, vocab=self.vocab)
lm_url = pretrained_models[tag]['lm_url']
lm_md5 = pretrained_models[tag]['lm_md5']
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
raise Exception("wrong type")
else:
raise Exception("wrong type")
# AM predictor
self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
# decoder
self.decoder = CTCDecoder(
odim=self.config.output_dim, # <blank> is in vocab
enc_n_units=self.config.rnn_layer_size * 2,
blank_id=self.config.blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
@paddle.no_grad()
def infer(self, model_type: str):
"""
Model inference and result stored in self.output.
"""
cfg = self.config.decode
audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"]
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
decode_batch_size = audio.shape[0]
# init once
self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
output_data = run_model(
self.am_predictor,
[audio.numpy(), audio_len.numpy()])
probs = output_data[0]
eouts_len = output_data[1]
batch_size = probs.shape[0]
self.decoder.reset_decoder(batch_size=batch_size)
self.decoder.next(probs, eouts_len)
trans_best, trans_beam = self.decoder.decode()
# self.model.decoder.del_decoder()
self._outputs["result"] = trans_best[0]
elif "conformer" in model_type or "transformer" in model_type:
raise Exception("invalid model name")
else:
raise Exception("invalid model name")
class ASREngine(BaseEngine):
"""ASR server engine
Args:
metaclass: Defaults to Singleton.
"""
def __init__(self):
super(ASREngine, self).__init__()
def init(self, config_file: str) -> bool:
"""init engine resource
Args:
config_file (str): config file
Returns:
bool: init failed or success
"""
self.input = None
self.output = None
self.executor = ASRServerExecutor()
self.config = get_config(config_file)
paddle.set_device(paddle.get_device())
self.executor._init_from_path(
model_type=self.config.model_type,
am_model=self.config.am_model,
am_params=self.config.am_params,
lang=self.config.lang,
sample_rate=self.config.sample_rate,
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
am_predictor_conf=self.config.am_predictor_conf)
logger.info("Initialize ASR server engine successfully.")
return True
def run(self, audio_data):
"""engine run
Args:
audio_data (bytes): base64.b64decode
"""
if self.executor._check(
io.BytesIO(audio_data), self.config.sample_rate,
self.config.force_yes):
logger.info("start running asr engine")
self.executor.preprocess(self.config.model_type, io.BytesIO(audio_data))
self.executor.infer(self.config.model_type)
self.output = self.executor.postprocess() # Retrieve result of asr.
logger.info("end inferring asr engine")
else:
logger.info("file check failed!")
self.output = None
def postprocess(self):
"""postprocess
"""
return self.output

@ -38,101 +38,6 @@ class ASRServerExecutor(ASRExecutor):
super().__init__()
pass
def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error("please input --sr 8000 or --sr 16000")
return False
logger.info("checking the audio file format......")
try:
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
except Exception as e:
logger.exception(e)
logger.error(
"can not open the audio file, please check the audio file format is 'wav'. \n \
you can try to use sox to change the file format.\n \
For example: \n \
sample rate: 16k \n \
sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \
sample rate: 8k \n \
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
")
logger.info("The sample rate is %d" % audio_sample_rate)
if audio_sample_rate != self.sample_rate:
logger.warning("The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \
If the result does not meet your expectations\n \
Please input the 16k 16 bit 1 channel wav file. \
".format(self.sample_rate, self.sample_rate))
self.change_format = True
else:
logger.info("The audio file format is right")
self.change_format = False
return True
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
"""
Input preprocess and return paddle.Tensor stored in self.input.
Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
"""
audio_file = input
# Get the object for feature extraction
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
audio, _ = self.collate_fn_test.process_utterance(
audio_file=audio_file, transcript=" ")
audio_len = audio.shape[0]
audio = paddle.to_tensor(audio, dtype='float32')
audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
# vocab_list = collate_fn_test.vocab_list
self._inputs["audio"] = audio
self._inputs["audio_len"] = audio_len
logger.info(f"audio feat shape: {audio.shape}")
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
logger.info("get the preprocess conf")
preprocess_conf = self.config.preprocess_config
preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf)
logger.info("read the audio file")
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
if self.change_format:
if audio.shape[1] >= 2:
audio = audio.mean(axis=1, dtype=np.int16)
else:
audio = audio[:, 0]
# pcm16 -> pcm 32
audio = self._pcm16to32(audio)
audio = librosa.resample(audio, audio_sample_rate,
self.sample_rate)
audio_sample_rate = self.sample_rate
# pcm32 -> pcm 16
audio = self._pcm32to16(audio)
else:
audio = audio[:, 0]
logger.info(f"audio shape: {audio.shape}")
# fbank
audio = preprocessing(audio, **preprocess_args)
audio_len = paddle.to_tensor(audio.shape[0])
audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
self._inputs["audio"] = audio
self._inputs["audio_len"] = audio_len
logger.info(f"audio feat shape: {audio.shape}")
else:
raise Exception("wrong type")
class ASREngine(BaseEngine):
"""ASR server engine
@ -157,16 +62,12 @@ class ASREngine(BaseEngine):
self.output = None
self.executor = ASRServerExecutor()
try:
self.config = get_config(config_file)
paddle.set_device(paddle.get_device())
self.executor._init_from_path(
self.config.model, self.config.lang, self.config.sample_rate,
self.config.cfg_path, self.config.decode_method,
self.config.ckpt_path)
except:
logger.info("Initialize ASR server engine Failed.")
return False
self.config = get_config(config_file)
paddle.set_device(paddle.get_device())
self.executor._init_from_path(
self.config.model, self.config.lang, self.config.sample_rate,
self.config.cfg_path, self.config.decode_method,
self.config.ckpt_path)
logger.info("Initialize ASR server engine successfully.")
return True

@ -13,20 +13,24 @@
# limitations under the License.
from typing import Text
from paddlespeech.server.engine.asr.python.asr_engine import ASREngine
#from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine
from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine
__all__ = ['EngineFactory']
class EngineFactory(object):
@staticmethod
def get_engine(engine_name: Text):
if engine_name == 'asr':
def get_engine(engine_name: Text, engine_type: Text):
if engine_name == 'asr' and engine_type == 'inference':
from paddlespeech.server.engine.asr.paddleinference.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'asr' and engine_type == 'python':
from paddlespeech.server.engine.asr.python.asr_engine import ASREngine
return ASREngine()
elif engine_name == 'tts':
elif engine_name == 'tts' and engine_type == 'inference':
from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine
return TTSEngine()
elif engine_name == 'tts' and engine_type == 'python':
from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine
return TTSEngine()
else:
return None

@ -0,0 +1,36 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlespeech.server.engine.engine_factory import EngineFactory
# global value
ENGINE_POOL = {}
def get_engine_pool() -> dict:
""" Get engine pool
"""
global ENGINE_POOL
return ENGINE_POOL
def init_engine_pool(config) -> bool:
""" Init engine pool
"""
global ENGINE_POOL
for engine in config.engine_backend:
ENGINE_POOL[engine] = EngineFactory.get_engine(engine_name=engine, engine_type=config.engine_type[engine])
if not ENGINE_POOL[engine].init(config_file=config.engine_backend[engine]):
return False
return True

@ -22,7 +22,14 @@ _router = APIRouter()
def setup_router(api_list: List):
"""setup router for fastapi
Args:
api_list (List): [asr, tts]
Returns:
APIRouter
"""
for api_name in api_list:
if api_name == 'asr':
_router.include_router(asr_router)

@ -16,7 +16,7 @@ import traceback
from typing import Union
from fastapi import APIRouter
from paddlespeech.server.engine.asr.python.asr_engine import ASREngine
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.restful.request import ASRRequest
from paddlespeech.server.restful.response import ASRResponse
from paddlespeech.server.restful.response import ErrorResponse
@ -61,9 +61,12 @@ def asr(request_body: ASRRequest):
json: [description]
"""
try:
# single
audio_data = base64.b64decode(request_body.audio)
asr_engine = ASREngine()
# get single engine from engine pool
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_engine.run(audio_data)
asr_results = asr_engine.postprocess()

Loading…
Cancel
Save