diff --git a/README.md b/README.md
index e1f57fcaf..49e40624d 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@
Quick Start
| Documents
| Models List
- | AIStudio Courses
+ | AIStudio Courses
| NAACL2022 Best Demo Award Paper
| Gitee
@@ -179,7 +179,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- Scan the QR code below with your Wechat, you can access to official technical exchange group and get the bonus ( more than 20GB learning materials, such as papers, codes and videos ) and the live link of the lessons. Look forward to your participation.
-

+
## Installation
diff --git a/README_cn.md b/README_cn.md
index 1e932201f..bf3ff4dfd 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -23,7 +23,7 @@
| 快速开始
| 教程文档
| 模型列表
- | AIStudio 课程
+ | AIStudio 课程
| NAACL2022 论文
| Gitee
@@ -162,22 +162,7 @@
- 🧩 级联模型应用: 作为传统语音任务的扩展,我们结合了自然语言处理、计算机视觉等任务,实现更接近实际需求的产业级应用。
-### 近期活动
-
- ❗️重磅❗️飞桨智慧金融行业系列直播课
-✅ 覆盖智能风控、智能运维、智能营销、智能客服四大金融主流场景
-
-📆 9月6日-9月29日每周二、四19:00
-+ 智慧金融行业深入洞察
-+ 8节理论+实践精品直播课
-+ 10+真实产业场景范例教学及实践
-+ 更有免费算力+结业证书等礼品等你来拿
-扫码报名码住直播链接,与行业精英深度交流
-
-
-

-
-
+
### 近期更新
- 👑 2022.10.11: 新增 [Wav2vec2ASR](./examples/librispeech/asr3), 在 LibriSpeech 上针对ASR任务对wav2vec2.0 的fine-tuning.
- 🔥 2022.09.26: 新增 Voice Cloning, TTS finetune 和 ERNIE-SAT 到 [PaddleSpeech 网页应用](./demos/speech_web)。
@@ -200,13 +185,13 @@
### 🔥 加入技术交流群获取入群福利
- - 3 日直播课链接: 深度解读 PP-TTS、PP-ASR、PP-VPR 三项核心语音系统关键技术
+ - 3 日直播课链接: 深度解读 【一句话语音合成】【小样本语音合成】【定制化语音识别】语音交互技术
- 20G 学习大礼包:视频课程、前沿论文与学习资料
微信扫描二维码关注公众号,点击“马上报名”填写问卷加入官方交流群,获得更高效的问题答疑,与各行各业开发者充分交流,期待您的加入。
-

+
diff --git a/demos/streaming_tts_serving_fastdeploy/README.md b/demos/streaming_tts_serving_fastdeploy/README.md
new file mode 100644
index 000000000..3e983a06d
--- /dev/null
+++ b/demos/streaming_tts_serving_fastdeploy/README.md
@@ -0,0 +1,67 @@
+([简体中文](./README_cn.md)|English)
+
+# Streaming Speech Synthesis Service
+
+## Introduction
+This demo shows how to start a streaming speech synthesis service with FastDeploy and how to send requests to it.
+
+The `Server` must be started inside the docker container, while the `Client` does not have to run inside docker.
+
+**The `streaming_tts_serving` directory under this demo's path (`$PWD`) contains the model's configuration and code, which the server loads to start the service; it needs to be mapped into the docker container.**
+
+## Usage
+### 1. Server
+#### 1.1 Docker
+
+```bash
+docker pull registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09
+docker run -dit --net=host --name fastdeploy --shm-size="1g" -v $PWD:/models registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09
+docker exec -it -u root fastdeploy bash
+```
+
+#### 1.2 Installation(inside the docker)
+```bash
+apt-get install build-essential python3-dev libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev libsndfile1 language-pack-zh-hans wget zip
+pip3 install paddlespeech
+export LC_ALL="zh_CN.UTF-8"
+export LANG="zh_CN.UTF-8"
+export LANGUAGE="zh_CN:zh:en_US:en"
+```
+
+#### 1.3 Download models(inside the docker)
+```bash
+cd /models/streaming_tts_serving/1
+wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
+wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip
+unzip fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
+unzip mb_melgan_csmsc_onnx_0.2.0.zip
+```
+**For convenience, we recommend using the `-v` option of `docker run` in step 1.1 to map `$PWD` (the `streaming_tts_serving` directory and the model configuration and code it contains) to the path `/models` inside the docker container. You can also use other methods, but whichever you choose, the final model directory and structure inside the docker container should match the figure below.**
+
+<p align="center">
+  <img src="./tree.png" />
+</p>
+
+#### 1.4 Start the server(inside the docker)
+
+```bash
+fastdeployserver --model-repository=/models --model-control-mode=explicit --load-model=streaming_tts_serving
+```
+Arguments:
+ - `model-repository`(required): Path of the model repository.
+ - `model-control-mode`(required): Model loading mode; use 'explicit' here.
+ - `load-model`(required): Name of the model to load.
+ - `http-port`(optional): Port for http service. Default: `8000`. This is not used in our example.
+ - `grpc-port`(optional): Port for grpc service. Default: `8001`.
+ - `metrics-port`(optional): Port for metrics service. Default: `8002`. This is not used in our example.
+
+### 2. Client
+#### 2.1 Installation
+```bash
+pip3 install tritonclient[all]
+```
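+
+Before sending requests, you can check that the service is reachable on the gRPC port from step 1.4. This is only a small sketch, assuming the default `localhost:8001` address:
+
+```python
+import tritonclient.grpc as grpcclient
+
+# connect to the gRPC endpoint opened by fastdeployserver (default port 8001)
+client = grpcclient.InferenceServerClient(url="localhost:8001")
+print("server live :", client.is_server_live())
+print("model ready :", client.is_model_ready("streaming_tts_serving"))
+```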
+
+#### 2.2 Send request
+```bash
+python3 /models/streaming_tts_serving/stream_client.py
+```
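+
+On success the server streams back audio chunks until the `status` output becomes 1. As a rough sketch (not part of the demo code), the chunks can also be collected and written to a wav file; the snippet below assumes the server listens on the default `localhost:8001`, that `soundfile` is installed (`pip3 install soundfile`), and that the synthesized audio is 24 kHz (the sample rate of `mb_melgan_csmsc`).
+
+```python
+import queue
+from functools import partial
+
+import numpy as np
+import soundfile as sf
+import tritonclient.grpc as grpcclient
+
+
+def callback(result_queue, result, error):
+    # the stream callback receives either an InferResult or an exception
+    result_queue.put(error if error else result)
+
+
+results = queue.Queue()
+text = "哈哈哈哈"
+
+with grpcclient.InferenceServerClient(url="localhost:8001") as client:
+    client.start_stream(callback=partial(callback, results))
+    infer_input = grpcclient.InferInput("INPUT_0", [1], "BYTES")
+    infer_input.set_data_from_numpy(
+        np.array([text.encode("utf-8")], dtype=np.object_))
+    client.async_stream_infer(
+        model_name="streaming_tts_serving",
+        inputs=[infer_input],
+        outputs=[
+            grpcclient.InferRequestedOutput("OUTPUT_0"),
+            grpcclient.InferRequestedOutput("status"),
+        ],
+        request_id="0")
+
+    chunks = []
+    while True:
+        item = results.get()
+        if isinstance(item, Exception):
+            raise item
+        chunks.append(item.as_numpy("OUTPUT_0"))
+        if item.as_numpy("status")[0] == 1:  # last chunk of this request
+            break
+
+# concatenate the streamed chunks and write a 24 kHz mono wav file
+wav = np.concatenate(chunks, axis=0).reshape(-1)
+sf.write("synthesized.wav", wav, samplerate=24000)
+```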
diff --git a/demos/streaming_tts_serving_fastdeploy/README_cn.md b/demos/streaming_tts_serving_fastdeploy/README_cn.md
new file mode 100644
index 000000000..7edd32830
--- /dev/null
+++ b/demos/streaming_tts_serving_fastdeploy/README_cn.md
@@ -0,0 +1,67 @@
+(简体中文|[English](./README.md))
+
+# 流式语音合成服务
+
+## 介绍
+
+本文介绍了使用FastDeploy搭建流式语音合成服务的方法。
+
+`服务端`必须在 docker 内启动,而`客户端`则不要求在 docker 容器内运行。
+
+**本文所在路径(`$PWD`)下的 `streaming_tts_serving` 目录里包含模型的配置和代码(服务端会加载模型和代码以启动服务),需要将其映射到 docker 中使用。**
+
+## 使用
+### 1. 服务端
+#### 1.1 Docker
+```bash
+docker pull registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09
+docker run -dit --net=host --name fastdeploy --shm-size="1g" -v $PWD:/models registry.baidubce.com/paddlepaddle/fastdeploy_serving_cpu_only:22.09
+docker exec -it -u root fastdeploy bash
+```
+
+#### 1.2 安装(在docker内)
+```bash
+apt-get install build-essential python3-dev libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev libsndfile1 language-pack-zh-hans wget zip
+pip3 install paddlespeech
+export LC_ALL="zh_CN.UTF-8"
+export LANG="zh_CN.UTF-8"
+export LANGUAGE="zh_CN:zh:en_US:en"
+```
+
+#### 1.3 下载模型(在docker内)
+```bash
+cd /models/streaming_tts_serving/1
+wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
+wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip
+unzip fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip
+unzip mb_melgan_csmsc_onnx_0.2.0.zip
+```
+**为了方便使用,我们推荐使用 1.1 中 `docker run` 的 `-v` 选项,将 `$PWD`(streaming_tts_serving 及其中包含的模型配置和代码)映射到 docker 内的 `/models` 路径。用户也可以使用其他办法,但无论使用哪种方法,最终 docker 内的模型目录及结构都应如下图所示。**
+<p align="center">
+  <img src="./tree.png" />
+</p>
+
+
+#### 1.4 启动服务端(在docker内)
+```bash
+fastdeployserver --model-repository=/models --model-control-mode=explicit --load-model=streaming_tts_serving
+```
+
+参数:
+ - `model-repository`(required): 整套模型streaming_tts_serving存放的路径.
+ - `model-control-mode`(required): 模型加载的方式,现阶段, 使用'explicit'即可.
+ - `load-model`(required): 需要加载的模型的名称.
+ - `http-port`(optional): HTTP服务的端口号. 默认: `8000`. 本示例中未使用该端口.
+ - `grpc-port`(optional): GRPC服务的端口号. 默认: `8001`.
+ - `metrics-port`(optional): 服务端指标的端口号. 默认: `8002`. 本示例中未使用该端口.
+
+### 2. 客户端
+#### 2.1 安装
+```bash
+pip3 install tritonclient[all]
+```
+
+#### 2.2 发送请求
+```bash
+python3 /models/streaming_tts_serving/stream_client.py
+```
diff --git a/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/1/model.py b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/1/model.py
new file mode 100644
index 000000000..46473fdb2
--- /dev/null
+++ b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/1/model.py
@@ -0,0 +1,289 @@
+import codecs
+import json
+import math
+import sys
+import threading
+import time
+
+import numpy as np
+import onnxruntime as ort
+import triton_python_backend_utils as pb_utils
+
+from paddlespeech.server.utils.util import denorm
+from paddlespeech.server.utils.util import get_chunks
+from paddlespeech.t2s.frontend.zh_frontend import Frontend
+
+voc_block = 36
+voc_pad = 14
+am_block = 72
+am_pad = 12
+voc_upsample = 300
+
+# Model paths
+dir_name = "/models/streaming_tts_serving/1/"
+phones_dict = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/phone_id_map.txt"
+am_stat_path = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/speech_stats.npy"
+
+onnx_am_encoder = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_encoder_infer.onnx"
+onnx_am_decoder = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_decoder.onnx"
+onnx_am_postnet = dir_name + "fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0/fastspeech2_csmsc_am_postnet.onnx"
+onnx_voc_melgan = dir_name + "mb_melgan_csmsc_onnx_0.2.0/mb_melgan_csmsc.onnx"
+
+frontend = Frontend(phone_vocab_path=phones_dict, tone_vocab_path=None)
+am_mu, am_std = np.load(am_stat_path)
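+# mel mean/std statistics; `denorm` uses them later to map the normalized mel
+# produced by the acoustic model back to the original scale (roughly
+# normalized_mel * am_std + am_mu)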
+
+# Run inference on CPU
+providers = ['CPUExecutionProvider']
+
+# Configure onnxruntime session options
+sess_options = ort.SessionOptions()
+
+# Create the onnxruntime inference sessions
+am_encoder_infer_sess = ort.InferenceSession(
+ onnx_am_encoder, providers=providers, sess_options=sess_options)
+am_decoder_sess = ort.InferenceSession(
+ onnx_am_decoder, providers=providers, sess_options=sess_options)
+am_postnet_sess = ort.InferenceSession(
+ onnx_am_postnet, providers=providers, sess_options=sess_options)
+voc_melgan_sess = ort.InferenceSession(
+ onnx_voc_melgan, providers=providers, sess_options=sess_options)
+
+
+def depadding(data, chunk_num, chunk_id, block, pad, upsample):
+ """
+ Streaming inference removes the result of pad inference
+ """
+ front_pad = min(chunk_id * block, pad)
+ # first chunk
+ if chunk_id == 0:
+ data = data[:block * upsample]
+ # last chunk
+ elif chunk_id == chunk_num - 1:
+ data = data[front_pad * upsample:]
+ # middle chunk
+ else:
+ data = data[front_pad * upsample:(front_pad + block) * upsample]
+
+ return data
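+
+# Example: with the vocoder settings above (block=36, pad=14, upsample=300), a
+# middle chunk keeps samples [14*300, (14+36)*300) of its output, i.e. only the
+# 36 "real" frames, dropping the padded frames on both sides.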
+
+
+class TritonPythonModel:
+ """Your Python model must use the same class name. Every Python model
+ that is created must have "TritonPythonModel" as the class name.
+ """
+
+ def initialize(self, args):
+ """`initialize` is called only once when the model is being loaded.
+ Implementing `initialize` function is optional. This function allows
+ the model to initialize any state associated with this model.
+ Parameters
+ ----------
+ args : dict
+ Both keys and values are strings. The dictionary keys and values are:
+ * model_config: A JSON string containing the model configuration
+ * model_instance_kind: A string containing model instance kind
+ * model_instance_device_id: A string containing model instance device ID
+ * model_repository: Model repository path
+ * model_version: Model version
+ * model_name: Model name
+ """
+ sys.stdout = codecs.getwriter("utf-8")(sys.stdout.detach())
+ print(sys.getdefaultencoding())
+ # You must parse model_config. JSON string is not parsed here
+ self.model_config = model_config = json.loads(args['model_config'])
+ print("model_config:", self.model_config)
+
+ using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
+ model_config)
+
+ if not using_decoupled:
+ raise pb_utils.TritonModelException(
+ """the model `{}` can generate any number of responses per request,
+ enable decoupled transaction policy in model configuration to
+ serve this model""".format(args['model_name']))
+
+ self.input_names = []
+ for input_config in self.model_config["input"]:
+ self.input_names.append(input_config["name"])
+ print("input:", self.input_names)
+
+ self.output_names = []
+ self.output_dtype = []
+ for output_config in self.model_config["output"]:
+ self.output_names.append(output_config["name"])
+ dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
+ self.output_dtype.append(dtype)
+ print("output:", self.output_names)
+
+ # To keep track of response threads so that we can delay
+ # the finalizing the model until all response threads
+ # have completed.
+ self.inflight_thread_count = 0
+ self.inflight_thread_count_lck = threading.Lock()
+
+ def execute(self, requests):
+ """`execute` must be implemented in every Python model. `execute`
+ function receives a list of pb_utils.InferenceRequest as the only
+ argument. This function is called when an inference is requested
+ for this model. Depending on the batching configuration (e.g. Dynamic
+ Batching) used, `requests` may contain multiple requests. Every
+ Python model, must create one pb_utils.InferenceResponse for every
+ pb_utils.InferenceRequest in `requests`. If there is an error, you can
+ set the error argument when creating a pb_utils.InferenceResponse.
+ Parameters
+ ----------
+ requests : list
+ A list of pb_utils.InferenceRequest
+ Returns
+ -------
+ list
+ A list of pb_utils.InferenceResponse. The length of this list must
+ be the same as `requests`
+ """
+
+ # This model does not support batching, so 'request_count' should always
+ # be 1.
+ if len(requests) != 1:
+ raise pb_utils.TritonModelException(
+ "unsupported batch size " + str(len(requests)))
+
+ input_data = []
+ for idx in range(len(self.input_names)):
+ data = pb_utils.get_input_tensor_by_name(requests[0],
+ self.input_names[idx])
+ data = data.as_numpy()
+ data = data[0].decode('utf-8')
+ input_data.append(data)
+ text = input_data[0]
+
+ # Start a separate thread to send the responses for the request. The
+ # sending back the responses is delegated to this thread.
+ thread = threading.Thread(
+ target=self.response_thread,
+ args=(requests[0].get_response_sender(), text))
+ thread.daemon = True
+ with self.inflight_thread_count_lck:
+ self.inflight_thread_count += 1
+
+ thread.start()
+ # Unlike in non-decoupled model transaction policy, execute function
+ # here returns no response. A return from this function only notifies
+ # Triton that the model instance is ready to receive another request. As
+ # we are not waiting for the response thread to complete here, it is
+ # possible that at any given time the model may be processing multiple
+ # requests. Depending upon the request workload, this may lead to a lot
+ # of requests being processed by a single model instance at a time. In
+ # real-world models, the developer should be mindful of when to return
+ # from execute and be willing to accept the next request.
+ return None
+
+ def response_thread(self, response_sender, text):
+ input_ids = frontend.get_input_ids(
+ text, merge_sentences=False, get_tone_ids=False)
+ phone_ids = input_ids["phone_ids"]
+ for i in range(len(phone_ids)):
+ part_phone_ids = phone_ids[i].numpy()
+ voc_chunk_id = 0
+
+ orig_hs = am_encoder_infer_sess.run(
+ None, input_feed={'text': part_phone_ids})
+ orig_hs = orig_hs[0]
+
+ # streaming voc chunk info
+ mel_len = orig_hs.shape[1]
+ voc_chunk_num = math.ceil(mel_len / voc_block)
+ start = 0
+ end = min(voc_block + voc_pad, mel_len)
+
+ # streaming am
+ hss = get_chunks(orig_hs, am_block, am_pad, "am")
+ am_chunk_num = len(hss)
+ for i, hs in enumerate(hss):
+ am_decoder_output = am_decoder_sess.run(
+ None, input_feed={'xs': hs})
+ am_postnet_output = am_postnet_sess.run(
+ None,
+ input_feed={
+ 'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
+ })
+ am_output_data = am_decoder_output + np.transpose(
+ am_postnet_output[0], (0, 2, 1))
+ normalized_mel = am_output_data[0][0]
+
+ sub_mel = denorm(normalized_mel, am_mu, am_std)
+ sub_mel = depadding(sub_mel, am_chunk_num, i, am_block, am_pad,
+ 1)
+
+ if i == 0:
+ mel_streaming = sub_mel
+ else:
+ mel_streaming = np.concatenate(
+ (mel_streaming, sub_mel), axis=0)
+
+ # streaming voc
+ # When the number of mel frames produced by streaming AM inference exceeds the chunk size of the streaming vocoder, start streaming vocoder inference
+ while (mel_streaming.shape[0] >= end and
+ voc_chunk_id < voc_chunk_num):
+ voc_chunk = mel_streaming[start:end, :]
+
+ sub_wav = voc_melgan_sess.run(
+ output_names=None, input_feed={'logmel': voc_chunk})
+ sub_wav = depadding(sub_wav[0], voc_chunk_num, voc_chunk_id,
+ voc_block, voc_pad, voc_upsample)
+
+ output_np = np.array(sub_wav, dtype=self.output_dtype[0])
+ out_tensor1 = pb_utils.Tensor(self.output_names[0],
+ output_np)
+
+ status = 0 if voc_chunk_id != (voc_chunk_num - 1) else 1
+ output_status = np.array(
+ [status], dtype=self.output_dtype[1])
+ out_tensor2 = pb_utils.Tensor(self.output_names[1],
+ output_status)
+
+ inference_response = pb_utils.InferenceResponse(
+ output_tensors=[out_tensor1, out_tensor2])
+
+ #yield sub_wav
+ response_sender.send(inference_response)
+
+ voc_chunk_id += 1
+ start = max(0, voc_chunk_id * voc_block - voc_pad)
+ end = min((voc_chunk_id + 1) * voc_block + voc_pad, mel_len)
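+ # e.g. with voc_block=36 and voc_pad=14 the mel window advances as
+ # [0, 50) -> [22, 86) -> [58, 122) -> ... (each end clipped to mel_len)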
+
+ # We must close the response sender to indicate to Triton that we are
+ # done sending responses for the corresponding request. We can't use the
+ # response sender after closing it. The response sender is closed by
+ # sending the TRITONSERVER_RESPONSE_COMPLETE_FINAL flag.
+ response_sender.send(
+ flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+
+ with self.inflight_thread_count_lck:
+ self.inflight_thread_count -= 1
+
+ def finalize(self):
+ """`finalize` is called only once when the model is being unloaded.
+ Implementing `finalize` function is OPTIONAL. This function allows
+ the model to perform any necessary clean ups before exit.
+ Here we will wait for all response threads to complete sending
+ responses.
+ """
+ print('Finalize invoked')
+
+ inflight_threads = True
+ cycles = 0
+ logging_time_sec = 5
+ sleep_time_sec = 0.1
+ cycle_to_log = (logging_time_sec / sleep_time_sec)
+ while inflight_threads:
+ with self.inflight_thread_count_lck:
+ inflight_threads = (self.inflight_thread_count != 0)
+ if (cycles % cycle_to_log == 0):
+ print(
+ f"Waiting for {self.inflight_thread_count} response threads to complete..."
+ )
+ if inflight_threads:
+ time.sleep(sleep_time_sec)
+ cycles += 1
+
+ print('Finalize complete...')
diff --git a/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/config.pbtxt b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/config.pbtxt
new file mode 100644
index 000000000..e63721d1c
--- /dev/null
+++ b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/config.pbtxt
@@ -0,0 +1,33 @@
+name: "streaming_tts_serving"
+backend: "python"
+max_batch_size: 0
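+# decoupled mode is required so the python backend can stream multiple
+# responses (audio chunks) for a single request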
+model_transaction_policy {
+ decoupled: True
+}
+input [
+ {
+ name: "INPUT_0"
+ data_type: TYPE_STRING
+ dims: [ 1 ]
+ }
+]
+
+output [
+ {
+ name: "OUTPUT_0"
+ data_type: TYPE_FP32
+ dims: [ -1, 1 ]
+ },
+ {
+ name: "status"
+ data_type: TYPE_BOOL
+ dims: [ 1 ]
+ }
+]
+
+instance_group [
+ {
+ count: 1
+ kind: KIND_CPU
+ }
+]
diff --git a/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/stream_client.py b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/stream_client.py
new file mode 100644
index 000000000..e7f120b7d
--- /dev/null
+++ b/demos/streaming_tts_serving_fastdeploy/streaming_tts_serving/stream_client.py
@@ -0,0 +1,117 @@
+#!/usr/bin/env python
+import argparse
+import queue
+import sys
+from functools import partial
+
+import numpy as np
+import tritonclient.grpc as grpcclient
+from tritonclient.utils import *
+
+FLAGS = None
+
+
+class UserData:
+ def __init__(self):
+ self._completed_requests = queue.Queue()
+
+
+# Define the callback function. Note the last two parameters should be
+# result and error. InferenceServerClient would provide the results of an
+# inference as grpcclient.InferResult in result. For successful
+# inference, error will be None, otherwise it will be an object of
+# tritonclient.utils.InferenceServerException holding the error details
+def callback(user_data, result, error):
+ if error:
+ user_data._completed_requests.put(error)
+ else:
+ user_data._completed_requests.put(result)
+
+
+def async_stream_send(triton_client, values, request_id, model_name):
+
+ infer_inputs = []
+ outputs = []
+ for idx, data in enumerate(values):
+ data = np.array([data.encode('utf-8')], dtype=np.object_)
+ infer_input = grpcclient.InferInput('INPUT_0', [len(data)], "BYTES")
+ infer_input.set_data_from_numpy(data)
+ infer_inputs.append(infer_input)
+
+ outputs.append(grpcclient.InferRequestedOutput('OUTPUT_0'))
+ # Issue the asynchronous sequence inference.
+ triton_client.async_stream_infer(
+ model_name=model_name,
+ inputs=infer_inputs,
+ outputs=outputs,
+ request_id=request_id)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '-v',
+ '--verbose',
+ action="store_true",
+ required=False,
+ default=False,
+ help='Enable verbose output')
+ parser.add_argument(
+ '-u',
+ '--url',
+ type=str,
+ required=False,
+ default='localhost:8001',
+ help='Inference server URL and its gRPC port. Default is localhost:8001.')
+
+ FLAGS = parser.parse_args()
+
+ # The streaming TTS model takes a single text input and streams back
+ # audio chunks (OUTPUT_0) together with a status flag.
+ model_name = "streaming_tts_serving"
+
+ values = ["哈哈哈哈"]
+
+ request_id = "0"
+
+ string_result0_list = []
+
+ user_data = UserData()
+
+ # It is advisable to use the client object within a with..as clause
+ # when sending streaming requests. This ensures the client
+ # is closed when the with block exits.
+ with grpcclient.InferenceServerClient(
+ url=FLAGS.url, verbose=FLAGS.verbose) as triton_client:
+ try:
+ # Establish stream
+ triton_client.start_stream(callback=partial(callback, user_data))
+ # Now send the inference sequences...
+ async_stream_send(triton_client, values, request_id, model_name)
+ except InferenceServerException as error:
+ print(error)
+ sys.exit(1)
+
+ # Retrieve results...
+ recv_count = 0
+ result_dict = {}
+ status = True
+ while True:
+ data_item = user_data._completed_requests.get()
+ if type(data_item) == InferenceServerException:
+ raise data_item
+ else:
+ this_id = data_item.get_response().id
+ if this_id not in result_dict.keys():
+ result_dict[this_id] = []
+ result_dict[this_id].append((recv_count, data_item))
+ sub_wav = data_item.as_numpy('OUTPUT_0')
+ status = data_item.as_numpy('status')
+ print('sub_wav = ', sub_wav, "subwav.shape = ", sub_wav.shape)
+ print('status = ', status)
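+ # the model sets status to 1 on the final audio chunk of the request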
+ if status[0] == 1:
+ break
+ recv_count += 1
+
+ print("PASS: stream_client")
diff --git a/demos/streaming_tts_serving_fastdeploy/tree.png b/demos/streaming_tts_serving_fastdeploy/tree.png
new file mode 100644
index 000000000..b8d61686a
Binary files /dev/null and b/demos/streaming_tts_serving_fastdeploy/tree.png differ
diff --git a/docs/source/released_model.md b/docs/source/released_model.md
index 4e76da033..2f3c9d098 100644
--- a/docs/source/released_model.md
+++ b/docs/source/released_model.md
@@ -9,7 +9,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
[Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz) | Aishell Dataset | Char-based | 491 MB | 2 Conv + 5 LSTM layers | 0.0666 |-| 151 h | [D2 Online Aishell ASR0](../../examples/aishell/asr0) | onnx/inference/python |
[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) | inference/python |
[Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |- | python |
-[Conformer U2PP Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.4.model.tar.gz) | WenetSpeech Dataset | Char-based | 476 MB | Encoder:Conformer, Decoder:BiTransformer, Decoding method: Attention rescoring| 0.047198 (aishell test\_-1) 0.059212 (aishell test\_16) |-| 10000 h |- | python |
+[Conformer U2PP Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.3.0.model.tar.gz) | WenetSpeech Dataset | Char-based | 476 MB | Encoder:Conformer, Decoder:BiTransformer, Decoding method: Attention rescoring| 0.047198 (aishell test\_-1) 0.059212 (aishell test\_16) |-| 10000 h |- | python |
[Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python |
[Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.0.1.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0460 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) | python |
[Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1) | python |
diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py
index 437f64631..004143361 100644
--- a/paddlespeech/cli/asr/infer.py
+++ b/paddlespeech/cli/asr/infer.py
@@ -52,7 +52,7 @@ class ASRExecutor(BaseExecutor):
self.parser.add_argument(
'--model',
type=str,
- default='conformer_u2pp_wenetspeech',
+ default='conformer_u2pp_online_wenetspeech',
choices=[
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
@@ -470,7 +470,7 @@ class ASRExecutor(BaseExecutor):
@stats_wrapper
def __call__(self,
audio_file: os.PathLike,
- model: str='conformer_u2pp_wenetspeech',
+ model: str='conformer_u2pp_online_wenetspeech',
lang: str='zh',
sample_rate: int=16000,
config: os.PathLike=None,
diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py
index f5ec655b7..8e9ecc4ba 100644
--- a/paddlespeech/resource/model_alias.py
+++ b/paddlespeech/resource/model_alias.py
@@ -25,7 +25,6 @@ model_alias = {
"deepspeech2online": ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"],
"conformer": ["paddlespeech.s2t.models.u2:U2Model"],
"conformer_online": ["paddlespeech.s2t.models.u2:U2Model"],
- "conformer_u2pp": ["paddlespeech.s2t.models.u2:U2Model"],
"conformer_u2pp_online": ["paddlespeech.s2t.models.u2:U2Model"],
"transformer": ["paddlespeech.s2t.models.u2:U2Model"],
"wenetspeech": ["paddlespeech.s2t.models.u2:U2Model"],
diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py
index efd6bb3f2..df50a6a9d 100644
--- a/paddlespeech/resource/pretrained_models.py
+++ b/paddlespeech/resource/pretrained_models.py
@@ -68,32 +68,12 @@ asr_dynamic_pretrained_models = {
'',
},
},
- "conformer_u2pp_wenetspeech-zh-16k": {
- '1.1': {
- 'url':
- 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.3.model.tar.gz',
- 'md5':
- '662b347e1d2131b7a4dc5398365e2134',
- 'cfg_path':
- 'model.yaml',
- 'ckpt_path':
- 'exp/chunk_conformer_u2pp/checkpoints/avg_10',
- 'model':
- 'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
- 'params':
- 'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
- 'lm_url':
- '',
- 'lm_md5':
- '',
- },
- },
"conformer_u2pp_online_wenetspeech-zh-16k": {
- '1.1': {
+ '1.3': {
'url':
- 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.4.model.tar.gz',
+ 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.3.0.model.tar.gz',
'md5':
- '3100fc1eac5779486cab859366992d0b',
+ '62d230c1bf27731192aa9d3b8deca300',
'cfg_path':
'model.yaml',
'ckpt_path':
diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index 128f87c07..d9568dcc9 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -19,7 +19,6 @@ from typing import Tuple
import paddle
from paddle import nn
-from paddle.nn import functional as F
from paddle.nn import initializer as I
from paddlespeech.s2t.modules.align import Linear
@@ -56,16 +55,6 @@ class MultiHeadedAttention(nn.Layer):
self.linear_out = Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate)
- def _build_once(self, *args, **kwargs):
- super()._build_once(*args, **kwargs)
- # if self.self_att:
- # self.linear_kv = Linear(self.n_feat, self.n_feat*2)
- if not self.training:
- self.weight = paddle.concat(
- [self.linear_k.weight, self.linear_v.weight], axis=-1)
- self.bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
- self._built = True
-
def forward_qkv(self,
query: paddle.Tensor,
key: paddle.Tensor,
@@ -87,13 +76,8 @@ class MultiHeadedAttention(nn.Layer):
n_batch = query.shape[0]
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
- if self.training:
- k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
- v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
- else:
- k, v = F.linear(key, self.weight, self.bias).view(
- n_batch, -1, 2 * self.h, self.d_k).split(
- 2, axis=2)
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k)
diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py
index 10a91d9be..1b1792bd1 100644
--- a/paddlespeech/server/bin/paddlespeech_server.py
+++ b/paddlespeech/server/bin/paddlespeech_server.py
@@ -113,7 +113,7 @@ class ServerExecutor(BaseExecutor):
"""
config = get_config(config_file)
if self.init(config):
- uvicorn.run(app, host=config.host, port=config.port, debug=True)
+ uvicorn.run(app, host=config.host, port=config.port)
@cli_server_register(
diff --git a/paddlespeech/t2s/frontend/polyphonic.yaml b/paddlespeech/t2s/frontend/polyphonic.yaml
index 51b76f23f..6885035e7 100644
--- a/paddlespeech/t2s/frontend/polyphonic.yaml
+++ b/paddlespeech/t2s/frontend/polyphonic.yaml
@@ -46,4 +46,5 @@ polyphonic:
幸免于难: ['xing4','mian3','yu2','nan4']
恶行: ['e4','xing2']
唉: ['ai4']
-
+ 扎实: ['zha1','shi2']
+ 干将: ['gan4','jiang4']
\ No newline at end of file
diff --git a/paddlespeech/t2s/frontend/tone_sandhi.py b/paddlespeech/t2s/frontend/tone_sandhi.py
index 10a9540c3..42f7b8b2f 100644
--- a/paddlespeech/t2s/frontend/tone_sandhi.py
+++ b/paddlespeech/t2s/frontend/tone_sandhi.py
@@ -65,7 +65,7 @@ class ToneSandhi():
'男子', '女子', '分子', '原子', '量子', '莲子', '石子', '瓜子', '电子', '人人', '虎虎',
'幺幺', '干嘛', '学子', '哈哈', '数数', '袅袅', '局地', '以下', '娃哈哈', '花花草草', '留得',
'耕地', '想想', '熙熙', '攘攘', '卵子', '死死', '冉冉', '恳恳', '佼佼', '吵吵', '打打',
- '考考', '整整', '莘莘'
+ '考考', '整整', '莘莘', '落地', '算子', '家家户户'
}
self.punc = ":,;。?!“”‘’':,;.?!"
diff --git a/paddlespeech/text/exps/ernie_linear/punc_restore.py b/paddlespeech/text/exps/ernie_linear/punc_restore.py
index 2cb4d0719..98804606c 100644
--- a/paddlespeech/text/exps/ernie_linear/punc_restore.py
+++ b/paddlespeech/text/exps/ernie_linear/punc_restore.py
@@ -25,8 +25,6 @@ DefinedClassifier = {
'ErnieLinear': ErnieLinear,
}
-tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
-
def _clean_text(text, punc_list):
text = text.lower()
@@ -35,7 +33,7 @@ def _clean_text(text, punc_list):
return text
-def preprocess(text, punc_list):
+def preprocess(text, punc_list, tokenizer):
clean_text = _clean_text(text, punc_list)
assert len(clean_text) > 0, f'Invalid input string: {text}'
tokenized_input = tokenizer(
@@ -51,7 +49,8 @@ def test(args):
with open(args.config) as f:
config = CfgNode(yaml.safe_load(f))
print("========Args========")
- print(yaml.safe_dump(vars(args)))
+ print(yaml.safe_dump(vars(args), allow_unicode=True))
+ # print(args)
print("========Config========")
print(config)
@@ -61,10 +60,16 @@ def test(args):
punc_list.append(line.strip())
model = DefinedClassifier[config["model_type"]](**config["model"])
+ # print(model)
+
+ pretrained_token = config['data_params']['pretrained_token']
+ tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
+ # tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
+
state_dict = paddle.load(args.checkpoint)
model.set_state_dict(state_dict["main_params"])
model.eval()
- _inputs = preprocess(args.text, punc_list)
+ _inputs = preprocess(args.text, punc_list, tokenizer)
seq_len = _inputs['seq_len']
input_ids = paddle.to_tensor(_inputs['input_ids']).unsqueeze(0)
seg_ids = paddle.to_tensor(_inputs['seg_ids']).unsqueeze(0)