From 15a00431ad2fcdce3938b372b9711ce9b3bad324 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Mon, 17 Jan 2022 12:07:53 +0000 Subject: [PATCH 1/5] rm s2t:augmentation.md, test=doc --- docs/source/asr/augmentation.md | 40 --------------------------------- docs/source/index.rst | 1 - 2 files changed, 41 deletions(-) delete mode 100644 docs/source/asr/augmentation.md diff --git a/docs/source/asr/augmentation.md b/docs/source/asr/augmentation.md deleted file mode 100644 index 8e65cb19..00000000 --- a/docs/source/asr/augmentation.md +++ /dev/null @@ -1,40 +0,0 @@ -# Data Augmentation Pipeline - -Data augmentation has often been a highly effective technique to boost deep learning performance. We augment our speech data by synthesizing new audios with small random perturbation (label-invariant transformation) added upon raw audios. You don't have to do the syntheses on your own, as it is already embedded into the data provider and is done on the fly, randomly for each epoch during training. - -Six optional augmentation components are provided to be selected, configured, and inserted into the processing pipeline. - -* Audio - - Volume Perturbation - - Speed Perturbation - - Shifting Perturbation - - Online Bayesian normalization - - Noise Perturbation (need background noise audio files) - - Impulse Response (need impulse audio files) - -* Feature - - SpecAugment - - Adaptive SpecAugment - -To inform the trainer of what augmentation components are needed and what their processing orders are, it is required to prepare in advance an *augmentation configuration file* in [JSON](http://www.json.org/) format. For example: - -``` -[{ - "type": "speed", - "params": {"min_speed_rate": 0.95, - "max_speed_rate": 1.05}, - "prob": 0.6 -}, -{ - "type": "shift", - "params": {"min_shift_ms": -5, - "max_shift_ms": 5}, - "prob": 0.8 -}] -``` - -When the `augment_conf_file` argument is set to the path of the above example configuration file, every audio clip in every epoch will be processed: with 60% of chance, it will first be speed perturbed with a uniformly random sampled speed-rate between 0.95 and 1.05, and then with 80% of chance it will be shifted in time with a randomly sampled offset between -5 ms and 5 ms. Finally, this newly synthesized audio clip will be fed into the feature extractor for further training. - -For other configuration examples, please refer to `examples/conf/augmentation.example.json`. - -Be careful when utilizing the data augmentation technique, as improper augmentation will harm the training, due to the enlarged train-test gap. diff --git a/docs/source/index.rst b/docs/source/index.rst index 5bbc9319..bf675b4b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -27,7 +27,6 @@ Contents asr/models_introduction asr/data_preparation - asr/augmentation asr/feature_list asr/ngram_lm From 1a9e59612a9124ffc3f97aae57f4d24792cdb9cf Mon Sep 17 00:00:00 2001 From: TianYuan Date: Tue, 18 Jan 2022 03:53:27 +0000 Subject: [PATCH 2/5] fix fastspeech2 multi speaker to static, test=tts --- examples/aishell3/tts3/README.md | 3 +- .../aishell3/tts3/local/synthesize_e2e.sh | 3 +- examples/vctk/tts3/README.md | 9 +-- examples/vctk/tts3/local/synthesize_e2e.sh | 3 +- paddlespeech/t2s/exps/inference.py | 66 ++++++++++++++++--- paddlespeech/t2s/exps/synthesize_e2e.py | 13 +++- .../t2s/models/fastspeech2/fastspeech2.py | 2 +- 7 files changed, 78 insertions(+), 21 deletions(-) diff --git a/examples/aishell3/tts3/README.md b/examples/aishell3/tts3/README.md index 2538e8f9..281ad836 100644 --- a/examples/aishell3/tts3/README.md +++ b/examples/aishell3/tts3/README.md @@ -257,6 +257,7 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \ --output_dir=exp/default/test_e2e \ --phones_dict=fastspeech2_nosil_aishell3_ckpt_0.4/phone_id_map.txt \ --speaker_dict=fastspeech2_nosil_aishell3_ckpt_0.4/speaker_id_map.txt \ - --spk_id=0 + --spk_id=0 \ + --inference_dir=exp/default/inference ``` diff --git a/examples/aishell3/tts3/local/synthesize_e2e.sh b/examples/aishell3/tts3/local/synthesize_e2e.sh index d0d92585..60e1a5ce 100755 --- a/examples/aishell3/tts3/local/synthesize_e2e.sh +++ b/examples/aishell3/tts3/local/synthesize_e2e.sh @@ -20,4 +20,5 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \ --output_dir=${train_output_path}/test_e2e \ --phones_dict=dump/phone_id_map.txt \ --speaker_dict=dump/speaker_id_map.txt \ - --spk_id=0 + --spk_id=0 \ + --inference_dir=${train_output_path}/inference diff --git a/examples/vctk/tts3/README.md b/examples/vctk/tts3/README.md index 74c1086a..157949d1 100644 --- a/examples/vctk/tts3/README.md +++ b/examples/vctk/tts3/README.md @@ -240,13 +240,14 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \ --am_ckpt=fastspeech2_nosil_vctk_ckpt_0.5/snapshot_iter_66200.pdz \ --am_stat=fastspeech2_nosil_vctk_ckpt_0.5/speech_stats.npy \ --voc=pwgan_vctk \ - --voc_config=pwg_vctk_ckpt_0.5/pwg_default.yaml \ - --voc_ckpt=pwg_vctk_ckpt_0.5/pwg_snapshot_iter_1000000.pdz \ - --voc_stat=pwg_vctk_ckpt_0.5/pwg_stats.npy \ + --voc_config=pwg_vctk_ckpt_0.1.1/default.yaml \ + --voc_ckpt=pwg_vctk_ckpt_0.1.1/snapshot_iter_1500000.pdz \ + --voc_stat=pwg_vctk_ckpt_0.1.1/feats_stats.npy \ --lang=en \ --text=${BIN_DIR}/../sentences_en.txt \ --output_dir=exp/default/test_e2e \ --phones_dict=fastspeech2_nosil_vctk_ckpt_0.5/phone_id_map.txt \ --speaker_dict=fastspeech2_nosil_vctk_ckpt_0.5/speaker_id_map.txt \ - --spk_id=0 + --spk_id=0 \ + --inference_dir=exp/default/inference ``` diff --git a/examples/vctk/tts3/local/synthesize_e2e.sh b/examples/vctk/tts3/local/synthesize_e2e.sh index 51bb9e19..60d56d1c 100755 --- a/examples/vctk/tts3/local/synthesize_e2e.sh +++ b/examples/vctk/tts3/local/synthesize_e2e.sh @@ -20,4 +20,5 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \ --output_dir=${train_output_path}/test_e2e \ --phones_dict=dump/phone_id_map.txt \ --speaker_dict=dump/speaker_id_map.txt \ - --spk_id=0 + --spk_id=0 \ + --inference_dir=${train_output_path}/inference diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index e1d5306c..2c9b51f9 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -14,9 +14,11 @@ import argparse from pathlib import Path +import numpy import soundfile as sf from paddle import inference +from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend.zh_frontend import Frontend @@ -29,20 +31,38 @@ def main(): '--am', type=str, default='fastspeech2_csmsc', - choices=['speedyspeech_csmsc', 'fastspeech2_csmsc'], + choices=[ + 'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_aishell3', + 'fastspeech2_vctk' + ], help='Choose acoustic model type of tts task.') parser.add_argument( "--phones_dict", type=str, default=None, help="phone vocabulary file.") parser.add_argument( "--tones_dict", type=str, default=None, help="tone vocabulary file.") + parser.add_argument( + "--speaker_dict", type=str, default=None, help="speaker id map file.") + parser.add_argument( + '--spk_id', + type=int, + default=0, + help='spk id for multi speaker acoustic model') # voc parser.add_argument( '--voc', type=str, default='pwgan_csmsc', - choices=['pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc'], + choices=[ + 'pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc', 'pwgan_aishell3', + 'pwgan_vctk' + ], help='Choose vocoder type of tts task.') # other + parser.add_argument( + '--lang', + type=str, + default='zh', + help='Choose model language. zh or en') parser.add_argument( "--text", type=str, @@ -53,8 +73,12 @@ def main(): args, _ = parser.parse_known_args() - frontend = Frontend( - phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict) + # frontend + if args.lang == 'zh': + frontend = Frontend( + phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict) + elif args.lang == 'en': + frontend = English(phone_vocab_path=args.phones_dict) print("frontend done!") # model: {model_name}_{dataset} @@ -83,30 +107,52 @@ def main(): print("in new inference") + # construct dataset for evaluation + sentences = [] with open(args.text, 'rt') as f: for line in f: items = line.strip().split() utt_id = items[0] - sentence = "".join(items[1:]) + if args.lang == 'zh': + sentence = "".join(items[1:]) + elif args.lang == 'en': + sentence = " ".join(items[1:]) sentences.append((utt_id, sentence)) get_tone_ids = False if am_name == 'speedyspeech': get_tone_ids = True + if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + get_spk_id = True + spk_id = numpy.array([args.spk_id]) am_input_names = am_predictor.get_input_names() - + print("am_input_names:", am_input_names) + merge_sentences = True for utt_id, sentence in sentences: - input_ids = frontend.get_input_ids( - sentence, merge_sentences=True, get_tone_ids=get_tone_ids) - phone_ids = input_ids["phone_ids"] + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + elif args.lang == 'en': + input_ids = frontend.get_input_ids( + sentence, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en'}!") + if get_tone_ids: tone_ids = input_ids["tone_ids"] tones = tone_ids[0].numpy() tones_handle = am_predictor.get_input_handle(am_input_names[1]) tones_handle.reshape(tones.shape) tones_handle.copy_from_cpu(tones) - + if get_spk_id: + spk_id_handle = am_predictor.get_input_handle(am_input_names[1]) + spk_id_handle.reshape(spk_id.shape) + spk_id_handle.copy_from_cpu(spk_id) phones = phone_ids[0].numpy() phones_handle = am_predictor.get_input_handle(am_input_names[0]) phones_handle.reshape(phones.shape) diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py index 15ed1e4d..9b503213 100644 --- a/paddlespeech/t2s/exps/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/synthesize_e2e.py @@ -159,9 +159,16 @@ def evaluate(args): # acoustic model if am_name == 'fastspeech2': if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: - print( - "Haven't test dygraph to static for multi speaker fastspeech2 now!" - ) + am_inference = jit.to_static( + am_inference, + input_spec=[ + InputSpec([-1], dtype=paddle.int64), + InputSpec([1], dtype=paddle.int64) + ]) + paddle.jit.save(am_inference, + os.path.join(args.inference_dir, args.am)) + am_inference = paddle.jit.load( + os.path.join(args.inference_dir, args.am)) else: am_inference = jit.to_static( am_inference, diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index 6bb651a0..dc136ffd 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -781,7 +781,7 @@ class FastSpeech2(nn.Layer): elif self.spk_embed_integration_type == "concat": # concat hidden states with spk embeds and then apply projection spk_emb = F.normalize(spk_emb).unsqueeze(1).expand( - shape=[-1, hs.shape[1], -1]) + shape=[-1, paddle.shape(hs)[1], -1]) hs = self.spk_projection(paddle.concat([hs, spk_emb], axis=-1)) else: raise NotImplementedError("support only add or concat.") From 41d24337cb52555b28bd3a72ab1334ea67dac352 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Tue, 18 Jan 2022 04:02:35 +0000 Subject: [PATCH 3/5] fix fastspeech2 multi speaker to static, test=tts --- examples/aishell3/tts3/local/inference.sh | 19 +++++++++++++++++++ examples/vctk/tts3/local/inference.sh | 20 ++++++++++++++++++++ paddlespeech/t2s/exps/inference.py | 1 + 3 files changed, 40 insertions(+) create mode 100755 examples/aishell3/tts3/local/inference.sh create mode 100755 examples/vctk/tts3/local/inference.sh diff --git a/examples/aishell3/tts3/local/inference.sh b/examples/aishell3/tts3/local/inference.sh new file mode 100755 index 00000000..3b03b53c --- /dev/null +++ b/examples/aishell3/tts3/local/inference.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +train_output_path=$1 + +stage=0 +stop_stage=0 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + python3 ${BIN_DIR}/../inference.py \ + --inference_dir=${train_output_path}/inference \ + --am=fastspeech2_aishell3 \ + --voc=pwgan_aishell3 \ + --text=${BIN_DIR}/../sentences.txt \ + --output_dir=${train_output_path}/pd_infer_out \ + --phones_dict=dump/phone_id_map.txt \ + --speaker_dict=dump/speaker_id_map.txt \ + --spk_id=0 +fi + diff --git a/examples/vctk/tts3/local/inference.sh b/examples/vctk/tts3/local/inference.sh new file mode 100755 index 00000000..caef89d8 --- /dev/null +++ b/examples/vctk/tts3/local/inference.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +train_output_path=$1 + +stage=0 +stop_stage=0 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + python3 ${BIN_DIR}/../inference.py \ + --inference_dir=${train_output_path}/inference \ + --am=fastspeech2_vctk \ + --voc=pwgan_vctk \ + --text=${BIN_DIR}/../sentences_en.txt \ + --output_dir=${train_output_path}/pd_infer_out \ + --phones_dict=dump/phone_id_map.txt \ + --speaker_dict=dump/speaker_id_map.txt \ + --spk_id=0 \ + --lang=en +fi + diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 2c9b51f9..37afd0ab 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -120,6 +120,7 @@ def main(): sentences.append((utt_id, sentence)) get_tone_ids = False + get_spk_id = False if am_name == 'speedyspeech': get_tone_ids = True if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: From 7c3bcec59a41ae9080da9fa03c6ea456ec321a0f Mon Sep 17 00:00:00 2001 From: Jackwaterveg <87408988+Jackwaterveg@users.noreply.github.com> Date: Tue, 18 Jan 2022 19:14:34 +0800 Subject: [PATCH 4/5] Update README.md --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index cca1cb53..91a3bdd8 100644 --- a/README.md +++ b/README.md @@ -463,7 +463,6 @@ Normally, [Speech SoTA](https://paperswithcode.com/area/speech), [Audio SoTA](ht - [Automatic Speech Recognition](./docs/source/asr/quick_start.md) - [Introduction](./docs/source/asr/models_introduction.md) - [Data Preparation](./docs/source/asr/data_preparation.md) - - [Data Augmentation](./docs/source/asr/augmentation.md) - [Ngram LM](./docs/source/asr/ngram_lm.md) - [Text-to-Speech](./docs/source/tts/quick_start.md) - [Introduction](./docs/source/tts/models_introduction.md) From af6cb9043450af6c3704c614c03a72c362d493a5 Mon Sep 17 00:00:00 2001 From: Jackwaterveg <87408988+Jackwaterveg@users.noreply.github.com> Date: Tue, 18 Jan 2022 19:14:58 +0800 Subject: [PATCH 5/5] Update README_cn.md --- README_cn.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README_cn.md b/README_cn.md index ddf189c3..b542d9ce 100644 --- a/README_cn.md +++ b/README_cn.md @@ -468,7 +468,6 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声 - [语音识别自定义训练](./docs/source/asr/quick_start.md) - [简介](./docs/source/asr/models_introduction.md) - [数据准备](./docs/source/asr/data_preparation.md) - - [数据增强](./docs/source/asr/augmentation.md) - [Ngram 语言模型](./docs/source/asr/ngram_lm.md) - [语音合成自定义训练](./docs/source/tts/quick_start.md) - [简介](./docs/source/tts/models_introduction.md)