diff --git a/README.md b/README.md
index 79b86e9ff..e1f57fcaf 100644
--- a/README.md
+++ b/README.md
@@ -157,6 +157,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural language processing (NLP) and Computer Vision (CV).
### Recent Update
+- 👑 2022.10.11: Add [Wav2vec2ASR](./examples/librispeech/asr3), wav2vec2.0 fine-tuning for ASR on LibriSpeech.
- 🔥 2022.09.26: Add Voice Cloning, TTS finetune, and ERNIE-SAT in [PaddleSpeech Web Demo](./demos/speech_web).
- ⚡ 2022.09.09: Add AISHELL-3 Voice Cloning [example](./examples/aishell3/vc2) with ECAPA-TDNN speaker encoder.
- ⚡ 2022.08.25: Release TTS [finetune](./examples/other/tts_finetune/tts3) example.
diff --git a/README_cn.md b/README_cn.md
index 3d60882b2..1e932201f 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -179,6 +179,7 @@
### 近期更新
+- 👑 2022.10.11: 新增 [Wav2vec2ASR](./examples/librispeech/asr3),在 LibriSpeech 上针对 ASR 任务对 wav2vec2.0 进行 fine-tuning。
- 🔥 2022.09.26: 新增 Voice Cloning, TTS finetune 和 ERNIE-SAT 到 [PaddleSpeech 网页应用](./demos/speech_web)。
- ⚡ 2022.09.09: 新增基于 ECAPA-TDNN 声纹模型的 AISHELL-3 Voice Cloning [示例](./examples/aishell3/vc2)。
- ⚡ 2022.08.25: 发布 TTS [finetune](./examples/other/tts_finetune/tts3) 示例。
diff --git a/demos/speech_web/README.md b/demos/speech_web/README.md
index 89d22382a..572781ab6 100644
--- a/demos/speech_web/README.md
+++ b/demos/speech_web/README.md
@@ -21,14 +21,14 @@ Paddle Speech Demo 是一个以 PaddleSpeech 的语音交互功能为主体开
+ 小数据微调:基于小数据集的微调方案,内置用12句话标贝中文女声微调示例,你也可以通过一键重置,录制自己的声音,注意在安静环境下录制,效果会更好。你可以在 [【Finetune your own AM based on FastSpeech2 with AISHELL-3】](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/tts_finetune/tts3)中尝试使用自己的数据集进行微调。
-+ ENIRE-SAT:语言-语音跨模态大模型 ENIRE-SAT 可视化展示示例,支持个性化合成,跨语言语音合成(音频为中文则输入英文文本进行合成),语音编辑(修改音频文字中间的结果)功能。 ENIRE-SAT 更多实现细节,可以参考:
++ ERNIE-SAT:语言-语音跨模态大模型 ERNIE-SAT 可视化展示示例,支持个性化合成,跨语言语音合成(音频为中文则输入英文文本进行合成),语音编辑(修改音频文字中间的结果)功能。 ERNIE-SAT 更多实现细节,可以参考:
+ [【ERNIE-SAT with AISHELL-3 dataset】](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/ernie_sat)
+ [【ERNIE-SAT with AISHELL3 and VCTK datasets】](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3_vctk/ernie_sat)
+ [【ERNIE-SAT with VCTK dataset】](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/ernie_sat)
运行效果:
- 
+ 
diff --git a/demos/speech_web/web_client/src/components/Experience.vue b/demos/speech_web/web_client/src/components/Experience.vue
index f593c0c14..ca0e1440f 100644
--- a/demos/speech_web/web_client/src/components/Experience.vue
+++ b/demos/speech_web/web_client/src/components/Experience.vue
@@ -7,7 +7,7 @@ import VPRT from './SubMenu/VPR/VPRT.vue'
import IET from './SubMenu/IE/IET.vue'
import VoiceCloneT from './SubMenu/VoiceClone/VoiceClone.vue'
-import ENIRE_SATT from './SubMenu/ENIRE_SAT/ENIRE_SAT.vue'
+import ERNIE_SATT from './SubMenu/ERNIE_SAT/ERNIE_SAT.vue'
import FineTuneT from './SubMenu/FineTune/FineTune.vue'
@@ -47,8 +47,8 @@ import FineTuneT from './SubMenu/FineTune/FineTune.vue'
-
-
+
+
diff --git a/demos/speech_web/web_client/src/components/SubMenu/ENIRE_SAT/ENIRE_SAT.vue b/demos/speech_web/web_client/src/components/SubMenu/ERNIE_SAT/ERNIE_SAT.vue
similarity index 100%
rename from demos/speech_web/web_client/src/components/SubMenu/ENIRE_SAT/ENIRE_SAT.vue
rename to demos/speech_web/web_client/src/components/SubMenu/ERNIE_SAT/ERNIE_SAT.vue
diff --git a/docs/source/reference.md b/docs/source/reference.md
index 0d36d96f7..9a47a2302 100644
--- a/docs/source/reference.md
+++ b/docs/source/reference.md
@@ -28,6 +28,8 @@ We borrowed a lot of code from these repos to build `model` and `engine`, thanks
* [speechbrain](https://github.com/speechbrain/speechbrain/blob/develop/LICENSE)
- Apache-2.0 License
- ECAPA-TDNN SV model
+- ASR with CTC and pre-trained wav2vec2 models.
+
* [chainer](https://github.com/chainer/chainer/blob/master/LICENSE)
- MIT License
@@ -43,3 +45,7 @@ We borrowed a lot of code from these repos to build `model` and `engine`, thanks
* [g2pW](https://github.com/GitYCC/g2pW/blob/master/LICENCE)
- Apache-2.0 license
+
+* [transformers](https://github.com/huggingface/transformers)
+- Apache-2.0 License
+- Wav2vec2.0
diff --git a/docs/source/released_model.md b/docs/source/released_model.md
index a2456f1fe..4e76da033 100644
--- a/docs/source/released_model.md
+++ b/docs/source/released_model.md
@@ -18,6 +18,12 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
[Transformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0381 | 960 h | [Transformer Librispeech ASR1](../../examples/librispeech/asr1) | python |
[Transformer Librispeech ASR2 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr2/asr2_transformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: JoinCTC w/ LM |-| 0.0240 | 960 h | [Transformer Librispeech ASR2](../../examples/librispeech/asr2) | python |
+### Self-Supervised Pre-trained Model
+Model | Pre-Train Method | Pre-Train Data | Finetune Data | Size | Descriptions | CER | WER | Example Link |
+:-------------:| :------------:| :-----: | :-----: | :-----: |:-----:| :-----: | :-----: | :-----: |
+[Wav2vec2-large-960h-lv60-self Model](https://paddlespeech.bj.bcebos.com/wav2vec/wav2vec2-large-960h-lv60-self.pdparams) | wav2vec2 | Librispeech and LV-60k Dataset (53,000 h) | - | 1.18 GB | Pre-trained Wav2vec2.0 Model | - | - | - |
+[Wav2vec2ASR-large-960h-librispeech Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.0.model.tar.gz) | wav2vec2 | Librispeech and LV-60k Dataset (53,000 h) | Librispeech (960 h) | 1.18 GB | Encoder: Wav2vec2.0, Decoder: CTC, Decoding method: Greedy search | - | 0.0189 | [Wav2vec2ASR Librispeech ASR3](../../examples/librispeech/asr3) |
+
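+The Wav2vec2ASR model above decodes the CTC output with plain greedy search. As a reference, here is a minimal, framework-agnostic sketch of CTC greedy decoding (the toy vocabulary and blank id below are illustrative, not taken from the released model):
+
+```python
+import numpy as np
+
+def ctc_greedy_decode(logits: np.ndarray, blank_id: int = 0) -> list:
+    """logits: (time, vocab) frame-level scores from the CTC head."""
+    best = logits.argmax(axis=-1)                  # best token per frame
+    collapsed = [t for i, t in enumerate(best)     # collapse consecutive repeats
+                 if i == 0 or t != best[i - 1]]
+    return [int(t) for t in collapsed if t != blank_id]   # drop CTC blanks
+
+# toy example: 5 frames over a 4-token vocabulary (0 = blank)
+frames = np.array([[0.1, 0.8, 0.05, 0.05],
+                   [0.1, 0.8, 0.05, 0.05],
+                   [0.9, 0.03, 0.03, 0.04],
+                   [0.1, 0.1, 0.1, 0.7],
+                   [0.1, 0.1, 0.1, 0.7]])
+print(ctc_greedy_decode(frames))  # -> [1, 3]
+```
+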
### Language Model based on NGram
Language Model | Training Data | Token-based | Size | Descriptions
:------------:| :------------:|:------------: | :------------: | :------------:
diff --git a/examples/other/tts_finetune/tts3/README.md b/examples/other/tts_finetune/tts3/README.md
index ceb8e7970..fa691764c 100644
--- a/examples/other/tts_finetune/tts3/README.md
+++ b/examples/other/tts_finetune/tts3/README.md
@@ -7,7 +7,7 @@ For more information on training Fastspeech2 with AISHELL-3, You can refer [exam
## Prepare
### Download Pretrained model
Assume the path to the model is `./pretrained_models`.
-If you want to finetune Chinese data, you need to download Fastspeech2 pretrained model with AISHELL-3: [fastspeech2_aishell3_ckpt_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_ckpt_1.1.0.zip) for finetuning. Download HiFiGAN pretrained model with aishell3: [hifigan_aishell3_ckpt_0.2.0](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip) for synthesis.
+If you want to finetune the Chinese pretrained model, you need to download the Fastspeech2 model pretrained on AISHELL-3, [fastspeech2_aishell3_ckpt_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_ckpt_1.1.0.zip), for finetuning, and the HiFiGAN model pretrained on AISHELL-3, [hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip), for synthesis.
```bash
mkdir -p pretrained_models && cd pretrained_models
@@ -21,7 +21,7 @@ cd ../
```
-If you want to finetune English data, you need to download Fastspeech2 pretrained model with VCTK: [fastspeech2_vctk_ckpt_1.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_ckpt_1.2.0.zip) for finetuning. Download HiFiGAN pretrained model with VCTK: [hifigan_vctk_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip) for synthesis.
+If you want to finetune the English pretrained model, you need to download the Fastspeech2 model pretrained on VCTK, [fastspeech2_vctk_ckpt_1.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_vctk_ckpt_1.2.0.zip), for finetuning, and the HiFiGAN model pretrained on VCTK, [hifigan_vctk_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip), for synthesis.
```bash
mkdir -p pretrained_models && cd pretrained_models
@@ -34,6 +34,59 @@ unzip hifigan_vctk_ckpt_0.2.0.zip
cd ../
```
+If you want to finetune the Chinese-English mixed pretrained model, you need to download the Fastspeech2 model pretrained on the mixed datasets, [fastspeech2_mix_ckpt_1.2.0.zip](https://paddlespeech.bj.bcebos.com/t2s/chinse_english_mixed/models/fastspeech2_mix_ckpt_1.2.0.zip), for finetuning, and the HiFiGAN model pretrained on AISHELL-3, [hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip), for synthesis.
+
+```bash
+mkdir -p pretrained_models && cd pretrained_models
+# pretrained fastspeech2 model
+wget https://paddlespeech.bj.bcebos.com/t2s/chinse_english_mixed/models/fastspeech2_mix_ckpt_1.2.0.zip
+unzip fastspeech2_mix_ckpt_1.2.0.zip
+# pretrained hifigan model
+wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip
+unzip hifigan_aishell3_ckpt_0.2.0.zip
+cd ../
+```
+
+### Prepare your data
+Assume the path to the dataset is `./input`, which contains a speaker folder. The speaker folder contains audio files (*.wav) and a label file (labels.txt). The audio files are in wav format. Each line of the label file has the format: utt_id|pronunciation.
+
+If you want to finetune the Chinese pretrained model, you need to prepare Chinese data. A Chinese label example:
+```
+000001|ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1
+```
+
+Here is an example using the first 200 utterances of CSMSC.
+
+```bash
+mkdir -p input && cd input
+wget https://paddlespeech.bj.bcebos.com/datasets/csmsc_mini.zip
+unzip csmsc_mini.zip
+cd ../
+```
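+
+With the mini dataset above in place, `labels.txt` can be read as plain text, one `utt_id|pronunciation` pair per line. A minimal sketch (the path assumes the `csmsc_mini` example; adapt it to your own speaker folder):
+
+```python
+from pathlib import Path
+
+labels = {}
+for line in Path("./input/csmsc_mini/labels.txt").read_text(encoding="utf-8").splitlines():
+    utt_id, pronunciation = line.strip().split("|", maxsplit=1)
+    labels[utt_id] = pronunciation
+
+# e.g. labels["000001"] == "ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1"
+```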
+
+If you want to finetune the English pretrained model, you need to prepare English data. An English label example:
+```
+LJ001-0001|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition
+```
+
+Here is an example using the first 200 utterances of LJSpeech.
+
+```bash
+mkdir -p input && cd input
+wget https://paddlespeech.bj.bcebos.com/datasets/ljspeech_mini.zip
+unzip ljspeech_mini.zip
+cd ../
+```
+
+If you want to finetune the Chinese-English mixed pretrained model, you need to prepare Chinese or English data. Here is an example using the first 12 utterances of SSB0005 (a speaker from AISHELL-3).
+
+```bash
+mkdir -p input && cd input
+wget https://paddlespeech.bj.bcebos.com/datasets/SSB0005_mini.zip
+unzip SSB0005_mini.zip
+cd ../
+```
+
### Download MFA tools and pretrained model
Assume the path to the MFA tool is `./tools`. Download [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.0.1/montreal-forced-aligner_linux.tar.gz).
@@ -46,7 +99,7 @@ cp montreal-forced-aligner/lib/libpython3.6m.so.1.0 montreal-forced-aligner/lib/
mkdir -p aligner && cd aligner
```
-If you want to finetune Chinese data, you need to download pretrained MFA models with aishell3: [aishell3_model.zip](https://paddlespeech.bj.bcebos.com/MFA/ernie_sat/aishell3_model.zip) and unzip it.
+If you want to get the MFA result of Chinese data, you need to download the MFA model pretrained on AISHELL-3, [aishell3_model.zip](https://paddlespeech.bj.bcebos.com/MFA/ernie_sat/aishell3_model.zip), and unzip it.
```bash
# pretrained mfa model for Chinese data
@@ -56,30 +109,17 @@ wget https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/simple.lexicon
cd ../../
```
-If you want to finetune English data, you need to download pretrained MFA models with vctk: [vctk_model.zip](https://paddlespeech.bj.bcebos.com/MFA/ernie_sat/vctk_model.zip) and unzip it.
+If you want to get the MFA result of English data, you need to download the MFA model pretrained on VCTK, [vctk_model.zip](https://paddlespeech.bj.bcebos.com/MFA/ernie_sat/vctk_model.zip), and unzip it.
```bash
-# pretrained mfa model for Chinese data
+# pretrained mfa model for English data
wget https://paddlespeech.bj.bcebos.com/MFA/ernie_sat/vctk_model.zip
unzip vctk_model.zip
wget https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/cmudict-0.7b
cd ../../
```
-### Prepare your data
-Assume the path to the dataset is `./input` which contains a speaker folder. Speaker folder contains audio files (*.wav) and label file (labels.txt). The format of the audio file is wav. The format of the label file is: utt_id|pronunciation.
-
-If you want to finetune Chinese data, Chinese label example: 000001|ka2 er2 pu3 pei2 wai4 sun1 wan2 hua2 ti1
-Here is an example of the first 200 data of csmsc.
-
-```bash
-mkdir -p input && cd input
-wget https://paddlespeech.bj.bcebos.com/datasets/csmsc_mini.zip
-unzip csmsc_mini.zip
-cd ../
-```
-
-When "Prepare" done. The structure of the current directory is listed below.
+When "Prepare" done. The structure of the current directory is similar to the following.
```text
├── input
│ ├── csmsc_mini
@@ -119,56 +159,6 @@ When "Prepare" done. The structure of the current directory is listed below.
```
-If you want to finetune English data, English label example: LJ001-0001|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition
-Here is an example of the first 200 data of ljspeech.
-
-```bash
-mkdir -p input && cd input
-wget https://paddlespeech.bj.bcebos.com/datasets/ljspeech_mini.zip
-unzip ljspeech_mini.zip
-cd ../
-```
-
-When "Prepare" done. The structure of the current directory is listed below.
-```text
-├── input
-│ ├── ljspeech_mini
-│ │ ├── LJ001-0001.wav
-│ │ ├── LJ001-0002.wav
-│ │ ├── LJ001-0003.wav
-│ │ ├── ...
-│ │ ├── LJ002-0014.wav
-│ │ ├── labels.txt
-│ └── ljspeech_mini.zip
-├── pretrained_models
-│ ├── fastspeech2_vctk_ckpt_1.2.0
-│ │ ├── default.yaml
-│ │ ├── energy_stats.npy
-│ │ ├── phone_id_map.txt
-│ │ ├── pitch_stats.npy
-│ │ ├── snapshot_iter_66200.pdz
-│ │ ├── speaker_id_map.txt
-│ │ └── speech_stats.npy
-│ ├── fastspeech2_vctk_ckpt_1.2.0.zip
-│ ├── hifigan_vctk_ckpt_0.2.0
-│ │ ├── default.yaml
-│ │ ├── feats_stats.npy
-│ │ └── snapshot_iter_2500000.pdz
-│ └── hifigan_vctk_ckpt_0.2.0.zip
-└── tools
- ├── aligner
- │ ├── vctk_model
- │ ├── vctk_model.zip
- │ └── cmudict-0.7b
- ├── montreal-forced-aligner
- │ ├── bin
- │ ├── lib
- │ └── pretrained_models
- └── montreal-forced-aligner_linux.tar.gz
- ...
-
-```
-
### Set finetune.yaml
`conf/finetune.yaml` contains some configurations for fine-tuning. You can try various options to get a better result. The value of `frozen_layers` can be changed according to `conf/fastspeech2_layers.txt`, which lists the model layers of fastspeech2.
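+
+How `frozen_layers` is actually applied is up to `local/finetune.py`; the snippet below is only a generic sketch of how layer freezing is typically done in Paddle (the toy model and layer names are illustrative, not the real FastSpeech2 layers):
+
+```python
+import paddle
+from paddle import nn
+
+class ToyAM(nn.Layer):          # hypothetical stand-in for FastSpeech2
+    def __init__(self):
+        super().__init__()
+        self.encoder = nn.Linear(4, 4)
+        self.decoder = nn.Linear(4, 4)
+
+model = ToyAM()
+frozen_layers = ["encoder"]     # layer name prefixes, cf. conf/fastspeech2_layers.txt
+
+for name, param in model.named_parameters():
+    if any(name.startswith(prefix) for prefix in frozen_layers):
+        param.stop_gradient = True   # parameter is excluded from gradient updates
+```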
@@ -180,7 +170,7 @@ Arguments:
## Get Started
-For Chinese data finetune, execute `./run.sh`. For English data finetune, execute `./run_en.sh`.
+For finetuning the Chinese pretrained model, execute `./run.sh`. For finetuning the English pretrained model, execute `./run_en.sh`. For finetuning the Chinese-English mixed pretrained model, execute `./run_mix.sh`.
Run the command below to
1. **source path**.
2. finetune the model.
diff --git a/examples/other/tts_finetune/tts3/local/extract_feature.py b/examples/other/tts_finetune/tts3/local/extract_feature.py
index 3277db531..daa3dacc7 100644
--- a/examples/other/tts_finetune/tts3/local/extract_feature.py
+++ b/examples/other/tts_finetune/tts3/local/extract_feature.py
@@ -56,13 +56,15 @@ def get_stats(pretrained_model_dir: Path):
def get_map(duration_file: Union[str, Path],
dump_dir: Path,
- pretrained_model_dir: Path):
+ pretrained_model_dir: Path,
+ replace_spkid: int=0):
"""get phone map and speaker map, save on dump_dir
Args:
duration_file (str): durantions.txt
dump_dir (Path): dump dir
pretrained_model_dir (Path): pretrained model dir
+        replace_spkid (int): the speaker id in the pretrained speaker map to be replaced by the new speaker
"""
# copy phone map file from pretrained model path
phones_dict = dump_dir / "phone_id_map.txt"
@@ -75,14 +77,24 @@ def get_map(duration_file: Union[str, Path],
speakers = sorted(list(speaker_set))
num = len(speakers)
speaker_dict = dump_dir / "speaker_id_map.txt"
- with open(speaker_dict, 'w') as f, open(pretrained_model_dir /
- "speaker_id_map.txt", 'r') as fr:
- for i, spk in enumerate(speakers):
- f.write(spk + ' ' + str(i) + '\n')
+ spk_dict = {}
+ # get raw spkid-spk dict
+ with open(pretrained_model_dir / "speaker_id_map.txt", 'r') as fr:
for line in fr.readlines():
- spk_id = line.strip().split(" ")[-1]
- if int(spk_id) >= num:
- f.write(line)
+ spk = line.strip().split(" ")[0]
+ spk_id = line.strip().split(" ")[1]
+ spk_dict[spk_id] = spk
+
+ # replace spk on spkid-spk dict
+ assert replace_spkid + num - 1 < len(
+ spk_dict), "Please set correct replace spk id."
+ for i, spk in enumerate(speakers):
+ spk_dict[str(replace_spkid + i)] = spk
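+    # e.g. with replace_spkid=174 and a single new speaker "SSB0005", the entry
+    # for id 174 in the pretrained map now points to "SSB0005"; all other
+    # pretrained speaker ids keep their original speakers.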
+
+ # write a new spk map file
+ with open(speaker_dict, 'w') as f:
+ for spk_id in spk_dict.keys():
+ f.write(spk_dict[spk_id] + ' ' + spk_id + '\n')
vocab_phones = {}
with open(phones_dict, 'rt') as f:
@@ -206,10 +218,11 @@ def extract_feature(duration_file: str,
config,
input_dir: Path,
dump_dir: Path,
- pretrained_model_dir: Path):
+ pretrained_model_dir: Path,
+ replace_spkid: int=0):
- sentences, vocab_phones, vocab_speaker = get_map(duration_file, dump_dir,
- pretrained_model_dir)
+ sentences, vocab_phones, vocab_speaker = get_map(
+ duration_file, dump_dir, pretrained_model_dir, replace_spkid)
mel_extractor, pitch_extractor, energy_extractor = get_extractor(config)
wav_files = sorted(list((input_dir).rglob("*.wav")))
@@ -315,6 +328,9 @@ if __name__ == '__main__':
default="./pretrained_models/fastspeech2_aishell3_ckpt_1.1.0",
help="Path to pretrained model")
+ parser.add_argument(
+ "--replace_spkid", type=int, default=0, help="replace spk id")
+
args = parser.parse_args()
input_dir = Path(args.input_dir).expanduser()
@@ -332,4 +348,5 @@ if __name__ == '__main__':
config=config,
input_dir=input_dir,
dump_dir=dump_dir,
- pretrained_model_dir=pretrained_model_dir)
+ pretrained_model_dir=pretrained_model_dir,
+ replace_spkid=args.replace_spkid)
diff --git a/examples/other/tts_finetune/tts3/run.sh b/examples/other/tts_finetune/tts3/run.sh
index 1faa2b46d..ed1705f8c 100755
--- a/examples/other/tts_finetune/tts3/run.sh
+++ b/examples/other/tts_finetune/tts3/run.sh
@@ -15,6 +15,7 @@ output_dir=./exp/default
lang=zh
ngpu=1
finetune_config=./conf/finetune.yaml
+replace_spkid=0
ckpt=snapshot_iter_96699
@@ -62,7 +63,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--duration_file="./durations.txt" \
--input_dir=${new_dir} \
--dump_dir=${dump_dir} \
- --pretrained_model_dir=${pretrained_model_dir}
+ --pretrained_model_dir=${pretrained_model_dir} \
+ --replace_spkid=$replace_spkid
fi
# create finetune env
@@ -102,5 +104,5 @@ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
--output_dir=./test_e2e/ \
--phones_dict=${dump_dir}/phone_id_map.txt \
--speaker_dict=${dump_dir}/speaker_id_map.txt \
- --spk_id=0
+ --spk_id=$replace_spkid
fi
diff --git a/examples/other/tts_finetune/tts3/run_en.sh b/examples/other/tts_finetune/tts3/run_en.sh
index e8551667e..765274e85 100755
--- a/examples/other/tts_finetune/tts3/run_en.sh
+++ b/examples/other/tts_finetune/tts3/run_en.sh
@@ -14,6 +14,7 @@ output_dir=./exp/default
lang=en
ngpu=1
finetune_config=./conf/finetune.yaml
+replace_spkid=0
ckpt=snapshot_iter_66300
@@ -61,7 +62,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--duration_file="./durations.txt" \
--input_dir=${new_dir} \
--dump_dir=${dump_dir} \
- --pretrained_model_dir=${pretrained_model_dir}
+ --pretrained_model_dir=${pretrained_model_dir} \
+ --replace_spkid=$replace_spkid
fi
# create finetune env
@@ -101,5 +103,5 @@ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
--output_dir=./test_e2e/ \
--phones_dict=${dump_dir}/phone_id_map.txt \
--speaker_dict=${dump_dir}/speaker_id_map.txt \
- --spk_id=0
+ --spk_id=$replace_spkid
fi
diff --git a/examples/other/tts_finetune/tts3/run_mix.sh b/examples/other/tts_finetune/tts3/run_mix.sh
new file mode 100644
index 000000000..71008ef5b
--- /dev/null
+++ b/examples/other/tts_finetune/tts3/run_mix.sh
@@ -0,0 +1,110 @@
+#!/bin/bash
+
+set -e
+source path.sh
+
+
+input_dir=./input/SSB0005_mini
+newdir_name="newdir"
+new_dir=${input_dir}/${newdir_name}
+pretrained_model_dir=./pretrained_models/fastspeech2_mix_ckpt_1.2.0
+mfa_tools=./tools
+mfa_dir=./mfa_result
+dump_dir=./dump
+output_dir=./exp/default
+lang=zh
+ngpu=1
+finetune_config=./conf/finetune.yaml
+replace_spkid=174 # csmsc: 174, ljspeech: 175, aishell3: 0~173, vctk: 176
+
+ckpt=snapshot_iter_99300
+
+gpus=1
+CUDA_VISIBLE_DEVICES=${gpus}
+stage=0
+stop_stage=100
+
+
+# with the following command, you can choose the stage range you want to run
+# such as `./run.sh --stage 0 --stop-stage 0`
+# this cannot be mixed with positional arguments `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+# check oov
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "check oov"
+ python3 local/check_oov.py \
+ --input_dir=${input_dir} \
+ --pretrained_model_dir=${pretrained_model_dir} \
+ --newdir_name=${newdir_name} \
+ --lang=${lang}
+fi
+
+# get mfa result
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "get mfa result"
+ python3 local/get_mfa_result.py \
+ --input_dir=${new_dir} \
+ --mfa_dir=${mfa_dir} \
+ --lang=${lang}
+fi
+
+# generate durations.txt
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "generate durations.txt"
+ python3 local/generate_duration.py \
+ --mfa_dir=${mfa_dir}
+fi
+
+# extract feature
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "extract feature"
+ python3 local/extract_feature.py \
+ --duration_file="./durations.txt" \
+ --input_dir=${new_dir} \
+ --dump_dir=${dump_dir} \
+ --pretrained_model_dir=${pretrained_model_dir} \
+ --replace_spkid=$replace_spkid
+
+fi
+
+# create finetune env
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "create finetune env"
+ python3 local/prepare_env.py \
+ --pretrained_model_dir=${pretrained_model_dir} \
+ --output_dir=${output_dir}
+fi
+
+# finetune
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "finetune..."
+ python3 local/finetune.py \
+ --pretrained_model_dir=${pretrained_model_dir} \
+ --dump_dir=${dump_dir} \
+ --output_dir=${output_dir} \
+ --ngpu=${ngpu} \
+ --epoch=100 \
+ --finetune_config=${finetune_config}
+fi
+
+# synthesize e2e
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ echo "in hifigan syn_e2e"
+ python3 ${BIN_DIR}/../synthesize_e2e.py \
+ --am=fastspeech2_aishell3 \
+ --am_config=${pretrained_model_dir}/default.yaml \
+ --am_ckpt=${output_dir}/checkpoints/${ckpt}.pdz \
+ --am_stat=${pretrained_model_dir}/speech_stats.npy \
+ --voc=hifigan_aishell3 \
+ --voc_config=pretrained_models/hifigan_aishell3_ckpt_0.2.0/default.yaml \
+ --voc_ckpt=pretrained_models/hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz \
+ --voc_stat=pretrained_models/hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \
+ --lang=mix \
+ --text=${BIN_DIR}/../sentences_mix.txt \
+ --output_dir=./test_e2e/ \
+ --phones_dict=${dump_dir}/phone_id_map.txt \
+ --speaker_dict=${dump_dir}/speaker_id_map.txt \
+ --spk_id=$replace_spkid
+fi
+
diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py
index de4c895f2..16feac5de 100644
--- a/paddlespeech/s2t/exps/wav2vec2/model.py
+++ b/paddlespeech/s2t/exps/wav2vec2/model.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Contains wav2vec2 model."""
import json
+import math
import os
import time
from collections import defaultdict
@@ -46,25 +47,20 @@ logger = Log(__name__).getlog()
class Wav2Vec2ASRTrainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
- self.avg_train_loss = 0
+ self.avg_train_loss = 0.0
- def update_average(self, batch_index, loss, avg_loss):
+ def update_average(self, batch_index, loss):
"""Update running average of the loss.
Arguments
---------
+ batch_index : int
+ current batch index
-        loss : paddle.tensor
-            detached loss, a single float value.
+        loss : float
+            detached loss value, already converted to a Python float.
- avg_loss : float
- current running average.
- Returns
- -------
- avg_loss : float
- The average loss.
"""
- if paddle.isfinite(loss):
- avg_loss -= avg_loss / (batch_index + 1)
- avg_loss += float(loss) / (batch_index + 1)
- return avg_loss
+ if math.isfinite(loss):
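+            # incremental mean: avg_n = avg_{n-1} + (loss - avg_{n-1}) / n, with n = batch_index + 1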
+ self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
+ self.avg_train_loss += loss / (batch_index + 1)
def train_batch(self, batch_index, batch, msg):
train_conf = self.config
@@ -80,8 +76,8 @@ class Wav2Vec2ASRTrainer(Trainer):
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
- self.avg_train_loss = self.update_average(batch_index, loss,
- self.avg_train_loss)
+ # update self.avg_train_loss
+ self.update_average(batch_index, float(loss))
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
@@ -106,7 +102,7 @@ class Wav2Vec2ASRTrainer(Trainer):
self.lr_scheduler.step()
self.iteration += 1
- losses_np = {'loss': float(self.avg_train_loss) * train_conf.accum_grad}
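+        # avg_train_loss already includes the division by accum_grad applied in
+        # train_batch, so scale it back up to report the per-batch loss.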
+ losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad}
iteration_time = time.time() - start
for k, v in losses_np.items():
report(k, v)