diff --git a/README.md b/README.md
index 2321920de..32e1c23d8 100644
--- a/README.md
+++ b/README.md
@@ -981,6 +981,7 @@ You are warmly welcome to submit questions in [discussions](https://github.com/P
 - Many thanks to [jerryuhoo](https://github.com/jerryuhoo)/[VTuberTalk](https://github.com/jerryuhoo/VTuberTalk) for developing a GUI tool based on PaddleSpeech TTS and code for making datasets from videos based on PaddleSpeech ASR.
 - Many thanks to [vpegasus](https://github.com/vpegasus)/[xuesebot](https://github.com/vpegasus/xuesebot) for developing a rasa chatbot, which is able to speak and listen thanks to PaddleSpeech.
 - Many thanks to [chenkui164](https://github.com/chenkui164)/[FastASR](https://github.com/chenkui164/FastASR) for the C++ inference implementation of PaddleSpeech ASR.
+- Many thanks to [heyudage](https://github.com/heyudage)/[VoiceTyping](https://github.com/heyudage/VoiceTyping) for developing a real-time voice typing tool based on the PaddleSpeech ASR streaming service.
 
 Besides, PaddleSpeech depends on a lot of open source repositories. See [references](./docs/source/reference.md) for more information.
diff --git a/README_cn.md b/README_cn.md
index 8127c5570..427d59caf 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -987,6 +987,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
 - 非常感谢 [vpegasus](https://github.com/vpegasus)/[xuesebot](https://github.com/vpegasus/xuesebot) 基于 PaddleSpeech 的 ASR 与 TTS 设计的可听、说对话机器人。
 - 非常感谢 [chenkui164](https://github.com/chenkui164)/[FastASR](https://github.com/chenkui164/FastASR) 对 PaddleSpeech 的 ASR 进行 C++ 推理实现。
+- 非常感谢 [heyudage](https://github.com/heyudage)/[VoiceTyping](https://github.com/heyudage/VoiceTyping) 基于 PaddleSpeech 的 ASR 流式服务实现的实时语音输入法工具。
 
 此外,PaddleSpeech 依赖于许多开源存储库。有关更多信息,请参阅 [references](./docs/source/reference.md)。
diff --git a/demos/speech_ssl/README.md b/demos/speech_ssl/README.md
index fdef37e7b..b98a7cc61 100644
--- a/demos/speech_ssl/README.md
+++ b/demos/speech_ssl/README.md
@@ -82,7 +82,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
   Output:
   ```bash
   ASR Result:
-  我认为跑步最重要的就是给我带来了身体健康
+  i knocked at the door on the ancient side of the building
 
   Representation:
   Tensor(shape=[1, 164, 1024], dtype=float32, place=Place(gpu:0), stop_gradient=True,
diff --git a/demos/speech_ssl/README_cn.md b/demos/speech_ssl/README_cn.md
index 76ec2f1ff..65961ce90 100644
--- a/demos/speech_ssl/README_cn.md
+++ b/demos/speech_ssl/README_cn.md
@@ -36,9 +36,9 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
   参数:
   - `input`(必须输入):用于识别的音频文件。
-  - `model`:ASR 任务的模型,默认值:`conformer_wenetspeech`。
+  - `model`:ASR 任务的模型,默认值:`wav2vec2ASR_librispeech`。
   - `task`:输出类别,默认值:`asr`。
-  - `lang`:模型语言,默认值:`zh`。
+  - `lang`:模型语言,默认值:`en`。
   - `sample_rate`:音频采样率,默认值:`16000`。
   - `config`:ASR 任务的参数文件,若不设置则使用预训练模型中的默认配置,默认值:`None`。
   - `ckpt_path`:模型参数文件,若不设置则下载预训练模型使用,默认值:`None`。
@@ -83,8 +83,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
   输出:
   ```bash
   ASR Result:
-  我认为跑步最重要的就是给我带来了身体健康
-
+  i knocked at the door on the ancient side of the building
+
   Representation:
   Tensor(shape=[1, 164, 1024], dtype=float32, place=Place(gpu:0), stop_gradient=True,
   [[[ 0.02351918, -0.12980647,  0.17868176, ...,  0.10118122,
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
index 82313c330..9c88796bb 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
@@ -46,7 +46,7 @@ class VanillaNN(containers.Sequential):
             dnn_neurons=512,
             activation=True,
             normalization=False,
-            dropout_rate=0.0):
+            dropout_rate=0.5):
         super().__init__(input_shape=[None, None, input_shape])
 
         if not isinstance(dropout_rate, list):
@@ -68,6 +68,5 @@ class VanillaNN(containers.Sequential):
             if activation:
                 self.append(paddle.nn.LeakyReLU(), layer_name="act")
             self.append(
-                paddle.nn.Dropout(),
-                p=dropout_rate[block_index],
+                paddle.nn.Dropout(p=dropout_rate[block_index]),
                 layer_name='dropout')
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/normalization.py b/paddlespeech/s2t/models/wav2vec2/modules/normalization.py
new file mode 100644
index 000000000..912981058
--- /dev/null
+++ b/paddlespeech/s2t/models/wav2vec2/modules/normalization.py
@@ -0,0 +1,97 @@
+# Authors
+#  * Mirco Ravanelli 2020
+#  * Guillermo Cámbara 2021
+#  * Sarthak Yadav 2022
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/normalization.py)
+import paddle.nn as nn
+
+from paddlespeech.s2t.modules.align import BatchNorm1D
+
+
+class BatchNorm1d(nn.Layer):
+    """Applies 1d batch normalization to the input tensor.
+
+    Arguments
+    ---------
+    input_shape : tuple
+        The expected shape of the input. Alternatively, use ``input_size``.
+    input_size : int
+        The expected size of the input. Alternatively, use ``input_shape``.
+    eps : float
+        This value is added to the standard deviation estimate to improve
+        numerical stability.
+    momentum : float
+        It is a value used for the running_mean and running_var computation.
+    combine_batch_time : bool
+        When True, it combines the batch and time axes.
+    skip_transpose : bool
+        When True, the input is assumed to already be (batch, channels, time),
+        and no transpose is applied around the normalization.
+
+    Example
+    -------
+    >>> input = paddle.randn([100, 10, 20])
+    >>> norm = BatchNorm1d(input_shape=input.shape)
+    >>> output = norm(input)
+    >>> output.shape
+    [100, 10, 20]
+    """
+
+    def __init__(
+            self,
+            input_shape=None,
+            input_size=None,
+            eps=1e-05,
+            momentum=0.9,
+            combine_batch_time=False,
+            skip_transpose=False, ):
+        super().__init__()
+        self.combine_batch_time = combine_batch_time
+        self.skip_transpose = skip_transpose
+
+        if input_size is None and skip_transpose:
+            input_size = input_shape[1]
+        elif input_size is None:
+            input_size = input_shape[-1]
+
+        self.norm = BatchNorm1D(input_size, momentum=momentum, epsilon=eps)
+
+    def forward(self, x):
+        """Returns the normalized input tensor.
+
+        Arguments
+        ---------
+        x : paddle.Tensor (batch, time, [channels])
+            Input to normalize. 2d or 3d tensors are expected in input;
+            4d tensors can be used when combine_batch_time=True.
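+
+        Example
+        -------
+        >>> x = paddle.randn([4, 64, 50])  # a made-up (batch, channels, time) input
+        >>> norm = BatchNorm1d(input_size=64, skip_transpose=True)
+        >>> norm(x).shape
+        [4, 64, 50]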
+        """
+        shape_or = x.shape
+        if self.combine_batch_time:
+            if x.ndim == 3:
+                x = x.reshape([shape_or[0] * shape_or[1], shape_or[2]])
+            else:
+                x = x.reshape([shape_or[0] * shape_or[1], shape_or[3],
+                               shape_or[2]])
+
+        elif not self.skip_transpose:
+            x = x.transpose([0, 2, 1])
+
+        x_n = self.norm(x)
+        if self.combine_batch_time:
+            x_n = x_n.reshape(shape_or)
+        elif not self.skip_transpose:
+            x_n = x_n.transpose([0, 2, 1])
+
+        return x_n
diff --git a/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py b/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py
index 1942e6661..1250e96ca 100644
--- a/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py
+++ b/paddlespeech/t2s/frontend/zh_normalization/text_normlization.py
@@ -65,7 +65,7 @@ class TextNormalizer():
         if lang == "zh":
             text = text.replace(" ", "")
             # 过滤掉特殊字符
-            text = re.sub(r'[《》【】<=>{}()()#&@“”^_|…\\]', '', text)
+            text = re.sub(r'[——《》【】<=>{}()()#&@“”^_|…\\]', '', text)
         text = self.SENTENCE_SPLITOR.sub(r'\1\n', text)
         text = text.strip()
         sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
@@ -85,7 +85,33 @@ class TextNormalizer():
         sentence = sentence.replace('⑧', '八')
         sentence = sentence.replace('⑨', '九')
         sentence = sentence.replace('⑩', '十')
-
+        sentence = sentence.replace('α', '阿尔法')
+        sentence = sentence.replace('β', '贝塔')
+        sentence = sentence.replace('γ', '伽玛').replace('Γ', '伽玛')
+        sentence = sentence.replace('δ', '德尔塔').replace('Δ', '德尔塔')
+        sentence = sentence.replace('ε', '艾普西龙')
+        sentence = sentence.replace('ζ', '捷塔')
+        sentence = sentence.replace('η', '依塔')
+        sentence = sentence.replace('θ', '西塔').replace('Θ', '西塔')
+        sentence = sentence.replace('ι', '艾欧塔')
+        sentence = sentence.replace('κ', '喀帕')
+        sentence = sentence.replace('λ', '拉姆达').replace('Λ', '拉姆达')
+        sentence = sentence.replace('μ', '缪')
+        sentence = sentence.replace('ν', '拗')
+        sentence = sentence.replace('ξ', '克西').replace('Ξ', '克西')
+        sentence = sentence.replace('ο', '欧米克伦')
+        sentence = sentence.replace('π', '派').replace('Π', '派')
+        sentence = sentence.replace('ρ', '肉')
+        sentence = sentence.replace('ς', '西格玛').replace('Σ', '西格玛').replace(
+            'σ', '西格玛')
+        sentence = sentence.replace('τ', '套')
+        sentence = sentence.replace('υ', '宇普西龙')
+        sentence = sentence.replace('φ', '服艾').replace('Φ', '服艾')
+        sentence = sentence.replace('χ', '器')
+        sentence = sentence.replace('ψ', '普赛').replace('Ψ', '普赛')
+        sentence = sentence.replace('ω', '欧米伽').replace('Ω', '欧米伽')
+        # filter special characters again; this pattern has one more character ("-") than the one used in _split
+        sentence = re.sub(r'[-——《》【】<=>{}()()#&@“”^_|…\\]', '', sentence)
 
         return sentence
 
     def normalize_sentence(self, sentence: str) -> str:
@@ -124,6 +150,5 @@ class TextNormalizer():
 
     def normalize(self, text: str) -> List[str]:
         sentences = self._split(text)
-        sentences = [self.normalize_sentence(sent) for sent in sentences]
 
         return sentences
diff --git a/speechx/examples/ds2_ol/onnx/README.md b/speechx/examples/ds2_ol/onnx/README.md
index e6ab953c8..b98b74b6f 100644
--- a/speechx/examples/ds2_ol/onnx/README.md
+++ b/speechx/examples/ds2_ol/onnx/README.md
@@ -1,11 +1,8 @@
-# DeepSpeech2 to ONNX model
+# Convert DeepSpeech2 model to ONNX format
 
-1. convert deepspeech2 model to ONNX, using Paddle2ONNX.
-2. check paddleinference and onnxruntime output equal.
-3. optimize onnx model
-4. check paddleinference and optimized onnxruntime output equal.
-5. quantize onnx model
-4. check paddleinference and optimized onnxruntime output equal.
+> We recommend using the U2/U2++ model instead of DS2; please see [here](../../u2pp_ol/wenetspeech/).
+
+This example demonstrates converting a DS2 model to ONNX format.
 
 Please make sure [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and [onnx-simplifier](https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape) versions are correct.
@@ -25,18 +22,24 @@
 onnxoptimizer         0.2.7
 onnxruntime           1.11.0
 ```
+
 ## Using
 
 ```
 bash run.sh --stage 0 --stop_stage 5
 ```
 
+1. Convert the DeepSpeech2 model to ONNX using Paddle2ONNX.
+2. Check that the Paddle Inference and ONNX Runtime outputs are equal.
+3. Optimize the ONNX model.
+4. Check that the Paddle Inference and optimized ONNX Runtime outputs are equal.
+5. Quantize the ONNX model.
+6. Check that the Paddle Inference and quantized ONNX Runtime outputs are equal.
+
 For more details please see `run.sh`.
 
 ## Outputs
-The optimized onnx model is `exp/model.opt.onnx`, quanted model is `$exp/model.optset11.quant.onnx`.
-
-To show the graph, please using `local/netron.sh`.
+The optimized ONNX model is `exp/model.opt.onnx`; the quantized model is `exp/model.optset11.quant.onnx`.
 
 ## [Results](https://github.com/PaddlePaddle/PaddleSpeech/wiki/ASR-Benchmark#streaming-asr)
diff --git a/speechx/examples/u2pp_ol/wenetspeech/README.md b/speechx/examples/u2pp_ol/wenetspeech/README.md
index b90b8e201..6ca8f6dd8 100644
--- a/speechx/examples/u2pp_ol/wenetspeech/README.md
+++ b/speechx/examples/u2pp_ol/wenetspeech/README.md
@@ -1,27 +1,77 @@
-# u2/u2pp Streaming ASR
+# U2/U2++ Streaming ASR
+
+A C++ deployment example for the `PaddleSpeech/examples/wenetspeech/asr1` recipe. The model is a static model produced by `export`; for how to export the model, please see [here](../../../../examples/wenetspeech/asr1/). If you want to use an exported model, `run.sh` will download it; for the model link, please see `run.sh`.
+
+This example demonstrates how to use the U2/U2++ model to recognize `wav` files and compute the `CER`. We use AISHELL-1 as the test data.
 
 ## Testing with Aishell Test Data
 
-### Download wav and model
+### Source `path.sh` first
+
+```bash
+source path.sh
+```
+
+All binaries are under the directory given by `echo $SPEECHX_BUILD`.
+
+### Download dataset and model
 
 ```
 ./run.sh --stop_stage 0
 ```
 
-### compute feature
+### Process `cmvn` and compute features
 
-```
+```bash
 ./run.sh --stage 1 --stop_stage 1
 ```
 
-### decoding using feature
+If you only want to convert the `cmvn` file format, you can use this command:
+
+```bash
+./local/feat.sh --stage 1 --stop_stage 1
+```
+
+### Decoding using `feature` input
 
 ```
 ./run.sh --stage 2 --stop_stage 2
 ```
 
-### decoding using wav
+### Decoding using `wav` input
 
 ```
 ./run.sh --stage 3 --stop_stage 3
 ```
+
+This stage uses `u2_recognizer_main` to recognize wav files.
+
+The input is an `scp` file which looks like this:
+```text
+# head data/split1/1/aishell_test.scp
+BAC009S0764W0121 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav
+BAC009S0764W0122 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav
+...
+BAC009S0764W0125 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav
+```
+
+If you want to recognize one wav, you can make an `scp` file like this:
+```text
+key path/to/wav/file
+```
+
+Then specify the `--wav_rspecifier=` param for the `u2_recognizer_main` bin, as sketched below. For the meaning of the other flags, please see the help:
+```bash
+u2_recognizer_main --help
+```
+
+For an example of using the `u2_recognizer_main` bin, please see `local/recognizer.sh`.
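+
+For example, a single-utterance run might look like the following sketch. The key name and wav path are placeholders, and the model, dict, and cmvn flags are elided; take the real flag set from `local/recognizer.sh`:
+
+```bash
+# Build a one-entry scp file: "<utterance-key> <path-to-wav>"
+echo "my_utt /path/to/my.wav" > one.scp
+
+# Decode it; the scp is passed as a kaldi-style rspecifier. All other
+# required flags are omitted here, see local/recognizer.sh.
+u2_recognizer_main --wav_rspecifier=scp:one.scp ...
+```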
+
+### Decoding with `wav` using the quantized model
+
+`local/recognizer_quant.sh` is the same as `local/recognizer.sh`, but uses the quantized model.
+
+## Results
+
+Please see [here](./RESULTS.md).
diff --git a/speechx/examples/u2pp_ol/wenetspeech/run.sh b/speechx/examples/u2pp_ol/wenetspeech/run.sh
index 870c5deeb..711d68083 100755
--- a/speechx/examples/u2pp_ol/wenetspeech/run.sh
+++ b/speechx/examples/u2pp_ol/wenetspeech/run.sh
@@ -72,13 +72,16 @@ fi
 
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # process cmvn and compute fbank feat
     ./local/feat.sh
 fi
 
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # decode with fbank feat input
     ./local/decode.sh
 fi
 
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # decode with wav input
     ./local/recognizer.sh
 fi
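
A usage sketch for the stage gating above (illustrative only; `--stage` and `--stop_stage` are parsed by the recipe's option parser, as in the README examples earlier). Each `[ ${stage} -le N ] && [ ${stop_stage} -ge N ]` guard runs stage N exactly when `stage <= N <= stop_stage`, so pinning both flags to the same value runs a single stage:

```bash
# Run only the cmvn/fbank feature stage (stage 1):
./run.sh --stage 1 --stop_stage 1

# Run feature-input decoding and then wav-input decoding (stages 2 and 3):
./run.sh --stage 2 --stop_stage 3
```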