diff --git a/README.md b/README.md
index 26f13d00e..9d7ed4258 100644
--- a/README.md
+++ b/README.md
@@ -157,6 +157,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural language processing (NLP) and Computer Vision (CV).
### Recent Update
+- 🔥 2022.10.26: Add [Prosody Prediction](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/rhy) for TTS.
- 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend.
- 👑 2022.10.11: Add [Wav2vec2ASR](./examples/librispeech/asr3), wav2vec2.0 fine-tuning for ASR on LibriSpeech.
- 🔥 2022.09.26: Add Voice Cloning, TTS finetune, and ERNIE-SAT in [PaddleSpeech Web Demo](./demos/speech_web).
@@ -698,6 +699,31 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
+
+
+**Keyword Spotting**
+
+| Task | Dataset | Model Type | Example |
+| :---: | :---: | :---: | :---: |
+| Keyword Spotting | hey-snips | PANN | pann-hey-snips |
+
**Speaker Verification**
@@ -827,7 +853,21 @@ The Text-to-Speech module is originally called [Parakeet](https://github.com/Pad
## Citation
To cite PaddleSpeech for research, please use the following format.
-```tex
+```text
+@InProceedings{pmlr-v162-bai22d,
+ title = {{A}$^3${T}: Alignment-Aware Acoustic and Text Pretraining for Speech Synthesis and Editing},
+ author = {Bai, He and Zheng, Renjie and Chen, Junkun and Ma, Mingbo and Li, Xintong and Huang, Liang},
+ booktitle = {Proceedings of the 39th International Conference on Machine Learning},
+ pages = {1399--1411},
+ year = {2022},
+ volume = {162},
+ series = {Proceedings of Machine Learning Research},
+ month = {17--23 Jul},
+ publisher = {PMLR},
+ pdf = {https://proceedings.mlr.press/v162/bai22d/bai22d.pdf},
+ url = {https://proceedings.mlr.press/v162/bai22d.html},
+}
+
@inproceedings{zhang2022paddlespeech,
title = {PaddleSpeech: An Easy-to-Use All-in-One Speech Toolkit},
author = {Hui Zhang, Tian Yuan, Junkun Chen, Xintong Li, Renjie Zheng, Yuxin Huang, Xiaojie Chen, Enlei Gong, Zeyu Chen, Xiaoguang Hu, dianhai yu, Yanjun Ma, Liang Huang},
diff --git a/README_cn.md b/README_cn.md
index 9a4549898..2db883b5a 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -164,7 +164,8 @@
### 近期更新
- - 🎉 2022.10.21: TTS 中文文本前端新增 [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) 功能。
+- 🔥 2022.10.26: TTS 新增[韵律预测](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/rhy)功能。
+- 🎉 2022.10.21: TTS 中文文本前端新增 [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) 功能。
- 👑 2022.10.11: 新增 [Wav2vec2ASR](./examples/librispeech/asr3), 在 LibriSpeech 上针对 ASR 任务对 wav2vec2.0 的 finetuning。
- 🔥 2022.09.26: 新增 Voice Cloning, TTS finetune 和 ERNIE-SAT 到 [PaddleSpeech 网页应用](./demos/speech_web)。
- ⚡ 2022.09.09: 新增基于 ECAPA-TDNN 声纹模型的 AISHELL-3 Voice Cloning [示例](./examples/aishell3/vc2)。
@@ -695,6 +696,31 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
+
+
+**唤醒**
+
+| 任务 | 数据集 | 模型类型 | 脚本 |
+| :---: | :---: | :---: | :---: |
+| 唤醒 | hey-snips | PANN | pann-hey-snips |
+
**声纹识别**
@@ -833,6 +859,20 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
要引用 PaddleSpeech 进行研究,请使用以下格式进行引用。
```text
+@InProceedings{pmlr-v162-bai22d,
+ title = {{A}$^3${T}: Alignment-Aware Acoustic and Text Pretraining for Speech Synthesis and Editing},
+ author = {Bai, He and Zheng, Renjie and Chen, Junkun and Ma, Mingbo and Li, Xintong and Huang, Liang},
+ booktitle = {Proceedings of the 39th International Conference on Machine Learning},
+ pages = {1399--1411},
+ year = {2022},
+ volume = {162},
+ series = {Proceedings of Machine Learning Research},
+ month = {17--23 Jul},
+ publisher = {PMLR},
+ pdf = {https://proceedings.mlr.press/v162/bai22d/bai22d.pdf},
+ url = {https://proceedings.mlr.press/v162/bai22d.html},
+}
+
@inproceedings{zhang2022paddlespeech,
title = {PaddleSpeech: An Easy-to-Use All-in-One Speech Toolkit},
author = {Hui Zhang, Tian Yuan, Junkun Chen, Xintong Li, Renjie Zheng, Yuxin Huang, Xiaojie Chen, Enlei Gong, Zeyu Chen, Xiaoguang Hu, dianhai yu, Yanjun Ma, Liang Huang},
diff --git a/docs/source/install.md b/docs/source/install.md
index 1e6c1c48b..20d7895df 100644
--- a/docs/source/install.md
+++ b/docs/source/install.md
@@ -188,10 +188,6 @@ conda activate tools/venv
conda install -y -c conda-forge sox libsndfile swig bzip2 libflac bc
```
### Install PaddlePaddle
-Some users may fail to install `kaldiio` due to the default download source, you can install `pytest-runner` at first;
-```bash
-pip install pytest-runner -i https://pypi.tuna.tsinghua.edu.cn/simple
-```
Make sure you have GPU and the paddlepaddle version is right. For example, for CUDA 10.2, CuDNN7.6 install paddle 2.4rc:
```bash
# Note, 2.4rc is just an example, please follow the minimum dependency of paddlepaddle for your selection
@@ -202,6 +198,11 @@ You can also install the develop version of paddlepaddle. For example, for CUDA
python3 -m pip install paddlepaddle-gpu==0.0.0.post102 -f https://www.paddlepaddle.org.cn/whl/linux/gpu/develop.html
```
### Install PaddleSpeech in Developing Mode
+Some users may fail to install `kaldiio` due to the default download source; in that case, install `pytest-runner` first:
+```bash
+pip install pytest-runner -i https://pypi.tuna.tsinghua.edu.cn/simple
+```
+Then install PaddleSpeech:
```bash
pip install -e .[develop] -i https://pypi.tuna.tsinghua.edu.cn/simple
```
diff --git a/docs/source/install_cn.md b/docs/source/install_cn.md
index ebc0cf7a2..dd06946f3 100644
--- a/docs/source/install_cn.md
+++ b/docs/source/install_cn.md
@@ -182,6 +182,7 @@ conda install -y -c conda-forge sox libsndfile swig bzip2 libflac bc
### 安装 PaddlePaddle
请确认你系统是否有 GPU,并且使用了正确版本的 paddlepaddle。例如系统使用 CUDA 10.2, CuDNN7.6 ,你可以安装 paddlepaddle-gpu 2.4rc:
```bash
+# 注意:2.4rc 只是一个示例,请按照对paddlepaddle的最小依赖进行选择。
python3 -m pip install paddlepaddle-gpu==2.4.0rc0 -i https://mirror.baidu.com/pypi/simple
```
你也可以安装 develop 版本的PaddlePaddle. 例如系统使用 CUDA 10.2, CuDNN7.6 ,你可以安装 paddlepaddle-gpu develop:
@@ -191,7 +192,6 @@ python3 -m pip install paddlepaddle-gpu==0.0.0.post102 -f https://www.paddlepadd
### 用开发者模式安装 PaddleSpeech
部分用户系统由于默认源的问题,安装中会出现 kaldiio 安转出错的问题,建议首先安装 pytest-runner:
```bash
-# 注意:2.4rc 只是一个示例,请按照对paddlepaddle的最小依赖进行选择。
pip install pytest-runner -i https://pypi.tuna.tsinghua.edu.cn/simple
```
然后安装 PaddleSpeech:
diff --git a/examples/hey_snips/README.md b/examples/hey_snips/README.md
index ba263906a..6311ad928 100644
--- a/examples/hey_snips/README.md
+++ b/examples/hey_snips/README.md
@@ -2,7 +2,7 @@
## Metrics
We measure FRRs with the number of false alarms fixed at one per hour:
-
+The released model can be downloaded here: https://paddlespeech.bj.bcebos.com/kws/heysnips/kws0_mdtc_heysnips_ckpt.tar.gz
+
|Model|False Alarm| False Reject Rate|
|--|--|--|
|MDTC| 1| 0.003559 |
diff --git a/examples/other/rhy/README.md b/examples/other/rhy/README.md
new file mode 100644
index 000000000..11336ad9f
--- /dev/null
+++ b/examples/other/rhy/README.md
@@ -0,0 +1,41 @@
+# Prosody Prediction with CSMSC and AISHELL-3
+
+## Get Started
+### Data Preprocessing
+```bash
+./run.sh --stage 0 --stop-stage 0
+```
+### Model Training
+```bash
+./run.sh --stage 1 --stop-stage 1
+```
+### Testing
+```bash
+./run.sh --stage 2 --stop-stage 2
+```
+### Prosody Prediction
+```bash
+./run.sh --stage 3 --stop-stage 3
+```
+## Pretrained Model
+The pretrained model can be downloaded here:
+
+[ernie-1.0_aishellcsmsc_ckpt_1.3.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/ernie-1.0_aishellcsmsc_ckpt_1.3.0.zip)
+
+Put it into the `exp/${YOUREXP}/checkpoints` folder.
+
+## Rhythm mapping
+Four punctuation marks are used to denote the rhythm labels:
+|rhy_token|csmsc|aishell3|
+|:---: |:---: |:---: |
+|%|#1|%|
+|`|#2||
+|~|#3||
+|$|#4|$|
+
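+For illustration, here is a minimal Python sketch of this mapping as used by the data-preparation scripts in `local/` (the dictionary mirrors `replace_` in `local/pre_for_sp_csmsc.py`; the sample sentence is made up):
+
+```python
+# map CSMSC rhythm labels (#1-#4) to the single-character tokens the model predicts
+replace_ = {"#1": "%", "#2": "`", "#3": "~", "#4": "$"}
+
+
+def replace_rhy_with_punc(line: str) -> str:
+    for src, tgt in replace_.items():
+        line = line.replace(src, tgt)
+    return line
+
+
+print(replace_rhy_with_punc("今天#1天气#2怎么样#4"))  # -> 今天%天气`怎么样$
+```
+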
+## Prediction Results
+| | #1 | #2 | #3 | #4 |
+|:-----:|:-----:|:-----:|:-----:|:-----:|
+|Precision |0.90 |0.66 |0.91 |0.90|
+|Recall |0.92 |0.62 |0.83 |0.85|
+|F1 |0.91 |0.64 |0.87 |0.87|
diff --git a/examples/other/rhy/conf/default.yaml b/examples/other/rhy/conf/default.yaml
new file mode 100644
index 000000000..1eb90f11f
--- /dev/null
+++ b/examples/other/rhy/conf/default.yaml
@@ -0,0 +1,44 @@
+###########################################################
+# DATA SETTING #
+###########################################################
+dataset_type: Ernie
+train_path: data/train.txt
+dev_path: data/dev.txt
+test_path: data/test.txt
+batch_size: 64
+num_workers: 2
+data_params:
+ pretrained_token: ernie-1.0
+ punc_path: data/rhy_token
+ seq_len: 100
+
+
+###########################################################
+# MODEL SETTING #
+###########################################################
+model_type: ErnieLinear
+model:
+ pretrained_token: ernie-1.0
+ num_classes: 5
+
+###########################################################
+# OPTIMIZER SETTING #
+###########################################################
+optimizer_params:
+ weight_decay: 1.0e-6 # weight decay coefficient.
+
+scheduler_params:
+ learning_rate: 1.0e-5 # learning rate.
+  gamma: 0.9999              # scheduler gamma; must be in (0.0, 1.0), and closer to 1.0 is better.
+
+###########################################################
+# TRAINING SETTING #
+###########################################################
+max_epoch: 20
+
+###########################################################
+# OTHER SETTING #
+###########################################################
+num_snapshots: 10 # max number of snapshots to keep while training
+seed: 42 # random seed for paddle, random, and np.random
diff --git a/examples/other/rhy/data/rhy_token b/examples/other/rhy/data/rhy_token
new file mode 100644
index 000000000..bf1fe253f
--- /dev/null
+++ b/examples/other/rhy/data/rhy_token
@@ -0,0 +1,4 @@
+%
+`
+~
+$
\ No newline at end of file
diff --git a/examples/other/rhy/local/data.sh b/examples/other/rhy/local/data.sh
new file mode 100755
index 000000000..93b134873
--- /dev/null
+++ b/examples/other/rhy/local/data.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+if [ ! -f 000001-010000.txt ]; then
+ wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/000001-010000.txt
+fi
+
+if [ ! -f label_train-set.txt ]; then
+ wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/rhy_predict/label_train-set.txt
+fi
+
+
+aishell_data=$1
+csmsc_data=$2
+processed_path=$3
+
+python3 ./local/pre_for_sp_csmsc.py \
+ --data=${csmsc_data} \
+ --processed_path=${processed_path}
+
+python3 ./local/pre_for_sp_aishell.py \
+ --data=${aishell_data} \
+ --processed_path=${processed_path}
+
+
+echo "Finish data preparation."
+exit 0
diff --git a/examples/other/rhy/local/pre_for_sp_aishell.py b/examples/other/rhy/local/pre_for_sp_aishell.py
new file mode 100644
index 000000000..a2a716683
--- /dev/null
+++ b/examples/other/rhy/local/pre_for_sp_aishell.py
@@ -0,0 +1,51 @@
+import argparse
+import os
+import re
+
+# Mapping from rhythm labels (#1-#4) to the punctuation tokens used for prediction.
+replace_ = {"#1": "%", "#2": "`", "#3": "~", "#4": "$"}
+
+
+def replace_rhy_with_punc(line):
+    # r'[:、,;。?!,.:;"?!”’《》【】<=>{}()()#&@“”^_|…\\]%*$', '', line)  # cf. checkcheck_oov.py
+ line = re.sub(r'[:、,;。?!,.:;"?!’《》【】<=>{}()()#&@“”^_|…\\]%*$', '', line)
+ for r in replace_.keys():
+ if r in line:
+ line = line.replace(r, replace_[r])
+ return line
+
+
+def pre_and_write(data, file):
+ with open(file, 'a') as rf:
+ for d in data:
+ d = d.split('|')[2].strip()
+ # d = replace_rhy_with_punc(d)
+ d = ' '.join(d) + ' \n'
+ rf.write(d)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+        description="Prepare AISHELL-3 data for rhythm prediction.")
+ parser.add_argument("--data", type=str, default="label_train-set.txt")
+ parser.add_argument(
+ "--processed_path", type=str, default="../data/rhy_predict")
+ args = parser.parse_args()
+ os.makedirs(args.processed_path, exist_ok=True)
+
+ with open(args.data) as rf:
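+        # skip the header lines at the top of label_train-set.txt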
+ text = rf.readlines()[5:]
+ len_ = len(text)
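+    # split the corpus 90% / 5% / 5% into train / test / dev sets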
+ lens = [int(len_ * 0.9), int(len_ * 0.05), int(len_ * 0.05)]
+ files = ['train.txt', 'test.txt', 'dev.txt']
+
+ i = 0
+ for l_, file in zip(lens, files):
+ file = os.path.join(args.processed_path, file)
+ pre_and_write(text[i:i + l_], file)
+ i = i + l_
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/other/rhy/local/pre_for_sp_csmsc.py b/examples/other/rhy/local/pre_for_sp_csmsc.py
new file mode 100644
index 000000000..0a96092c1
--- /dev/null
+++ b/examples/other/rhy/local/pre_for_sp_csmsc.py
@@ -0,0 +1,51 @@
+import argparse
+import os
+import re
+
+replace_ = {"#1": "%", "#2": "`", "#3": "~", "#4": "$"}
+
+
+def replace_rhy_with_punc(line):
+    # r'[:、,;。?!,.:;"?!”’《》【】<=>{}()()#&@“”^_|…\\]%*$', '', line)  # cf. checkcheck_oov.py
+ line = re.sub(r'^$\*%', '', line)
+ for r in replace_.keys():
+ if r in line:
+ line = line.replace(r, replace_[r])
+ return line
+
+
+def pre_and_write(data, file):
+ with open(file, 'w') as rf:
+ for d in data:
+ d = d.split('\t')[1].strip()
+ d = replace_rhy_with_punc(d)
+ d = ' '.join(d) + ' \n'
+ rf.write(d)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Train a Rhy prediction model.")
+        description="Prepare CSMSC data for rhythm prediction.")
+ parser.add_argument(
+ "--processed_path", type=str, default="../data/rhy_predict")
+ args = parser.parse_args()
+ print(args.data, args.processed_path)
+ os.makedirs(args.processed_path, exist_ok=True)
+
+ with open(args.data) as rf:
+ rf = rf.readlines()
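+        # 000001-010000.txt alternates a text line (with #1-#4 rhythm marks) and a pinyin line; keep only the text lines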
+ text = rf[0::2]
+ len_ = len(text)
+ lens = [int(len_ * 0.9), int(len_ * 0.05), int(len_ * 0.05)]
+ files = ['train.txt', 'test.txt', 'dev.txt']
+
+ i = 0
+ for l_, file in zip(lens, files):
+ file = os.path.join(args.processed_path, file)
+ pre_and_write(text[i:i + l_], file)
+ i = i + l_
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/other/rhy/local/rhy_predict.sh b/examples/other/rhy/local/rhy_predict.sh
new file mode 100755
index 000000000..30a4f12f8
--- /dev/null
+++ b/examples/other/rhy/local/rhy_predict.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+text=$4
+ckpt_prefix=${ckpt_name%.*}
+
+python3 ${BIN_DIR}/punc_restore.py \
+ --config=${config_path} \
+ --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --text=${text}
diff --git a/examples/other/rhy/local/test.sh b/examples/other/rhy/local/test.sh
new file mode 100755
index 000000000..bd490b5b9
--- /dev/null
+++ b/examples/other/rhy/local/test.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+print_eval=$4
+
+ckpt_prefix=${ckpt_name%.*}
+
+python3 ${BIN_DIR}/test.py \
+ --config=${config_path} \
+ --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --print_eval=${print_eval}
\ No newline at end of file
diff --git a/examples/other/rhy/local/train.sh b/examples/other/rhy/local/train.sh
new file mode 100755
index 000000000..85227eacb
--- /dev/null
+++ b/examples/other/rhy/local/train.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+
+python3 ${BIN_DIR}/train.py \
+ --config=${config_path} \
+ --output-dir=${train_output_path} \
+ --ngpu=1
diff --git a/examples/other/rhy/path.sh b/examples/other/rhy/path.sh
new file mode 100755
index 000000000..da790261f
--- /dev/null
+++ b/examples/other/rhy/path.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+export MAIN_ROOT=${PWD}/../../../
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
+
+MODEL=ernie_linear
+export BIN_DIR=${MAIN_ROOT}/paddlespeech/text/exps/${MODEL}
diff --git a/examples/other/rhy/run.sh b/examples/other/rhy/run.sh
new file mode 100755
index 000000000..aed58152e
--- /dev/null
+++ b/examples/other/rhy/run.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+set -e
+source path.sh
+
+gpus=0
+stage=0
+stop_stage=100
+
+aishell_data=label_train-set.txt
+csmsc_data=000001-010000.txt
+processed_path=data
+
+conf_path=conf/default.yaml
+train_output_path=exp/default
+ckpt_name=snapshot_iter_2600.pdz
+text=我们城市的复苏有赖于他强有力的政策。
+print_eval=false
+
+# with the following command, you can choose the stage range you want to run
+# such as `./run.sh --stage 0 --stop-stage 0`
+# this cannot be mixed with positional arguments `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ # prepare data
+ ./local/data.sh ${aishell_data} ${csmsc_data} ${processed_path}
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ # train model, all `ckpt` under `train_output_path/checkpoints/` dir
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${train_output_path} ${ckpt_name} ${print_eval} || exit -1
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/rhy_predict.sh ${conf_path} ${train_output_path} ${ckpt_name} ${text} || exit -1
+fi
\ No newline at end of file
diff --git a/examples/zh_en_tts/tts3/README.md b/examples/zh_en_tts/tts3/README.md
index b4b683089..012028007 100644
--- a/examples/zh_en_tts/tts3/README.md
+++ b/examples/zh_en_tts/tts3/README.md
@@ -116,6 +116,8 @@ optional arguments:
5. `--phones-dict` is the path of the phone vocabulary file.
6. `--speaker-dict` is the path of the speaker id map file when training a multi-speaker FastSpeech2.
+We have **added a speaker classifier module** following [Learning to Speak Fluently in a Foreign Language: Multilingual Speech Synthesis and Cross-Language Voice Cloning](https://arxiv.org/pdf/1907.04448.pdf). The main parameters are `config["model"]["enable_speaker_classifier"]`, `config["model"]["hidden_sc_dim"]`, and `config["updater"]["spk_loss_scale"]` in `conf/default.yaml`. Current experimental results show that this module can decouple text information from speaker information; more experiments are still being organized. The module is disabled by default; if you are interested, you can enable it and try it yourself.
+
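+For reference, here is a minimal sketch of the relevant switches in `conf/default.yaml` (the shipped default is `enable_speaker_classifier: False`; set it to `True` to try the module):
+
+```yaml
+model:
+  enable_speaker_classifier: True   # turn on the speaker classifier module (default: False)
+  hidden_sc_dim: 256                # hidden layer dim of the speaker classifier
+updater:
+  spk_loss_scale: 0.02              # weight of the speaker classifier loss in the total loss
+```
+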
### Synthesizing
We use [parallel wavegan](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1) as the default neural vocoder.
diff --git a/examples/zh_en_tts/tts3/conf/default.yaml b/examples/zh_en_tts/tts3/conf/default.yaml
index e65b5d0ec..efa8b3ea2 100644
--- a/examples/zh_en_tts/tts3/conf/default.yaml
+++ b/examples/zh_en_tts/tts3/conf/default.yaml
@@ -74,6 +74,9 @@ model:
stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder
spk_embed_dim: 256 # speaker embedding dimension
spk_embed_integration_type: concat # speaker embedding integration type
+ enable_speaker_classifier: False # Whether to use speaker classifier module
+ hidden_sc_dim: 256 # The hidden layer dim of speaker classifier
+
@@ -82,6 +85,7 @@ model:
###########################################################
updater:
use_masking: True # whether to apply masking for padded part in loss calculation
+    spk_loss_scale: 0.02 # The scale of the speaker classifier loss
###########################################################
diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/test.py b/paddlespeech/s2t/exps/wav2vec2/bin/test.py
index d1a6fd405..a376651df 100644
--- a/paddlespeech/s2t/exps/wav2vec2/bin/test.py
+++ b/paddlespeech/s2t/exps/wav2vec2/bin/test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/test_wav.py b/paddlespeech/s2t/exps/wav2vec2/bin/test_wav.py
index 3a537bce5..0d66ac410 100644
--- a/paddlespeech/s2t/exps/wav2vec2/bin/test_wav.py
+++ b/paddlespeech/s2t/exps/wav2vec2/bin/test_wav.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/train.py b/paddlespeech/s2t/exps/wav2vec2/bin/train.py
index b2edecca1..3ae3a9e73 100644
--- a/paddlespeech/s2t/exps/wav2vec2/bin/train.py
+++ b/paddlespeech/s2t/exps/wav2vec2/bin/train.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py
index 16feac5de..933e268ed 100644
--- a/paddlespeech/s2t/exps/wav2vec2/model.py
+++ b/paddlespeech/s2t/exps/wav2vec2/model.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
index ae141d1b3..cfd8f507e 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
@@ -1,7 +1,19 @@
-"""Vanilla Neural Network for simple tests.
-Authors
-* Elena Rastorgueva 2020
-"""
+# Authors
+# * Elena Rastorgueva 2020
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/lobes/models/VanillaNN.py).
import paddle
from paddlespeech.s2t.models.wav2vec2.modules import containers
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/activations.py b/paddlespeech/s2t/models/wav2vec2/modules/activations.py
index 722d8a0d6..af42b8a47 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/activations.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/activations.py
@@ -1,3 +1,4 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/containers.py b/paddlespeech/s2t/models/wav2vec2/modules/containers.py
index b39733570..180d0bd32 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/containers.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/containers.py
@@ -1,3 +1,19 @@
+# Authors
+# * Peter Plantinga 2020
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/containers.py).
import inspect
import paddle
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/linear.py b/paddlespeech/s2t/models/wav2vec2/modules/linear.py
index 488949d14..adae4514a 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/linear.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/linear.py
@@ -1,8 +1,20 @@
-"""Library implementing linear transformation.
-Authors
- * Mirco Ravanelli 2020
- * Davide Borra 2021
-"""
+# Authors
+# * Mirco Ravanelli 2020
+# * Davide Borra 2021
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/linear.py).
import logging
import paddle
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/modeling_outputs.py b/paddlespeech/s2t/models/wav2vec2/modules/modeling_outputs.py
index fb2a87122..ab623a996 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/modeling_outputs.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/modeling_outputs.py
@@ -1,3 +1,4 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py b/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
index 3d5e5fa64..e484fff68 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py
@@ -1,4 +1,4 @@
-# coding=utf-8
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/paddlespeech/s2t/models/wav2vec2/processing/signal_processing.py b/paddlespeech/s2t/models/wav2vec2/processing/signal_processing.py
index 9998a8e5e..0c4ade7b7 100644
--- a/paddlespeech/s2t/models/wav2vec2/processing/signal_processing.py
+++ b/paddlespeech/s2t/models/wav2vec2/processing/signal_processing.py
@@ -1,12 +1,23 @@
-"""
-Low level signal processing utilities
-Authors
- * Peter Plantinga 2020
- * Francois Grondin 2020
- * William Aris 2020
- * Samuele Cornell 2020
- * Sarthak Yadav 2022
-"""
+# Authors
+# * Peter Plantinga 2020
+# * Francois Grondin 2020
+# * William Aris 2020
+# * Samuele Cornell 2020
+# * Sarthak Yadav 2022
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/processing/signal_processing.py)
import numpy as np
import paddle
diff --git a/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py b/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py
index 471ab7657..78a0782e7 100644
--- a/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py
+++ b/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py
@@ -1,3 +1,19 @@
+# Authors
+# * Peter Plantinga 2020
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from speechbrain(https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/processing/speech_augmentation.py)
import math
import paddle
diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
index 0d99e8708..e13347740 100644
--- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
+++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
@@ -1,3 +1,16 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from collections import defaultdict
from typing import Dict
from typing import List
diff --git a/paddlespeech/t2s/exps/fastspeech2/train.py b/paddlespeech/t2s/exps/fastspeech2/train.py
index 10e023d0c..d31e62a82 100644
--- a/paddlespeech/t2s/exps/fastspeech2/train.py
+++ b/paddlespeech/t2s/exps/fastspeech2/train.py
@@ -145,17 +145,27 @@ def train_sp(args, config):
# copy conf to output_dir
shutil.copyfile(args.config, output_dir / config_name)
+ if "enable_speaker_classifier" in config.model:
+ enable_spk_cls = config.model.enable_speaker_classifier
+ else:
+ enable_spk_cls = False
+
updater = FastSpeech2Updater(
model=model,
optimizer=optimizer,
dataloader=train_dataloader,
output_dir=output_dir,
- **config["updater"])
+ enable_spk_cls=enable_spk_cls,
+ **config["updater"], )
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
evaluator = FastSpeech2Evaluator(
- model, dev_dataloader, output_dir=output_dir, **config["updater"])
+ model,
+ dev_dataloader,
+ output_dir=output_dir,
+ enable_spk_cls=enable_spk_cls,
+ **config["updater"], )
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))
diff --git a/paddlespeech/t2s/frontend/g2pw/onnx_api.py b/paddlespeech/t2s/frontend/g2pw/onnx_api.py
index ad32c4050..4e6fad4e5 100644
--- a/paddlespeech/t2s/frontend/g2pw/onnx_api.py
+++ b/paddlespeech/t2s/frontend/g2pw/onnx_api.py
@@ -210,7 +210,8 @@ class G2PWOnnxConverter:
for sent_id, sent in enumerate(sentences):
# pypinyin works well for Simplified Chinese than Traditional Chinese
sent_s = tranditional_to_simplified(sent)
- pypinyin_result = pinyin(sent_s, style=Style.TONE3)
+ pypinyin_result = pinyin(
+ sent_s, neutral_tone_with_five=True, style=Style.TONE3)
partial_result = [None] * len(sent)
for i, char in enumerate(sent):
if char in self.polyphonic_chars_new:
diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
index 9905765db..0eb44beb6 100644
--- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
+++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
@@ -25,6 +25,8 @@ import paddle.nn.functional as F
from paddle import nn
from typeguard import check_argument_types
+from paddlespeech.t2s.modules.adversarial_loss.gradient_reversal import GradientReversalLayer
+from paddlespeech.t2s.modules.adversarial_loss.speaker_classifier import SpeakerClassifier
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
@@ -138,7 +140,10 @@ class FastSpeech2(nn.Layer):
# training related
init_type: str="xavier_uniform",
init_enc_alpha: float=1.0,
- init_dec_alpha: float=1.0, ):
+ init_dec_alpha: float=1.0,
+ # speaker classifier
+ enable_speaker_classifier: bool=False,
+ hidden_sc_dim: int=256, ):
"""Initialize FastSpeech2 module.
Args:
idim (int):
@@ -268,6 +273,10 @@ class FastSpeech2(nn.Layer):
Initial value of alpha in scaled pos encoding of the encoder.
init_dec_alpha (float):
Initial value of alpha in scaled pos encoding of the decoder.
+        enable_speaker_classifier (bool):
+            Whether to use the speaker classifier module.
+        hidden_sc_dim (int):
+            The hidden layer dimension of the speaker classifier.
"""
assert check_argument_types()
@@ -281,6 +290,9 @@ class FastSpeech2(nn.Layer):
self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
self.use_scaled_pos_enc = use_scaled_pos_enc
+ self.hidden_sc_dim = hidden_sc_dim
+ self.spk_num = spk_num
+ self.enable_speaker_classifier = enable_speaker_classifier
self.spk_embed_dim = spk_embed_dim
if self.spk_embed_dim is not None:
@@ -373,6 +385,12 @@ class FastSpeech2(nn.Layer):
self.tone_projection = nn.Linear(adim + self.tone_embed_dim,
adim)
+ if self.spk_num and self.enable_speaker_classifier:
+ # set lambda = 1
+ self.grad_reverse = GradientReversalLayer(1)
+ self.speaker_classifier = SpeakerClassifier(
+ idim=adim, hidden_sc_dim=self.hidden_sc_dim, spk_num=spk_num)
+
# define duration predictor
self.duration_predictor = DurationPredictor(
idim=adim,
@@ -547,7 +565,7 @@ class FastSpeech2(nn.Layer):
if tone_id is not None:
tone_id = paddle.cast(tone_id, 'int64')
# forward propagation
- before_outs, after_outs, d_outs, p_outs, e_outs = self._forward(
+ before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits = self._forward(
xs,
ilens,
olens,
@@ -564,7 +582,7 @@ class FastSpeech2(nn.Layer):
max_olen = max(olens)
ys = ys[:, :max_olen]
- return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens
+ return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits
def _forward(self,
xs: paddle.Tensor,
@@ -584,6 +602,12 @@ class FastSpeech2(nn.Layer):
# (B, Tmax, adim)
hs, _ = self.encoder(xs, x_masks)
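+        # adversarial branch: during training, pass the encoder outputs through the
+        # gradient reversal layer and the speaker classifier to get per-frame speaker
+        # logits; the reversed gradient pushes the encoder to drop speaker information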
+ if self.spk_num and self.enable_speaker_classifier and not is_inference:
+ hs_for_spk_cls = self.grad_reverse(hs)
+ spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens)
+ else:
+ spk_logits = None
+
# integrate speaker embedding
if self.spk_embed_dim is not None:
# spk_emb has a higher priority than spk_id
@@ -676,7 +700,7 @@ class FastSpeech2(nn.Layer):
after_outs = before_outs + self.postnet(
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
- return before_outs, after_outs, d_outs, p_outs, e_outs
+ return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits
def encoder_infer(
self,
@@ -771,7 +795,7 @@ class FastSpeech2(nn.Layer):
es = e.unsqueeze(0) if e is not None else None
# (1, L, odim)
- _, outs, d_outs, p_outs, e_outs = self._forward(
+ _, outs, d_outs, p_outs, e_outs, _ = self._forward(
xs,
ilens,
ds=ds,
@@ -783,7 +807,7 @@ class FastSpeech2(nn.Layer):
is_inference=True)
else:
# (1, L, odim)
- _, outs, d_outs, p_outs, e_outs = self._forward(
+ _, outs, d_outs, p_outs, e_outs, _ = self._forward(
xs,
ilens,
is_inference=True,
@@ -791,6 +815,7 @@ class FastSpeech2(nn.Layer):
spk_emb=spk_emb,
spk_id=spk_id,
tone_id=tone_id)
+
return outs[0], d_outs[0], p_outs[0], e_outs[0]
def _integrate_with_spk_embed(self, hs, spk_emb):
@@ -1058,6 +1083,7 @@ class FastSpeech2Loss(nn.Layer):
self.l1_criterion = nn.L1Loss(reduction=reduction)
self.mse_criterion = nn.MSELoss(reduction=reduction)
self.duration_criterion = DurationPredictorLoss(reduction=reduction)
+ self.ce_criterion = nn.CrossEntropyLoss()
def forward(
self,
@@ -1072,7 +1098,10 @@ class FastSpeech2Loss(nn.Layer):
es: paddle.Tensor,
ilens: paddle.Tensor,
olens: paddle.Tensor,
- ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+ spk_logits: paddle.Tensor=None,
+ spk_ids: paddle.Tensor=None,
+ ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
+ paddle.Tensor, ]:
"""Calculate forward propagation.
Args:
@@ -1098,11 +1127,18 @@ class FastSpeech2Loss(nn.Layer):
Batch of the lengths of each input (B,).
olens(Tensor):
Batch of the lengths of each target (B,).
+            spk_logits(Optional[Tensor]):
+                Batch of outputs of the speaker classifier (B, Lmax, num_spk)
+            spk_ids(Optional[Tensor]):
+                Batch of target spk_id (B,)
+
Returns:
"""
+ speaker_loss = 0.0
+
# apply mask to remove padded part
if self.use_masking:
out_masks = make_non_pad_mask(olens).unsqueeze(-1)
@@ -1124,6 +1160,16 @@ class FastSpeech2Loss(nn.Layer):
ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
es = es.masked_select(pitch_masks.broadcast_to(es.shape))
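+        # speaker-adversarial loss inputs: repeat each utterance's spk_id over the time
+        # axis, flatten the logits, and drop frames whose logits are all zero (padding)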
+ if spk_logits is not None and spk_ids is not None:
+ batch_size = spk_ids.shape[0]
+ spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1],
+ None)
+ spk_logits = paddle.reshape(spk_logits,
+ [-1, spk_logits.shape[-1]])
+ mask_index = spk_logits.abs().sum(axis=1) != 0
+ spk_ids = spk_ids[mask_index]
+ spk_logits = spk_logits[mask_index]
+
# calculate loss
l1_loss = self.l1_criterion(before_outs, ys)
if after_outs is not None:
@@ -1132,6 +1178,9 @@ class FastSpeech2Loss(nn.Layer):
pitch_loss = self.mse_criterion(p_outs, ps)
energy_loss = self.mse_criterion(e_outs, es)
+ if spk_logits is not None and spk_ids is not None:
+ speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size
+
# make weighted mask and apply it
if self.use_weighted_masking:
out_masks = make_non_pad_mask(olens).unsqueeze(-1)
@@ -1161,4 +1210,4 @@ class FastSpeech2Loss(nn.Layer):
energy_loss = energy_loss.masked_select(
pitch_masks.broadcast_to(energy_loss.shape)).sum()
- return l1_loss, duration_loss, pitch_loss, energy_loss
+ return l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss
diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py
index 92aa9dfc7..b398267e6 100644
--- a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py
+++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py
@@ -14,6 +14,7 @@
import logging
from pathlib import Path
+from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
@@ -23,6 +24,7 @@ from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
+
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
@@ -31,24 +33,30 @@ logger.setLevel(logging.INFO)
class FastSpeech2Updater(StandardUpdater):
- def __init__(self,
- model: Layer,
- optimizer: Optimizer,
- dataloader: DataLoader,
- init_state=None,
- use_masking: bool=False,
- use_weighted_masking: bool=False,
- output_dir: Path=None):
+ def __init__(
+ self,
+ model: Layer,
+ optimizer: Optimizer,
+ dataloader: DataLoader,
+ init_state=None,
+ use_masking: bool=False,
+ spk_loss_scale: float=0.02,
+ use_weighted_masking: bool=False,
+ output_dir: Path=None,
+ enable_spk_cls: bool=False, ):
super().__init__(model, optimizer, dataloader, init_state=None)
self.criterion = FastSpeech2Loss(
- use_masking=use_masking, use_weighted_masking=use_weighted_masking)
+ use_masking=use_masking,
+ use_weighted_masking=use_weighted_masking, )
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
+ self.spk_loss_scale = spk_loss_scale
+ self.enable_spk_cls = enable_spk_cls
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
@@ -60,18 +68,33 @@ class FastSpeech2Updater(StandardUpdater):
if spk_emb is not None:
spk_id = None
- before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
- text=batch["text"],
- text_lengths=batch["text_lengths"],
- speech=batch["speech"],
- speech_lengths=batch["speech_lengths"],
- durations=batch["durations"],
- pitch=batch["pitch"],
- energy=batch["energy"],
- spk_id=spk_id,
- spk_emb=spk_emb)
-
- l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(
+ if type(
+ self.model
+ ) == DataParallel and self.model._layers.spk_num and self.model._layers.enable_speaker_classifier:
+ with self.model.no_sync():
+ before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
+ text=batch["text"],
+ text_lengths=batch["text_lengths"],
+ speech=batch["speech"],
+ speech_lengths=batch["speech_lengths"],
+ durations=batch["durations"],
+ pitch=batch["pitch"],
+ energy=batch["energy"],
+ spk_id=spk_id,
+ spk_emb=spk_emb)
+ else:
+ before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
+ text=batch["text"],
+ text_lengths=batch["text_lengths"],
+ speech=batch["speech"],
+ speech_lengths=batch["speech_lengths"],
+ durations=batch["durations"],
+ pitch=batch["pitch"],
+ energy=batch["energy"],
+ spk_id=spk_id,
+ spk_emb=spk_emb)
+
+ l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion(
after_outs=after_outs,
before_outs=before_outs,
d_outs=d_outs,
@@ -82,9 +105,12 @@ class FastSpeech2Updater(StandardUpdater):
ps=batch["pitch"],
es=batch["energy"],
ilens=batch["text_lengths"],
- olens=olens)
+ olens=olens,
+ spk_logits=spk_logits,
+ spk_ids=spk_id, )
- loss = l1_loss + duration_loss + pitch_loss + energy_loss
+ scaled_speaker_loss = self.spk_loss_scale * speaker_loss
+ loss = l1_loss + duration_loss + pitch_loss + energy_loss + scaled_speaker_loss
optimizer = self.optimizer
optimizer.clear_grad()
@@ -96,11 +122,18 @@ class FastSpeech2Updater(StandardUpdater):
report("train/duration_loss", float(duration_loss))
report("train/pitch_loss", float(pitch_loss))
report("train/energy_loss", float(energy_loss))
+ if self.enable_spk_cls:
+ report("train/speaker_loss", float(speaker_loss))
+ report("train/scaled_speaker_loss", float(scaled_speaker_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
+ if self.enable_spk_cls:
+ losses_dict["speaker_loss"] = float(speaker_loss)
+ losses_dict["scaled_speaker_loss"] = float(scaled_speaker_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
@@ -112,7 +145,9 @@ class FastSpeech2Evaluator(StandardEvaluator):
dataloader: DataLoader,
use_masking: bool=False,
use_weighted_masking: bool=False,
- output_dir: Path=None):
+ spk_loss_scale: float=0.02,
+ output_dir: Path=None,
+ enable_spk_cls: bool=False):
super().__init__(model, dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
@@ -120,6 +155,8 @@ class FastSpeech2Evaluator(StandardEvaluator):
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
+ self.spk_loss_scale = spk_loss_scale
+ self.enable_spk_cls = enable_spk_cls
self.criterion = FastSpeech2Loss(
use_masking=use_masking, use_weighted_masking=use_weighted_masking)
@@ -133,18 +170,33 @@ class FastSpeech2Evaluator(StandardEvaluator):
if spk_emb is not None:
spk_id = None
- before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
- text=batch["text"],
- text_lengths=batch["text_lengths"],
- speech=batch["speech"],
- speech_lengths=batch["speech_lengths"],
- durations=batch["durations"],
- pitch=batch["pitch"],
- energy=batch["energy"],
- spk_id=spk_id,
- spk_emb=spk_emb)
-
- l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(
+ if type(
+ self.model
+ ) == DataParallel and self.model._layers.spk_num and self.model._layers.enable_speaker_classifier:
+ with self.model.no_sync():
+ before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
+ text=batch["text"],
+ text_lengths=batch["text_lengths"],
+ speech=batch["speech"],
+ speech_lengths=batch["speech_lengths"],
+ durations=batch["durations"],
+ pitch=batch["pitch"],
+ energy=batch["energy"],
+ spk_id=spk_id,
+ spk_emb=spk_emb)
+ else:
+ before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
+ text=batch["text"],
+ text_lengths=batch["text_lengths"],
+ speech=batch["speech"],
+ speech_lengths=batch["speech_lengths"],
+ durations=batch["durations"],
+ pitch=batch["pitch"],
+ energy=batch["energy"],
+ spk_id=spk_id,
+ spk_emb=spk_emb)
+
+ l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion(
after_outs=after_outs,
before_outs=before_outs,
d_outs=d_outs,
@@ -155,19 +207,29 @@ class FastSpeech2Evaluator(StandardEvaluator):
ps=batch["pitch"],
es=batch["energy"],
ilens=batch["text_lengths"],
- olens=olens, )
- loss = l1_loss + duration_loss + pitch_loss + energy_loss
+ olens=olens,
+ spk_logits=spk_logits,
+ spk_ids=spk_id, )
+
+ scaled_speaker_loss = self.spk_loss_scale * speaker_loss
+ loss = l1_loss + duration_loss + pitch_loss + energy_loss + scaled_speaker_loss
report("eval/loss", float(loss))
report("eval/l1_loss", float(l1_loss))
report("eval/duration_loss", float(duration_loss))
report("eval/pitch_loss", float(pitch_loss))
report("eval/energy_loss", float(energy_loss))
+ if self.enable_spk_cls:
+            report("eval/speaker_loss", float(speaker_loss))
+            report("eval/scaled_speaker_loss", float(scaled_speaker_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
+ if self.enable_spk_cls:
+ losses_dict["speaker_loss"] = float(speaker_loss)
+ losses_dict["scaled_speaker_loss"] = float(scaled_speaker_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
diff --git a/paddlespeech/t2s/modules/adversarial_loss/__init__.py b/paddlespeech/t2s/modules/adversarial_loss/__init__.py
new file mode 100644
index 000000000..abf198b97
--- /dev/null
+++ b/paddlespeech/t2s/modules/adversarial_loss/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/t2s/modules/adversarial_loss/gradient_reversal.py b/paddlespeech/t2s/modules/adversarial_loss/gradient_reversal.py
new file mode 100644
index 000000000..64da16053
--- /dev/null
+++ b/paddlespeech/t2s/modules/adversarial_loss/gradient_reversal.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+from paddle.autograd import PyLayer
+
+
+class GradientReversalFunction(PyLayer):
+ """Gradient Reversal Layer from:
+ Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)
+
+ Forward pass is the identity function. In the backward pass,
+ the upstream gradients are multiplied by -lambda (i.e. gradient is reversed)
+ """
+
+ @staticmethod
+ def forward(ctx, x, lambda_=1):
+ """Forward in networks
+ """
+ ctx.save_for_backward(lambda_)
+ return x.clone()
+
+ @staticmethod
+ def backward(ctx, grads):
+ """Backward in networks
+ """
+ lambda_, = ctx.saved_tensor()
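+        # reverse the upstream gradient; clipping to [-0.5, 0.5] bounds the adversarial signal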
+ dx = -lambda_ * grads
+ return paddle.clip(dx, min=-0.5, max=0.5)
+
+
+class GradientReversalLayer(nn.Layer):
+ """Gradient Reversal Layer from:
+ Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)
+
+ Forward pass is the identity function. In the backward pass,
+ the upstream gradients are multiplied by -lambda (i.e. gradient is reversed)
+ """
+
+ def __init__(self, lambda_=1):
+ super(GradientReversalLayer, self).__init__()
+ self.lambda_ = lambda_
+
+ def forward(self, x):
+ """Forward in networks
+ """
+ return GradientReversalFunction.apply(x, self.lambda_)
diff --git a/paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py b/paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py
new file mode 100644
index 000000000..d731b2d27
--- /dev/null
+++ b/paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from Cross-Lingual-Voice-Cloning(https://github.com/deterministic-algorithms-lab/Cross-Lingual-Voice-Cloning)
+import paddle
+from paddle import nn
+from typeguard import check_argument_types
+
+
+class SpeakerClassifier(nn.Layer):
+ def __init__(
+ self,
+ idim: int,
+ hidden_sc_dim: int,
+ spk_num: int, ):
+ assert check_argument_types()
+ super().__init__()
+ # store hyperparameters
+ self.idim = idim
+ self.hidden_sc_dim = hidden_sc_dim
+ self.spk_num = spk_num
+
+ self.model = nn.Sequential(
+ nn.Linear(self.idim, self.hidden_sc_dim),
+ nn.Linear(self.hidden_sc_dim, self.spk_num))
+
+ def parse_outputs(self, out, text_lengths):
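+        # zero out positions beyond each utterance's text length so that padding
+        # does not contribute to the speaker logits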
+ mask = paddle.arange(out.shape[1]).expand(
+ [out.shape[0], out.shape[1]]) < text_lengths.unsqueeze(1)
+ out = paddle.transpose(out, perm=[2, 0, 1])
+ out = out * mask
+ out = paddle.transpose(out, perm=[1, 2, 0])
+ return out
+
+ def forward(self, encoder_outputs, text_lengths):
+ """
+ encoder_outputs = [batch_size, seq_len, encoder_embedding_size]
+ text_lengths = [batch_size]
+
+        returns speaker classification logits = [batch_size, seq_len, spk_num]
+ """
+
+ out = self.model(encoder_outputs)
+ out = self.parse_outputs(out, text_lengths)
+ return out
diff --git a/paddlespeech/text/exps/ernie_linear/test.py b/paddlespeech/text/exps/ernie_linear/test.py
index 4302a1a3b..aa172cc69 100644
--- a/paddlespeech/text/exps/ernie_linear/test.py
+++ b/paddlespeech/text/exps/ernie_linear/test.py
@@ -23,6 +23,7 @@ from sklearn.metrics import classification_report
from sklearn.metrics import precision_recall_fscore_support
from yacs.config import CfgNode
+from paddlespeech.t2s.utils import str2bool
from paddlespeech.text.models.ernie_linear import ErnieLinear
from paddlespeech.text.models.ernie_linear import PuncDataset
from paddlespeech.text.models.ernie_linear import PuncDatasetFromErnieTokenizer
@@ -91,9 +92,10 @@ def test(args):
t = classification_report(
test_total_label, test_total_predict, target_names=punc_list)
print(t)
- t2 = evaluation(test_total_label, test_total_predict)
- print('=========================================================')
- print(t2)
+ if args.print_eval:
+ t2 = evaluation(test_total_label, test_total_predict)
+ print('=========================================================')
+ print(t2)
def main():
@@ -101,6 +103,7 @@ def main():
parser = argparse.ArgumentParser(description="Test a ErnieLinear model.")
parser.add_argument("--config", type=str, help="ErnieLinear config file.")
parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
+ parser.add_argument("--print_eval", type=str2bool, default=True)
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt
index 978a23d95..09bdb3c1e 100644
--- a/speechx/CMakeLists.txt
+++ b/speechx/CMakeLists.txt
@@ -134,7 +134,7 @@ string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS)
# for LD_LIBRARY_PATH
# set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/)
execute_process(
- COMMAND python -c " \
+ COMMAND python -c "\
import os; \
import paddle; \
include_dir=paddle.sysconfig.get_include(); \
diff --git a/speechx/README.md b/speechx/README.md
index f744defae..a575040db 100644
--- a/speechx/README.md
+++ b/speechx/README.md
@@ -70,3 +70,46 @@ popd
### Deepspeech2 with linear feature
* DecibelNormalizer: there is a small difference between the offline and online db norm. The computation of online db norm reads features chunk by chunk, which causes the feature size to be different from the offline db norm. In `normalizer.cc:73`, the `samples.size()` is different, which causes the different result.
+
+## FAQ
+
+1. No module named `paddle`.
+
+```
+CMake Error at CMakeLists.txt:119 (string):
+ string sub-command STRIP requires two arguments.
+
+
+Traceback (most recent call last):
+ File "", line 1, in
+ModuleNotFoundError: No module named 'paddle'
+-- PADDLE_COMPILE_FLAGS=
+CMake Error at CMakeLists.txt:131 (string):
+ string sub-command STRIP requires two arguments.
+
+
+ File "", line 1
+ import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);
+ ^
+```
+
+Please install paddlepaddle >= 2.4rc.
+
+2. `u2_recognizer_main: error while loading shared libraries: liblibpaddle.so: cannot open shared object file: No such file or directory`
+
+
+```
+cd $YOUR_ENV_PATH/lib/python3.7/site-packages/paddle/fluid
+patchelf --set-soname libpaddle.so libpaddle.so
+```
+
+3. `u2_recognizer_main: error while loading shared libraries: libgfortran.so.5: cannot open shared object file: No such file or directory`
+
+```
+# my gcc version is 8.2
+apt-get install gfortran-8
+```
+
+4. `Undefined reference to '_gfortran_concat_string'`
+
+Use gcc 8.2 and gfortran 8.2.
diff --git a/speechx/build.sh b/speechx/build.sh
index 7655f9635..e0a386752 100755
--- a/speechx/build.sh
+++ b/speechx/build.sh
@@ -20,4 +20,4 @@ fi
mkdir -p build
cmake -B build -DBOOST_ROOT:STRING=${boost_SOURCE_DIR}
-cmake --build build -j
+cmake --build build
diff --git a/speechx/examples/u2pp_ol/wenetspeech/.gitignore b/speechx/examples/u2pp_ol/wenetspeech/.gitignore
index 02c0cc21f..bbd86a25b 100644
--- a/speechx/examples/u2pp_ol/wenetspeech/.gitignore
+++ b/speechx/examples/u2pp_ol/wenetspeech/.gitignore
@@ -1,3 +1,2 @@
data
-utils
exp
diff --git a/speechx/examples/u2pp_ol/wenetspeech/RESULTS.md b/speechx/examples/u2pp_ol/wenetspeech/RESULTS.md
new file mode 100644
index 000000000..6a8e8c46d
--- /dev/null
+++ b/speechx/examples/u2pp_ol/wenetspeech/RESULTS.md
@@ -0,0 +1,36 @@
+# aishell test
+
+7176 utts, duration 36108.9 sec.
+
+## Attention Rescore
+
+### u2++ FP32
+
+#### CER
+
+```
+Overall -> 5.75 % N=104765 C=99035 S=5587 D=143 I=294
+Mandarin -> 5.75 % N=104762 C=99035 S=5584 D=143 I=294
+English -> 0.00 % N=0 C=0 S=0 D=0 I=0
+Other -> 100.00 % N=3 C=0 S=3 D=0 I=0
+```
+
+#### RTF
+
+> RTF is measured with feature extraction and decoding included, which is closer to an end-to-end measurement.
+
+* Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz, supports `avx512_vnni`
+
+```
+I1027 10:52:38.662868 51665 u2_recognizer_main.cc:122] total wav duration is: 36108.9 sec
+I1027 10:52:38.662858 51665 u2_recognizer_main.cc:121] total cost:11169.1 sec
+I1027 10:52:38.662876 51665 u2_recognizer_main.cc:123] RTF is: 0.309318
+```
+
+* Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz, does not support `avx512_vnni`
+
+```
+I1026 16:13:26.247121 48038 u2_recognizer_main.cc:123] total wav duration is: 36108.9 sec
+I1026 16:13:26.247130 48038 u2_recognizer_main.cc:124] total decode cost:13656.7 sec
+I1026 16:13:26.247138 48038 u2_recognizer_main.cc:125] RTF is: 0.378208
+```
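+
+> In both runs RTF = total decode cost / total wav duration, e.g. 11169.1 / 36108.9 ≈ 0.309 on the `avx512_vnni` machine.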
diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/decode.sh b/speechx/examples/u2pp_ol/wenetspeech/local/decode.sh
index c17cdbe65..e9c81009c 100755
--- a/speechx/examples/u2pp_ol/wenetspeech/local/decode.sh
+++ b/speechx/examples/u2pp_ol/wenetspeech/local/decode.sh
@@ -8,7 +8,7 @@ exp=exp
nj=20
mkdir -p $exp
ckpt_dir=./data/model
-model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model/
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/decoder.fbank.wolm.log \
ctc_prefix_beam_search_decoder_main \
diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/feat.sh b/speechx/examples/u2pp_ol/wenetspeech/local/feat.sh
index 4341cec8b..e181951e3 100755
--- a/speechx/examples/u2pp_ol/wenetspeech/local/feat.sh
+++ b/speechx/examples/u2pp_ol/wenetspeech/local/feat.sh
@@ -3,29 +3,40 @@ set -e
. path.sh
+nj=20
+stage=-1
+stop_stage=100
+
+. utils/parse_options.sh
+
data=data
exp=exp
-nj=20
mkdir -p $exp
+
ckpt_dir=./data/model
-model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model/
aishell_wav_scp=aishell_test.scp
-cmvn_json2kaldi_main \
- --json_file $model_dir/mean_std.json \
- --cmvn_write_path $exp/cmvn.ark \
- --binary=false
-
-echo "convert json cmvn to kaldi ark."
-
-./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
-utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
-compute_fbank_main \
- --num_bins 80 \
- --cmvn_file=$exp/cmvn.ark \
- --streaming_chunk=36 \
- --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
- --feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank.ark,$data/split${nj}/JOB/fbank.scp
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ cmvn_json2kaldi_main \
+ --json_file $model_dir/mean_std.json \
+ --cmvn_write_path $exp/cmvn.ark \
+ --binary=false
+
+ echo "convert json cmvn to kaldi ark."
+fi
-echo "compute fbank feature."
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
+
+ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
+ compute_fbank_main \
+ --num_bins 80 \
+ --cmvn_file=$exp/cmvn.ark \
+ --streaming_chunk=36 \
+ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
+ --feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank.ark,$data/split${nj}/JOB/fbank.scp
+
+ echo "compute fbank feature."
+fi
diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/nnet.sh b/speechx/examples/u2pp_ol/wenetspeech/local/nnet.sh
index 4419201cf..5455b5c9b 100755
--- a/speechx/examples/u2pp_ol/wenetspeech/local/nnet.sh
+++ b/speechx/examples/u2pp_ol/wenetspeech/local/nnet.sh
@@ -8,7 +8,7 @@ data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
-model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model/
u2_nnet_main \
--model_path=$model_dir/export.jit \
diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/recognizer.sh b/speechx/examples/u2pp_ol/wenetspeech/local/recognizer.sh
index f4553f2ab..344fbcbce 100755
--- a/speechx/examples/u2pp_ol/wenetspeech/local/recognizer.sh
+++ b/speechx/examples/u2pp_ol/wenetspeech/local/recognizer.sh
@@ -1,16 +1,15 @@
#!/bin/bash
set -e
-. path.sh
-
data=data
exp=exp
nj=20
+. utils/parse_options.sh
mkdir -p $exp
ckpt_dir=./data/model
-model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model/
aishell_wav_scp=aishell_test.scp
text=$data/test/text
diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh b/speechx/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh
new file mode 100755
index 000000000..1ce403a3c
--- /dev/null
+++ b/speechx/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+set -e
+
+data=data
+exp=exp
+nj=20
+
+. utils/parse_options.sh
+
+mkdir -p $exp
+ckpt_dir=./data/model
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_quant_1.3.0.model/
+aishell_wav_scp=aishell_test.scp
+text=$data/test/text
+
+./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
+
+utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.quant.log \
+u2_recognizer_main \
+ --use_fbank=true \
+ --num_bins=80 \
+ --cmvn_file=$exp/cmvn.ark \
+ --model_path=$model_dir/export \
+ --vocab_path=$model_dir/unit.txt \
+ --nnet_decoder_chunk=16 \
+ --receptive_field_length=7 \
+ --subsampling_rate=4 \
+ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
+ --result_wspecifier=ark,t:$data/split${nj}/JOB/recognizer.quant.rsl.ark
+
+
+cat $data/split${nj}/*/recognizer.quant.rsl.ark > $exp/aishell.recognizer.quant.rsl
+utils/compute-wer.py --char=1 --v=1 $text $exp/aishell.recognizer.quant.rsl > $exp/aishell.recognizer.quant.err
+echo "recognizer quant test have finished!!!"
+echo "please checkout in $exp/aishell.recognizer.quant.err"
+tail -n 7 $exp/aishell.recognizer.quant.err
diff --git a/speechx/examples/u2pp_ol/wenetspeech/run.sh b/speechx/examples/u2pp_ol/wenetspeech/run.sh
index 12e3af95a..2bc855dec 100755
--- a/speechx/examples/u2pp_ol/wenetspeech/run.sh
+++ b/speechx/examples/u2pp_ol/wenetspeech/run.sh
@@ -1,12 +1,11 @@
#!/bin/bash
-set +x
set -e
. path.sh
nj=40
-stage=0
-stop_stage=5
+stage=-1
+stop_stage=100
. utils/parse_options.sh
@@ -14,7 +13,7 @@ stop_stage=5
data=data
exp=exp
mkdir -p $exp $data
-
+aishell_wav_scp=aishell_test.scp
# 1. compile
if [ ! -d ${SPEECHX_BUILD} ]; then
@@ -25,17 +24,28 @@ fi
ckpt_dir=$data/model
-model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
+model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model/
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
- # download model
- if [ ! -f $ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz ]; then
+ # download u2pp model
+ if [ ! -f $ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model.tar.gz ]; then
mkdir -p $ckpt_dir
pushd $ckpt_dir
- wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
- tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
+ wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model.tar.gz
+ tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model.tar.gz
+
+ popd
+ fi
+
+ # download u2pp quant model
+ if [ ! -f $ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_quant_1.3.0.model.tar.gz ]; then
+ mkdir -p $ckpt_dir
+ pushd $ckpt_dir
+
+ wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_quant_1.3.0.model.tar.gz
+ tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_quant_1.3.0.model.tar.gz
popd
fi
@@ -73,4 +83,4 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    ./local/recognizer.sh
-fi
\ No newline at end of file
+fi
diff --git a/speechx/examples/u2pp_ol/wenetspeech/utils b/speechx/examples/u2pp_ol/wenetspeech/utils
new file mode 120000
index 000000000..c2519a9dd
--- /dev/null
+++ b/speechx/examples/u2pp_ol/wenetspeech/utils
@@ -0,0 +1 @@
+../../../../utils/
\ No newline at end of file
diff --git a/speechx/requirement.txt b/speechx/requirement.txt
new file mode 100644
index 000000000..6a6db0960
--- /dev/null
+++ b/speechx/requirement.txt
@@ -0,0 +1 @@
+paddlepaddle>=2.4rc
diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
index 03a7c1336..07e8e5608 100644
--- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
+++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
@@ -69,20 +69,30 @@ void CTCPrefixBeamSearch::InitDecoder() { Reset(); }
void CTCPrefixBeamSearch::AdvanceDecode(
const std::shared_ptr& decodable) {
+ double search_cost = 0.0;
+ double feat_nnet_cost = 0.0;
while (1) {
// forward frame by frame
+ kaldi::Timer timer;
        std::vector<kaldi::BaseFloat> frame_prob;
bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
+ feat_nnet_cost += timer.Elapsed();
if (flag == false) {
- VLOG(1) << "decoder advance decode exit." << frame_prob.size();
+ VLOG(3) << "decoder advance decode exit." << frame_prob.size();
break;
}
+ timer.Reset();
        std::vector<std::vector<kaldi::BaseFloat>> likelihood;
likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood);
+ search_cost += timer.Elapsed();
+
VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_;
}
+ VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
+ << " sec.";
+ VLOG(1) << "AdvanceDecode search cost: " << search_cost << " sec.";
}
static bool PrefixScoreCompare(
diff --git a/speechx/speechx/frontend/audio/assembler.cc b/speechx/speechx/frontend/audio/assembler.cc
index 56dfc3aaf..9d5fc4036 100644
--- a/speechx/speechx/frontend/audio/assembler.cc
+++ b/speechx/speechx/frontend/audio/assembler.cc
@@ -40,7 +40,9 @@ void Assembler::Accept(const kaldi::VectorBase& inputs) {
// pop feature chunk
bool Assembler::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
+ kaldi::Timer timer;
bool result = Compute(feats);
+ VLOG(1) << "Assembler::Read cost: " << timer.Elapsed() << " sec.";
return result;
}
@@ -51,14 +53,14 @@ bool Assembler::Compute(Vector* feats) {
        Vector<BaseFloat> feature;
bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) {
- VLOG(1) << "result: " << result
+ VLOG(3) << "result: " << result
<< " feature dim: " << feature.Dim();
if (IsFinished() == false) {
- VLOG(1) << "finished reading feature. cache size: "
+ VLOG(3) << "finished reading feature. cache size: "
<< feature_cache_.size();
return false;
} else {
- VLOG(1) << "break";
+ VLOG(3) << "break";
break;
}
}
@@ -67,11 +69,11 @@ bool Assembler::Compute(Vector* feats) {
feature_cache_.push(feature);
nframes_ += 1;
- VLOG(1) << "nframes: " << nframes_;
+ VLOG(3) << "nframes: " << nframes_;
}
if (feature_cache_.size() < receptive_filed_length_) {
- VLOG(1) << "feature_cache less than receptive_filed_lenght. "
+        VLOG(3) << "feature_cache less than receptive_filed_length. "
<< feature_cache_.size() << ": " << receptive_filed_length_;
return false;
}
@@ -87,7 +89,7 @@ bool Assembler::Compute(Vector* feats) {
int32 this_chunk_size =
        std::min(static_cast<int32>(feature_cache_.size()), frame_chunk_size_);
feats->Resize(dim_ * this_chunk_size);
- VLOG(1) << "read " << this_chunk_size << " feat.";
+ VLOG(3) << "read " << this_chunk_size << " feat.";
int32 counter = 0;
while (counter < this_chunk_size) {
diff --git a/speechx/speechx/frontend/audio/audio_cache.cc b/speechx/speechx/frontend/audio/audio_cache.cc
index 61ef8841d..c6a91f4b3 100644
--- a/speechx/speechx/frontend/audio/audio_cache.cc
+++ b/speechx/speechx/frontend/audio/audio_cache.cc
@@ -38,6 +38,7 @@ BaseFloat AudioCache::Convert2PCM32(BaseFloat val) {
}
void AudioCache::Accept(const VectorBase<BaseFloat>& waves) {
+ kaldi::Timer timer;
    std::unique_lock<std::mutex> lock(mutex_);
while (size_ + waves.Dim() > ring_buffer_.size()) {
ready_feed_condition_.wait(lock);
@@ -48,11 +49,13 @@ void AudioCache::Accept(const VectorBase& waves) {
if (to_float32_) ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx));
}
size_ += waves.Dim();
+ VLOG(1) << "AudioCache::Accept cost: " << timer.Elapsed() << " sec. "
+ << waves.Dim() << " samples.";
}
bool AudioCache::Read(Vector<BaseFloat>* waves) {
- size_t chunk_size = waves->Dim();
kaldi::Timer timer;
+ size_t chunk_size = waves->Dim();
    std::unique_lock<std::mutex> lock(mutex_);
while (chunk_size > size_) {
// when audio is empty and no more data feed
@@ -86,9 +89,11 @@ bool AudioCache::Read(Vector* waves) {
offset_ = (offset_ + chunk_size) % ring_buffer_.size();
nsamples_ += chunk_size;
- VLOG(1) << "nsamples readed: " << nsamples_;
+    VLOG(3) << "nsamples read: " << nsamples_;
ready_feed_condition_.notify_one();
+ VLOG(1) << "AudioCache::Read cost: " << timer.Elapsed() << " sec. "
+ << chunk_size << " samples.";
return true;
}
diff --git a/speechx/speechx/frontend/audio/cmvn.cc b/speechx/speechx/frontend/audio/cmvn.cc
index 3d80e0011..a4d861d2d 100644
--- a/speechx/speechx/frontend/audio/cmvn.cc
+++ b/speechx/speechx/frontend/audio/cmvn.cc
@@ -50,8 +50,11 @@ bool CMVN::Read(kaldi::Vector* feats) {
if (base_extractor_->Read(feats) == false || feats->Dim() == 0) {
return false;
}
+
// appply cmvn
+ kaldi::Timer timer;
Compute(feats);
+ VLOG(1) << "CMVN::Read cost: " << timer.Elapsed() << " sec.";
return true;
}
diff --git a/speechx/speechx/frontend/audio/data_cache.h b/speechx/speechx/frontend/audio/data_cache.h
index 5f5cd51b4..f538df1dd 100644
--- a/speechx/speechx/frontend/audio/data_cache.h
+++ b/speechx/speechx/frontend/audio/data_cache.h
@@ -27,27 +27,32 @@ namespace ppspeech {
// pre-recorded audio/feature
class DataCache : public FrontendInterface {
public:
- DataCache() { finished_ = false; }
+ DataCache() : finished_{false}, dim_{0} {}
// accept waves/feats
-    virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
+    void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) override {
data_ = inputs;
+ SetDim(data_.Dim());
}
-    virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
+    bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) override {
if (data_.Dim() == 0) {
return false;
}
(*feats) = data_;
data_.Resize(0);
+ SetDim(data_.Dim());
return true;
}
- virtual void SetFinished() { finished_ = true; }
- virtual bool IsFinished() const { return finished_; }
- virtual size_t Dim() const { return dim_; }
+ void SetFinished() override { finished_ = true; }
+ bool IsFinished() const override { return finished_; }
+ size_t Dim() const override { return dim_; }
void SetDim(int32 dim) { dim_ = dim; }
- virtual void Reset() { finished_ = true; }
+ void Reset() override {
+ finished_ = true;
+ dim_ = 0;
+ }
private:
    kaldi::Vector<kaldi::BaseFloat> data_;
diff --git a/speechx/speechx/frontend/audio/feature_cache.cc b/speechx/speechx/frontend/audio/feature_cache.cc
index 3f05eae62..5110d7046 100644
--- a/speechx/speechx/frontend/audio/feature_cache.cc
+++ b/speechx/speechx/frontend/audio/feature_cache.cc
@@ -34,6 +34,7 @@ FeatureCache::FeatureCache(FeatureCacheOptions opts,
void FeatureCache::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
// read inputs
base_extractor_->Accept(inputs);
+
// feed current data
bool result = false;
do {
@@ -62,6 +63,7 @@ bool FeatureCache::Read(kaldi::Vector* feats) {
feats->CopyFromVec(cache_.front());
cache_.pop();
ready_feed_condition_.notify_one();
+ VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec.";
return true;
}
@@ -72,9 +74,11 @@ bool FeatureCache::Compute() {
bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) return false;
+ kaldi::Timer timer;
+
int32 num_chunk = feature.Dim() / dim_;
nframe_ += num_chunk;
- VLOG(1) << "nframe computed: " << nframe_;
+ VLOG(3) << "nframe computed: " << nframe_;
for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
int32 start = chunk_idx * dim_;
@@ -92,7 +96,10 @@ bool FeatureCache::Compute() {
cache_.push(feature_chunk);
ready_read_condition_.notify_one();
}
- return result;
+
+ VLOG(1) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. "
+ << num_chunk << " feats.";
+ return true;
}
} // namespace ppspeech
\ No newline at end of file
diff --git a/speechx/speechx/frontend/audio/feature_cache.h b/speechx/speechx/frontend/audio/feature_cache.h
index bd8692251..a4ebd6047 100644
--- a/speechx/speechx/frontend/audio/feature_cache.h
+++ b/speechx/speechx/frontend/audio/feature_cache.h
@@ -58,7 +58,7 @@ class FeatureCache : public FrontendInterface {
std::swap(cache_, empty);
nframe_ = 0;
base_extractor_->Reset();
- VLOG(1) << "feature cache reset: cache size: " << cache_.size();
+ VLOG(3) << "feature cache reset: cache size: " << cache_.size();
}
private:
diff --git a/speechx/speechx/frontend/audio/feature_common_inl.h b/speechx/speechx/frontend/audio/feature_common_inl.h
index b86f79918..dcf44ef61 100644
--- a/speechx/speechx/frontend/audio/feature_common_inl.h
+++ b/speechx/speechx/frontend/audio/feature_common_inl.h
@@ -34,6 +34,7 @@ bool StreamingFeatureTpl::Read(kaldi::Vector* feats) {
bool flag = base_extractor_->Read(&wav);
if (flag == false || wav.Dim() == 0) return false;
+ kaldi::Timer timer;
    // append remaining waves
int32 wav_len = wav.Dim();
int32 left_len = remained_wav_.Dim();
@@ -52,6 +53,8 @@ bool StreamingFeatureTpl::Read(kaldi::Vector* feats) {
remained_wav_.Resize(left_samples);
remained_wav_.CopyFromVec(
waves.Range(frame_shift * num_frames, left_samples));
+ VLOG(1) << "StreamingFeatureTpl::Read cost: " << timer.Elapsed()
+ << " sec.";
return true;
}
diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc
index 11d60d3e2..7f6859082 100644
--- a/speechx/speechx/nnet/decodable.cc
+++ b/speechx/speechx/nnet/decodable.cc
@@ -68,9 +68,10 @@ bool Decodable::AdvanceChunk() {
    Vector<BaseFloat> features;
if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat or frontend_ not init.
- VLOG(1) << "decodable exit;";
+ VLOG(3) << "decodable exit;";
return false;
}
+ VLOG(1) << "AdvanceChunk feat cost: " << timer.Elapsed() << " sec.";
VLOG(2) << "Forward in " << features.Dim() / frontend_->Dim() << " feats.";
// forward feats
@@ -88,7 +89,8 @@ bool Decodable::AdvanceChunk() {
// update state, decoding frame.
frame_offset_ = frames_ready_;
frames_ready_ += nnet_out_cache_.NumRows();
- VLOG(2) << "Forward feat chunk cost: " << timer.Elapsed() << " sec.";
+ VLOG(1) << "AdvanceChunk feat + forward cost: " << timer.Elapsed()
+ << " sec.";
return true;
}
@@ -115,7 +117,7 @@ bool Decodable::AdvanceChunk(kaldi::Vector* logprobs,
// read one frame likelihood
bool Decodable::FrameLikelihood(int32 frame, vector<kaldi::BaseFloat>* likelihood) {
if (EnsureFrameHaveComputed(frame) == false) {
- VLOG(1) << "framelikehood exit.";
+        VLOG(3) << "frame likelihood exit.";
return false;
}
@@ -168,7 +170,9 @@ void Decodable::Reset() {
void Decodable::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
                                   float reverse_weight,
                                   std::vector<float>* rescoring_score) {
+ kaldi::Timer timer;
nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
+ VLOG(1) << "Attention Rescoring cost: " << timer.Elapsed() << " sec.";
}
} // namespace ppspeech
\ No newline at end of file
diff --git a/speechx/speechx/nnet/u2_nnet.cc b/speechx/speechx/nnet/u2_nnet.cc
index 19cb85fda..7707406a1 100644
--- a/speechx/speechx/nnet/u2_nnet.cc
+++ b/speechx/speechx/nnet/u2_nnet.cc
@@ -154,7 +154,7 @@ void U2Nnet::Reset() {
std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32));
encoder_outs_.clear();
- VLOG(1) << "u2nnet reset";
+ VLOG(3) << "u2nnet reset";
}
// Debug API
@@ -168,6 +168,7 @@ void U2Nnet::FeedEncoderOuts(const paddle::Tensor& encoder_out) {
void U2Nnet::FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
const int32& feature_dim,
NnetOut* out) {
+ kaldi::Timer timer;
    std::vector<kaldi::BaseFloat> chunk_feats(features.Data(),
                                              features.Data() + features.Dim());
@@ -179,6 +180,8 @@ void U2Nnet::FeedForward(const kaldi::Vector& features,
std::memcpy(out->logprobs.Data(),
ctc_probs.data(),
ctc_probs.size() * sizeof(kaldi::BaseFloat));
+ VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. "
+ << chunk_feats.size() / feature_dim << " frames.";
}
@@ -415,7 +418,6 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps,
#ifdef USE_PROFILING
RecordEvent event("AttentionRescoring", TracerEventType::UserDefined, 1);
#endif
-
CHECK(rescoring_score != nullptr);
int num_hyps = hyps.size();
@@ -627,7 +629,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps,
        // combined left-to-right and right-to-left score
(*rescoring_score)[i] =
score * (1 - reverse_weight) + r_score * reverse_weight;
- VLOG(1) << "hyp " << i << " " << hyp.size() << " score: " << score
+ VLOG(3) << "hyp " << i << " " << hyp.size() << " score: " << score
<< " r_score: " << r_score
<< " reverse_weight: " << reverse_weight
<< " final score: " << (*rescoring_score)[i];
@@ -639,7 +641,7 @@ void U2Nnet::EncoderOuts(
    std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const {
// list of (B=1,T,D)
int size = encoder_outs_.size();
- VLOG(1) << "encoder_outs_ size: " << size;
+ VLOG(3) << "encoder_outs_ size: " << size;
for (int i = 0; i < size; i++) {
const paddle::Tensor& item = encoder_outs_[i];
@@ -649,7 +651,7 @@ void U2Nnet::EncoderOuts(
const int& T = shape[1];
const int& D = shape[2];
CHECK(B == 1) << "Only support batch one.";
- VLOG(1) << "encoder out " << i << " shape: (" << B << "," << T << ","
+ VLOG(3) << "encoder out " << i << " shape: (" << B << "," << T << ","
<< D << ")";
        const float* this_tensor_ptr = item.data<float>();
diff --git a/speechx/speechx/recognizer/u2_recognizer.cc b/speechx/speechx/recognizer/u2_recognizer.cc
index 382f622f2..d1d308ebd 100644
--- a/speechx/speechx/recognizer/u2_recognizer.cc
+++ b/speechx/speechx/recognizer/u2_recognizer.cc
@@ -67,7 +67,10 @@ void U2Recognizer::ResetContinuousDecoding() {
void U2Recognizer::Accept(const VectorBase<BaseFloat>& waves) {
+ kaldi::Timer timer;
feature_pipeline_->Accept(waves);
+ VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " << waves.Dim()
+ << " samples.";
}
@@ -78,9 +81,7 @@ void U2Recognizer::Decode() {
void U2Recognizer::Rescoring() {
// Do attention Rescoring
- kaldi::Timer timer;
AttentionRescoring();
- VLOG(1) << "Rescoring cost latency: " << timer.Elapsed() << " sec.";
}
void U2Recognizer::UpdateResult(bool finish) {
@@ -181,15 +182,13 @@ void U2Recognizer::AttentionRescoring() {
return;
}
- kaldi::Timer timer;
std::vector rescoring_score;
decodable_->AttentionRescoring(
hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
- VLOG(1) << "Attention Rescoring takes " << timer.Elapsed() << " sec.";
// combine ctc score and rescoring score
for (size_t i = 0; i < num_hyps; i++) {
- VLOG(1) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
+ VLOG(3) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
<< " ctc_score: " << result_[i].score
<< " rescoring_weight: " << opts_.decoder_opts.rescoring_weight
<< " ctc_weight: " << opts_.decoder_opts.ctc_weight;
@@ -197,12 +196,12 @@ void U2Recognizer::AttentionRescoring() {
opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
opts_.decoder_opts.ctc_weight * result_[i].score;
- VLOG(1) << "hyp: " << result_[0].sentence
+ VLOG(3) << "hyp: " << result_[0].sentence
<< " score: " << result_[0].score;
}
std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
- VLOG(1) << "result: " << result_[0].sentence
+ VLOG(3) << "result: " << result_[0].sentence
<< " score: " << result_[0].score;
}
diff --git a/speechx/speechx/recognizer/u2_recognizer_main.cc b/speechx/speechx/recognizer/u2_recognizer_main.cc
index 5cb8dbb15..d7c584074 100644
--- a/speechx/speechx/recognizer/u2_recognizer_main.cc
+++ b/speechx/speechx/recognizer/u2_recognizer_main.cc
@@ -31,6 +31,7 @@ int main(int argc, char* argv[]) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
+ double tot_decode_time = 0.0;
    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
@@ -47,9 +48,7 @@ int main(int argc, char* argv[]) {
ppspeech::U2RecognizerResource::InitFromFlags();
ppspeech::U2Recognizer recognizer(resource);
- kaldi::Timer timer;
for (; !wav_reader.Done(); wav_reader.Next()) {
- kaldi::Timer local_timer;
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt;
@@ -65,6 +64,8 @@ int main(int argc, char* argv[]) {
int sample_offset = 0;
int cnt = 0;
+ kaldi::Timer timer;
+ kaldi::Timer local_timer;
while (sample_offset < tot_samples) {
int cur_chunk_size =
@@ -95,6 +96,8 @@ int main(int argc, char* argv[]) {
// second pass decoding
recognizer.Rescoring();
+ tot_decode_time += timer.Elapsed();
+
std::string result = recognizer.GetFinalResult();
recognizer.Reset();
@@ -115,10 +118,8 @@ int main(int argc, char* argv[]) {
++num_done;
}
- double elapsed = timer.Elapsed();
-
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
- LOG(INFO) << "total cost:" << elapsed << " sec";
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
- LOG(INFO) << "RTF is: " << elapsed / tot_wav_duration;
+ LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
+ LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}
diff --git a/tools/watermark.py b/tools/watermark.py
new file mode 100644
index 000000000..fc592d5bc
--- /dev/null
+++ b/tools/watermark.py
@@ -0,0 +1,20 @@
+# add a cyclic watermark to text by interleaving pattern characters with the content
+def watermark(content, pattern):
+ m = list(zip(pattern * (len(content) // len(pattern) + 1), content))
+ return ''.join([x for t in m
+ for x in t] + [pattern[len(content) % len(pattern)]])
+
+
+# remove a cyclic watermark by detecting and dropping the repeating interleaved pattern
+def iwatermark(content):
+ e = [x for i, x in enumerate(content) if i % 2 == 0]
+ o = [x for i, x in enumerate(content) if i % 2 != 0]
+ for i in range(1, len(e) // 2 + 1):
+ if e[i:] == e[:-i]:
+ return ''.join(o)
+ return ''.join(e)
+
+
+if __name__ == "__main__":
+ print(watermark('跟世龙对齐 Triton 开发计划', 'hbzs'))
+ print(iwatermark('h跟b世z龙s对h齐b zTsrhibtzosnh b开z发s计h划b'))