Merge branch 'develop' of https://github.com/lym0302/PaddleSpeech into paddlespeech_stats

pull/1497/head
lym0302 3 years ago
commit fe6be4a65e

@@ -9,9 +9,17 @@ port: 8090
 ##################################################################
 #                         CONFIG FILE                            #
 ##################################################################
-# add engine type (Options: asr, tts) and config file here.
+# The engine_type of a speech task must match the config file of that task.
+# E.g. if the engine_type of asr is 'python', the engine_backend of asr is 'XX/asr.yaml'
+# E.g. if the engine_type of asr is 'inference', the engine_backend of asr is 'XX/asr_pd.yaml'
+#
+# add engine type (Options: python, inference)
+engine_type:
+    asr: 'inference'
+    tts: 'inference'
+# add engine backend type (Options: asr, tts) and config file here.
+# Adding a speech task to engine_backend means starting the service.
 engine_backend:
-    asr: 'conf/asr/asr.yaml'
-    tts: 'conf/tts/tts.yaml'
+    asr: 'conf/asr/asr_pd.yaml'
+    tts: 'conf/tts/tts_pd.yaml'
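A quick sketch of how this pairing is consumed at startup: each task listed under engine_backend is looked up in engine_type to pick the implementation. This is a minimal illustration assuming PyYAML; the file name is hypothetical.

    import yaml

    with open("conf/application.yaml") as f:
        config = yaml.safe_load(f)

    for task, conf_file in config["engine_backend"].items():   # 'asr', 'tts'
        engine_type = config["engine_type"][task]              # 'python' or 'inference'
        print(f"{task}: engine_type={engine_type}, engine_backend={conf_file}")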

@@ -1,7 +1,8 @@
 model: 'conformer_wenetspeech'
 lang: 'zh'
 sample_rate: 16000
-cfg_path:
-ckpt_path:
+cfg_path:  # [optional]
+ckpt_path:  # [optional]
 decode_method: 'attention_rescoring'
-force_yes: False
+force_yes: True
+device: 'gpu:3'  # set 'gpu:id' or 'cpu'

@@ -0,0 +1,25 @@
# This is the parameter configuration file for the ASR server.
# These are static models that support Paddle Inference.

##################################################################
#                     ACOUSTIC MODEL SETTING                     #
#        am choices=['deepspeech2offline_aishell'] TODO          #
##################################################################
model_type: 'deepspeech2offline_aishell'
am_model:   # the pdmodel file of am static model [optional]
am_params:  # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
force_yes: True

am_predictor_conf:
    device: 'gpu:3'  # set 'gpu:id' or 'cpu'
    enable_mkldnn: True
    switch_ir_optim: True

##################################################################
#                            OTHERS                              #
##################################################################
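A minimal sketch of reading this file on the server side, assuming get_config returns an attribute-accessible config object as it is used elsewhere in this PR; the path is hypothetical.

    from paddlespeech.server.utils.config import get_config

    config = get_config("conf/asr/asr_pd.yaml")
    print(config.model_type)                  # 'deepspeech2offline_aishell'
    print(config.am_predictor_conf.device)    # 'gpu:3'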

@@ -29,4 +29,4 @@ voc_stat:
 #                            OTHERS                              #
 ##################################################################
 lang: 'zh'
-device: 'gpu:2'
+device: 'gpu:3'  # set 'gpu:id' or 'cpu'
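The new device field uses Paddle's own device-string format, so it can be passed straight through; a minimal sketch:

    import paddle

    device = 'gpu:3'            # value read from the yaml; 'cpu' also works
    paddle.set_device(device)
    print(paddle.get_device())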

@@ -6,8 +6,8 @@
 # am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc']
 ##################################################################
 am: 'fastspeech2_csmsc'
-am_model:   # the pdmodel file of am static model
-am_params:  # the pdiparams file of am static model
+am_model:   # the pdmodel file of your am static model (XX.pdmodel)
+am_params:  # the pdiparams file of your am static model (XX.pdiparams)
 am_sample_rate: 24000
 phones_dict:
 tones_dict:

@@ -15,9 +15,9 @@ speaker_dict:
 spk_id: 0

 am_predictor_conf:
-    use_gpu: True
-    enable_mkldnn: True
-    switch_ir_optim: True
+    device: 'gpu:3'  # set 'gpu:id' or 'cpu'
+    enable_mkldnn: False
+    switch_ir_optim: False

 ##################################################################

@@ -25,17 +25,16 @@ am_predictor_conf:
 # voc choices=['pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc']
 ##################################################################
 voc: 'pwgan_csmsc'
-voc_model:   # the pdmodel file of vocoder static model
-voc_params:  # the pdiparams file of vocoder static model
+voc_model:   # the pdmodel file of your vocoder static model (XX.pdmodel)
+voc_params:  # the pdiparams file of your vocoder static model (XX.pdiparams)
 voc_sample_rate: 24000
 voc_predictor_conf:
-    use_gpu: True
-    enable_mkldnn: True
-    switch_ir_optim: True
+    device: 'gpu:3'  # set 'gpu:id' or 'cpu'
+    enable_mkldnn: False
+    switch_ir_optim: False

 ##################################################################
 #                            OTHERS                              #
 ##################################################################
 lang: 'zh'
-device: paddle.get_device()

@@ -11,3 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

@@ -291,6 +291,7 @@ class ASRExecutor(BaseExecutor):
         """
         audio_file = input
+        if isinstance(audio_file, (str, os.PathLike)):
             logger.info("Preprocess audio_file:" + audio_file)

         # Get the object for feature extraction
@@ -412,13 +413,13 @@ class ASRExecutor(BaseExecutor):
     def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
         self.sample_rate = sample_rate
         if self.sample_rate != 16000 and self.sample_rate != 8000:
-            logger.error("please input --sr 8000 or --sr 16000")
-            raise Exception("invalid sample rate")
-            sys.exit(-1)
+            logger.error("invalid sample rate, please input --sr 8000 or --sr 16000")
+            return False
+        if isinstance(audio_file, (str, os.PathLike)):
             if not os.path.isfile(audio_file):
                 logger.error("Please input the right audio file path")
-                sys.exit(-1)
+                return False
         logger.info("checking the audio file format......")
         try:

@@ -435,7 +436,7 @@ class ASRExecutor(BaseExecutor):
                     sample rate: 8k \n \
                     sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
                     ")
-            sys.exit(-1)
+            return False
         logger.info("The sample rate is %d" % audio_sample_rate)
         if audio_sample_rate != self.sample_rate:
             logger.warning("The sample rate of the input file is not {}.\n \

@@ -469,6 +470,8 @@ class ASRExecutor(BaseExecutor):
             logger.info("The audio file format is right")
             self.change_format = False

+        return True
+
     def execute(self, argv: List[str]) -> bool:
         """
         Command line entry.

@@ -523,7 +526,8 @@ class ASRExecutor(BaseExecutor):
         Python API to call an executor.
         """
         audio_file = os.path.abspath(audio_file)
-        self._check(audio_file, sample_rate, force_yes)
+        if not self._check(audio_file, sample_rate, force_yes):
+            sys.exit(-1)
         paddle.set_device(device)
         self._init_from_path(model, lang, sample_rate, config, decode_method,
                              ckpt_path)
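With this refactor _check reports failure by returning False instead of raising or exiting, so library callers can choose the failure policy themselves. A hedged usage sketch (the executor instance and wav path are hypothetical):

    from paddlespeech.cli.asr.infer import ASRExecutor

    executor = ASRExecutor()
    if not executor._check("input.wav", 16000, force_yes=True):
        raise SystemExit("audio file check failed")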

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import List
+from io import BytesIO

 import numpy as np

@@ -88,6 +89,10 @@ def pad_sequence(sequences: List[np.ndarray],

 def feat_type(filepath):
+    # deal with BytesIO type for paddlespeech server
+    if isinstance(filepath, BytesIO):
+        return 'sound'
     suffix = filepath.split(":")[0].split('.')[-1].lower()
     if suffix == 'ark':
         return 'mat'
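Since a BytesIO object carries no file suffix, the new branch short-circuits before the suffix parsing below; a small sketch, assuming feat_type is in scope:

    from io import BytesIO

    assert feat_type(BytesIO(b"\x00\x01")) == 'sound'   # in-memory audio from the server
    assert feat_type("feats.ark") == 'mat'              # file paths still go by suffix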

@@ -16,10 +16,9 @@ import uvicorn
 import yaml
 from fastapi import FastAPI

-from paddlespeech.server.engine.engine_factory import EngineFactory
+from paddlespeech.server.engine.engine_pool import init_engine_pool
 from paddlespeech.server.restful.api import setup_router
 from paddlespeech.server.utils.config import get_config
-from paddlespeech.server.utils.log import logger

 app = FastAPI(
     title="PaddleSpeech Serving API", description="Api", version="0.0.1")

@@ -39,11 +38,7 @@ def init(config):
     api_router = setup_router(api_list)
     app.include_router(api_router)

-    # init engine
-    engine_pool = []
-    for engine in config.engine_backend:
-        engine_pool.append(EngineFactory.get_engine(engine_name=engine))
-        if not engine_pool[-1].init(config_file=config.engine_backend[engine]):
-            return False
+    if not init_engine_pool(config):
+        return False

     return True

@@ -20,7 +20,7 @@ from fastapi import FastAPI
 from ..executor import BaseExecutor
 from ..util import cli_server_register
 from ..util import stats_wrapper
-from paddlespeech.server.engine.engine_factory import EngineFactory
+from paddlespeech.server.engine.engine_pool import init_engine_pool
 from paddlespeech.server.restful.api import setup_router
 from paddlespeech.server.utils.config import get_config

@@ -63,12 +63,7 @@ class ServerExecutor(BaseExecutor):
         api_router = setup_router(api_list)
         app.include_router(api_router)

-        # init engine
-        engine_pool = []
-        for engine in config.engine_backend:
-            engine_pool.append(EngineFactory.get_engine(engine_name=engine))
-            if not engine_pool[-1].init(
-                    config_file=config.engine_backend[engine]):
-                return False
+        if not init_engine_pool(config):
+            return False

         return True

@@ -9,9 +9,17 @@ port: 8090
 ##################################################################
 #                         CONFIG FILE                            #
 ##################################################################
-# add engine type (Options: asr, tts) and config file here.
+# The engine_type of a speech task must match the config file of that task.
+# E.g. if the engine_type of asr is 'python', the engine_backend of asr is 'XX/asr.yaml'
+# E.g. if the engine_type of asr is 'inference', the engine_backend of asr is 'XX/asr_pd.yaml'
+#
+# add engine type (Options: python, inference)
+engine_type:
+    asr: 'python'
+    tts: 'python'
+# add engine backend type (Options: asr, tts) and config file here.
+# Adding a speech task to engine_backend means starting the service.
 engine_backend:
     asr: 'conf/asr/asr.yaml'
-    tts: 'conf/tts/tts_pd.yaml'
+    tts: 'conf/tts/tts.yaml'

@@ -1,7 +1,8 @@
 model: 'conformer_wenetspeech'
 lang: 'zh'
 sample_rate: 16000
-cfg_path:
-ckpt_path:
+cfg_path:  # [optional]
+ckpt_path:  # [optional]
 decode_method: 'attention_rescoring'
-force_yes: False
+force_yes: True
+device: 'gpu:3'  # set 'gpu:id' or 'cpu'

@@ -0,0 +1,25 @@
# This is the parameter configuration file for the ASR server.
# These are static models that support Paddle Inference.

##################################################################
#                     ACOUSTIC MODEL SETTING                     #
#        am choices=['deepspeech2offline_aishell'] TODO          #
##################################################################
model_type: 'deepspeech2offline_aishell'
am_model:   # the pdmodel file of am static model [optional]
am_params:  # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
force_yes: True

am_predictor_conf:
    device: 'gpu:3'  # set 'gpu:id' or 'cpu'
    enable_mkldnn: True
    switch_ir_optim: True

##################################################################
#                            OTHERS                              #
##################################################################

@@ -29,4 +29,4 @@ voc_stat:
 #                            OTHERS                              #
 ##################################################################
 lang: 'zh'
-device: paddle.get_device()
+device: 'gpu:3'  # set 'gpu:id' or 'cpu'

@@ -6,18 +6,18 @@
 # am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc']
 ##################################################################
 am: 'fastspeech2_csmsc'
-am_model:   # the pdmodel file of am static model
-am_params:  # the pdiparams file of am static model
-am_sample_rate: 24000
+am_model:   # the pdmodel file of your am static model (XX.pdmodel)
+am_params:  # the pdiparams file of your am static model (XX.pdiparams)
+am_sample_rate: 24000  # must match the model
 phones_dict:
 tones_dict:
 speaker_dict:
 spk_id: 0

 am_predictor_conf:
-    use_gpu: True
-    enable_mkldnn: True
-    switch_ir_optim: True
+    device: 'gpu:3'  # set 'gpu:id' or 'cpu'
+    enable_mkldnn: False
+    switch_ir_optim: False

 ##################################################################

@@ -25,17 +25,16 @@ am_predictor_conf:
 # voc choices=['pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc']
 ##################################################################
 voc: 'pwgan_csmsc'
-voc_model:   # the pdmodel file of vocoder static model
-voc_params:  # the pdiparams file of vocoder static model
-voc_sample_rate: 24000
+voc_model:   # the pdmodel file of your vocoder static model (XX.pdmodel)
+voc_params:  # the pdiparams file of your vocoder static model (XX.pdiparams)
+voc_sample_rate: 24000  # must match the model
 voc_predictor_conf:
-    use_gpu: True
-    enable_mkldnn: True
-    switch_ir_optim: True
+    device: 'gpu:3'  # set 'gpu:id' or 'cpu'
+    enable_mkldnn: False
+    switch_ir_optim: False

 ##################################################################
 #                            OTHERS                              #
 ##################################################################
 lang: 'zh'
-device: paddle.get_device()

@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,244 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from typing import List
from typing import Optional
from typing import Union

import librosa
import paddle
import soundfile
from yacs.config import CfgNode

from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
from paddlespeech.server.engine.base_engine import BaseEngine

__all__ = ['ASREngine']

pretrained_models = {
    "deepspeech2offline_aishell-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
        'md5':
        '932c3593d62fe5c741b59b31318aa314',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2/checkpoints/avg_1',
        'model':
        'exp/deepspeech2/checkpoints/avg_1.jit.pdmodel',
        'params':
        'exp/deepspeech2/checkpoints/avg_1.jit.pdiparams',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
}


class ASRServerExecutor(ASRExecutor):
    def __init__(self):
        super().__init__()
        pass

    def _init_from_path(self,
                        model_type: str='wenetspeech',
                        am_model: Optional[os.PathLike]=None,
                        am_params: Optional[os.PathLike]=None,
                        lang: str='zh',
                        sample_rate: int=16000,
                        cfg_path: Optional[os.PathLike]=None,
                        decode_method: str='attention_rescoring',
                        am_predictor_conf: dict=None):
        """
        Init model and other resources from a specific path.
        """
        if cfg_path is None or am_model is None or am_params is None:
            sample_rate_str = '16k' if sample_rate == 16000 else '8k'
            tag = model_type + '-' + lang + '-' + sample_rate_str
            res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
            self.res_path = res_path
            self.cfg_path = os.path.join(res_path,
                                         pretrained_models[tag]['cfg_path'])
            self.am_model = os.path.join(res_path,
                                         pretrained_models[tag]['model'])
            self.am_params = os.path.join(res_path,
                                          pretrained_models[tag]['params'])
            logger.info(res_path)
            logger.info(self.cfg_path)
            logger.info(self.am_model)
            logger.info(self.am_params)
        else:
            self.cfg_path = os.path.abspath(cfg_path)
            self.am_model = os.path.abspath(am_model)
            self.am_params = os.path.abspath(am_params)
            self.res_path = os.path.dirname(
                os.path.dirname(os.path.abspath(self.cfg_path)))

        # Init body.
        self.config = CfgNode(new_allowed=True)
        self.config.merge_from_file(self.cfg_path)

        with UpdateConfig(self.config):
            if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
                from paddlespeech.s2t.io.collator import SpeechCollator
                self.vocab = self.config.vocab_filepath
                self.config.decode.lang_model_path = os.path.join(
                    MODEL_HOME, 'language_model',
                    self.config.decode.lang_model_path)
                self.collate_fn_test = SpeechCollator.from_config(self.config)
                self.text_feature = TextFeaturizer(
                    unit_type=self.config.unit_type, vocab=self.vocab)

                lm_url = pretrained_models[tag]['lm_url']
                lm_md5 = pretrained_models[tag]['lm_md5']
                self.download_lm(
                    lm_url,
                    os.path.dirname(self.config.decode.lang_model_path), lm_md5)
            elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
                raise Exception("wrong type")
            else:
                raise Exception("wrong type")

        # AM predictor
        self.am_predictor_conf = am_predictor_conf
        self.am_predictor = init_predictor(
            model_file=self.am_model,
            params_file=self.am_params,
            predictor_conf=self.am_predictor_conf)

        # decoder
        self.decoder = CTCDecoder(
            odim=self.config.output_dim,  # <blank> is in vocab
            enc_n_units=self.config.rnn_layer_size * 2,
            blank_id=self.config.blank_id,
            dropout_rate=0.0,
            reduction=True,  # sum
            batch_average=True,  # sum / batch_size
            grad_norm_type=self.config.get('ctc_grad_norm_type', None))

    @paddle.no_grad()
    def infer(self, model_type: str):
        """
        Model inference and result stored in self.output.
        """
        cfg = self.config.decode
        audio = self._inputs["audio"]
        audio_len = self._inputs["audio_len"]
        if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
            decode_batch_size = audio.shape[0]
            # init once
            self.decoder.init_decoder(
                decode_batch_size, self.text_feature.vocab_list,
                cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
                cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                cfg.num_proc_bsearch)

            output_data = run_model(self.am_predictor,
                                    [audio.numpy(), audio_len.numpy()])
            probs = output_data[0]
            eouts_len = output_data[1]

            batch_size = probs.shape[0]
            self.decoder.reset_decoder(batch_size=batch_size)
            self.decoder.next(probs, eouts_len)
            trans_best, trans_beam = self.decoder.decode()
            # self.model.decoder.del_decoder()
            self._outputs["result"] = trans_best[0]
        elif "conformer" in model_type or "transformer" in model_type:
            raise Exception("invalid model name")
        else:
            raise Exception("invalid model name")


class ASREngine(BaseEngine):
    """ASR server engine

    Args:
        metaclass: Defaults to Singleton.
    """

    def __init__(self):
        super(ASREngine, self).__init__()

    def init(self, config_file: str) -> bool:
        """init engine resource

        Args:
            config_file (str): config file

        Returns:
            bool: init failed or success
        """
        self.input = None
        self.output = None
        self.executor = ASRServerExecutor()

        self.config = get_config(config_file)
        paddle.set_device(paddle.get_device())
        self.executor._init_from_path(
            model_type=self.config.model_type,
            am_model=self.config.am_model,
            am_params=self.config.am_params,
            lang=self.config.lang,
            sample_rate=self.config.sample_rate,
            cfg_path=self.config.cfg_path,
            decode_method=self.config.decode_method,
            am_predictor_conf=self.config.am_predictor_conf)

        logger.info("Initialize ASR server engine successfully.")
        return True

    def run(self, audio_data):
        """engine run

        Args:
            audio_data (bytes): base64.b64decode
        """
        if self.executor._check(
                io.BytesIO(audio_data), self.config.sample_rate,
                self.config.force_yes):
            logger.info("start running asr engine")
            self.executor.preprocess(self.config.model_type,
                                     io.BytesIO(audio_data))
            self.executor.infer(self.config.model_type)
            self.output = self.executor.postprocess()  # Retrieve result of asr.
            logger.info("end inferring asr engine")
        else:
            logger.info("file check failed!")
            self.output = None

    def postprocess(self):
        """postprocess
        """
        return self.output
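A hedged end-to-end sketch of this engine; the config path and wav file are hypothetical, and run() expects raw bytes (the result of base64.b64decode on the request payload):

    engine = ASREngine()
    assert engine.init(config_file="conf/asr/asr_pd.yaml")

    with open("zh.wav", "rb") as f:
        engine.run(f.read())
    print(engine.postprocess())   # recognized text, or None on a failed check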

@@ -12,21 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import io
-import os
-from typing import List
-from typing import Optional
-from typing import Union
-
-import librosa
 import paddle
-import soundfile

 from paddlespeech.cli.asr.infer import ASRExecutor
 from paddlespeech.cli.log import logger
-from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
-from paddlespeech.s2t.transform.transformation import Transformation
-from paddlespeech.s2t.utils.dynamic_import import dynamic_import
-from paddlespeech.s2t.utils.utility import UpdateConfig
 from paddlespeech.server.engine.base_engine import BaseEngine
 from paddlespeech.server.utils.config import get_config
@@ -38,101 +28,6 @@ class ASRServerExecutor(ASRExecutor):
         super().__init__()
         pass

-    def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
-        self.sample_rate = sample_rate
-        if self.sample_rate != 16000 and self.sample_rate != 8000:
-            logger.error("please input --sr 8000 or --sr 16000")
-            return False
-
-        logger.info("checking the audio file format......")
-        try:
-            audio, audio_sample_rate = soundfile.read(
-                audio_file, dtype="int16", always_2d=True)
-        except Exception as e:
-            logger.exception(e)
-            logger.error(
-                "can not open the audio file, please check the audio file format is 'wav'. \n \
-                you can try to use sox to change the file format.\n \
-                For example: \n \
-                sample rate: 16k \n \
-                sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \
-                sample rate: 8k \n \
-                sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
-                ")
-
-        logger.info("The sample rate is %d" % audio_sample_rate)
-        if audio_sample_rate != self.sample_rate:
-            logger.warning("The sample rate of the input file is not {}.\n \
-                The program will resample the wav file to {}.\n \
-                If the result does not meet your expectations\n \
-                Please input the 16k 16 bit 1 channel wav file. \
-                ".format(self.sample_rate, self.sample_rate))
-            self.change_format = True
-        else:
-            logger.info("The audio file format is right")
-            self.change_format = False
-
-        return True
-
-    def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
-        """
-        Input preprocess and return paddle.Tensor stored in self.input.
-        Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
-        """
-        audio_file = input
-
-        # Get the object for feature extraction
-        if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
-            audio, _ = self.collate_fn_test.process_utterance(
-                audio_file=audio_file, transcript=" ")
-            audio_len = audio.shape[0]
-            audio = paddle.to_tensor(audio, dtype='float32')
-            audio_len = paddle.to_tensor(audio_len)
-            audio = paddle.unsqueeze(audio, axis=0)
-            # vocab_list = collate_fn_test.vocab_list
-            self._inputs["audio"] = audio
-            self._inputs["audio_len"] = audio_len
-            logger.info(f"audio feat shape: {audio.shape}")
-        elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
-            logger.info("get the preprocess conf")
-            preprocess_conf = self.config.preprocess_config
-            preprocess_args = {"train": False}
-            preprocessing = Transformation(preprocess_conf)
-            logger.info("read the audio file")
-            audio, audio_sample_rate = soundfile.read(
-                audio_file, dtype="int16", always_2d=True)
-
-            if self.change_format:
-                if audio.shape[1] >= 2:
-                    audio = audio.mean(axis=1, dtype=np.int16)
-                else:
-                    audio = audio[:, 0]
-                # pcm16 -> pcm 32
-                audio = self._pcm16to32(audio)
-                audio = librosa.resample(audio, audio_sample_rate,
-                                         self.sample_rate)
-                audio_sample_rate = self.sample_rate
-                # pcm32 -> pcm 16
-                audio = self._pcm32to16(audio)
-            else:
-                audio = audio[:, 0]
-
-            logger.info(f"audio shape: {audio.shape}")
-            # fbank
-            audio = preprocessing(audio, **preprocess_args)
-
-            audio_len = paddle.to_tensor(audio.shape[0])
-            audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
-
-            self._inputs["audio"] = audio
-            self._inputs["audio_len"] = audio_len
-            logger.info(f"audio feat shape: {audio.shape}")
-        else:
-            raise Exception("wrong type")


 class ASREngine(BaseEngine):
     """ASR server engine

@@ -157,16 +52,12 @@ class ASREngine(BaseEngine):
         self.output = None
         self.executor = ASRServerExecutor()

-        try:
-            self.config = get_config(config_file)
-            paddle.set_device(paddle.get_device())
-            self.executor._init_from_path(
-                self.config.model, self.config.lang, self.config.sample_rate,
-                self.config.cfg_path, self.config.decode_method,
-                self.config.ckpt_path)
-        except:
-            logger.info("Initialize ASR server engine Failed.")
-            return False
+        self.config = get_config(config_file)
+        paddle.set_device(self.config.device)
+        self.executor._init_from_path(
+            self.config.model, self.config.lang, self.config.sample_rate,
+            self.config.cfg_path, self.config.decode_method,
+            self.config.ckpt_path)

         logger.info("Initialize ASR server engine successfully.")
         return True

@@ -13,20 +13,24 @@
 # limitations under the License.
 from typing import Text

-from paddlespeech.server.engine.asr.python.asr_engine import ASREngine
-#from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine
-from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine

 __all__ = ['EngineFactory']


 class EngineFactory(object):
     @staticmethod
-    def get_engine(engine_name: Text):
-        if engine_name == 'asr':
+    def get_engine(engine_name: Text, engine_type: Text):
+        if engine_name == 'asr' and engine_type == 'inference':
+            from paddlespeech.server.engine.asr.paddleinference.asr_engine import ASREngine
+            return ASREngine()
+        elif engine_name == 'asr' and engine_type == 'python':
+            from paddlespeech.server.engine.asr.python.asr_engine import ASREngine
             return ASREngine()
-        elif engine_name == 'tts':
+        elif engine_name == 'tts' and engine_type == 'inference':
+            from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine
+            return TTSEngine()
+        elif engine_name == 'tts' and engine_type == 'python':
+            from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine
             return TTSEngine()
         else:
             return None
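Usage sketch: the factory now needs both coordinates, and returns None for unknown combinations, so callers should check the result:

    engine = EngineFactory.get_engine(engine_name='asr', engine_type='inference')
    if engine is None:
        raise ValueError("unsupported engine_name/engine_type combination")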

@@ -0,0 +1,36 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlespeech.server.engine.engine_factory import EngineFactory

# global value
ENGINE_POOL = {}


def get_engine_pool() -> dict:
    """ Get engine pool
    """
    global ENGINE_POOL
    return ENGINE_POOL


def init_engine_pool(config) -> bool:
    """ Init engine pool
    """
    global ENGINE_POOL
    for engine in config.engine_backend:
        ENGINE_POOL[engine] = EngineFactory.get_engine(
            engine_name=engine, engine_type=config.engine_type[engine])
        if not ENGINE_POOL[engine].init(
                config_file=config.engine_backend[engine]):
            return False

    return True
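The pool is a module-level dict keyed by task name, so handlers created after init_engine_pool(config) succeeds can fetch the shared engine instead of constructing one per request:

    pool = get_engine_pool()
    asr_engine = pool['asr']    # same instance for every request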

@@ -344,7 +344,6 @@ class TTSEngine(BaseEngine):
         try:
             self.config = get_config(config_file)
-
             self.executor._init_from_path(
                 am=self.config.am,
                 am_model=self.config.am_model,

@@ -22,7 +22,14 @@ _router = APIRouter()


 def setup_router(api_list: List):
+    """setup router for fastapi
+
+    Args:
+        api_list (List): [asr, tts]
+
+    Returns:
+        APIRouter
+    """
     for api_name in api_list:
         if api_name == 'asr':
             _router.include_router(asr_router)

@@ -16,7 +16,7 @@ import traceback
 from typing import Union

 from fastapi import APIRouter

-from paddlespeech.server.engine.asr.python.asr_engine import ASREngine
+from paddlespeech.server.engine.engine_pool import get_engine_pool
 from paddlespeech.server.restful.request import ASRRequest
 from paddlespeech.server.restful.response import ASRResponse
 from paddlespeech.server.restful.response import ErrorResponse

@@ -61,9 +61,12 @@ def asr(request_body: ASRRequest):
         json: [description]
     """
     try:
-        # single
         audio_data = base64.b64decode(request_body.audio)
-        asr_engine = ASREngine()
+
+        # get single engine from engine pool
+        engine_pool = get_engine_pool()
+        asr_engine = engine_pool['asr']
+
         asr_engine.run(audio_data)
         asr_results = asr_engine.postprocess()
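A hedged client-side sketch of this endpoint; the URL, port, and request fields are assumptions based on the demo config above (port 8090) and the ASRRequest model:

    import base64
    import requests

    with open("zh.wav", "rb") as f:
        audio = base64.b64encode(f.read()).decode("utf8")

    r = requests.post(
        "http://127.0.0.1:8090/paddlespeech/asr",
        json={"audio": audio, "audio_format": "wav",
              "sample_rate": 16000, "lang": "zh_cn"})
    print(r.json())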

@@ -16,7 +16,7 @@ from typing import Union

 from fastapi import APIRouter

-from paddlespeech.server.engine.tts.paddleinference.tts_engine import TTSEngine
+from paddlespeech.server.engine.engine_pool import get_engine_pool
 from paddlespeech.server.restful.request import TTSRequest
 from paddlespeech.server.restful.response import ErrorResponse
 from paddlespeech.server.restful.response import TTSResponse

@@ -60,28 +60,41 @@ def tts(request_body: TTSRequest):
     Returns:
         json: [description]
     """
-    # json to dict
-    item_dict = request_body.dict()
-    sentence = item_dict['text']
-    spk_id = item_dict['spk_id']
-    speed = item_dict['speed']
-    volume = item_dict['volume']
-    sample_rate = item_dict['sample_rate']
-    save_path = item_dict['save_path']
+    # get params
+    text = request_body.text
+    spk_id = request_body.spk_id
+    speed = request_body.speed
+    volume = request_body.volume
+    sample_rate = request_body.sample_rate
+    save_path = request_body.save_path

     # Check parameters
-    if speed <= 0 or speed > 3 or volume <= 0 or volume > 3 or \
-        sample_rate not in [0, 16000, 8000] or \
-        (save_path is not None and not save_path.endswith("pcm") and not save_path.endswith("wav")):
-        return failed_response(ErrorCode.SERVER_PARAM_ERR)
-
-    # single
-    tts_engine = TTSEngine()
+    if speed <= 0 or speed > 3:
+        return failed_response(
+            ErrorCode.SERVER_PARAM_ERR,
+            "invalid speed value, the value should be between 0 and 3.")
+    if volume <= 0 or volume > 3:
+        return failed_response(
+            ErrorCode.SERVER_PARAM_ERR,
+            "invalid volume value, the value should be between 0 and 3.")
+    if sample_rate not in [0, 16000, 8000]:
+        return failed_response(
+            ErrorCode.SERVER_PARAM_ERR,
+            "invalid sample_rate value, the choice of value is 0, 8000, 16000.")
+    if save_path is not None and not save_path.endswith(
+            "pcm") and not save_path.endswith("wav"):
+        return failed_response(
+            ErrorCode.SERVER_PARAM_ERR,
+            "invalid save_path, saved audio formats support pcm and wav")

     # run
     try:
+        # get single engine from engine pool
+        engine_pool = get_engine_pool()
+        tts_engine = engine_pool['tts']
+
         lang, target_sample_rate, wav_base64 = tts_engine.run(
-            sentence, spk_id, speed, volume, sample_rate, save_path)
+            text, spk_id, speed, volume, sample_rate, save_path)

         response = {
             "success": True,

@@ -41,8 +41,9 @@ def init_predictor(model_dir: Optional[os.PathLike]=None,
         config = Config(model_file, params_file)
     config.enable_memory_optim()

-    if predictor_conf["use_gpu"]:
-        config.enable_use_gpu(1000, 0)
+    if "gpu" in predictor_conf["device"]:
+        gpu_id = predictor_conf["device"].split(":")[-1]
+        config.enable_use_gpu(1000, int(gpu_id))
     if predictor_conf["enable_mkldnn"]:
         config.enable_mkldnn()
     if predictor_conf["switch_ir_optim"]:
