Add TTS fastdeploy serving (#2528)

* add triton tts server

* change readme

* fix path bug

* fix code style

* fix code style and readme

* Add files via upload
pull/2552/head
Thomas Young 2 years ago committed by GitHub
parent 2d71577e75
commit bf6451ed69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,67 @@
([简体中文](./README_cn.md)|English)
# Streaming Speech Synthesis Service
## Introduction
This demo is an implementation of starting the streaming speech synthesis service and accessing the service.
`Server` must be started in the docker, while `Client` does not have to be in the docker.
**The streaming_tts_serving under the path of this article ($PWD) contains the configuration and code of the model, which needs to be mapped to the docker for use.**
## Usage
### 1. Server
#### 1.1 Docker
```bash
docker pull registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09
docker run -dit --net=host --name fastdeploy --shm-size="1g" -v $PWD:/models registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09
docker exec -it -u root fastdeploy bash
```
#### 1.2 Installation(inside the docker)
```bash
apt-get install build-essential python3-dev libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev libsndfile1 language-pack-zh-hans wget zip
pip3 install paddlespeech
export LC_ALL="zh_CN.UTF-8"
export LANG="zh_CN.UTF-8"
export LANGUAGE="zh_CN:zh:en_US:en"
```
#### 1.3 Download models(inside the docker)
```bash
cd /models/streaming_tts_serving/1
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip
unzip fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
unzip mb_melgan_csmsc_onnx_0.2.0.zip
```
**For the convenience of users, we recommend that you use the command `docker -v` to map $PWD (streaming_tts_service and the configuration and code of the model contained therein) to the docker path `/models`. You can also use other methods, but regardless of which method you use, the final model directory and structure in the docker are shown in the following figure.**
<p align="center">
<img src="./tree.png" />
</p>
#### 1.4 Start the server(inside the docker)
```bash
fastdeployserver --model-repository=/models --model-control-mode=explicit --load-model=streaming_tts_serving
```
Arguments:
- `model-repository`(required): Path of model storage.
- `model-control-mode`(required): The mode of loading the model. At present, you can use 'explicit'.
- `load-model`(required): Name of the model to be loaded.
- `http-port`(optional): Port for http service. Default: `8000`. This is not used in our example.
- `grpc-port`(optional): Port for grpc service. Default: `8001`.
- `metrics-port`(optional): Port for metrics service. Default: `8002`. This is not used in our example.
### 2. Client
#### 2.1 Installation
```bash
pip3 install tritonclient[all]
```
#### 2.2 Send request
```bash
python3 /models/streaming_tts_serving/stream_client.py
```

@ -0,0 +1,67 @@
(简体中文|[English](./README.md))
# 流式语音合成服务
## 介绍
本文介绍了使用FastDeploy搭建流式语音合成服务的方法。
`服务端`必须在docker内启动,而`客户端`不是必须在docker容器内.
**本文所在路径`($PWD)下的streaming_tts_serving里包含模型的配置和代码`(服务端会加载模型和代码以启动服务),需要将其映射到docker中使用。**
## 使用
### 1. 服务端
#### 1.1 Docker
```bash
docker pull registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09
docker run -dit --net=host --name fastdeploy --shm-size="1g" -v $PWD:/models registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09
docker exec -it -u root fastdeploy bash
```
#### 1.2 安装(在docker内)
```bash
apt-get install build-essential python3-dev libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev libsndfile1 language-pack-zh-hans wget zip
pip3 install paddlespeech
export LC_ALL="zh_CN.UTF-8"
export LANG="zh_CN.UTF-8"
export LANGUAGE="zh_CN:zh:en_US:en"
```
#### 1.3 下载模型(在docker内)
```bash
cd /models/streaming_tts_serving/1
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip
unzip fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
unzip mb_melgan_csmsc_onnx_0.2.0.zip
```
**为了方便用户使用我们推荐用户使用1.1中的`docker -v`命令将`$PWD(streaming_tts_serving及里面包含的模型的配置和代码)映射到了docker内的/models路径`,用户也可以使用其他办法,但无论使用哪种方法,最终在docker内的模型目录及结构如下图所示。**
<p align="center">
<img src="./tree.png" />
</p>
#### 1.4 启动服务端(在docker内)
```bash
fastdeployserver --model-repository=/models --model-control-mode=explicit --load-model=streaming_tts_serving
```
参数:
- `model-repository`(required): 整套模型streaming_tts_serving存放的路径.
- `model-control-mode`(required): 模型加载的方式,现阶段, 使用'explicit'即可.
- `load-model`(required): 需要加载的模型的名称.
- `http-port`(optional): HTTP服务的端口号. 默认: `8000`. 本示例中未使用该端口.
- `grpc-port`(optional): GRPC服务的端口号. 默认: `8001`.
- `metrics-port`(optional): 服务端指标的端口号. 默认: `8002`. 本示例中未使用该端口.
### 2. 客户端
#### 2.1 安装
```bash
pip3 install tritonclient[all]
```
#### 2.2 发送请求
```bash
python3 /models/streaming_tts_serving/stream_client.py
```

@ -0,0 +1,289 @@
import codecs
import json
import math
import sys
import threading
import time
import numpy as np
import onnxruntime as ort
import triton_python_backend_utils as pb_utils
from paddlespeech.server.utils.util import denorm
from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend.zh_frontend import Frontend
voc_block = 36
voc_pad = 14
am_block = 72
am_pad = 12
voc_upsample = 300
# 模型路径
dir_name = "/models/streaming_tts_serving/1/"
phones_dict = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/phone_id_map.txt"
am_stat_path = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/speech_stats.npy"
onnx_am_encoder = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_encoder_infer.onnx"
onnx_am_decoder = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_decoder.onnx"
onnx_am_postnet = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_postnet.onnx"
onnx_voc_melgan = dir_name + "mb_melgan_csmsc_onnx_0.2.0/mb_melgan_csmsc.onnx"
frontend = Frontend(phone_vocab_path=phones_dict, tone_vocab_path=None)
am_mu, am_std = np.load(am_stat_path)
# 用CPU推理
providers = ['CPUExecutionProvider']
# 配置ort session
sess_options = ort.SessionOptions()
# 创建session
am_encoder_infer_sess = ort.InferenceSession(
onnx_am_encoder, providers=providers, sess_options=sess_options)
am_decoder_sess = ort.InferenceSession(
onnx_am_decoder, providers=providers, sess_options=sess_options)
am_postnet_sess = ort.InferenceSession(
onnx_am_postnet, providers=providers, sess_options=sess_options)
voc_melgan_sess = ort.InferenceSession(
onnx_voc_melgan, providers=providers, sess_options=sess_options)
def depadding(data, chunk_num, chunk_id, block, pad, upsample):
"""
Streaming inference removes the result of pad inference
"""
front_pad = min(chunk_id * block, pad)
# first chunk
if chunk_id == 0:
data = data[:block * upsample]
# last chunk
elif chunk_id == chunk_num - 1:
data = data[front_pad * upsample:]
# middle chunk
else:
data = data[front_pad * upsample:(front_pad + block) * upsample]
return data
class TritonPythonModel:
"""Your Python model must use the same class name. Every Python model
that is created must have "TritonPythonModel" as the class name.
"""
def initialize(self, args):
"""`initialize` is called only once when the model is being loaded.
Implementing `initialize` function is optional. This function allows
the model to intialize any state associated with this model.
Parameters
----------
args : dict
Both keys and values are strings. The dictionary keys and values are:
* model_config: A JSON string containing the model configuration
* model_instance_kind: A string containing model instance kind
* model_instance_device_id: A string containing model instance device ID
* model_repository: Model repository path
* model_version: Model version
* model_name: Model name
"""
sys.stdout = codecs.getwriter("utf-8")(sys.stdout.detach())
print(sys.getdefaultencoding())
# You must parse model_config. JSON string is not parsed here
self.model_config = model_config = json.loads(args['model_config'])
print("model_config:", self.model_config)
using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
model_config)
if not using_decoupled:
raise pb_utils.TritonModelException(
"""the model `{}` can generate any number of responses per request,
enable decoupled transaction policy in model configuration to
serve this model""".format(args['model_name']))
self.input_names = []
for input_config in self.model_config["input"]:
self.input_names.append(input_config["name"])
print("input:", self.input_names)
self.output_names = []
self.output_dtype = []
for output_config in self.model_config["output"]:
self.output_names.append(output_config["name"])
dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
self.output_dtype.append(dtype)
print("output:", self.output_names)
# To keep track of response threads so that we can delay
# the finalizing the model until all response threads
# have completed.
self.inflight_thread_count = 0
self.inflight_thread_count_lck = threading.Lock()
def execute(self, requests):
"""`execute` must be implemented in every Python model. `execute`
function receives a list of pb_utils.InferenceRequest as the only
argument. This function is called when an inference is requested
for this model. Depending on the batching configuration (e.g. Dynamic
Batching) used, `requests` may contain multiple requests. Every
Python model, must create one pb_utils.InferenceResponse for every
pb_utils.InferenceRequest in `requests`. If there is an error, you can
set the error argument when creating a pb_utils.InferenceResponse.
Parameters
----------
requests : list
A list of pb_utils.InferenceRequest
Returns
-------
list
A list of pb_utils.InferenceResponse. The length of this list must
be the same as `requests`
"""
# This model does not support batching, so 'request_count' should always
# be 1.
if len(requests) != 1:
raise pb_utils.TritonModelException("unsupported batch size " + len(
requests))
input_data = []
for idx in range(len(self.input_names)):
data = pb_utils.get_input_tensor_by_name(requests[0],
self.input_names[idx])
data = data.as_numpy()
data = data[0].decode('utf-8')
input_data.append(data)
text = input_data[0]
# Start a separate thread to send the responses for the request. The
# sending back the responses is delegated to this thread.
thread = threading.Thread(
target=self.response_thread,
args=(requests[0].get_response_sender(), text))
thread.daemon = True
with self.inflight_thread_count_lck:
self.inflight_thread_count += 1
thread.start()
# Unlike in non-decoupled model transaction policy, execute function
# here returns no response. A return from this function only notifies
# Triton that the model instance is ready to receive another request. As
# we are not waiting for the response thread to complete here, it is
# possible that at any give time the model may be processing multiple
# requests. Depending upon the request workload, this may lead to a lot
# of requests being processed by a single model instance at a time. In
# real-world models, the developer should be mindful of when to return
# from execute and be willing to accept next request.
return None
def response_thread(self, response_sender, text):
input_ids = frontend.get_input_ids(
text, merge_sentences=False, get_tone_ids=False)
phone_ids = input_ids["phone_ids"]
for i in range(len(phone_ids)):
part_phone_ids = phone_ids[i].numpy()
voc_chunk_id = 0
orig_hs = am_encoder_infer_sess.run(
None, input_feed={'text': part_phone_ids})
orig_hs = orig_hs[0]
# streaming voc chunk info
mel_len = orig_hs.shape[1]
voc_chunk_num = math.ceil(mel_len / voc_block)
start = 0
end = min(voc_block + voc_pad, mel_len)
# streaming am
hss = get_chunks(orig_hs, am_block, am_pad, "am")
am_chunk_num = len(hss)
for i, hs in enumerate(hss):
am_decoder_output = am_decoder_sess.run(
None, input_feed={'xs': hs})
am_postnet_output = am_postnet_sess.run(
None,
input_feed={
'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
})
am_output_data = am_decoder_output + np.transpose(
am_postnet_output[0], (0, 2, 1))
normalized_mel = am_output_data[0][0]
sub_mel = denorm(normalized_mel, am_mu, am_std)
sub_mel = depadding(sub_mel, am_chunk_num, i, am_block, am_pad,
1)
if i == 0:
mel_streaming = sub_mel
else:
mel_streaming = np.concatenate(
(mel_streaming, sub_mel), axis=0)
# streaming voc
# 当流式AM推理的mel帧数大于流式voc推理的chunk size开始进行流式voc 推理
while (mel_streaming.shape[0] >= end and
voc_chunk_id < voc_chunk_num):
voc_chunk = mel_streaming[start:end, :]
sub_wav = voc_melgan_sess.run(
output_names=None, input_feed={'logmel': voc_chunk})
sub_wav = depadding(sub_wav[0], voc_chunk_num, voc_chunk_id,
voc_block, voc_pad, voc_upsample)
output_np = np.array(sub_wav, dtype=self.output_dtype[0])
out_tensor1 = pb_utils.Tensor(self.output_names[0],
output_np)
status = 0 if voc_chunk_id != (voc_chunk_num - 1) else 1
output_status = np.array(
[status], dtype=self.output_dtype[1])
out_tensor2 = pb_utils.Tensor(self.output_names[1],
output_status)
inference_response = pb_utils.InferenceResponse(
output_tensors=[out_tensor1, out_tensor2])
#yield sub_wav
response_sender.send(inference_response)
voc_chunk_id += 1
start = max(0, voc_chunk_id * voc_block - voc_pad)
end = min((voc_chunk_id + 1) * voc_block + voc_pad, mel_len)
# We must close the response sender to indicate to Triton that we are
# done sending responses for the corresponding request. We can't use the
# response sender after closing it. The response sender is closed by
# setting the TRITONSERVER_RESPONSE_COMPLETE_FINAL.
response_sender.send(
flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
with self.inflight_thread_count_lck:
self.inflight_thread_count -= 1
def finalize(self):
"""`finalize` is called only once when the model is being unloaded.
Implementing `finalize` function is OPTIONAL. This function allows
the model to perform any necessary clean ups before exit.
Here we will wait for all response threads to complete sending
responses.
"""
print('Finalize invoked')
inflight_threads = True
cycles = 0
logging_time_sec = 5
sleep_time_sec = 0.1
cycle_to_log = (logging_time_sec / sleep_time_sec)
while inflight_threads:
with self.inflight_thread_count_lck:
inflight_threads = (self.inflight_thread_count != 0)
if (cycles % cycle_to_log == 0):
print(
f"Waiting for {self.inflight_thread_count} response threads to complete..."
)
if inflight_threads:
time.sleep(sleep_time_sec)
cycles += 1
print('Finalize complete...')

@ -0,0 +1,33 @@
name: "streaming_tts_serving"
backend: "python"
max_batch_size: 0
model_transaction_policy {
decoupled: True
}
input [
{
name: "INPUT_0"
data_type: TYPE_STRING
dims: [ 1 ]
}
]
output [
{
name: "OUTPUT_0"
data_type: TYPE_FP32
dims: [ -1, 1 ]
},
{
name: "status"
data_type: TYPE_BOOL
dims: [ 1 ]
}
]
instance_group [
{
count: 1
kind: KIND_CPU
}
]

@ -0,0 +1,117 @@
#!/usr/bin/env python
import argparse
import queue
import sys
from functools import partial
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import *
FLAGS = None
class UserData:
def __init__(self):
self._completed_requests = queue.Queue()
# Define the callback function. Note the last two parameters should be
# result and error. InferenceServerClient would povide the results of an
# inference as grpcclient.InferResult in result. For successful
# inference, error will be None, otherwise it will be an object of
# tritonclientutils.InferenceServerException holding the error details
def callback(user_data, result, error):
if error:
user_data._completed_requests.put(error)
else:
user_data._completed_requests.put(result)
def async_stream_send(triton_client, values, request_id, model_name):
infer_inputs = []
outputs = []
for idx, data in enumerate(values):
data = np.array([data.encode('utf-8')], dtype=np.object_)
infer_input = grpcclient.InferInput('INPUT_0', [len(data)], "BYTES")
infer_input.set_data_from_numpy(data)
infer_inputs.append(infer_input)
outputs.append(grpcclient.InferRequestedOutput('OUTPUT_0'))
# Issue the asynchronous sequence inference.
triton_client.async_stream_infer(
model_name=model_name,
inputs=infer_inputs,
outputs=outputs,
request_id=request_id)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'-v',
'--verbose',
action="store_true",
required=False,
default=False,
help='Enable verbose output')
parser.add_argument(
'-u',
'--url',
type=str,
required=False,
default='localhost:8001',
help='Inference server URL and it gRPC port. Default is localhost:8001.')
FLAGS = parser.parse_args()
# We use custom "sequence" models which take 1 input
# value. The output is the accumulated value of the inputs. See
# src/custom/sequence.
model_name = "streaming_tts_serving"
values = ["哈哈哈哈"]
request_id = "0"
string_result0_list = []
user_data = UserData()
# It is advisable to use client object within with..as clause
# when sending streaming requests. This ensures the client
# is closed when the block inside with exits.
with grpcclient.InferenceServerClient(
url=FLAGS.url, verbose=FLAGS.verbose) as triton_client:
try:
# Establish stream
triton_client.start_stream(callback=partial(callback, user_data))
# Now send the inference sequences...
async_stream_send(triton_client, values, request_id, model_name)
except InferenceServerException as error:
print(error)
sys.exit(1)
# Retrieve results...
recv_count = 0
result_dict = {}
status = True
while True:
data_item = user_data._completed_requests.get()
if type(data_item) == InferenceServerException:
raise data_item
else:
this_id = data_item.get_response().id
if this_id not in result_dict.keys():
result_dict[this_id] = []
result_dict[this_id].append((recv_count, data_item))
sub_wav = data_item.as_numpy('OUTPUT_0')
status = data_item.as_numpy('status')
print('sub_wav = ', sub_wav, "subwav.shape = ", sub_wav.shape)
print('status = ', status)
if status[0] == 1:
break
recv_count += 1
print("PASS: stream_client")

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

Loading…
Cancel
Save