add triton tts server

pull/2528/head
HexToString 3 years ago
parent c9b0c96b7b
commit 9e61b81f68

@@ -0,0 +1,34 @@
#We assume the absolute path of your model repository on the host is $PWD and that it is mounted at /models inside the Docker container
#For Server
#Start docker
docker pull registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09
docker run -dit --net=host --name fastdeploy --shm-size="1g" -v $PWD:/models registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09
docker exec -it -u root fastdeploy bash
#Inside the docker
apt-get install build-essential python3-dev libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev libsndfile1 language-pack-zh-hans wget zip
pip install paddlespeech
export LC_ALL="zh_CN.UTF-8"
export LANG="zh_CN.UTF-8"
export LANGUAGE="zh_CN:zh:en_US:en"
#Download models
cd /models/streaming_tts_serving/1
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip
unzip fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
unzip mb_melgan_csmsc_onnx_0.2.0.zip
#Start the server
fastdeployserver --model-repository=/models --model-control-mode=explicit --load-model=streaming_tts_serving
#For Client
#Install
pip3 install tritonclient[all]
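#Optionally verify that the model has been loaded before sending requests (assumes the default gRPC port 8001)
python3 -c "import tritonclient.grpc as grpcclient; print(grpcclient.InferenceServerClient('localhost:8001').is_model_ready('streaming_tts_serving'))"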
#Request
python3 /models/streaming_tts_serving/stream_client.py

@@ -0,0 +1,281 @@
import json
import sys
import locale
import codecs
import threading
import triton_python_backend_utils as pb_utils
import math
import time
import numpy as np
import onnxruntime as ort
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.util import denorm
from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend.zh_frontend import Frontend
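# Streaming chunk settings: *_block is the number of mel frames handled per chunk,
# *_pad is the extra context added around each chunk (removed again by depadding
# below), and voc_upsample is the number of waveform samples the vocoder produces
# per mel frame.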
voc_block = 36
voc_pad = 14
am_block = 72
am_pad = 12
voc_upsample = 300
# Model paths
dir_name = "/models/streaming_tts_serving/1/"
phones_dict = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/phone_id_map.txt"
am_stat_path = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/speech_stats.npy"
onnx_am_encoder = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_encoder_infer.onnx"
onnx_am_decoder = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_decoder.onnx"
onnx_am_postnet = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_postnet.onnx"
onnx_voc_melgan = dir_name + "mb_melgan_csmsc_onnx_0.2.0/mb_melgan_csmsc.onnx"
frontend = Frontend(phone_vocab_path=phones_dict, tone_vocab_path=None)
am_mu, am_std = np.load(am_stat_path)
# Run inference on CPU
providers = ['CPUExecutionProvider']
# Configure the ONNX Runtime session options
sess_options = ort.SessionOptions()
# Create the ONNX Runtime inference sessions
am_encoder_infer_sess = ort.InferenceSession(
onnx_am_encoder, providers=providers, sess_options=sess_options)
am_decoder_sess = ort.InferenceSession(
onnx_am_decoder, providers=providers, sess_options=sess_options)
am_postnet_sess = ort.InferenceSession(
onnx_am_postnet, providers=providers, sess_options=sess_options)
voc_melgan_sess = ort.InferenceSession(
onnx_voc_melgan, providers=providers, sess_options=sess_options)
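# The frontend and the four ONNX Runtime sessions above are created once when
# Triton loads this file and are then shared by every request served by this
# model instance.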
def depadding(data, chunk_num, chunk_id, block, pad, upsample):
"""
Streaming inference removes the result of pad inference
"""
front_pad = min(chunk_id * block, pad)
# first chunk
if chunk_id == 0:
data = data[:block * upsample]
# last chunk
elif chunk_id == chunk_num - 1:
data = data[front_pad * upsample:]
# middle chunk
else:
data = data[front_pad * upsample:(front_pad + block) * upsample]
return data
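# Example with the vocoder settings above (block=36, pad=14, upsample=300) and
# three chunks: chunk 0 keeps samples [0, 36*300), the middle chunk keeps
# [14*300, (14+36)*300) of its own output, and the last chunk keeps everything
# from sample 14*300 onward, so the concatenated chunks line up without the
# padded context being emitted twice.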
class TritonPythonModel:
"""Your Python model must use the same class name. Every Python model
that is created must have "TritonPythonModel" as the class name.
"""
def initialize(self, args):
"""`initialize` is called only once when the model is being loaded.
Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.
Parameters
----------
args : dict
Both keys and values are strings. The dictionary keys and values are:
* model_config: A JSON string containing the model configuration
* model_instance_kind: A string containing model instance kind
* model_instance_device_id: A string containing model instance device ID
* model_repository: Model repository path
* model_version: Model version
* model_name: Model name
"""
sys.stdout = codecs.getwriter("utf-8")(sys.stdout.detach())
print(sys.getdefaultencoding())
# You must parse model_config. JSON string is not parsed here
self.model_config = model_config = json.loads(args['model_config'])
print("model_config:", self.model_config)
using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
model_config)
if not using_decoupled:
raise pb_utils.TritonModelException(
"""the model `{}` can generate any number of responses per request,
enable decoupled transaction policy in model configuration to
serve this model""".format(args['model_name']))
self.input_names = []
for input_config in self.model_config["input"]:
self.input_names.append(input_config["name"])
print("input:", self.input_names)
self.output_names = []
self.output_dtype = []
for output_config in self.model_config["output"]:
self.output_names.append(output_config["name"])
dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
self.output_dtype.append(dtype)
print("output:", self.output_names)
        # Keep track of response threads so that we can delay finalizing the
        # model until all response threads have completed.
self.inflight_thread_count = 0
self.inflight_thread_count_lck = threading.Lock()
def execute(self, requests):
"""`execute` must be implemented in every Python model. `execute`
function receives a list of pb_utils.InferenceRequest as the only
argument. This function is called when an inference is requested
for this model. Depending on the batching configuration (e.g. Dynamic
Batching) used, `requests` may contain multiple requests. Every
        Python model must create one pb_utils.InferenceResponse for every
pb_utils.InferenceRequest in `requests`. If there is an error, you can
set the error argument when creating a pb_utils.InferenceResponse.
Parameters
----------
requests : list
A list of pb_utils.InferenceRequest
Returns
-------
list
A list of pb_utils.InferenceResponse. The length of this list must
be the same as `requests`
"""
        # This model does not support batching, so `requests` should always
        # contain a single request.
        if len(requests) != 1:
            raise pb_utils.TritonModelException("unsupported batch size " +
                                                str(len(requests)))
input_data = []
for idx in range(len(self.input_names)):
data = pb_utils.get_input_tensor_by_name(requests[0],
self.input_names[idx])
data = data.as_numpy()
data = data[0].decode('utf-8')
input_data.append(data)
text = input_data[0]
        # Start a separate thread to send the responses for the request.
        # Sending the responses back is delegated to this thread.
thread = threading.Thread(target=self.response_thread,
args=(requests[0].get_response_sender(),text)
)
thread.daemon = True
with self.inflight_thread_count_lck:
self.inflight_thread_count += 1
thread.start()
        # Unlike in the non-decoupled model transaction policy, the execute
        # function here returns no response. A return from this function only
        # notifies Triton that the model instance is ready to receive another
        # request. As we are not waiting for the response thread to complete
        # here, it is possible that at any given time the model may be
        # processing multiple requests. Depending upon the request workload,
        # this may lead to a lot of requests being processed by a single model
        # instance at a time. In real-world models, the developer should be
        # mindful of when to return from execute and be willing to accept the
        # next request.
return None
def response_thread(self, response_sender, text):
input_ids = frontend.get_input_ids(text, merge_sentences=False, get_tone_ids=False)
phone_ids = input_ids["phone_ids"]
for i in range(len(phone_ids)):
part_phone_ids = phone_ids[i].numpy()
voc_chunk_id = 0
orig_hs = am_encoder_infer_sess.run(None, input_feed={'text': part_phone_ids})
orig_hs = orig_hs[0]
# streaming voc chunk info
mel_len = orig_hs.shape[1]
voc_chunk_num = math.ceil(mel_len / voc_block)
start = 0
end = min(voc_block + voc_pad, mel_len)
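            # start/end delimit the window of mel frames (chunk plus padding)
            # that is fed to the vocoder for the next streaming chunk.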
# streaming am
hss = get_chunks(orig_hs, am_block, am_pad, "am")
am_chunk_num = len(hss)
for i, hs in enumerate(hss):
am_decoder_output = am_decoder_sess.run(None, input_feed={'xs': hs})
am_postnet_output = am_postnet_sess.run(
None,
input_feed={
'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
})
am_output_data = am_decoder_output + np.transpose(am_postnet_output[0], (0, 2, 1))
normalized_mel = am_output_data[0][0]
sub_mel = denorm(normalized_mel, am_mu, am_std)
sub_mel = depadding(sub_mel, am_chunk_num, i, am_block, am_pad, 1)
if i == 0:
mel_streaming = sub_mel
else:
mel_streaming = np.concatenate((mel_streaming, sub_mel), axis=0)
# streaming voc
                # Start streaming vocoder inference once the streaming acoustic
                # model has produced more mel frames than the vocoder chunk size.
while (mel_streaming.shape[0] >= end and voc_chunk_id < voc_chunk_num):
voc_chunk = mel_streaming[start:end, :]
sub_wav = voc_melgan_sess.run(
output_names=None, input_feed={'logmel': voc_chunk})
sub_wav = depadding(sub_wav[0], voc_chunk_num, voc_chunk_id,
voc_block, voc_pad, voc_upsample)
output_np = np.array(sub_wav, dtype=self.output_dtype[0])
out_tensor1 = pb_utils.Tensor(self.output_names[0], output_np)
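                    # status stays 0 while more chunks of the current sentence
                    # will follow and becomes 1 on its final chunk.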
status = 0 if voc_chunk_id != (voc_chunk_num-1) else 1
output_status = np.array([status], dtype=self.output_dtype[1])
out_tensor2 = pb_utils.Tensor(self.output_names[1], output_status)
inference_response = pb_utils.InferenceResponse(output_tensors=[out_tensor1,out_tensor2])
#yield sub_wav
response_sender.send(inference_response)
voc_chunk_id += 1
start = max(0, voc_chunk_id * voc_block - voc_pad)
end = min((voc_chunk_id + 1) * voc_block + voc_pad, mel_len)
        # We must close the response sender to indicate to Triton that we are
        # done sending responses for the corresponding request. We can't use
        # the response sender after closing it. The response sender is closed
        # by sending the TRITONSERVER_RESPONSE_COMPLETE_FINAL flag.
response_sender.send(
flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
with self.inflight_thread_count_lck:
self.inflight_thread_count -= 1
def finalize(self):
"""`finalize` is called only once when the model is being unloaded.
Implementing `finalize` function is OPTIONAL. This function allows
the model to perform any necessary clean ups before exit.
Here we will wait for all response threads to complete sending
responses.
"""
print('Finalize invoked')
inflight_threads = True
cycles = 0
logging_time_sec = 5
sleep_time_sec = 0.1
cycle_to_log = (logging_time_sec / sleep_time_sec)
while inflight_threads:
with self.inflight_thread_count_lck:
inflight_threads = (self.inflight_thread_count != 0)
if (cycles % cycle_to_log == 0):
print(
f"Waiting for {self.inflight_thread_count} response threads to complete..."
)
if inflight_threads:
time.sleep(sleep_time_sec)
cycles += 1
print('Finalize complete...')

@@ -0,0 +1,33 @@
name: "streaming_tts_serving"
backend: "python"
max_batch_size: 0
model_transaction_policy {
decoupled: True
}
input [
{
name: "INPUT_0"
data_type: TYPE_STRING
dims: [ 1 ]
}
]
output [
{
name: "OUTPUT_0"
data_type: TYPE_FP32
dims: [ -1, 1 ]
},
{
name: "status"
data_type: TYPE_BOOL
dims: [ 1 ]
}
]
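# OUTPUT_0 carries one FP32 waveform chunk per streamed response; status is
# True (1) on the last chunk of each synthesized sentence.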
instance_group [
{
count: 1
kind: KIND_CPU
}
]
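For reference, after following the server steps above the model repository mounted at /models should look roughly like this (a sketch: it assumes the Python model file above is saved as 1/model.py, and the ONNX folder names come from the downloaded archives):

/models
└── streaming_tts_serving
    ├── config.pbtxt
    ├── stream_client.py
    └── 1
        ├── model.py
        ├── fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/
        └── mb_melgan_csmsc_onnx_0.2.0/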

@@ -0,0 +1,126 @@
#!/usr/bin/env python
from functools import partial
import argparse
import numpy as np
import sys
import queue
import uuid
import tritonclient.grpc as grpcclient
from tritonclient.utils import *
FLAGS = None
class UserData:
def __init__(self):
self._completed_requests = queue.Queue()
# Define the callback function. Note the last two parameters should be
# result and error. InferenceServerClient will provide the results of an
# inference as grpcclient.InferResult in result. For a successful
# inference, error will be None; otherwise it will be an object of
# tritonclientutils.InferenceServerException holding the error details.
def callback(user_data, result, error):
if error:
user_data._completed_requests.put(error)
else:
user_data._completed_requests.put(result)
def async_stream_send(triton_client, values, request_id,
model_name):
infer_inputs = []
outputs = []
for idx, data in enumerate(values):
data = np.array([data.encode('utf-8')],
dtype=np.object_)
infer_input = grpcclient.InferInput('INPUT_0', [len(data)],
"BYTES")
infer_input.set_data_from_numpy(data)
infer_inputs.append(infer_input)
outputs.append(grpcclient.InferRequestedOutput('OUTPUT_0'))
# Issue the asynchronous sequence inference.
triton_client.async_stream_infer(model_name=model_name,
inputs=infer_inputs,
outputs=outputs,
request_id=request_id)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-v',
'--verbose',
action="store_true",
required=False,
default=False,
help='Enable verbose output')
parser.add_argument(
'-u',
'--url',
type=str,
required=False,
default='localhost:8001',
        help='Inference server URL and its gRPC port. Default is localhost:8001.'
)
FLAGS = parser.parse_args()
    # Send one text sentence to the streaming TTS model and receive the
    # synthesized audio back as a stream of chunks.
model_name = "streaming_tts_serving"
values = ["哈哈哈哈"]
request_id = "0"
#string_sequence_id0 = str(uuid.uuid4())
string_result0_list = []
user_data = UserData()
    # It is advisable to use the client object within a with..as clause
    # when sending streaming requests. This ensures the client is closed
    # when the block inside with exits.
with grpcclient.InferenceServerClient(
url=FLAGS.url, verbose=FLAGS.verbose) as triton_client:
try:
# Establish stream
triton_client.start_stream(callback=partial(callback, user_data))
# Now send the inference sequences...
async_stream_send(triton_client,
values,
request_id,
model_name)
except InferenceServerException as error:
print(error)
sys.exit(1)
# Retrieve results...
recv_count = 0
result_dict = {}
status = True
while True:
data_item = user_data._completed_requests.get()
if type(data_item) == InferenceServerException:
raise data_item
else:
this_id = data_item.get_response().id
if this_id not in result_dict.keys():
result_dict[this_id] = []
result_dict[this_id].append((recv_count, data_item))
sub_wav = data_item.as_numpy('OUTPUT_0')
status = data_item.as_numpy('status')
                print('sub_wav = ', sub_wav, "sub_wav.shape = ", sub_wav.shape)
print('status = ', status)
if status[0] == True:
break
recv_count += 1
print("PASS: stream_client")