diff --git a/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/1/model.py b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/1/model.py
index 323bb09d5..f165bf287 100644
--- a/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/1/model.py
+++ b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/1/model.py
@@ -1,22 +1,18 @@
+import codecs
 import json
+import math
 import sys
-import locale
-import codecs
 import threading
-import triton_python_backend_utils as pb_utils
-
-import math
 import time
-import numpy as np
 
+import numpy as np
 import onnxruntime as ort
+import triton_python_backend_utils as pb_utils
 
-from paddlespeech.server.utils.audio_process import float2pcm
 from paddlespeech.server.utils.util import denorm
 from paddlespeech.server.utils.util import get_chunks
 from paddlespeech.t2s.frontend.zh_frontend import Frontend
-
 
 voc_block = 36
 voc_pad = 14
 am_block = 72
@@ -52,6 +48,7 @@ am_postnet_sess = ort.InferenceSession(
 voc_melgan_sess = ort.InferenceSession(
     onnx_voc_melgan, providers=providers, sess_options=sess_options)
 
+
 def depadding(data, chunk_num, chunk_id, block, pad, upsample):
     """
     Streaming inference removes the result of pad inference
@@ -69,6 +66,7 @@ def depadding(data, chunk_num, chunk_id, block, pad, upsample):
     return data
 
 
+
 class TritonPythonModel:
     """Your Python model must use the same class name. Every Python model
     that is created must have "TritonPythonModel" as the class name.
@@ -146,8 +144,8 @@ class TritonPythonModel:
         # This model does not support batching, so 'request_count' should always
         # be 1.
         if len(requests) != 1:
-            raise pb_utils.TritonModelException("unsupported batch size " +
-                                                len(requests))
+            raise pb_utils.TritonModelException(
+                "unsupported batch size " + str(len(requests)))
 
         input_data = []
         for idx in range(len(self.input_names)):
@@ -160,9 +158,9 @@ class TritonPythonModel:
 
         # Start a separate thread to send the responses for the request. The
        # sending back the responses is delegated to this thread.
-        thread = threading.Thread(target=self.response_thread,
-                                  args=(requests[0].get_response_sender(),text)
-                                  )
+        thread = threading.Thread(
+            target=self.response_thread,
+            args=(requests[0].get_response_sender(), text))
         thread.daemon = True
         with self.inflight_thread_count_lck:
             self.inflight_thread_count += 1
@@ -180,13 +178,15 @@ class TritonPythonModel:
         return None
 
     def response_thread(self, response_sender, text):
-        input_ids = frontend.get_input_ids(text, merge_sentences=False, get_tone_ids=False)
+        input_ids = frontend.get_input_ids(
+            text, merge_sentences=False, get_tone_ids=False)
         phone_ids = input_ids["phone_ids"]
         for i in range(len(phone_ids)):
             part_phone_ids = phone_ids[i].numpy()
             voc_chunk_id = 0
 
-            orig_hs = am_encoder_infer_sess.run(None, input_feed={'text': part_phone_ids})
+            orig_hs = am_encoder_infer_sess.run(
+                None, input_feed={'text': part_phone_ids})
             orig_hs = orig_hs[0]
 
             # streaming voc chunk info
@@ -199,27 +199,31 @@ class TritonPythonModel:
             hss = get_chunks(orig_hs, am_block, am_pad, "am")
             am_chunk_num = len(hss)
             for i, hs in enumerate(hss):
-                am_decoder_output = am_decoder_sess.run(None, input_feed={'xs': hs})
+                am_decoder_output = am_decoder_sess.run(
+                    None, input_feed={'xs': hs})
                 am_postnet_output = am_postnet_sess.run(
                     None,
                     input_feed={
                         'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
                     })
 
-                am_output_data = am_decoder_output + np.transpose(am_postnet_output[0], (0, 2, 1))
+                am_output_data = am_decoder_output + np.transpose(
+                    am_postnet_output[0], (0, 2, 1))
                 normalized_mel = am_output_data[0][0]
                 sub_mel = denorm(normalized_mel, am_mu, am_std)
-                sub_mel = depadding(sub_mel, am_chunk_num, i, am_block, am_pad, 1)
+                sub_mel = depadding(sub_mel, am_chunk_num, i, am_block, am_pad,
+                                    1)
 
                 if i == 0:
                     mel_streaming = sub_mel
                 else:
-                    mel_streaming = np.concatenate((mel_streaming, sub_mel), axis=0)
-
+                    mel_streaming = np.concatenate(
+                        (mel_streaming, sub_mel), axis=0)
 
                 # streaming voc
                 # Start streaming voc inference once the number of mel frames produced by streaming AM inference exceeds the voc chunk size
-                while (mel_streaming.shape[0] >= end and voc_chunk_id < voc_chunk_num):
+                while (mel_streaming.shape[0] >= end and
+                       voc_chunk_id < voc_chunk_num):
                     voc_chunk = mel_streaming[start:end, :]
 
                     sub_wav = voc_melgan_sess.run(
@@ -228,13 +232,17 @@ class TritonPythonModel:
                                         voc_block, voc_pad, voc_upsample)
 
                     output_np = np.array(sub_wav, dtype=self.output_dtype[0])
-                    out_tensor1 = pb_utils.Tensor(self.output_names[0], output_np)
+                    out_tensor1 = pb_utils.Tensor(self.output_names[0],
+                                                  output_np)
 
-                    status = 0 if voc_chunk_id != (voc_chunk_num-1) else 1
-                    output_status = np.array([status], dtype=self.output_dtype[1])
-                    out_tensor2 = pb_utils.Tensor(self.output_names[1], output_status)
+                    status = 0 if voc_chunk_id != (voc_chunk_num - 1) else 1
+                    output_status = np.array(
+                        [status], dtype=self.output_dtype[1])
+                    out_tensor2 = pb_utils.Tensor(self.output_names[1],
+                                                  output_status)
 
-                    inference_response = pb_utils.InferenceResponse(output_tensors=[out_tensor1,out_tensor2])
+                    inference_response = pb_utils.InferenceResponse(
+                        output_tensors=[out_tensor1, out_tensor2])
 
                     #yield sub_wav
                     response_sender.send(inference_response)
@@ -252,7 +260,7 @@ class TritonPythonModel:
         with self.inflight_thread_count_lck:
             self.inflight_thread_count -= 1
 
-
+
     def finalize(self):
         """`finalize` is called only once when the model is being unloaded.
         Implementing `finalize` function is OPTIONAL. This function allows
diff --git a/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/stream_client.py b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/stream_client.py
index ce58de5ed..0d36eaaf9 100644
--- a/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/stream_client.py
+++ b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/stream_client.py
@@ -1,12 +1,10 @@
 #!/usr/bin/env python
-
-from functools import partial
 import argparse
-import numpy as np
-import sys
 import queue
-import uuid
+import sys
+from functools import partial
 
+import numpy as np
 import tritonclient.grpc as grpcclient
 from tritonclient.utils import *
 
@@ -14,7 +12,6 @@ FLAGS = None
 
 
 class UserData:
-
     def __init__(self):
         self._completed_requests = queue.Queue()
 
@@ -31,43 +28,41 @@ def callback(user_data, result, error):
         user_data._completed_requests.put(result)
 
 
-def async_stream_send(triton_client, values, request_id,
-                      model_name):
+def async_stream_send(triton_client, values, request_id, model_name):
     infer_inputs = []
     outputs = []
     for idx, data in enumerate(values):
-        data = np.array([data.encode('utf-8')],
-                        dtype=np.object_)
-        infer_input = grpcclient.InferInput('INPUT_0', [len(data)],
-                                            "BYTES")
+        data = np.array([data.encode('utf-8')], dtype=np.object_)
+        infer_input = grpcclient.InferInput('INPUT_0', [len(data)], "BYTES")
         infer_input.set_data_from_numpy(data)
         infer_inputs.append(infer_input)
 
         outputs.append(grpcclient.InferRequestedOutput('OUTPUT_0'))
 
     # Issue the asynchronous sequence inference.
-    triton_client.async_stream_infer(model_name=model_name,
-                                     inputs=infer_inputs,
-                                     outputs=outputs,
-                                     request_id=request_id)
+    triton_client.async_stream_infer(
+        model_name=model_name,
+        inputs=infer_inputs,
+        outputs=outputs,
+        request_id=request_id)
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-v',
-                        '--verbose',
-                        action="store_true",
-                        required=False,
-                        default=False,
-                        help='Enable verbose output')
+    parser.add_argument(
+        '-v',
+        '--verbose',
+        action="store_true",
+        required=False,
+        default=False,
+        help='Enable verbose output')
     parser.add_argument(
         '-u',
         '--url',
         type=str,
         required=False,
         default='localhost:8001',
-        help='Inference server URL and it gRPC port. Default is localhost:8001.'
-    )
+        help='Inference server URL and its gRPC port. Default is localhost:8001.')
 
     FLAGS = parser.parse_args()
 
@@ -94,10 +89,7 @@ if __name__ == '__main__':
         # Establish stream
         triton_client.start_stream(callback=partial(callback, user_data))
         # Now send the inference sequences...
-        async_stream_send(triton_client,
-                          values,
-                          request_id,
-                          model_name)
+        async_stream_send(triton_client, values, request_id, model_name)
     except InferenceServerException as error:
         print(error)
         sys.exit(1)
@@ -119,7 +111,7 @@ if __name__ == '__main__':
             status = data_item.as_numpy('status')
             print('sub_wav = ', sub_wav, "subwav.shape = ", sub_wav.shape)
             print('status = ', status)
-            if status[0] == True:
+            if status[0]:
                 break
             recv_count += 1
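
For reference, a minimal usage sketch (not part of the patch itself) of how a caller could assemble the chunks that stream_client.py receives into a single waveform. The tensor names 'OUTPUT_0' and 'status' come from the model above; the 24 kHz sample rate and the soundfile dependency are assumptions made purely for illustration.

import queue

import numpy as np
import soundfile as sf  # assumed dependency, used only to write the WAV file
from tritonclient.utils import InferenceServerException


def collect_wav(completed_requests: queue.Queue, path="out.wav"):
    """Drain the callback queue until the chunk flagged as final arrives."""
    chunks = []
    while True:
        data_item = completed_requests.get()
        if isinstance(data_item, InferenceServerException):
            raise data_item  # the callback enqueues errors as-is
        sub_wav = data_item.as_numpy('OUTPUT_0')  # one audio chunk
        status = data_item.as_numpy('status')  # non-zero marks the last chunk
        chunks.append(sub_wav.reshape(-1))
        if status[0]:  # plain truthiness also works for NumPy scalar values
            break
    wav = np.concatenate(chunks)
    sf.write(path, wav, samplerate=24000)  # sample rate is an assumption
    return wav

With the client above, this could be called as collect_wav(user_data._completed_requests) once the stream has been started.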