diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 6e7ae1fbf..0435cfbe1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -50,13 +50,20 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
- exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
+ exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.py)$
#- id: copyright_checker
# name: copyright_checker
# entry: python .pre-commit-hooks/copyright-check.hook
# language: system
# files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
# exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
+ - id: cpplint
+ name: cpplint
+ description: Static code analysis of C/C++ files
+ language: python
+ files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$
+ exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.py)$
+ entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent
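+      # --filter: disable all `build/*` and `whitespace/*` checks, re-enable
+      # `whitespace/comma`, and keep `whitespace/indent` off (indentation is
+      # handled by the clang-format hook above)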
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
hooks:
diff --git a/README.md b/README.md
index 49e40624d..c54f6fed1 100644
--- a/README.md
+++ b/README.md
@@ -157,6 +157,8 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural language processing (NLP) and Computer Vision (CV).
### Recent Update
+- 🔥 2022.10.26: Add [Prosody Prediction](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/rhy) for TTS.
+- 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend.
- 👑 2022.10.11: Add [Wav2vec2ASR](./examples/librispeech/asr3), wav2vec2.0 fine-tuning for ASR on LibriSpeech.
- 🔥 2022.09.26: Add Voice Cloning, TTS finetune, and ERNIE-SAT in [PaddleSpeech Web Demo](./demos/speech_web).
- ⚡ 2022.09.09: Add AISHELL-3 Voice Cloning [example](./examples/aishell3/vc2) with ECAPA-TDNN speaker encoder.
@@ -826,7 +828,21 @@ The Text-to-Speech module is originally called [Parakeet](https://github.com/Pad
## Citation
To cite PaddleSpeech for research, please use the following format.
-```tex
+```text
+@InProceedings{pmlr-v162-bai22d,
+ title = {{A}$^3${T}: Alignment-Aware Acoustic and Text Pretraining for Speech Synthesis and Editing},
+ author = {Bai, He and Zheng, Renjie and Chen, Junkun and Ma, Mingbo and Li, Xintong and Huang, Liang},
+ booktitle = {Proceedings of the 39th International Conference on Machine Learning},
+ pages = {1399--1411},
+ year = {2022},
+ volume = {162},
+ series = {Proceedings of Machine Learning Research},
+ month = {17--23 Jul},
+ publisher = {PMLR},
+ pdf = {https://proceedings.mlr.press/v162/bai22d/bai22d.pdf},
+ url = {https://proceedings.mlr.press/v162/bai22d.html},
+}
+
@inproceedings{zhang2022paddlespeech,
title = {PaddleSpeech: An Easy-to-Use All-in-One Speech Toolkit},
author = {Hui Zhang, Tian Yuan, Junkun Chen, Xintong Li, Renjie Zheng, Yuxin Huang, Xiaojie Chen, Enlei Gong, Zeyu Chen, Xiaoguang Hu, dianhai yu, Yanjun Ma, Liang Huang},
@@ -923,8 +939,8 @@ You are warmly welcome to submit questions in [discussions](https://github.com/P
## Acknowledgement
- Many thanks to [HighCWu](https://github.com/HighCWu) for adding [VITS-aishell3](./examples/aishell3/vits) and [VITS-VC](./examples/aishell3/vits-vc) examples.
-- Many thanks to [david-95](https://github.com/david-95) improved TTS, fixed multi-punctuation bug, and contributed to multiple program and data.
-- Many thanks to [BarryKCL](https://github.com/BarryKCL) improved TTS Chinses frontend based on [G2PW](https://github.com/GitYCC/g2pW).
+- Many thanks to [david-95](https://github.com/david-95) for fixing the multi-punctuation bug, contributing multiple programs and data, and adding [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for the TTS Chinese Text Frontend.
+- Many thanks to [BarryKCL](https://github.com/BarryKCL) for improving the TTS Chinese Frontend based on [G2PW](https://github.com/GitYCC/g2pW).
- Many thanks to [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) for years of attention, constructive advice and great help.
- Many thanks to [mymagicpower](https://github.com/mymagicpower) for the Java implementation of ASR upon [short](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk) and [long](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk) audio files.
- Many thanks to [JiehangXie](https://github.com/JiehangXie)/[PaddleBoBo](https://github.com/JiehangXie/PaddleBoBo) for developing Virtual Uploader(VUP)/Virtual YouTuber(VTuber) with PaddleSpeech TTS function.
diff --git a/README_cn.md b/README_cn.md
index bf3ff4dfd..406aa6abc 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -164,7 +164,9 @@
### 近期更新
-- 👑 2022.10.11: 新增 [Wav2vec2ASR](./examples/librispeech/asr3), 在 LibriSpeech 上针对ASR任务对wav2vec2.0 的fine-tuning.
+- 🔥 2022.10.26: TTS 新增[韵律预测](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/rhy)功能。
+- 🎉 2022.10.21: TTS 中文文本前端新增 [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) 功能。
+- 👑 2022.10.11: 新增 [Wav2vec2ASR](./examples/librispeech/asr3), 在 LibriSpeech 上针对 ASR 任务对 wav2vec2.0 的 finetuning。
- 🔥 2022.09.26: 新增 Voice Cloning, TTS finetune 和 ERNIE-SAT 到 [PaddleSpeech 网页应用](./demos/speech_web)。
- ⚡ 2022.09.09: 新增基于 ECAPA-TDNN 声纹模型的 AISHELL-3 Voice Cloning [示例](./examples/aishell3/vc2)。
- ⚡ 2022.08.25: 发布 TTS [finetune](./examples/other/tts_finetune/tts3) 示例。
@@ -832,6 +834,20 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
要引用 PaddleSpeech 进行研究,请使用以下格式进行引用。
```text
+@InProceedings{pmlr-v162-bai22d,
+ title = {{A}$^3${T}: Alignment-Aware Acoustic and Text Pretraining for Speech Synthesis and Editing},
+ author = {Bai, He and Zheng, Renjie and Chen, Junkun and Ma, Mingbo and Li, Xintong and Huang, Liang},
+ booktitle = {Proceedings of the 39th International Conference on Machine Learning},
+ pages = {1399--1411},
+ year = {2022},
+ volume = {162},
+ series = {Proceedings of Machine Learning Research},
+ month = {17--23 Jul},
+ publisher = {PMLR},
+ pdf = {https://proceedings.mlr.press/v162/bai22d/bai22d.pdf},
+ url = {https://proceedings.mlr.press/v162/bai22d.html},
+}
+
@inproceedings{zhang2022paddlespeech,
title = {PaddleSpeech: An Easy-to-Use All-in-One Speech Toolkit},
author = {Hui Zhang, Tian Yuan, Junkun Chen, Xintong Li, Renjie Zheng, Yuxin Huang, Xiaojie Chen, Enlei Gong, Zeyu Chen, Xiaoguang Hu, dianhai yu, Yanjun Ma, Liang Huang},
@@ -928,7 +944,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
## 致谢
- 非常感谢 [HighCWu](https://github.com/HighCWu) 新增 [VITS-aishell3](./examples/aishell3/vits) 和 [VITS-VC](./examples/aishell3/vits-vc) 代码示例。
-- 非常感谢 [david-95](https://github.com/david-95) 修复句尾多标点符号出错的问题,贡献补充多条程序和数据。
+- 非常感谢 [david-95](https://github.com/david-95) 修复 TTS 句尾多标点符号出错的问题,贡献补充多条程序和数据,并为 TTS 中文文本前端新增 [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) 功能。
- 非常感谢 [BarryKCL](https://github.com/BarryKCL) 基于 [G2PW](https://github.com/GitYCC/g2pW) 对 TTS 中文文本前端的优化。
- 非常感谢 [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) 多年来的关注和建议,以及在诸多问题上的帮助。
- 非常感谢 [mymagicpower](https://github.com/mymagicpower) 采用PaddleSpeech 对 ASR 的[短语音](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk)及[长语音](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk)进行 Java 实现。
diff --git a/examples/other/rhy/README.md b/examples/other/rhy/README.md
new file mode 100644
index 000000000..11336ad9f
--- /dev/null
+++ b/examples/other/rhy/README.md
@@ -0,0 +1,41 @@
+# Prosody Prediction with CSMSC and AISHELL-3
+
+## Get Started
+### Data Preprocessing
+```bash
+./run.sh --stage 0 --stop-stage 0
+```
+### Model Training
+```bash
+./run.sh --stage 1 --stop-stage 1
+```
+### Testing
+```bash
+./run.sh --stage 2 --stop-stage 2
+```
+### Prosody Prediction
+```bash
+./run.sh --stage 3 --stop-stage 3
+```
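+This stage runs `local/rhy_predict.sh` on the sentence set by the `text` variable in `run.sh`, annotating it with the rhythm tokens described below.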
+## Pretrained Model
+The pretrained model can be downloaded here:
+
+[ernie-1.0_aishellcsmsc_ckpt_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/ernie-1.0_aishellcsmsc_ckpt_1.3.0.zip)
+
+Unzip it into the `exp/${YOUREXP}/checkpoints` folder.
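+
+With the default `run.sh` settings this would look like (a sketch; the checkpoint name comes from `ckpt_name` in `run.sh`):
+```text
+exp/default/checkpoints/
+└── snapshot_iter_2600.pdz
+```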
+
+## Rhythm Mapping
+Four punctuation marks are used to denote the rhythm labels:
+|rhy_token|csmsc|aishell3|
+|:---: |:---: |:---: |
+|%|#1|%|
+|`|#2||
+|~|#3||
+|$|#4|$|
+
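+For example, a CSMSC-style label (illustrative) maps as follows:
+```text
+卡尔普#2陪外孙#1玩滑梯#4  ->  卡尔普`陪外孙%玩滑梯$
+```
+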
+## Prediction Results
+| | #1 | #2 | #3 | #4 |
+|:-----:|:-----:|:-----:|:-----:|:-----:|
+|Precision |0.90 |0.66 |0.91 |0.90|
+|Recall |0.92 |0.62 |0.83 |0.85|
+|F1 |0.91 |0.64 |0.87 |0.87|
diff --git a/examples/other/rhy/conf/default.yaml b/examples/other/rhy/conf/default.yaml
new file mode 100644
index 000000000..1eb90f11f
--- /dev/null
+++ b/examples/other/rhy/conf/default.yaml
@@ -0,0 +1,44 @@
+###########################################################
+# DATA SETTING #
+###########################################################
+dataset_type: Ernie
+train_path: data/train.txt
+dev_path: data/dev.txt
+test_path: data/test.txt
+batch_size: 64
+num_workers: 2
+data_params:
+ pretrained_token: ernie-1.0
+ punc_path: data/rhy_token
+ seq_len: 100
+
+
+###########################################################
+# MODEL SETTING #
+###########################################################
+model_type: ErnieLinear
+model:
+ pretrained_token: ernie-1.0
+ num_classes: 5
+
+###########################################################
+# OPTIMIZER SETTING #
+###########################################################
+optimizer_params:
+ weight_decay: 1.0e-6 # weight decay coefficient.
+
+scheduler_params:
+ learning_rate: 1.0e-5 # learning rate.
+  gamma: 0.9999                   # scheduler gamma: must be in (0.0, 1.0), and closer to 1.0 is better.
+
+###########################################################
+# TRAINING SETTING #
+###########################################################
+max_epoch: 20
+
+###########################################################
+# OTHER SETTING #
+###########################################################
+num_snapshots: 10 # max number of snapshots to keep while training
+seed: 42 # random seed for paddle, random, and np.random
diff --git a/examples/other/rhy/data/rhy_token b/examples/other/rhy/data/rhy_token
new file mode 100644
index 000000000..bf1fe253f
--- /dev/null
+++ b/examples/other/rhy/data/rhy_token
@@ -0,0 +1,4 @@
+%
+`
+~
+$
\ No newline at end of file
diff --git a/examples/other/rhy/local/data.sh b/examples/other/rhy/local/data.sh
new file mode 100755
index 000000000..93b134873
--- /dev/null
+++ b/examples/other/rhy/local/data.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+if [ ! -f 000001-010000.txt ]; then
+ wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/000001-010000.txt
+fi
+
+if [ ! -f label_train-set.txt ]; then
+ wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/label_train-set.txt
+fi
+
+
+aishell_data=$1
+csmsc_data=$2
+processed_path=$3
+
+python3 ./local/pre_for_sp_csmsc.py \
+ --data=${csmsc_data} \
+ --processed_path=${processed_path}
+
+python3 ./local/pre_for_sp_aishell.py \
+ --data=${aishell_data} \
+ --processed_path=${processed_path}
+
+
+echo "Finish data preparation."
+exit 0
diff --git a/examples/other/rhy/local/pre_for_sp_aishell.py b/examples/other/rhy/local/pre_for_sp_aishell.py
new file mode 100644
index 000000000..a2a716683
--- /dev/null
+++ b/examples/other/rhy/local/pre_for_sp_aishell.py
@@ -0,0 +1,51 @@
+import argparse
+import os
+import re
+
+# Mapping from rhythm labels to the punctuation tokens used for prediction.
+replace_ = {"#1": "%", "#2": "`", "#3": "~", "#4": "$"}
+
+
+def replace_rhy_with_punc(line):
+    # punctuation-stripping pattern, adapted from check_oov.py
+ line = re.sub(r'[:、,;。?!,.:;"?!’《》【】<=>{}()()#&@“”^_|…\\]%*$', '', line)
+ for r in replace_.keys():
+ if r in line:
+ line = line.replace(r, replace_[r])
+ return line
+
+
+def pre_and_write(data, file):
+    with open(file, 'a') as wf:
+        for d in data:
+            d = d.split('|')[2].strip()
+            # d = replace_rhy_with_punc(d)
+            d = ' '.join(d) + ' \n'
+            wf.write(d)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Train a Rhy prediction model.")
+ parser.add_argument("--data", type=str, default="label_train-set.txt")
+ parser.add_argument(
+ "--processed_path", type=str, default="../data/rhy_predict")
+ args = parser.parse_args()
+ os.makedirs(args.processed_path, exist_ok=True)
+
+ with open(args.data) as rf:
+ text = rf.readlines()[5:]
+ len_ = len(text)
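+    # split into 90% train / 5% test / 5% dev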
+ lens = [int(len_ * 0.9), int(len_ * 0.05), int(len_ * 0.05)]
+ files = ['train.txt', 'test.txt', 'dev.txt']
+
+ i = 0
+ for l_, file in zip(lens, files):
+ file = os.path.join(args.processed_path, file)
+ pre_and_write(text[i:i + l_], file)
+ i = i + l_
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/other/rhy/local/pre_for_sp_csmsc.py b/examples/other/rhy/local/pre_for_sp_csmsc.py
new file mode 100644
index 000000000..0a96092c1
--- /dev/null
+++ b/examples/other/rhy/local/pre_for_sp_csmsc.py
@@ -0,0 +1,51 @@
+import argparse
+import os
+import re
+
+replace_ = {"#1": "%", "#2": "`", "#3": "~", "#4": "$"}
+
+
+def replace_rhy_with_punc(line):
+    # strip stray symbols that would collide with the rhythm tokens
+    line = re.sub(r'[$*%]', '', line)
+ for r in replace_.keys():
+ if r in line:
+ line = line.replace(r, replace_[r])
+ return line
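+
+# Illustrative mapping (hypothetical CSMSC-style line):
+#   replace_rhy_with_punc("卡尔普#2陪外孙#1玩滑梯#4。") -> "卡尔普`陪外孙%玩滑梯$。"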
+
+
+def pre_and_write(data, file):
+    with open(file, 'w') as wf:
+        for d in data:
+            d = d.split('\t')[1].strip()
+            d = replace_rhy_with_punc(d)
+            d = ' '.join(d) + ' \n'
+            wf.write(d)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Train a Rhy prediction model.")
+ parser.add_argument("--data", type=str, default="label_train-set.txt")
+ parser.add_argument(
+ "--processed_path", type=str, default="../data/rhy_predict")
+ args = parser.parse_args()
+ print(args.data, args.processed_path)
+ os.makedirs(args.processed_path, exist_ok=True)
+
+ with open(args.data) as rf:
+        lines = rf.readlines()
+        text = lines[0::2]
+ len_ = len(text)
+ lens = [int(len_ * 0.9), int(len_ * 0.05), int(len_ * 0.05)]
+ files = ['train.txt', 'test.txt', 'dev.txt']
+
+ i = 0
+ for l_, file in zip(lens, files):
+ file = os.path.join(args.processed_path, file)
+ pre_and_write(text[i:i + l_], file)
+ i = i + l_
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/other/rhy/local/rhy_predict.sh b/examples/other/rhy/local/rhy_predict.sh
new file mode 100755
index 000000000..30a4f12f8
--- /dev/null
+++ b/examples/other/rhy/local/rhy_predict.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+text=$4
+ckpt_prefix=${ckpt_name%.*}
+
+python3 ${BIN_DIR}/punc_restore.py \
+ --config=${config_path} \
+ --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --text=${text}
diff --git a/examples/other/rhy/local/test.sh b/examples/other/rhy/local/test.sh
new file mode 100755
index 000000000..bd490b5b9
--- /dev/null
+++ b/examples/other/rhy/local/test.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+print_eval=$4
+
+ckpt_prefix=${ckpt_name%.*}
+
+python3 ${BIN_DIR}/test.py \
+ --config=${config_path} \
+ --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --print_eval=${print_eval}
\ No newline at end of file
diff --git a/examples/other/rhy/local/train.sh b/examples/other/rhy/local/train.sh
new file mode 100755
index 000000000..85227eacb
--- /dev/null
+++ b/examples/other/rhy/local/train.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+
+python3 ${BIN_DIR}/train.py \
+ --config=${config_path} \
+ --output-dir=${train_output_path} \
+ --ngpu=1
diff --git a/examples/other/rhy/path.sh b/examples/other/rhy/path.sh
new file mode 100755
index 000000000..da790261f
--- /dev/null
+++ b/examples/other/rhy/path.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+export MAIN_ROOT=${PWD}/../../../
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
+
+MODEL=ernie_linear
+export BIN_DIR=${MAIN_ROOT}/paddlespeech/text/exps/${MODEL}
diff --git a/examples/other/rhy/run.sh b/examples/other/rhy/run.sh
new file mode 100755
index 000000000..aed58152e
--- /dev/null
+++ b/examples/other/rhy/run.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+set -e
+source path.sh
+
+gpus=0
+stage=0
+stop_stage=100
+
+aishell_data=label_train-set.txt
+csmsc_data=000001-010000.txt
+processed_path=data
+
+conf_path=conf/default.yaml
+train_output_path=exp/default
+ckpt_name=snapshot_iter_2600.pdz
+text=我们城市的复苏有赖于他强有力的政策。
+print_eval=false
+
+# with the following command, you can choose the stage range you want to run
+# such as `./run.sh --stage 0 --stop-stage 0`
+# this cannot be mixed with positional arguments `$1`, `$2`, ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ # prepare data
+ ./local/data.sh ${aishell_data} ${csmsc_data} ${processed_path}
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train the model; checkpoints are saved under `train_output_path/checkpoints/`
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${train_output_path} ${ckpt_name} ${print_eval} || exit -1
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/rhy_predict.sh ${conf_path} ${train_output_path} ${ckpt_name} ${text} || exit -1
+fi
\ No newline at end of file
diff --git a/examples/wenetspeech/asr1/RESULTS.md b/examples/wenetspeech/asr1/RESULTS.md
index f22c652e6..cd480163e 100644
--- a/examples/wenetspeech/asr1/RESULTS.md
+++ b/examples/wenetspeech/asr1/RESULTS.md
@@ -53,3 +53,22 @@ Pretrain model from https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1
| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_greedy_search | -1 | 0.061884 |
| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_prefix_beam_search | -1 | 0.062056 |
| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention_rescoring | -1 | 0.052110 |
+
+
+## U2PP Streaming Pretrained Model
+
+Pretrained model from https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.3.0.model.tar.gz
+
+| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size | CER |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention | 16 | 0.057031 |
+| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_greedy_search | 16 | 0.068826 |
+| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_prefix_beam_search | 16 | 0.069111 |
+| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention_rescoring | 16 | 0.059213 |
+
+| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size | CER |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention | -1 | 0.049256 |
+| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_greedy_search | -1 | 0.052086 |
+| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_prefix_beam_search | -1 | 0.052267 |
+| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention_rescoring | -1 | 0.047198 |
diff --git a/examples/wenetspeech/asr1/local/test_wav.sh b/examples/wenetspeech/asr1/local/test_wav.sh
index 474642624..c3a17f491 100755
--- a/examples/wenetspeech/asr1/local/test_wav.sh
+++ b/examples/wenetspeech/asr1/local/test_wav.sh
@@ -42,6 +42,7 @@ for type in attention_rescoring; do
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test_wav.py \
+ --debug True \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
diff --git a/paddlespeech/s2t/exps/u2/bin/test_wav.py b/paddlespeech/s2t/exps/u2/bin/test_wav.py
index d12ea3646..0df443193 100644
--- a/paddlespeech/s2t/exps/u2/bin/test_wav.py
+++ b/paddlespeech/s2t/exps/u2/bin/test_wav.py
@@ -16,6 +16,8 @@ import os
import sys
from pathlib import Path
+import distutils.util
+import numpy as np
import paddle
import soundfile
from yacs.config import CfgNode
@@ -74,6 +76,8 @@ class U2Infer():
# fbank
feat = self.preprocessing(audio, **self.preprocess_args)
logger.info(f"feat shape: {feat.shape}")
+ if self.args.debug:
+ np.savetxt("feat.transform.txt", feat)
ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
@@ -126,6 +130,11 @@ if __name__ == "__main__":
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_file", type=str, help="path of the input audio file")
+ parser.add_argument(
+ "--debug",
+ type=distutils.util.strtobool,
+ default=False,
+ help="for debug.")
args = parser.parse_args()
config = CfgNode(new_allowed=True)
diff --git a/paddlespeech/t2s/__init__.py b/paddlespeech/t2s/__init__.py
index 7d93c026e..57fe82a9c 100644
--- a/paddlespeech/t2s/__init__.py
+++ b/paddlespeech/t2s/__init__.py
@@ -18,5 +18,6 @@ from . import exps
from . import frontend
from . import models
from . import modules
+from . import ssml
from . import training
from . import utils
diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py
index 15d8dfb78..41663891e 100644
--- a/paddlespeech/t2s/exps/syn_utils.py
+++ b/paddlespeech/t2s/exps/syn_utils.py
@@ -13,6 +13,7 @@
# limitations under the License.
import math
import os
+import re
from pathlib import Path
from typing import Any
from typing import Dict
@@ -33,6 +34,7 @@ from paddlespeech.t2s.frontend.mix_frontend import MixFrontend
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.utils.dynamic_import import dynamic_import
+
# remove [W:onnxruntime: xxx] from ort
ort.set_default_logger_severity(3)
@@ -103,14 +105,15 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
sentences = []
with open(text_file, 'rt') as f:
for line in f:
- items = line.strip().split()
- utt_id = items[0]
- if lang == 'zh':
- sentence = "".join(items[1:])
- elif lang == 'en':
- sentence = " ".join(items[1:])
- elif lang == 'mix':
- sentence = " ".join(items[1:])
+ if line.strip() != "":
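+                # each line is "<utt_id> <sentence>"; e.g. "001 你好世界" (illustrative)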
+                items = re.split(r"\s+", line.strip(), maxsplit=1)
+ utt_id = items[0]
+ if lang == 'zh':
+ sentence = "".join(items[1:])
+ elif lang == 'en':
+ sentence = " ".join(items[1:])
+ elif lang == 'mix':
+ sentence = " ".join(items[1:])
sentences.append((utt_id, sentence))
return sentences
@@ -180,11 +183,20 @@ def run_frontend(frontend: object,
to_tensor: bool=True):
outs = dict()
if lang == 'zh':
- input_ids = frontend.get_input_ids(
- text,
- merge_sentences=merge_sentences,
- get_tone_ids=get_tone_ids,
- to_tensor=to_tensor)
+ input_ids = {}
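+        # SSML input is wrapped in <speak>...</speak>, e.g. (tag form assumed
+        # from the SSML discussion): <speak>绿是<say-as pinyin="lv4">绿</say-as>色</speak>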
+        if text.strip() != "" and re.match(r".*?<speak>.*?</speak>.*", text,
+                                           re.DOTALL):
+ input_ids = frontend.get_input_ids_ssml(
+ text,
+ merge_sentences=merge_sentences,
+ get_tone_ids=get_tone_ids,
+ to_tensor=to_tensor)
+ else:
+ input_ids = frontend.get_input_ids(
+ text,
+ merge_sentences=merge_sentences,
+ get_tone_ids=get_tone_ids,
+ to_tensor=to_tensor)
phone_ids = input_ids["phone_ids"]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
diff --git a/paddlespeech/t2s/frontend/zh_frontend.py b/paddlespeech/t2s/frontend/zh_frontend.py
index 722eed601..e30286986 100644
--- a/paddlespeech/t2s/frontend/zh_frontend.py
+++ b/paddlespeech/t2s/frontend/zh_frontend.py
@@ -13,6 +13,7 @@
# limitations under the License.
import os
import re
+from operator import itemgetter
from typing import Dict
from typing import List
@@ -31,6 +32,7 @@ from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter
from paddlespeech.t2s.frontend.generate_lexicon import generate_lexicon
from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi
from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer
+from paddlespeech.t2s.ssml.xml_processor import MixTextProcessor
INITIALS = [
'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'zh', 'ch', 'sh',
@@ -81,6 +83,7 @@ class Frontend():
g2p_model="g2pW",
phone_vocab_path=None,
tone_vocab_path=None):
+ self.mix_ssml_processor = MixTextProcessor()
self.tone_modifier = ToneSandhi()
self.text_normalizer = TextNormalizer()
self.punc = ":,;。?!“”‘’':,;.?!"
@@ -281,6 +284,65 @@ class Frontend():
phones_list.append(merge_list)
return phones_list
+    def _split_word_to_char(self, words):
+        return list(words)
+
+    # if the SSML input specifies pinyin, assign it to the corresponding words
+ def _g2p_assign(self,
+ words: List[str],
+ pinyin_spec: List[str],
+ merge_sentences: bool=True) -> List[List[str]]:
+ phones_list = []
+ initials = []
+ finals = []
+
+ words = self._split_word_to_char(words[0])
+ for pinyin, char in zip(pinyin_spec, words):
+ sub_initials = []
+ sub_finals = []
+ pinyin = pinyin.replace("u:", "v")
+            # self.pinyin2phone: dict mapping a pinyin syllable to "initial final" (sheng_mu yun_mu)
+ if pinyin in self.pinyin2phone:
+ initial_final_list = self.pinyin2phone[pinyin].split(" ")
+ if len(initial_final_list) == 2:
+ sub_initials.append(initial_final_list[0])
+ sub_finals.append(initial_final_list[1])
+                elif len(initial_final_list) == 1:
+                    sub_initials.append('')
+                    sub_finals.append(initial_final_list[0])
+ else:
+ # If it's not pinyin (possibly punctuation) or no conversion is required
+ sub_initials.append(pinyin)
+ sub_finals.append(pinyin)
+ initials.append(sub_initials)
+ finals.append(sub_finals)
+
+ initials = sum(initials, [])
+ finals = sum(finals, [])
+ phones = []
+ for c, v in zip(initials, finals):
+ # NOTE: post process for pypinyin outputs
+ # we discriminate i, ii and iii
+ if c and c not in self.punc:
+ phones.append(c)
+ if c and c in self.punc:
+ phones.append('sp')
+ if v and v not in self.punc:
+ phones.append(v)
+ phones_list.append(phones)
+ if merge_sentences:
+ merge_list = sum(phones_list, [])
+            # rm the last 'sp' to avoid noise at the end,
+            # because there is no trailing 'sp' in the training data
+ if merge_list[-1] == 'sp':
+ merge_list = merge_list[:-1]
+ phones_list = []
+ phones_list.append(merge_list)
+ return phones_list
+
def _merge_erhua(self,
initials: List[str],
finals: List[str],
@@ -396,6 +458,52 @@ class Frontend():
print("----------------------------")
return phonemes
+    # added for SSML pinyin support
+ def get_phonemes_ssml(self,
+ ssml_inputs: list,
+ merge_sentences: bool=True,
+ with_erhua: bool=True,
+ robot: bool=False,
+ print_info: bool=False) -> List[List[str]]:
+ all_phonemes = []
+ for word_pinyin_item in ssml_inputs:
+ phonemes = []
+ sentence, pinyin_spec = itemgetter(0, 1)(word_pinyin_item)
+ sentences = self.text_normalizer.normalize(sentence)
+ if len(pinyin_spec) == 0:
+ phonemes = self._g2p(
+ sentences,
+ merge_sentences=merge_sentences,
+ with_erhua=with_erhua)
+ else:
+ # phonemes should be pinyin_spec
+ phonemes = self._g2p_assign(
+ sentences, pinyin_spec, merge_sentences=merge_sentences)
+
+ all_phonemes = all_phonemes + phonemes
+
+ if robot:
+ new_phonemes = []
+ for sentence in all_phonemes:
+ new_sentence = []
+ for item in sentence:
+                    # `er` only has tone `2`
+ if item[-1] in "12345" and item != "er2":
+ item = item[:-1] + "1"
+ new_sentence.append(item)
+ new_phonemes.append(new_sentence)
+ all_phonemes = new_phonemes
+
+ if print_info:
+ print("----------------------------")
+ print("text norm results:")
+ print(sentences)
+ print("----------------------------")
+ print("g2p results:")
+ print(all_phonemes[0])
+ print("----------------------------")
+ return [sum(all_phonemes, [])]
+
def get_input_ids(self,
sentence: str,
merge_sentences: bool=True,
@@ -405,6 +513,7 @@ class Frontend():
add_blank: bool=False,
blank_token: str="",
to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
+
phonemes = self.get_phonemes(
sentence,
merge_sentences=merge_sentences,
@@ -437,3 +546,49 @@ class Frontend():
if temp_phone_ids:
result["phone_ids"] = temp_phone_ids
return result
+
+    # added for SSML
+ def get_input_ids_ssml(
+ self,
+ sentence: str,
+ merge_sentences: bool=True,
+ get_tone_ids: bool=False,
+ robot: bool=False,
+ print_info: bool=False,
+ add_blank: bool=False,
+ blank_token: str="",
+ to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
+
+ l_inputs = MixTextProcessor.get_pinyin_split(sentence)
+ phonemes = self.get_phonemes_ssml(
+ l_inputs,
+ merge_sentences=merge_sentences,
+ print_info=print_info,
+ robot=robot)
+ result = {}
+ phones = []
+ tones = []
+ temp_phone_ids = []
+ temp_tone_ids = []
+
+ for part_phonemes in phonemes:
+ phones, tones = self._get_phone_tone(
+ part_phonemes, get_tone_ids=get_tone_ids)
+ if add_blank:
+ phones = insert_after_character(phones, blank_token)
+ if tones:
+ tone_ids = self._t2id(tones)
+ if to_tensor:
+ tone_ids = paddle.to_tensor(tone_ids)
+ temp_tone_ids.append(tone_ids)
+ if phones:
+ phone_ids = self._p2id(phones)
+            # if paddle.to_tensor() is used in onnxruntime, the first call will be too slow
+ if to_tensor:
+ phone_ids = paddle.to_tensor(phone_ids)
+ temp_phone_ids.append(phone_ids)
+ if temp_tone_ids:
+ result["tone_ids"] = temp_tone_ids
+ if temp_phone_ids:
+ result["phone_ids"] = temp_phone_ids
+ return result
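+
+
+# Usage sketch (assumes the <say-as pinyin="..."> SSML form; values are hypothetical):
+#   frontend = Frontend()
+#   ids = frontend.get_input_ids_ssml(
+#       '<speak>绿是<say-as pinyin="lv4">绿</say-as>色</speak>')
+#   phone_ids = ids["phone_ids"]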
diff --git a/paddlespeech/t2s/ssml/__init__.py b/paddlespeech/t2s/ssml/__init__.py
new file mode 100644
index 000000000..9b4db053b
--- /dev/null
+++ b/paddlespeech/t2s/ssml/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .xml_processor import *
diff --git a/paddlespeech/t2s/ssml/xml_processor.py b/paddlespeech/t2s/ssml/xml_processor.py
new file mode 100644
index 000000000..b39121347
--- /dev/null
+++ b/paddlespeech/t2s/ssml/xml_processor.py
@@ -0,0 +1,158 @@
+# -*- coding: utf-8 -*-
+import re
+import xml.dom.minidom
+import xml.parsers.expat
+from xml.dom.minidom import Node
+from xml.dom.minidom import parseString
+'''
+Note: XML has 5 special characters: & < > " '
+Option 1: wrap strings containing special characters in a special tag, e.g.:
+<![CDATA[ ... ]]>
+Option 2: represent the special characters with their XML escape sequences:
+& -> &amp;
+< -> &lt;
+> -> &gt;
+" -> &quot;
+' -> &apos;
+For example:
+<name>&quot;姓名&quot;</name>
+'''
+
+
+class MixTextProcessor():
+    def __repr__(self):
+        return "MixTextProcessor"
+
+ def get_xml_content(self, mixstr):
+        '''Return the <speak> XML content of the string, if any.'''
+        xmlptn = re.compile(r"<speak>.*?</speak>", re.M | re.S)
+ ctn = re.search(xmlptn, mixstr)
+ if ctn:
+ return ctn.group(0)
+ else:
+ return None
+
+ def get_content_split(self, mixstr):
+        '''Split the text into a list, keeping non-XML and XML parts (with
+        their punctuation) in order. Whitespace must not be stripped, because
+        tag attributes inside the XML contain spaces.
+        '''
+        ctlist = []
+        patn = re.compile(r'(.*\s*?)(<speak>.*?</speak>)(.*\s*)$', re.M | re.S)
+ mat = re.match(patn, mixstr)
+ if mat:
+ pre_xml = mat.group(1)
+ in_xml = mat.group(2)
+ after_xml = mat.group(3)
+
+ ctlist.append(pre_xml)
+ ctlist.append(in_xml)
+ ctlist.append(after_xml)
+ return ctlist
+ else:
+ ctlist.append(mixstr)
+ return ctlist
+
+    @classmethod
+    def get_pinyin_split(cls, mixstr):
+        ctlist = []
+        patn = re.compile(r'(.*\s*?)(<speak>.*?</speak>)(.*\s*)$', re.M | re.S)
+ mat = re.match(patn, mixstr)
+ if mat:
+ pre_xml = mat.group(1)
+ in_xml = mat.group(2)
+ after_xml = mat.group(3)
+
+ ctlist.append([pre_xml, []])
+ dom = DomXml(in_xml)
+ pinyinlist = dom.get_pinyins_for_xml()
+ ctlist = ctlist + pinyinlist
+ ctlist.append([after_xml, []])
+ else:
+ ctlist.append([mixstr, []])
+ return ctlist
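+
+    # Example (assuming the <speak>/<say-as pinyin="..."> form; input is hypothetical):
+    #   get_pinyin_split('前<speak>浪<say-as pinyin="dao3">倒</say-as></speak>后')
+    #   returns [['前', []], ['浪', []], ['倒', ['dao3']], ['后', []]]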
+
+
+class DomXml():
+ def __init__(self, xmlstr):
+        self.tdom = parseString(xmlstr)  # Document
+        self.root = self.tdom.documentElement  # Element
+        self.rnode = self.tdom.childNodes  # NodeList
+
+ def get_text(self):
+        '''Return a list of all text content inside the XML.'''
+ res = []
+
+ for x1 in self.rnode:
+ if x1.nodeType == Node.TEXT_NODE:
+ res.append(x1.value)
+ else:
+ for x2 in x1.childNodes:
+ if isinstance(x2, xml.dom.minidom.Text):
+ res.append(x2.data)
+ else:
+ for x3 in x2.childNodes:
+ if isinstance(x3, xml.dom.minidom.Text):
+ res.append(x3.data)
+ else:
+ print("len(nodes of x3):", len(x3.childNodes))
+
+ return res
+
+ def get_xmlchild_list(self):
+        '''Return the XML content as a list of all text content (without tags).'''
+ res = []
+
+ for x1 in self.rnode:
+ if x1.nodeType == Node.TEXT_NODE:
+ res.append(x1.value)
+ else:
+ for x2 in x1.childNodes:
+ if isinstance(x2, xml.dom.minidom.Text):
+ res.append(x2.data)
+ else:
+ for x3 in x2.childNodes:
+ if isinstance(x3, xml.dom.minidom.Text):
+ res.append(x3.data)
+ else:
+ print("len(nodes of x3):", len(x3.childNodes))
+ print(res)
+ return res
+
+ def get_pinyins_for_xml(self):
+        '''Return the XML content as a list of [text, pinyin list] pairs.'''
+ res = []
+
+ for x1 in self.rnode:
+ if x1.nodeType == Node.TEXT_NODE:
+ t = re.sub(r"\s+", "", x1.value)
+ res.append([t, []])
+ else:
+ for x2 in x1.childNodes:
+ if isinstance(x2, xml.dom.minidom.Text):
+ t = re.sub(r"\s+", "", x2.data)
+ res.append([t, []])
+ else:
+ # print("x2",x2,x2.tagName)
+ if x2.hasAttribute('pinyin'):
+ pinyin_value = x2.getAttribute("pinyin")
+ pinyins = pinyin_value.split(" ")
+ for x3 in x2.childNodes:
+ # print('x3',x3)
+ if isinstance(x3, xml.dom.minidom.Text):
+ t = re.sub(r"\s+", "", x3.data)
+ res.append([t, pinyins])
+ else:
+ print("len(nodes of x3):", len(x3.childNodes))
+
+ return res
+
+ def get_all_tags(self, tag_name):
+        '''Print every tag with the given name that has a pinyin attribute, with its value.'''
+ alltags = self.root.getElementsByTagName(tag_name)
+ for x in alltags:
+ if x.hasAttribute('pinyin'): # pinyin
+ print(x.tagName, 'pinyin',
+ x.getAttribute('pinyin'), x.firstChild.data)
diff --git a/paddlespeech/text/exps/ernie_linear/test.py b/paddlespeech/text/exps/ernie_linear/test.py
index 4302a1a3b..aa172cc69 100644
--- a/paddlespeech/text/exps/ernie_linear/test.py
+++ b/paddlespeech/text/exps/ernie_linear/test.py
@@ -23,6 +23,7 @@ from sklearn.metrics import classification_report
from sklearn.metrics import precision_recall_fscore_support
from yacs.config import CfgNode
+from paddlespeech.t2s.utils import str2bool
from paddlespeech.text.models.ernie_linear import ErnieLinear
from paddlespeech.text.models.ernie_linear import PuncDataset
from paddlespeech.text.models.ernie_linear import PuncDatasetFromErnieTokenizer
@@ -91,9 +92,10 @@ def test(args):
t = classification_report(
test_total_label, test_total_predict, target_names=punc_list)
print(t)
- t2 = evaluation(test_total_label, test_total_predict)
- print('=========================================================')
- print(t2)
+ if args.print_eval:
+ t2 = evaluation(test_total_label, test_total_predict)
+ print('=========================================================')
+ print(t2)
def main():
@@ -101,6 +103,7 @@ def main():
parser = argparse.ArgumentParser(description="Test a ErnieLinear model.")
parser.add_argument("--config", type=str, help="ErnieLinear config file.")
parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
+ parser.add_argument("--print_eval", type=str2bool, default=True)
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
diff --git a/setup.py b/setup.py
index e551d9fa6..35668bddb 100644
--- a/setup.py
+++ b/setup.py
@@ -75,6 +75,7 @@ base = [
"braceexpand",
"pyyaml",
"pybind11",
+ "paddleslim==2.3.4",
]
server = ["fastapi", "uvicorn", "pattern_singleton", "websockets"]
diff --git a/speechx/.clang-format b/speechx/.clang-format
new file mode 100644
index 000000000..af946a4a9
--- /dev/null
+++ b/speechx/.clang-format
@@ -0,0 +1,29 @@
+# This file is used by clang-format to autoformat paddle source code
+#
+# The clang-format is part of llvm toolchain.
+# It need to install llvm and clang to format source code style.
+#
+# The basic usage is,
+# clang-format -i -style=file PATH/TO/SOURCE/CODE
+#
+# The -style=file implicit use ".clang-format" file located in one of
+# parent directory.
+# The -i means inplace change.
+#
+# The document of clang-format is
+# http://clang.llvm.org/docs/ClangFormat.html
+# http://clang.llvm.org/docs/ClangFormatStyleOptions.html
+---
+Language: Cpp
+BasedOnStyle: Google
+IndentWidth: 4
+TabWidth: 4
+ContinuationIndentWidth: 4
+MaxEmptyLinesToKeep: 2
+AccessModifierOffset: -2 # The private/protected/public has no indent in class
+Standard: Cpp11
+AllowAllParametersOfDeclarationOnNextLine: true
+BinPackParameters: false
+BinPackArguments: false
+...
+
diff --git a/speechx/.gitignore b/speechx/.gitignore
index e0c618470..9a93805c0 100644
--- a/speechx/.gitignore
+++ b/speechx/.gitignore
@@ -1 +1,2 @@
tools/valgrind*
+*log
diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt
index 4b5838e5c..978a23d95 100644
--- a/speechx/CMakeLists.txt
+++ b/speechx/CMakeLists.txt
@@ -13,7 +13,6 @@ set(CMAKE_CXX_STANDARD 14)
set(speechx_cmake_dir ${PROJECT_SOURCE_DIR}/cmake)
# Modules
-list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir}/external)
list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir})
include(FetchContent)
include(ExternalProject)
@@ -32,9 +31,13 @@ SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall
###############################################################################
# Option Configurations
###############################################################################
-# option configurations
option(TEST_DEBUG "option for debug" OFF)
+option(USE_PROFILING "enable c++ profiling" OFF)
+option(USING_U2 "compile u2 model." ON)
+option(USING_DS2 "compile with ds2 model." ON)
+
+option(USING_GPU "u2 compute on GPU." OFF)
###############################################################################
# Include third party
@@ -83,48 +86,65 @@ add_dependencies(openfst gflags glog)
# paddle lib
-set(paddle_SOURCE_DIR ${fc_patch}/paddle-lib)
-set(paddle_PREFIX_DIR ${fc_patch}/paddle-lib-prefix)
-ExternalProject_Add(paddle
- URL https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/CPU/gcc8.2_avx_mkl/paddle_inference.tgz
- URL_HASH SHA256=7c6399e778c6554a929b5a39ba2175e702e115145e8fa690d2af974101d98873
- PREFIX ${paddle_PREFIX_DIR}
- SOURCE_DIR ${paddle_SOURCE_DIR}
- CONFIGURE_COMMAND ""
- BUILD_COMMAND ""
- INSTALL_COMMAND ""
-)
-
-set(PADDLE_LIB ${fc_patch}/paddle-lib)
-include_directories("${PADDLE_LIB}/paddle/include")
-set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")
-include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
-include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
-include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
-
-link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
-link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
-link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
-link_directories("${PADDLE_LIB}/paddle/lib")
-link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mklml/lib")
-
-##paddle with mkl
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
-set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml")
-include_directories("${MATH_LIB_PATH}/include")
-set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
- ${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
-set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
-include_directories("${MKLDNN_PATH}/include")
-set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
-set(EXTERNAL_LIB "-lrt -ldl -lpthread")
-
-set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
-set(DEPS ${DEPS}
- ${MATH_LIB} ${MKLDNN_LIB}
- glog gflags protobuf xxhash cryptopp
- ${EXTERNAL_LIB})
-
+include(paddleinference)
+
+
+# paddle core.so
+find_package(Threads REQUIRED)
+find_package(PythonLibs REQUIRED)
+find_package(Python3 REQUIRED)
+find_package(pybind11 CONFIG)
+
+message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
+message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}")
+message(STATUS "Pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}, pybind11_LIBRARIES=${pybind11_LIBRARIES}, pybind11_DEFINITIONS=${pybind11_DEFINITIONS}")
+
+# paddle include and link option
+# -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so
+execute_process(
+ COMMAND python -c "\
+import os;\
+import paddle;\
+include_dir=paddle.sysconfig.get_include();\
+paddle_dir=os.path.split(include_dir)[0];\
+libs_dir=os.path.join(paddle_dir, 'libs');\
+fluid_dir=os.path.join(paddle_dir, 'fluid');\
+out=' '.join([\"-L\" + libs_dir, \"-L\" + fluid_dir]);\
+out += \" -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so\"; print(out);\
+ "
+ OUTPUT_VARIABLE PADDLE_LINK_FLAGS
+  RESULT_VARIABLE SUCCESS)
+
+message(STATUS PADDLE_LINK_FLAGS= ${PADDLE_LINK_FLAGS})
+string(STRIP ${PADDLE_LINK_FLAGS} PADDLE_LINK_FLAGS)
+
+# paddle compile option
+# -I/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/include
+execute_process(
+ COMMAND python -c "\
+import paddle; \
+include_dir = paddle.sysconfig.get_include(); \
+print(f\"-I{include_dir}\"); \
+ "
+ OUTPUT_VARIABLE PADDLE_COMPILE_FLAGS)
+message(STATUS PADDLE_COMPILE_FLAGS= ${PADDLE_COMPILE_FLAGS})
+string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS)
+
+
+# for LD_LIBRARY_PATH
+# set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/)
+execute_process(
+ COMMAND python -c " \
+import os; \
+import paddle; \
+include_dir=paddle.sysconfig.get_include(); \
+paddle_dir=os.path.split(include_dir)[0]; \
+libs_dir=os.path.join(paddle_dir, 'libs'); \
+fluid_dir=os.path.join(paddle_dir, 'fluid'); \
+out=':'.join([libs_dir, fluid_dir]); print(out); \
+ "
+ OUTPUT_VARIABLE PADDLE_LIB_DIRS)
+message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS})
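+# runtime usage sketch: export LD_LIBRARY_PATH=${PADDLE_LIB_DIRS}:$LD_LIBRARY_PATH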
###############################################################################
diff --git a/speechx/README.md b/speechx/README.md
index cd1cd62c1..f744defae 100644
--- a/speechx/README.md
+++ b/speechx/README.md
@@ -3,11 +3,14 @@
## Environment
We develop under:
+* python - 3.7
* docker - `registry.baidubce.com/paddlepaddle/paddle:2.2.2-gpu-cuda10.2-cudnn7`
* os - Ubuntu 16.04.7 LTS
* gcc/g++/gfortran - 8.2.0
* cmake - 3.16.0
+> Please use `tools/env.sh` to create a python `venv`, then `source venv/bin/activate` before building speechx.
+
> We make sure all things work fun under docker, and recommend using it to develop and deploy.
* [How to Install Docker](https://docs.docker.com/engine/install/)
@@ -24,16 +27,23 @@ docker run --privileged --net=host --ipc=host -it --rm -v $PWD:/workspace --nam
* More `Paddle` docker images you can see [here](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html).
+2. Create python environment.
-2. Build `speechx` and `examples`.
+```
+bash tools/venv.sh
+```
-> Do not source venv.
+2. Build `speechx` and `examples`.
+For now we are using features from the `develop` branch of paddle, so we need to install the `paddlepaddle` nightly build.
+For example:
```
-pushd /path/to/speechx
+source venv/bin/activate
+python -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html
./build.sh
```
+
3. Go to `examples` to have a fun.
More details please see `README.md` under `examples`.
diff --git a/speechx/build.sh b/speechx/build.sh
index a6eef6565..7655f9635 100755
--- a/speechx/build.sh
+++ b/speechx/build.sh
@@ -1,4 +1,5 @@
#!/usr/bin/env bash
+set -xe
# the build script had verified in the paddlepaddle docker image.
# please follow the instruction below to install PaddlePaddle image.
@@ -17,11 +18,6 @@ fi
#rm -rf build
mkdir -p build
-cd build
-cmake .. -DBOOST_ROOT:STRING=${boost_SOURCE_DIR}
-#cmake ..
-
-make -j
-
-cd -
+cmake -B build -DBOOST_ROOT:STRING=${boost_SOURCE_DIR}
+cmake --build build -j
diff --git a/speechx/cmake/external/absl.cmake b/speechx/cmake/absl.cmake
similarity index 100%
rename from speechx/cmake/external/absl.cmake
rename to speechx/cmake/absl.cmake
diff --git a/speechx/cmake/external/boost.cmake b/speechx/cmake/boost.cmake
similarity index 100%
rename from speechx/cmake/external/boost.cmake
rename to speechx/cmake/boost.cmake
diff --git a/speechx/cmake/external/eigen.cmake b/speechx/cmake/eigen.cmake
similarity index 100%
rename from speechx/cmake/external/eigen.cmake
rename to speechx/cmake/eigen.cmake
diff --git a/speechx/cmake/external/gflags.cmake b/speechx/cmake/external/gflags.cmake
deleted file mode 100644
index 66ae47f70..000000000
--- a/speechx/cmake/external/gflags.cmake
+++ /dev/null
@@ -1,12 +0,0 @@
-include(FetchContent)
-
-FetchContent_Declare(
- gflags
- URL https://github.com/gflags/gflags/archive/v2.2.1.zip
- URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a
-)
-
-FetchContent_MakeAvailable(gflags)
-
-# openfst need
-include_directories(${gflags_BINARY_DIR}/include)
\ No newline at end of file
diff --git a/speechx/cmake/gflags.cmake b/speechx/cmake/gflags.cmake
new file mode 100644
index 000000000..36bebc877
--- /dev/null
+++ b/speechx/cmake/gflags.cmake
@@ -0,0 +1,11 @@
+include(FetchContent)
+
+FetchContent_Declare(
+ gflags
+ URL https://github.com/gflags/gflags/archive/v2.2.2.zip
+ URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5
+)
+FetchContent_MakeAvailable(gflags)
+
+# openfst need
+include_directories(${gflags_BINARY_DIR}/include)
\ No newline at end of file
diff --git a/speechx/cmake/external/glog.cmake b/speechx/cmake/glog.cmake
similarity index 100%
rename from speechx/cmake/external/glog.cmake
rename to speechx/cmake/glog.cmake
diff --git a/speechx/cmake/external/gtest.cmake b/speechx/cmake/gtest.cmake
similarity index 69%
rename from speechx/cmake/external/gtest.cmake
rename to speechx/cmake/gtest.cmake
index 7fe397fcb..1ea8ed0b7 100644
--- a/speechx/cmake/external/gtest.cmake
+++ b/speechx/cmake/gtest.cmake
@@ -1,8 +1,8 @@
include(FetchContent)
FetchContent_Declare(
gtest
- URL https://github.com/google/googletest/archive/release-1.10.0.zip
- URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91
+ URL https://github.com/google/googletest/archive/release-1.11.0.zip
+ URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a
)
FetchContent_MakeAvailable(gtest)
diff --git a/speechx/cmake/external/kenlm.cmake b/speechx/cmake/kenlm.cmake
similarity index 100%
rename from speechx/cmake/external/kenlm.cmake
rename to speechx/cmake/kenlm.cmake
diff --git a/speechx/cmake/external/libsndfile.cmake b/speechx/cmake/libsndfile.cmake
similarity index 100%
rename from speechx/cmake/external/libsndfile.cmake
rename to speechx/cmake/libsndfile.cmake
diff --git a/speechx/cmake/external/openblas.cmake b/speechx/cmake/openblas.cmake
similarity index 88%
rename from speechx/cmake/external/openblas.cmake
rename to speechx/cmake/openblas.cmake
index 5c196527e..27e132075 100644
--- a/speechx/cmake/external/openblas.cmake
+++ b/speechx/cmake/openblas.cmake
@@ -1,7 +1,7 @@
include(FetchContent)
-set(OpenBLAS_SOURCE_DIR ${fc_patch}/OpenBLAS-src)
-set(OpenBLAS_PREFIX ${fc_patch}/OpenBLAS-prefix)
+set(OpenBLAS_SOURCE_DIR ${fc_patch}/openblas-src)
+set(OpenBLAS_PREFIX ${fc_patch}/openblas-prefix)
# ######################################################################################################################
# OPENBLAS https://github.com/lattice/quda/blob/develop/CMakeLists.txt#L575
@@ -43,6 +43,7 @@ ExternalProject_Add(
# https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition
ExternalProject_Get_Property(OPENBLAS INSTALL_DIR)
+message(STATUS "OPENBLAS install dir: ${INSTALL_DIR}")
set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR})
add_library(openblas STATIC IMPORTED)
add_dependencies(openblas OPENBLAS)
@@ -55,4 +56,6 @@ set_target_properties(openblas PROPERTIES IMPORTED_LOCATION ${OpenBLAS_INSTALL_P
# ${CMAKE_INSTALL_LIBDIR} lib
# ${CMAKE_INSTALL_INCLUDEDIR} include
link_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
-include_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR})
\ No newline at end of file
+# include_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR})
+# fix for "cannot find cblas.h"
+include_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/openblas)
\ No newline at end of file
diff --git a/speechx/cmake/external/openfst.cmake b/speechx/cmake/openfst.cmake
similarity index 100%
rename from speechx/cmake/external/openfst.cmake
rename to speechx/cmake/openfst.cmake
diff --git a/speechx/cmake/paddleinference.cmake b/speechx/cmake/paddleinference.cmake
new file mode 100644
index 000000000..d8a9c6134
--- /dev/null
+++ b/speechx/cmake/paddleinference.cmake
@@ -0,0 +1,49 @@
+set(paddle_SOURCE_DIR ${fc_patch}/paddle-lib)
+set(paddle_PREFIX_DIR ${fc_patch}/paddle-lib-prefix)
+
+include(FetchContent)
+FetchContent_Declare(
+ paddle
+ URL https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/CPU/gcc8.2_avx_mkl/paddle_inference.tgz
+ URL_HASH SHA256=7c6399e778c6554a929b5a39ba2175e702e115145e8fa690d2af974101d98873
+ PREFIX ${paddle_PREFIX_DIR}
+ SOURCE_DIR ${paddle_SOURCE_DIR}
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND ""
+ INSTALL_COMMAND ""
+)
+FetchContent_MakeAvailable(paddle)
+
+set(PADDLE_LIB_THIRD_PARTY_PATH "${paddle_SOURCE_DIR}/third_party/install/")
+
+include_directories("${paddle_SOURCE_DIR}/paddle/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
+
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
+link_directories("${paddle_SOURCE_DIR}/paddle/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mklml/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn/lib")
+
+##paddle with mkl
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
+set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml")
+include_directories("${MATH_LIB_PATH}/include")
+set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
+ ${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
+
+set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
+include_directories("${MKLDNN_PATH}/include")
+set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
+set(EXTERNAL_LIB "-lrt -ldl -lpthread")
+
+# global vars
+set(DEPS ${paddle_SOURCE_DIR}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX} CACHE INTERNAL "deps")
+set(DEPS ${DEPS}
+ ${MATH_LIB} ${MKLDNN_LIB}
+ glog gflags protobuf xxhash cryptopp
+ ${EXTERNAL_LIB} CACHE INTERNAL "deps")
+message(STATUS "Deps libraries: ${DEPS}")
diff --git a/speechx/examples/README.md b/speechx/examples/README.md
index f7f6f9ac0..de27bd94b 100644
--- a/speechx/examples/README.md
+++ b/speechx/examples/README.md
@@ -1,20 +1,42 @@
# Examples for SpeechX
+> `u2pp_ol` is recommended.
+
+* `u2pp_ol` - u2++ streaming ASR test on the `aishell-1` test dataset.
* `ds2_ol` - ds2 streaming test under `aishell-1` test dataset.
+
## How to run
-`run.sh` is the entry point.
+### Create env
+
+Use `tools/env.sh` under `speechx` to create the python env.
+
+```
+bash tools/env.sh
+```
+
+Source the env before playing with the examples.
+```
+. venv/bin/activate
+```
+
+### Play with example
+
+`run.sh` is the entry point for every example.
-Example to play `ds2_ol`:
+Example to play `u2pp_ol`:
```
-pushd ds2_ol/aishell
-bash run.sh
+pushd u2pp_ol/wenetspeech
+bash run.sh --stop_stage 4
```
## Display Model with [Netron](https://github.com/lutzroeder/netron)
+If you have a model, you can use this command to show the model graph.
+
+For example:
```
pip install netron
netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host 10.21.55.20
diff --git a/speechx/examples/codelab/README.md b/speechx/examples/codelab/README.md
index f89184de9..803f25fac 100644
--- a/speechx/examples/codelab/README.md
+++ b/speechx/examples/codelab/README.md
@@ -1,8 +1,9 @@
# Codelab
-## introduction
+> The below is for developing and offline testing.
+> Do not run it unless you know what it is.
-> The below is for developing and offline testing. Do not run it only if you know what it is.
* nnet
* feat
* decoder
+* u2
diff --git a/speechx/examples/codelab/decoder/run.sh b/speechx/examples/codelab/decoder/run.sh
index a911eb033..1a9e3cd7e 100755
--- a/speechx/examples/codelab/decoder/run.sh
+++ b/speechx/examples/codelab/decoder/run.sh
@@ -69,7 +69,7 @@ compute_linear_spectrogram_main \
echo "compute linear spectrogram feature."
# run ctc beam search decoder as streaming
-ctc_prefix_beam_search_decoder_main \
+ctc_beam_search_decoder_main \
--result_wspecifier=ark,t:$exp_dir/result.txt \
--feature_rspecifier=ark:$feat_wspecifier \
--model_path=$model_dir/avg_1.jit.pdmodel \
diff --git a/speechx/examples/codelab/feat/.gitignore b/speechx/examples/codelab/feat/.gitignore
new file mode 100644
index 000000000..bbd86a25b
--- /dev/null
+++ b/speechx/examples/codelab/feat/.gitignore
@@ -0,0 +1,2 @@
+data
+exp
diff --git a/speechx/examples/codelab/feat/path.sh b/speechx/examples/codelab/feat/path.sh
index 3b89d01e9..9d2291743 100644
--- a/speechx/examples/codelab/feat/path.sh
+++ b/speechx/examples/codelab/feat/path.sh
@@ -1,12 +1,12 @@
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../../../
-SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
+SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
-[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
+[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure that the project built successfully."; }
export LC_AL=C
diff --git a/speechx/examples/codelab/feat/run.sh b/speechx/examples/codelab/feat/run.sh
index 1fa37f981..5d7612ae5 100755
--- a/speechx/examples/codelab/feat/run.sh
+++ b/speechx/examples/codelab/feat/run.sh
@@ -42,8 +42,8 @@ mkdir -p $exp_dir
export GLOG_logtostderr=1
cmvn_json2kaldi_main \
- --json_file $model_dir/data/mean_std.json \
- --cmvn_write_path $exp_dir/cmvn.ark \
+ --json_file=$model_dir/data/mean_std.json \
+ --cmvn_write_path=$exp_dir/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
@@ -54,4 +54,10 @@ compute_linear_spectrogram_main \
--cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature."
+compute_fbank_main \
+ --num_bins=161 \
+ --wav_rspecifier=scp:$data_dir/wav.scp \
+ --feature_wspecifier=ark,t:$exp_dir/fbank.ark \
+ --cmvn_file=$exp_dir/cmvn.ark
+echo "compute fbank feature."
diff --git a/speechx/examples/codelab/nnet/path.sh b/speechx/examples/codelab/nnet/path.sh
index 7d395d648..11c8aef8b 100644
--- a/speechx/examples/codelab/nnet/path.sh
+++ b/speechx/examples/codelab/nnet/path.sh
@@ -6,7 +6,7 @@ SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
-[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
+[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure the project was built successfully."; }
export LC_AL=C
diff --git a/speechx/examples/codelab/u2/.gitignore b/speechx/examples/codelab/u2/.gitignore
new file mode 100644
index 000000000..1269488f7
--- /dev/null
+++ b/speechx/examples/codelab/u2/.gitignore
@@ -0,0 +1 @@
+data
diff --git a/speechx/examples/codelab/u2/README.md b/speechx/examples/codelab/u2/README.md
new file mode 100644
index 000000000..3c85dc917
--- /dev/null
+++ b/speechx/examples/codelab/u2/README.md
@@ -0,0 +1 @@
+# u2/u2pp Streaming Test
diff --git a/speechx/examples/codelab/u2/local/decode.sh b/speechx/examples/codelab/u2/local/decode.sh
new file mode 100755
index 000000000..11c1afe86
--- /dev/null
+++ b/speechx/examples/codelab/u2/local/decode.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+set +x
+set -e
+
+. path.sh
+
+data=data
+exp=exp
+mkdir -p $exp
+ckpt_dir=$data/model
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+
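+# run the CTC prefix beam search decoder on the fbank features produced by
+# local/feat.sh; the chunk, receptive-field, and subsampling flags match the
+# exported u2pp streaming conformer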
+ctc_prefix_beam_search_decoder_main \
+ --model_path=$model_dir/export.jit \
+ --nnet_decoder_chunk=16 \
+ --receptive_field_length=7 \
+ --subsampling_rate=4 \
+ --vocab_path=$model_dir/unit.txt \
+ --feature_rspecifier=ark,t:$exp/fbank.ark \
+ --result_wspecifier=ark,t:$exp/result.ark
+
+echo "u2 ctc prefix beam search decode."
diff --git a/speechx/examples/codelab/u2/local/feat.sh b/speechx/examples/codelab/u2/local/feat.sh
new file mode 100755
index 000000000..1eec3aae3
--- /dev/null
+++ b/speechx/examples/codelab/u2/local/feat.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+set -x
+set -e
+
+. path.sh
+
+data=data
+exp=exp
+mkdir -p $exp
+ckpt_dir=./data/model
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+
+
+cmvn_json2kaldi_main \
+ --json_file=$model_dir/mean_std.json \
+ --cmvn_write_path=$exp/cmvn.ark \
+ --binary=false
+
+echo "convert json cmvn to kaldi ark."
+
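+# extract 80-dim fbank features for each utterance in wav.scp, applying the
+# CMVN stats converted above; results go to $exp/fbank.ark in text ark format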
+compute_fbank_main \
+ --num_bins=80 \
+ --wav_rspecifier=scp:$data/wav.scp \
+ --cmvn_file=$exp/cmvn.ark \
+ --feature_wspecifier=ark,t:$exp/fbank.ark
+
+echo "compute fbank feature."
diff --git a/speechx/examples/codelab/u2/local/nnet.sh b/speechx/examples/codelab/u2/local/nnet.sh
new file mode 100755
index 000000000..4419201cf
--- /dev/null
+++ b/speechx/examples/codelab/u2/local/nnet.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+set -x
+set -e
+
+. path.sh
+
+data=data
+exp=exp
+mkdir -p $exp
+ckpt_dir=./data/model
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+
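+# forward the fbank features through the u2pp nnet only (no decoding);
+# writes the encoder outputs and CTC log-probs for inspection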
+u2_nnet_main \
+ --model_path=$model_dir/export.jit \
+ --feature_rspecifier=ark,t:$exp/fbank.ark \
+ --nnet_decoder_chunk=16 \
+ --receptive_field_length=7 \
+ --subsampling_rate=4 \
+ --acoustic_scale=1.0 \
+ --nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \
+ --nnet_prob_wspecifier=ark,t:$exp/logprobs.ark
+echo "u2 nnet decode."
+
diff --git a/speechx/examples/codelab/u2/local/recognizer.sh b/speechx/examples/codelab/u2/local/recognizer.sh
new file mode 100755
index 000000000..9f697b459
--- /dev/null
+++ b/speechx/examples/codelab/u2/local/recognizer.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+set -e
+
+. path.sh
+
+data=data
+exp=exp
+mkdir -p $exp
+ckpt_dir=./data/model
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+
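+# end-to-end streaming recognizer: fbank extraction, nnet forward, and
+# decoding in a single binary, reading wavs and writing transcripts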
+u2_recognizer_main \
+ --use_fbank=true \
+ --num_bins=80 \
+ --cmvn_file=$exp/cmvn.ark \
+ --model_path=$model_dir/export.jit \
+ --nnet_decoder_chunk=16 \
+ --receptive_field_length=7 \
+ --subsampling_rate=4 \
+ --vocab_path=$model_dir/unit.txt \
+ --wav_rspecifier=scp:$data/wav.scp \
+ --result_wspecifier=ark,t:$exp/result.ark
diff --git a/speechx/examples/codelab/u2/path.sh b/speechx/examples/codelab/u2/path.sh
new file mode 100644
index 000000000..ec278bd3d
--- /dev/null
+++ b/speechx/examples/codelab/u2/path.sh
@@ -0,0 +1,18 @@
+# This contains the locations of the binaries required for running the examples.
+
+unset GREP_OPTIONS
+
+SPEECHX_ROOT=$PWD/../../../
+SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
+
+SPEECHX_TOOLS=$SPEECHX_ROOT/tools
+TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
+
+[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure the project was built successfully."; }
+
+export LC_ALL=C
+
+export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio:$SPEECHX_BUILD/recognizer
+
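+# ask the installed paddle package where its shared libraries live so the
+# speechx binaries can load them at runtime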
+PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);")
+export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH
diff --git a/speechx/examples/codelab/u2/run.sh b/speechx/examples/codelab/u2/run.sh
new file mode 100755
index 000000000..d314262ba
--- /dev/null
+++ b/speechx/examples/codelab/u2/run.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+set -x
+set -e
+
+. path.sh
+
+# 1. compile
+if [ ! -d ${SPEECHX_BUILD} ]; then
+ pushd ${SPEECHX_ROOT}
+ bash build.sh
+ popd
+fi
+
+# 2. download model
+if [ ! -f data/model/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz ]; then
+ mkdir -p data/model
+ pushd data/model
+ wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
+ tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
+ popd
+fi
+
+# produce wav scp
+if [ ! -f data/wav.scp ]; then
+ mkdir -p data
+ pushd data
+ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
+ echo "utt1 " $PWD/zh.wav > wav.scp
+ popd
+fi
+
+data=data
+exp=exp
+mkdir -p $exp
+ckpt_dir=./data/model
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+
+
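+# 3. run feature extraction, nnet forward, and decoding in sequence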
+./local/feat.sh
+
+./local/nnet.sh
+
+./local/decode.sh
diff --git a/speechx/examples/ds2_ol/aishell/run.sh b/speechx/examples/ds2_ol/aishell/run.sh
index 82e889ce5..794b533ff 100755
--- a/speechx/examples/ds2_ol/aishell/run.sh
+++ b/speechx/examples/ds2_ol/aishell/run.sh
@@ -1,5 +1,5 @@
#!/bin/bash
-set +x
+set -x
set -e
. path.sh
@@ -11,7 +11,7 @@ stop_stage=100
. utils/parse_options.sh
# 1. compile
-if [ ! -d ${SPEECHX_EXAMPLES} ]; then
+if [ ! -d ${SPEECHX_BUILD} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
@@ -84,7 +84,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# recognizer
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \
- ctc_prefix_beam_search_decoder_main \
+ ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
@@ -103,7 +103,7 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \
- ctc_prefix_beam_search_decoder_main \
+ ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
@@ -135,7 +135,7 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \
- tlg_decoder_main \
+ ctc_tlg_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
diff --git a/speechx/examples/ds2_ol/aishell/run_fbank.sh b/speechx/examples/ds2_ol/aishell/run_fbank.sh
index 720728354..1c3c3e010 100755
--- a/speechx/examples/ds2_ol/aishell/run_fbank.sh
+++ b/speechx/examples/ds2_ol/aishell/run_fbank.sh
@@ -84,7 +84,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# recognizer
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wolm.log \
- ctc_prefix_beam_search_decoder_main \
+ ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \
@@ -102,7 +102,7 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.lm.log \
- ctc_prefix_beam_search_decoder_main \
+ ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \
@@ -133,7 +133,7 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \
- tlg_decoder_main \
+ ctc_tlg_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \
diff --git a/speechx/examples/u2pp_ol/README.md b/speechx/examples/u2pp_ol/README.md
new file mode 100644
index 000000000..838db435c
--- /dev/null
+++ b/speechx/examples/u2pp_ol/README.md
@@ -0,0 +1,5 @@
+# U2/U2++ Streaming ASR
+
+## Examples
+
+* `wenetspeech` - streaming decoding with the wenetspeech u2/u2++ model, using the AISHELL test set for evaluation.
diff --git a/speechx/examples/u2pp_ol/wenetspeech/.gitignore b/speechx/examples/u2pp_ol/wenetspeech/.gitignore
new file mode 100644
index 000000000..02c0cc21f
--- /dev/null
+++ b/speechx/examples/u2pp_ol/wenetspeech/.gitignore
@@ -0,0 +1,3 @@
+data
+utils
+exp
diff --git a/speechx/examples/u2pp_ol/wenetspeech/README.md b/speechx/examples/u2pp_ol/wenetspeech/README.md
new file mode 100644
index 000000000..9a8f8af51
--- /dev/null
+++ b/speechx/examples/u2pp_ol/wenetspeech/README.md
@@ -0,0 +1,28 @@
+# u2/u2pp Streaming ASR
+
+## Testing with AISHELL Test Data
+
+### Download wav and model
+
+```
+./run.sh --stop_stage 0
+```
+
+### Compute feature
+
+```
+./run.sh --stage 1 --stop_stage 1
+```
+
+### Decoding using feature
+
+```
+./run.sh --stage 2 --stop_stage 2
+```
+
+### Decoding using wav
+
+```
+./run.sh --stage 3 --stop_stage 3
+```
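+
+Each stage consumes the outputs of the previous one, so run the stages in order, or simply run `./run.sh` to execute them all.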
diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/aishell_train_lms.sh b/speechx/examples/u2pp_ol/wenetspeech/local/aishell_train_lms.sh
new file mode 100755
index 000000000..544a1f59a
--- /dev/null
+++ b/speechx/examples/u2pp_ol/wenetspeech/local/aishell_train_lms.sh
@@ -0,0 +1,71 @@
+#!/bin/bash
+
+# To be run from one directory above this script.
+. ./path.sh
+
+nj=40
+text=data/local/lm/text
+lexicon=data/local/dict/lexicon.txt
+
+for f in "$text" "$lexicon"; do
+ [ ! -f "$f" ] && echo "$0: No such file $f" && exit 1;
+done
+
+# Check SRILM tools
+if ! which ngram-count > /dev/null; then
+ echo "srilm tools are not found, please download it and install it from: "
+ echo "http://www.speech.sri.com/projects/srilm/download.html"
+ echo "Then add the tools to your PATH"
+ exit 1
+fi
+
+# This script takes no arguments. It assumes you have already run
+# aishell_data_prep.sh.
+# It takes as input the files
+# data/local/lm/text
+# data/local/dict/lexicon.txt
+dir=data/local/lm
+mkdir -p $dir
+
+cleantext=$dir/text.no_oov
+
+# oov to <SPOKEN_NOISE>
+# lexicon line: word char0 ... charn
+# text line: utt word0 ... wordn -> line: word0 ... wordn
+text_dir=$(dirname $text)
+split_name=$(basename $text)
+./local/split_data.sh $text_dir $text $split_name $nj
+
+utils/run.pl JOB=1:$nj $text_dir/split${nj}/JOB/${split_name}.no_oov.log \
+ cat ${text_dir}/split${nj}/JOB/${split_name} \| awk -v lex=$lexicon 'BEGIN{while((getline<lex) >0){ seen[$1]=1; } }
+ {for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf("<SPOKEN_NOISE> ");} } printf("\n");}' \
+ \> ${text_dir}/split${nj}/JOB/${split_name}.no_oov || exit 1;
+cat ${text_dir}/split${nj}/*/${split_name}.no_oov > $cleantext
+
+# compute word counts, sort in descending order
+# line: count word
+cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort --parallel=`nproc` | uniq -c | \
+ sort --parallel=`nproc` -nr > $dir/word.counts || exit 1;
+
+# Get counts from acoustic training transcripts, and add one-count
+# for each word in the lexicon (but not silence, we don't want it
+# in the LM-- we'll add it optionally later).
+cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \
+ cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \
+ sort --parallel=`nproc` | uniq -c | sort --parallel=`nproc` -nr > $dir/unigram.counts || exit 1;
+
+# word list with <s> and </s>
+cat $dir/unigram.counts | awk '{print $2}' | cat - <(echo "<s>"; echo "</s>" ) > $dir/wordlist
+
+# hold out to compute ppl
+heldout_sent=10000 # Don't change this if you want result to be comparable with kaldi_lm results
+
+mkdir -p $dir
+cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | head -$heldout_sent > $dir/heldout
+cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | tail -n +$heldout_sent > $dir/train
+
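+# train a trigram LM with Kneser-Ney discounting on the training portion and
+# report its perplexity on the held-out set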
+ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
+ -map-unk "" -kndiscount -interpolate -lm $dir/lm.arpa
+ngram -lm $dir/lm.arpa -ppl $dir/heldout
\ No newline at end of file
diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/decode.sh b/speechx/examples/u2pp_ol/wenetspeech/local/decode.sh
new file mode 100755
index 000000000..c17cdbe65
--- /dev/null
+++ b/speechx/examples/u2pp_ol/wenetspeech/local/decode.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+set -e
+
+. path.sh
+
+data=data
+exp=exp
+nj=20
+mkdir -p $exp
+ckpt_dir=./data/model
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+text=$data/test/text
+# names for the merged decode output and the WER report (arbitrary; any names work)
+label_file=result_decode
+wer=decode.wer
+
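+# decode each feature split in parallel with run.pl, then merge the per-job
+# results and score WER against the reference transcripts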
+utils/run.pl JOB=1:$nj $data/split${nj}/JOB/decoder.fbank.wolm.log \
+ctc_prefix_beam_search_decoder_main \
+ --model_path=$model_dir/export.jit \
+ --vocab_path=$model_dir/unit.txt \
+ --nnet_decoder_chunk=16 \
+ --receptive_field_length=7 \
+ --subsampling_rate=4 \
+ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank.scp \
+ --result_wspecifier=ark,t:$data/split${nj}/JOB/result_decode.ark
+
+cat $data/split${nj}/*/result_decode.ark > $exp/${label_file}
+utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer}
+tail -n 7 $exp/${wer}
\ No newline at end of file
diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/feat.sh b/speechx/examples/u2pp_ol/wenetspeech/local/feat.sh
new file mode 100755
index 000000000..4341cec8b
--- /dev/null
+++ b/speechx/examples/u2pp_ol/wenetspeech/local/feat.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+set -e
+
+. path.sh
+
+data=data
+exp=exp
+nj=20
+mkdir -p $exp
+ckpt_dir=./data/model
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+aishell_wav_scp=aishell_test.scp
+
+cmvn_json2kaldi_main \
+ --json_file=$model_dir/mean_std.json \
+ --cmvn_write_path=$exp/cmvn.ark \
+ --binary=false
+
+echo "convert json cmvn to kaldi ark."
+
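+# split the test scp into $nj pieces so feature extraction can run in parallel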
+./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
+
+utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
+compute_fbank_main \
+ --num_bins=80 \
+ --cmvn_file=$exp/cmvn.ark \
+ --streaming_chunk=36 \
+ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
+ --feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank.ark,$data/split${nj}/JOB/fbank.scp
+
+echo "compute fbank feature."
diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/nnet.sh b/speechx/examples/u2pp_ol/wenetspeech/local/nnet.sh
new file mode 100755
index 000000000..4419201cf
--- /dev/null
+++ b/speechx/examples/u2pp_ol/wenetspeech/local/nnet.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+set -x
+set -e
+
+. path.sh
+
+data=data
+exp=exp
+mkdir -p $exp
+ckpt_dir=./data/model
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+
+u2_nnet_main \
+ --model_path=$model_dir/export.jit \
+ --feature_rspecifier=ark,t:$exp/fbank.ark \
+ --nnet_decoder_chunk=16 \
+ --receptive_field_length=7 \
+ --subsampling_rate=4 \
+ --acoustic_scale=1.0 \
+ --nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \
+ --nnet_prob_wspecifier=ark,t:$exp/logprobs.ark
+echo "u2 nnet decode."
+
diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/recognizer.sh b/speechx/examples/u2pp_ol/wenetspeech/local/recognizer.sh
new file mode 100755
index 000000000..f4553f2ab
--- /dev/null
+++ b/speechx/examples/u2pp_ol/wenetspeech/local/recognizer.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+set -e
+
+. path.sh
+
+data=data
+exp=exp
+nj=20
+
+
+mkdir -p $exp
+ckpt_dir=./data/model
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+aishell_wav_scp=aishell_test.scp
+text=$data/test/text
+
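+# split the wav scp so the $nj recognizer jobs can run in parallel via run.pl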
+./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
+
+utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \
+u2_recognizer_main \
+ --use_fbank=true \
+ --num_bins=80 \
+ --cmvn_file=$exp/cmvn.ark \
+ --model_path=$model_dir/export.jit \
+ --vocab_path=$model_dir/unit.txt \
+ --nnet_decoder_chunk=16 \
+ --receptive_field_length=7 \
+ --subsampling_rate=4 \
+ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
+ --result_wspecifier=ark,t:$data/split${nj}/JOB/result_recognizer.ark
+
+
+cat $data/split${nj}/*/result_recognizer.ark > $exp/aishell_recognizer
+utils/compute-wer.py --char=1 --v=1 $text $exp/aishell_recognizer > $exp/aishell.recognizer.err
+echo "recognizer test have finished!!!"
+echo "please checkout in $exp/aishell.recognizer.err"
+tail -n 7 $exp/aishell.recognizer.err
diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/split_data.sh b/speechx/examples/u2pp_ol/wenetspeech/local/split_data.sh
new file mode 100755
index 000000000..faa5c42dc
--- /dev/null
+++ b/speechx/examples/u2pp_ol/wenetspeech/local/split_data.sh
@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+
+set -eo pipefail
+
+data=$1
+scp=$2
+split_name=$3
+numsplit=$4
+
+# split $scp into $numsplit pieces, saved under $data/split${numsplit}/N/$split_name
+
+if [[ ! $numsplit -gt 0 ]]; then
+ echo "$0: Invalid num-split argument";
+ exit 1;
+fi
+
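+# build the lists of per-job directories and per-job scp file names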
+directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done)
+scp_splits=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/${split_name}; done)
+
+# if this mkdir fails due to argument-list being too long, iterate.
+if ! mkdir -p $directories >&/dev/null; then
+ for n in `seq $numsplit`; do
+ mkdir -p $data/split${numsplit}/$n
+ done
+fi
+
+echo "utils/split_scp.pl $scp $scp_splits"
+utils/split_scp.pl $scp $scp_splits
diff --git a/speechx/examples/u2pp_ol/wenetspeech/path.sh b/speechx/examples/u2pp_ol/wenetspeech/path.sh
new file mode 100644
index 000000000..ec278bd3d
--- /dev/null
+++ b/speechx/examples/u2pp_ol/wenetspeech/path.sh
@@ -0,0 +1,18 @@
+# This contains the locations of the binaries required for running the examples.
+
+unset GREP_OPTIONS
+
+SPEECHX_ROOT=$PWD/../../../
+SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
+
+SPEECHX_TOOLS=$SPEECHX_ROOT/tools
+TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
+
+[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. Please ensure the project was built successfully."; }
+
+export LC_ALL=C
+
+export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio:$SPEECHX_BUILD/recognizer
+
+PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);")
+export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH
diff --git a/speechx/examples/u2pp_ol/wenetspeech/run.sh b/speechx/examples/u2pp_ol/wenetspeech/run.sh
new file mode 100755
index 000000000..12e3af95a
--- /dev/null
+++ b/speechx/examples/u2pp_ol/wenetspeech/run.sh
@@ -0,0 +1,76 @@
+#!/bin/bash
+set +x
+set -e
+
+. path.sh
+
+nj=40
+stage=0
+stop_stage=5
+
+. utils/parse_options.sh
+
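+# stages: 0 download model and test data, 1 fbank feature, 2 ctc decode, 3 recognizer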
+# input
+data=data
+exp=exp
+aishell_wav_scp=aishell_test.scp
+mkdir -p $exp $data
+
+
+# 1. compile
+if [ ! -d ${SPEECHX_BUILD} ]; then
+ pushd ${SPEECHX_ROOT}
+ bash build.sh
+ popd
+fi
+
+
+ckpt_dir=$data/model
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
+ # download model
+ if [ ! -f $ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz ]; then
+ mkdir -p $ckpt_dir
+ pushd $ckpt_dir
+
+ wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
+ tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
+
+ popd
+ fi
+
+ # test wav scp
+ if [ ! -f data/wav.scp ]; then
+ mkdir -p $data
+ pushd $data
+ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
+ echo "utt1 " $PWD/zh.wav > wav.scp
+ popd
+ fi
+
+ # aishell wav scp
+ if [ ! -d $data/test ]; then
+ pushd $data
+ wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
+ unzip aishell_test.zip
+ popd
+
+ realpath $data/test/*/*.wav > $data/wavlist
+ awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
+ paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
+ fi
+fi
+
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ ./local/feat.sh
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ ./local/decode.sh
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ ./local/recognizer.sh
+fi
\ No newline at end of file
diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt
index c8e21d486..60c183472 100644
--- a/speechx/speechx/CMakeLists.txt
+++ b/speechx/speechx/CMakeLists.txt
@@ -32,6 +32,12 @@ ${CMAKE_CURRENT_SOURCE_DIR}/decoder
)
add_subdirectory(decoder)
+include_directories(
+${CMAKE_CURRENT_SOURCE_DIR}
+${CMAKE_CURRENT_SOURCE_DIR}/recognizer
+)
+add_subdirectory(recognizer)
+
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/protocol
diff --git a/speechx/speechx/base/basic_types.h b/speechx/speechx/base/basic_types.h
index 206b7be67..2b15a61fe 100644
--- a/speechx/speechx/base/basic_types.h
+++ b/speechx/speechx/base/basic_types.h
@@ -14,47 +14,47 @@
#pragma once
-#include "kaldi/base/kaldi-types.h"
-
#include <limits>
+#include "kaldi/base/kaldi-types.h"
+
typedef float BaseFloat;
typedef double double64;
typedef signed char int8;
-typedef short int16;
-typedef int int32;
+typedef short int16; // NOLINT
+typedef int int32; // NOLINT
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
-typedef long int64;
+typedef long int64; // NOLINT
#else
-typedef long long int64;
+typedef long long int64; // NOLINT
#endif
-typedef unsigned char uint8;
-typedef unsigned short uint16;
-typedef unsigned int uint32;
+typedef unsigned char uint8; // NOLINT
+typedef unsigned short uint16; // NOLINT
+typedef unsigned int uint32; // NOLINT
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
-typedef unsigned long uint64;
+typedef unsigned long uint64; // NOLINT
#else
-typedef unsigned long long uint64;
+typedef unsigned long long uint64; // NOLINT
#endif
typedef signed int char32;
-const uint8 kuint8max = ((uint8)0xFF);
-const uint16 kuint16max = ((uint16)0xFFFF);
-const uint32 kuint32max = ((uint32)0xFFFFFFFF);
-const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL));
-const int8 kint8min = ((int8)0x80);
-const int8 kint8max = ((int8)0x7F);
-const int16 kint16min = ((int16)0x8000);
-const int16 kint16max = ((int16)0x7FFF);
-const int32 kint32min = ((int32)0x80000000);
-const int32 kint32max = ((int32)0x7FFFFFFF);
-const int64 kint64min = ((int64)(0x8000000000000000LL));
-const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL));
+const uint8 kuint8max = static_cast<uint8>(0xFF);
+const uint16 kuint16max = static_cast<uint16>(0xFFFF);
+const uint32 kuint32max = static_cast<uint32>(0xFFFFFFFF);
+const uint64 kuint64max = static_cast<uint64>(0xFFFFFFFFFFFFFFFFLL);
+const int8 kint8min = static_cast<int8>(0x80);
+const int8 kint8max = static_cast<int8>(0x7F);
+const int16 kint16min = static_cast<int16>(0x8000);
+const int16 kint16max = static_cast<int16>(0x7FFF);
+const int32 kint32min = static_cast<int32>(0x80000000);
+const int32 kint32max = static_cast<int32>(0x7FFFFFFF);
+const int64 kint64min = static_cast<int64>(0x8000000000000000LL);
+const int64 kint64max = static_cast<int64>(0x7FFFFFFFFFFFFFFFLL);
const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();
diff --git a/speechx/speechx/base/common.h b/speechx/speechx/base/common.h
index a9303cbbc..97bff9662 100644
--- a/speechx/speechx/base/common.h
+++ b/speechx/speechx/base/common.h
@@ -14,21 +14,30 @@
#pragma once
+#include <algorithm>
+#include <cmath>
+#include <cstring>
#include <deque>
+#include <fstream>
#include <iostream>
#include <istream>
+#include <limits>
#include <map>
#include <memory>
#include <mutex>