fix code style

pull/2528/head
HexToString 3 years ago
parent 6a14b1bb95
commit 7cf94a3693

@@ -1,22 +1,18 @@
import codecs
import json
import math
import sys
import threading
import time

import numpy as np
import onnxruntime as ort
import triton_python_backend_utils as pb_utils
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.util import denorm
from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend.zh_frontend import Frontend

voc_block = 36
voc_pad = 14
am_block = 72
@@ -52,6 +48,7 @@ am_postnet_sess = ort.InferenceSession(
voc_melgan_sess = ort.InferenceSession(
    onnx_voc_melgan, providers=providers, sess_options=sess_options)


def depadding(data, chunk_num, chunk_id, block, pad, upsample):
    """
    Streaming inference removes the result of pad inference
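Only the first docstring line of depadding is visible in this hunk. For context, a minimal sketch of chunk depadding of this kind, assuming block and pad are measured in mel frames and upsample converts frames to output samples (an illustration, not necessarily the exact body of this function):

# Hedged sketch of chunk depadding: keep only the "block" part of each
# padded chunk so consecutive chunks concatenate without overlap.
def depadding_sketch(data, chunk_num, chunk_id, block, pad, upsample):
    front_pad = min(chunk_id * block, pad)
    if chunk_id == 0:
        # first chunk: keep the first `block` frames, drop the trailing pad
        data = data[:block * upsample]
    elif chunk_id == chunk_num - 1:
        # last chunk: drop only the leading pad
        data = data[front_pad * upsample:]
    else:
        # middle chunks: drop the pad on both sides
        data = data[front_pad * upsample:(front_pad + block) * upsample]
    return data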
@@ -69,6 +66,7 @@ def depadding(data, chunk_num, chunk_id, block, pad, upsample):
    return data


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
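The docstring above refers to the Python backend contract: Triton imports the module, instantiates TritonPythonModel, and drives it through initialize/execute/finalize. A minimal skeleton of that contract, independent of this file (the bodies here are placeholders, not the code in this commit):

# Minimal Triton Python backend skeleton, for orientation only.
import json

import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        # args['model_config'] is the model configuration as a JSON string.
        self.model_config = json.loads(args['model_config'])

    def execute(self, requests):
        # A decoupled model (as in this file) pushes results through each
        # request's response sender and returns None from execute.
        for request in requests:
            _ = request.get_response_sender()
        return None

    def finalize(self):
        # Called once when the model is unloaded.
        pass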
@@ -146,8 +144,8 @@ class TritonPythonModel:
        # This model does not support batching, so 'request_count' should always
        # be 1.
        if len(requests) != 1:
            raise pb_utils.TritonModelException("unsupported batch size " + len(
                requests))

        input_data = []
        for idx in range(len(self.input_names)):
@@ -160,9 +158,9 @@ class TritonPythonModel:
        # Start a separate thread to send the responses for the request. The
        # sending back the responses is delegated to this thread.
        thread = threading.Thread(
            target=self.response_thread,
            args=(requests[0].get_response_sender(), text))
        thread.daemon = True
        with self.inflight_thread_count_lck:
            self.inflight_thread_count += 1
@@ -180,13 +178,15 @@ class TritonPythonModel:
        return None

    def response_thread(self, response_sender, text):
        input_ids = frontend.get_input_ids(
            text, merge_sentences=False, get_tone_ids=False)
        phone_ids = input_ids["phone_ids"]
        for i in range(len(phone_ids)):
            part_phone_ids = phone_ids[i].numpy()
            voc_chunk_id = 0
            orig_hs = am_encoder_infer_sess.run(
                None, input_feed={'text': part_phone_ids})
            orig_hs = orig_hs[0]

            # streaming voc chunk info
@@ -199,27 +199,31 @@ class TritonPythonModel:
            hss = get_chunks(orig_hs, am_block, am_pad, "am")
            am_chunk_num = len(hss)
            for i, hs in enumerate(hss):
                am_decoder_output = am_decoder_sess.run(
                    None, input_feed={'xs': hs})
                am_postnet_output = am_postnet_sess.run(
                    None,
                    input_feed={
                        'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
                    })
                am_output_data = am_decoder_output + np.transpose(
                    am_postnet_output[0], (0, 2, 1))
                normalized_mel = am_output_data[0][0]
                sub_mel = denorm(normalized_mel, am_mu, am_std)
                sub_mel = depadding(sub_mel, am_chunk_num, i, am_block, am_pad,
                                    1)
                if i == 0:
                    mel_streaming = sub_mel
                else:
                    mel_streaming = np.concatenate(
                        (mel_streaming, sub_mel), axis=0)

                # streaming voc
                # Once the streaming AM has produced more mel frames than the
                # streaming voc chunk size, start streaming voc inference.
                while (mel_streaming.shape[0] >= end and
                       voc_chunk_id < voc_chunk_num):
                    voc_chunk = mel_streaming[start:end, :]
                    sub_wav = voc_melgan_sess.run(
@@ -228,13 +232,17 @@ class TritonPythonModel:
                                        voc_block, voc_pad, voc_upsample)
                    output_np = np.array(sub_wav, dtype=self.output_dtype[0])
                    out_tensor1 = pb_utils.Tensor(self.output_names[0],
                                                  output_np)

                    status = 0 if voc_chunk_id != (voc_chunk_num - 1) else 1
                    output_status = np.array(
                        [status], dtype=self.output_dtype[1])
                    out_tensor2 = pb_utils.Tensor(self.output_names[1],
                                                  output_status)

                    inference_response = pb_utils.InferenceResponse(
                        output_tensors=[out_tensor1, out_tensor2])
                    #yield sub_wav
                    response_sender.send(inference_response)
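The while loop in the previous hunk (see the translated comment there) runs the vocoder whenever the mel buffer has grown past the current window end. A hedged sketch of the window bookkeeping such a loop relies on, reusing the start, end, voc_block, voc_pad and voc_chunk_id names from above; the actual update code is not part of this diff:

# Hedged sketch of the sliding-window index arithmetic for streaming
# vocoder inference; an illustration, not lines from this commit.
import math


def next_voc_window(voc_chunk_id, voc_block, voc_pad, mel_len):
    # Each chunk covers `voc_block` new frames plus up to `voc_pad` frames
    # of overlap on each side, clipped to the buffer boundaries.
    start = max(0, voc_chunk_id * voc_block - voc_pad)
    end = min((voc_chunk_id + 1) * voc_block + voc_pad, mel_len)
    return start, end


# Example: a 100-frame utterance with voc_block = 36 and voc_pad = 14 gives
# math.ceil(100 / 36) == 3 vocoder chunks with windows
# (0, 50), (22, 86) and (58, 100).
print([next_voc_window(i, 36, 14, 100) for i in range(math.ceil(100 / 36))])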
@@ -252,7 +260,7 @@ class TritonPythonModel:
        with self.inflight_thread_count_lck:
            self.inflight_thread_count -= 1

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is OPTIONAL. This function allows
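The tail of response_thread is elided from this diff. In Triton's decoupled mode the sender must be told explicitly that no more responses will follow, so the elided part presumably closes the stream roughly like this (an assumption based on the pb_utils API, not lines shown in this commit):

# Hedged sketch of how a decoupled response thread typically ends: an
# empty send carrying the "final" flag tells Triton this request is done.
response_sender.send(
    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)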

@@ -1,12 +1,10 @@
#!/usr/bin/env python
import argparse
import queue
import sys
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import *
@@ -14,7 +12,6 @@ FLAGS = None


class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()
@@ -31,43 +28,41 @@ def callback(user_data, result, error):
        user_data._completed_requests.put(result)


def async_stream_send(triton_client, values, request_id, model_name):
    infer_inputs = []
    outputs = []
    for idx, data in enumerate(values):
        data = np.array([data.encode('utf-8')], dtype=np.object_)
        infer_input = grpcclient.InferInput('INPUT_0', [len(data)], "BYTES")
        infer_input.set_data_from_numpy(data)
        infer_inputs.append(infer_input)
        outputs.append(grpcclient.InferRequestedOutput('OUTPUT_0'))
    # Issue the asynchronous sequence inference.
    triton_client.async_stream_infer(
        model_name=model_name,
        inputs=infer_inputs,
        outputs=outputs,
        request_id=request_id)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-v',
        '--verbose',
        action="store_true",
        required=False,
        default=False,
        help='Enable verbose output')
    parser.add_argument(
        '-u',
        '--url',
        type=str,
        required=False,
        default='localhost:8001',
        help='Inference server URL and its gRPC port. Default is localhost:8001.')

    FLAGS = parser.parse_args()
@@ -94,10 +89,7 @@ if __name__ == '__main__':
        # Establish stream
        triton_client.start_stream(callback=partial(callback, user_data))
        # Now send the inference sequences...
        async_stream_send(triton_client, values, request_id, model_name)
    except InferenceServerException as error:
        print(error)
        sys.exit(1)
@@ -119,7 +111,7 @@ if __name__ == '__main__':
            status = data_item.as_numpy('status')
            print('sub_wav = ', sub_wav, "subwav.shape = ", sub_wav.shape)
            print('status = ', status)
            if status[0] is True:
                break
        recv_count += 1
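The receive loop above drains the UserData queue until the model reports status 1 for the last chunk of an utterance. For illustration, a self-contained, hedged variant of the same pattern that also stitches the streamed chunks into one waveform; the tensor names mirror the client above, and the real callback's error handling may differ:

# Hedged sketch: drain the stream-callback queue and concatenate the
# audio chunks; an illustration, not code elided from this diff.
import numpy as np
from tritonclient.utils import InferenceServerException


def assemble_wav(completed_requests, timeout_s=30):
    chunks = []
    while True:
        data_item = completed_requests.get(timeout=timeout_s)
        if isinstance(data_item, InferenceServerException):
            raise data_item
        chunks.append(data_item.as_numpy('OUTPUT_0'))
        # 'status' is 1 on the final chunk of the utterance, 0 otherwise.
        if data_item.as_numpy('status')[0]:
            break
    return np.concatenate(chunks, axis=0)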
