fix code style

pull/2528/head
HexToString 3 years ago
parent 6a14b1bb95
commit 7cf94a3693

@@ -1,22 +1,18 @@
import codecs
import json
import math
import sys
import locale
import codecs
import threading
import triton_python_backend_utils as pb_utils
import math
import time
import numpy as np
import numpy as np
import onnxruntime as ort
import triton_python_backend_utils as pb_utils
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.util import denorm
from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend.zh_frontend import Frontend
voc_block = 36
voc_pad = 14
am_block = 72
@@ -52,6 +48,7 @@ am_postnet_sess = ort.InferenceSession(
voc_melgan_sess = ort.InferenceSession(
onnx_voc_melgan, providers=providers, sess_options=sess_options)
def depadding(data, chunk_num, chunk_id, block, pad, upsample):
"""
Streaming inference removes the result of pad inference
@@ -69,6 +66,7 @@ def depadding(data, chunk_num, chunk_id, block, pad, upsample):
return data
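
The body of depadding is elided from these hunks, but its role follows from the docstring: each chunk returned by get_chunks carries pad frames of context on either side, and only the central block frames (scaled by the upsample factor) should survive. A minimal sketch of that trimming, assuming this chunk layout rather than quoting the file's actual body:

def depadding_sketch(data, chunk_num, chunk_id, block, pad, upsample):
    # The first chunk has no left context, so its usable region starts at 0.
    front_pad = min(chunk_id * block, pad)
    if chunk_id == 0:
        # first chunk: keep the leading block, drop the right-hand pad
        return data[:block * upsample]
    if chunk_id == chunk_num - 1:
        # last chunk: drop the left-hand pad, keep the remainder
        return data[front_pad * upsample:]
    # middle chunks: keep the central block between the two pads
    return data[front_pad * upsample:(front_pad + block) * upsample]
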
class TritonPythonModel:
"""Your Python model must use the same class name. Every Python model
that is created must have "TritonPythonModel" as the class name.
@@ -146,8 +144,8 @@ class TritonPythonModel:
# This model does not support batching, so 'request_count' should always
# be 1.
if len(requests) != 1:
raise pb_utils.TritonModelException("unsupported batch size " +
len(requests))
raise pb_utils.TritonModelException("unsupported batch size " + len(
requests))
input_data = []
for idx in range(len(self.input_names)):
@@ -160,9 +158,9 @@ class TritonPythonModel:
# Start a separate thread to send the responses for the request. The
# sending back the responses is delegated to this thread.
thread = threading.Thread(target=self.response_thread,
args=(requests[0].get_response_sender(),text)
)
thread = threading.Thread(
target=self.response_thread,
args=(requests[0].get_response_sender(), text))
thread.daemon = True
with self.inflight_thread_count_lck:
self.inflight_thread_count += 1
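
The hunk above only reflows the threading.Thread call, but it is the core of the decoupled response model: execute() hands the request's response sender to a daemon thread and returns immediately, and that thread streams the audio chunks back as they are produced. A compact sketch of the pattern (the output name and placeholder chunks are illustrative, not taken from this model):

import threading
import numpy as np
import triton_python_backend_utils as pb_utils

class DecoupledSketch:
    def execute(self, requests):
        sender = requests[0].get_response_sender()
        worker = threading.Thread(target=self.respond, args=(sender,))
        worker.daemon = True
        worker.start()
        return None  # decoupled mode: responses are sent from the worker thread

    def respond(self, response_sender):
        for chunk in (np.zeros(1024, dtype=np.float32) for _ in range(3)):
            out = pb_utils.Tensor("OUTPUT_0", chunk)
            response_sender.send(
                pb_utils.InferenceResponse(output_tensors=[out]))
        # close the stream so the client knows no more responses will arrive
        response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)

The real model additionally increments inflight_thread_count under its lock, as seen just below the thread creation, so in-flight worker threads can be tracked.
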
@@ -180,13 +178,15 @@ class TritonPythonModel:
return None
def response_thread(self, response_sender, text):
input_ids = frontend.get_input_ids(text, merge_sentences=False, get_tone_ids=False)
input_ids = frontend.get_input_ids(
text, merge_sentences=False, get_tone_ids=False)
phone_ids = input_ids["phone_ids"]
for i in range(len(phone_ids)):
part_phone_ids = phone_ids[i].numpy()
voc_chunk_id = 0
orig_hs = am_encoder_infer_sess.run(None, input_feed={'text': part_phone_ids})
orig_hs = am_encoder_infer_sess.run(
None, input_feed={'text': part_phone_ids})
orig_hs = orig_hs[0]
# streaming voc chunk info
@@ -199,27 +199,31 @@ class TritonPythonModel:
hss = get_chunks(orig_hs, am_block, am_pad, "am")
am_chunk_num = len(hss)
for i, hs in enumerate(hss):
am_decoder_output = am_decoder_sess.run(None, input_feed={'xs': hs})
am_decoder_output = am_decoder_sess.run(
None, input_feed={'xs': hs})
am_postnet_output = am_postnet_sess.run(
None,
input_feed={
'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
})
am_output_data = am_decoder_output + np.transpose(am_postnet_output[0], (0, 2, 1))
am_output_data = am_decoder_output + np.transpose(
am_postnet_output[0], (0, 2, 1))
normalized_mel = am_output_data[0][0]
sub_mel = denorm(normalized_mel, am_mu, am_std)
sub_mel = depadding(sub_mel, am_chunk_num, i, am_block, am_pad, 1)
sub_mel = depadding(sub_mel, am_chunk_num, i, am_block, am_pad,
1)
if i == 0:
mel_streaming = sub_mel
else:
mel_streaming = np.concatenate((mel_streaming, sub_mel), axis=0)
mel_streaming = np.concatenate(
(mel_streaming, sub_mel), axis=0)
# streaming voc
# When the number of mel frames produced by streaming AM inference exceeds the chunk size needed by streaming voc inference, start streaming voc inference
while (mel_streaming.shape[0] >= end and voc_chunk_id < voc_chunk_num):
while (mel_streaming.shape[0] >= end and
voc_chunk_id < voc_chunk_num):
voc_chunk = mel_streaming[start:end, :]
sub_wav = voc_melgan_sess.run(
@@ -228,13 +232,17 @@ class TritonPythonModel:
voc_block, voc_pad, voc_upsample)
output_np = np.array(sub_wav, dtype=self.output_dtype[0])
out_tensor1 = pb_utils.Tensor(self.output_names[0], output_np)
out_tensor1 = pb_utils.Tensor(self.output_names[0],
output_np)
status = 0 if voc_chunk_id != (voc_chunk_num-1) else 1
output_status = np.array([status], dtype=self.output_dtype[1])
out_tensor2 = pb_utils.Tensor(self.output_names[1], output_status)
status = 0 if voc_chunk_id != (voc_chunk_num - 1) else 1
output_status = np.array(
[status], dtype=self.output_dtype[1])
out_tensor2 = pb_utils.Tensor(self.output_names[1],
output_status)
inference_response = pb_utils.InferenceResponse(output_tensors=[out_tensor1,out_tensor2])
inference_response = pb_utils.InferenceResponse(
output_tensors=[out_tensor1, out_tensor2])
#yield sub_wav
response_sender.send(inference_response)
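
The start/end bookkeeping that drives the while loop above sits outside these hunks, but its effect can be reconstructed from the voc_block=36 / voc_pad=14 constants at the top of the file: the vocoder waits until enough mel frames have accumulated to cover one padded window, runs on that window, and depadding() trims the result back to voc_block frames of new audio. A small sketch of that windowing arithmetic, assuming the first window is [0, block + pad) and later windows carry pad frames of context on both sides:

import math

voc_block, voc_pad = 36, 14

def voc_chunk_windows(total_mel_frames):
    # yields (chunk_id, start, end) mel-frame windows for the streaming vocoder
    chunk_num = math.ceil(total_mel_frames / voc_block)
    for chunk_id in range(chunk_num):
        start = max(0, chunk_id * voc_block - voc_pad)
        end = min((chunk_id + 1) * voc_block + voc_pad, total_mel_frames)
        yield chunk_id, start, end

# e.g. 100 accumulated mel frames -> windows (0, 50), (22, 86), (58, 100)
print(list(voc_chunk_windows(100)))
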

@@ -1,12 +1,10 @@
#!/usr/bin/env python
from functools import partial
import argparse
import numpy as np
import sys
import queue
import uuid
import sys
from functools import partial
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import *
@@ -14,7 +12,6 @@ FLAGS = None
class UserData:
def __init__(self):
self._completed_requests = queue.Queue()
@@ -31,22 +28,20 @@ def callback(user_data, result, error):
user_data._completed_requests.put(result)
def async_stream_send(triton_client, values, request_id,
model_name):
def async_stream_send(triton_client, values, request_id, model_name):
infer_inputs = []
outputs = []
for idx, data in enumerate(values):
data = np.array([data.encode('utf-8')],
dtype=np.object_)
infer_input = grpcclient.InferInput('INPUT_0', [len(data)],
"BYTES")
data = np.array([data.encode('utf-8')], dtype=np.object_)
infer_input = grpcclient.InferInput('INPUT_0', [len(data)], "BYTES")
infer_input.set_data_from_numpy(data)
infer_inputs.append(infer_input)
outputs.append(grpcclient.InferRequestedOutput('OUTPUT_0'))
# Issue the asynchronous sequence inference.
triton_client.async_stream_infer(model_name=model_name,
triton_client.async_stream_infer(
model_name=model_name,
inputs=infer_inputs,
outputs=outputs,
request_id=request_id)
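
async_stream_send only queues work; the responses come back through the callback registered on the stream. A minimal end-to-end usage sketch of that flow (the URL, model name and sample text are placeholders, not values from this client):

import queue
import numpy as np
import tritonclient.grpc as grpcclient

responses = queue.Queue()

def on_result(result, error):
    # store whatever arrives; errors surface when the queue is drained
    responses.put(error if error is not None else result)

client = grpcclient.InferenceServerClient(url="localhost:8001")
client.start_stream(callback=on_result)

data = np.array(["hello".encode("utf-8")], dtype=np.object_)
infer_input = grpcclient.InferInput("INPUT_0", [len(data)], "BYTES")
infer_input.set_data_from_numpy(data)
client.async_stream_infer(model_name="streaming_tts",
                          inputs=[infer_input],
                          outputs=[grpcclient.InferRequestedOutput("OUTPUT_0")],
                          request_id="0_0")
# drain the responses queue until the model's status output marks the final
# chunk (a fuller receive-loop sketch follows the last hunk below), then:
client.stop_stream()
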
@@ -54,7 +49,8 @@ def async_stream_send(triton_client, values, request_id,
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-v',
parser.add_argument(
'-v',
'--verbose',
action="store_true",
required=False,
@@ -66,8 +62,7 @@ if __name__ == '__main__':
type=str,
required=False,
default='localhost:8001',
help='Inference server URL and its gRPC port. Default is localhost:8001.'
)
help='Inference server URL and its gRPC port. Default is localhost:8001.')
FLAGS = parser.parse_args()
@@ -94,10 +89,7 @@ if __name__ == '__main__':
# Establish stream
triton_client.start_stream(callback=partial(callback, user_data))
# Now send the inference sequences...
async_stream_send(triton_client,
values,
request_id,
model_name)
async_stream_send(triton_client, values, request_id, model_name)
except InferenceServerException as error:
print(error)
sys.exit(1)
@@ -119,7 +111,7 @@ if __name__ == '__main__':
status = data_item.as_numpy('status')
print('sub_wav = ', sub_wav, "subwav.shape = ", sub_wav.shape)
print('status = ', status)
if status[0] == True:
if status[0] is True:
break
recv_count += 1
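
The receive loop above stops once the status output reports the final chunk. A hedged sketch of the same loop packaged as a helper, assuming the waveform tensor carries the OUTPUT_0 name requested by the client and that the callback stores either a result or an error object in the UserData queue:

import numpy as np

def collect_streamed_wav(user_data, timeout_s=30):
    # pop streamed responses until the server flags the last vocoder chunk
    pieces = []
    while True:
        data_item = user_data._completed_requests.get(timeout=timeout_s)
        if isinstance(data_item, Exception):
            raise data_item               # the callback stored an error object
        pieces.append(data_item.as_numpy('OUTPUT_0'))
        if data_item.as_numpy('status')[0]:
            break                         # final-chunk flag from the server
    return np.concatenate(pieces, axis=0)
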
