From 2e825a4cec9d6b2b593b80f0752df91569aab5ad Mon Sep 17 00:00:00 2001
From: WongLaw
Date: Wed, 15 Feb 2023 04:43:17 +0000
Subject: [PATCH] Cantonese FastSpeech2 e2e infer, test=tts

---
 examples/aishell3/tts3/local/synthesize.sh   |   4 +-
 examples/canton/tts3/local/synthesize_e2e.sh |  53 ++++++++++
 examples/canton/tts3/run.sh                  |  37 +------
 paddlespeech/t2s/exps/sentences_canton.txt   |   7 ++
 paddlespeech/t2s/exps/syn_utils.py           |  29 +++--
 paddlespeech/t2s/exps/synthesize.py          |   3 +-
 paddlespeech/t2s/exps/synthesize_e2e.py      |   5 +-
 paddlespeech/t2s/frontend/canton_frontend.py | 106 +++++++++++++++++++
 8 files changed, 194 insertions(+), 50 deletions(-)
 create mode 100755 examples/canton/tts3/local/synthesize_e2e.sh
 create mode 100644 paddlespeech/t2s/exps/sentences_canton.txt
 create mode 100644 paddlespeech/t2s/frontend/canton_frontend.py

diff --git a/examples/aishell3/tts3/local/synthesize.sh b/examples/aishell3/tts3/local/synthesize.sh
index 9134e0426..0d288dbb8 100755
--- a/examples/aishell3/tts3/local/synthesize.sh
+++ b/examples/aishell3/tts3/local/synthesize.sh
@@ -12,7 +12,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     FLAGS_allocator_strategy=naive_best_fit \
     FLAGS_fraction_of_gpu_memory_to_use=0.01 \
     python3 ${BIN_DIR}/../synthesize.py \
-        --am=fastspeech2_aishell3 \
+        --am=fastspeech2_canton \
         --am_config=${config_path} \
         --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
         --am_stat=dump/train/speech_stats.npy \
@@ -31,7 +31,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
     FLAGS_allocator_strategy=naive_best_fit \
     FLAGS_fraction_of_gpu_memory_to_use=0.01 \
     python3 ${BIN_DIR}/../synthesize.py \
-        --am=fastspeech2_aishell3 \
+        --am=fastspeech2_canton \
         --am_config=${config_path} \
         --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
         --am_stat=dump/train/speech_stats.npy \
diff --git a/examples/canton/tts3/local/synthesize_e2e.sh b/examples/canton/tts3/local/synthesize_e2e.sh
new file mode 100755
index 000000000..836772584
--- /dev/null
+++ b/examples/canton/tts3/local/synthesize_e2e.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/../synthesize_e2e.py \
+        --am=fastspeech2_canton \
+        --am_config=${config_path} \
+        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --am_stat=dump/train/speech_stats.npy \
+        --voc=pwgan_aishell3 \
+        --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
+        --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
+        --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
+        --lang=canton \
+        --text=${BIN_DIR}/../sentences_canton.txt \
+        --output_dir=${train_output_path}/test_e2e \
+        --phones_dict=dump/phone_id_map.txt \
+        --speaker_dict=dump/speaker_id_map.txt \
+        --spk_id=0 \
+        --inference_dir=${train_output_path}/inference
+fi
+
+# hifigan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    echo "in hifigan syn_e2e"
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/../synthesize_e2e.py \
+        --am=fastspeech2_canton \
+        --am_config=${config_path} \
+        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --am_stat=dump/train/speech_stats.npy \
+        --voc=hifigan_aishell3 \
+        --voc_config=hifigan_aishell3_ckpt_0.2.0/default.yaml \
+        --voc_ckpt=hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz \
+        --voc_stat=hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \
+        --lang=canton \
+        --text=${BIN_DIR}/../sentences_canton.txt \
+        --output_dir=${train_output_path}/test_e2e \
+        --phones_dict=dump/phone_id_map.txt \
+        --speaker_dict=dump/speaker_id_map.txt \
+        --spk_id=0 \
+        --inference_dir=${train_output_path}/inference
+fi
diff --git a/examples/canton/tts3/run.sh b/examples/canton/tts3/run.sh
index 9e5c27a16..c1b4b3eb4 100755
--- a/examples/canton/tts3/run.sh
+++ b/examples/canton/tts3/run.sh
@@ -9,7 +9,7 @@ stop_stage=100
 
 conf_path=conf/default.yaml
 train_output_path=exp/default
-ckpt_name=snapshot_iter_112793.pdz
+ckpt_name=snapshot_iter_280000.pdz
 
 # with the following command, you can choose the stage range you want to run
 # such as `./run.sh --stage 0 --stop-stage 0`
@@ -34,37 +34,4 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # synthesize_e2e, vocoder is pwgan by default
     CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
-fi
-
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-    # inference with static model, vocoder is pwgan by default
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1
-fi
-
-if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
-    # install paddle2onnx
-    version=$(echo `pip list |grep "paddle2onnx"` |awk -F" " '{print $2}')
-    if [[ -z "$version" || ${version} != '1.0.0' ]]; then
-        pip install paddle2onnx==1.0.0
-    fi
-    ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx fastspeech2_aishell3
-    # considering the balance between speed and quality, we recommend that you use hifigan as vocoder
-    ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx pwgan_aishell3
-    # ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx hifigan_aishell3
-
-fi
-
-# inference with onnxruntime, use fastspeech2 + pwgan by default
-if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
-    ./local/ort_predict.sh ${train_output_path}
-fi
-
-if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
-    ./local/export2lite.sh ${train_output_path} inference pdlite fastspeech2_aishell3 x86
-    ./local/export2lite.sh ${train_output_path} inference pdlite pwgan_aishell3 x86
-    # ./local/export2lite.sh ${train_output_path} inference pdlite hifigan_aishell3 x86
-fi
-
-if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
-    CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1
-fi
+fi
\ No newline at end of file
diff --git a/paddlespeech/t2s/exps/sentences_canton.txt b/paddlespeech/t2s/exps/sentences_canton.txt
new file mode 100644
index 000000000..dcc6019b9
--- /dev/null
+++ b/paddlespeech/t2s/exps/sentences_canton.txt
@@ -0,0 +1,7 @@
+001 白云山爬过一次嘅，好远啊，爬上去都成两个钟
+002 睇书咯，番屋企，而家好多人好少睇书噶喎
+003 因为如果唔考试嘅话，工资好低噶
+004 冇固定噶，你中意休边日就边日噶
+005 即系太迟嘅话咧，落班太迟嘅话就喺出边食啲咯
+006 是非有公理，慎言莫冒犯别人
+007 遇上冷风雨，休太认真
diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py
index 491edda30..b12b088e9 100644
--- a/paddlespeech/t2s/exps/syn_utils.py
+++ b/paddlespeech/t2s/exps/syn_utils.py
@@ -33,6 +33,7 @@ from paddlespeech.t2s.datasets.am_batch_fn import *
 from paddlespeech.t2s.datasets.data_table import DataTable
 from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip_static
 from paddlespeech.t2s.frontend import English
+from paddlespeech.t2s.frontend.canton_frontend import CantonFrontend
 from paddlespeech.t2s.frontend.mix_frontend import MixFrontend
 from paddlespeech.t2s.frontend.zh_frontend import Frontend
 from paddlespeech.t2s.modules.normalizer import ZScore
@@ -117,6 +118,8 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
                     sentence = " ".join(items[1:])
                 elif lang == 'mix':
                     sentence = " ".join(items[1:])
+                elif lang == 'canton':
+                    sentence = " ".join(items[1:])
             sentences.append((utt_id, sentence))
     return sentences
 
@@ -132,8 +135,8 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
     converters = {}
     if am_name == 'fastspeech2':
         fields = ["utt_id", "text"]
-        if am_dataset in {"aishell3", "vctk",
-                          "mix"} and speaker_dict is not None:
+        if am_dataset in {"aishell3", "vctk", "mix",
+                          "canton"} and speaker_dict is not None:
             print("multiple speaker fastspeech2!")
             fields += ["spk_id"]
         elif voice_cloning:
@@ -177,8 +180,8 @@ def get_dev_dataloader(dev_metadata: List[Dict[str, Any]],
     converters = {}
     if am_name == 'fastspeech2':
         fields = ["utt_id", "text"]
-        if am_dataset in {"aishell3", "vctk",
-                          "mix"} and speaker_dict is not None:
+        if am_dataset in {"aishell3", "vctk", "mix",
+                          "canton"} and speaker_dict is not None:
             print("multiple speaker fastspeech2!")
             collate_fn = fastspeech2_multi_spk_batch_fn_static
             fields += ["spk_id"]
@@ -266,6 +269,8 @@ def get_frontend(lang: str='zh',
             phone_vocab_path=phones_dict,
             tone_vocab_path=tones_dict,
             use_rhy=use_rhy)
+    elif lang == 'canton':
+        frontend = CantonFrontend(phone_vocab_path=phones_dict)
     elif lang == 'en':
         frontend = English(phone_vocab_path=phones_dict)
     elif lang == 'mix':
@@ -302,6 +307,10 @@ def run_frontend(frontend: object,
         if get_tone_ids:
             tone_ids = input_ids["tone_ids"]
             outs.update({'tone_ids': tone_ids})
+    elif lang == 'canton':
+        input_ids = frontend.get_input_ids(
+            text, merge_sentences=merge_sentences, to_tensor=to_tensor)
+        phone_ids = input_ids["phone_ids"]
     elif lang == 'en':
         input_ids = frontend.get_input_ids(
             text, merge_sentences=merge_sentences, to_tensor=to_tensor)
@@ -311,7 +320,7 @@ def run_frontend(frontend: object,
             text, merge_sentences=merge_sentences, to_tensor=to_tensor)
         phone_ids = input_ids["phone_ids"]
     else:
-        print("lang should in {'zh', 'en', 'mix'}!")
+        print("lang should be in {'zh', 'en', 'mix', 'canton'}!")
     outs.update({'phone_ids': phone_ids})
     return outs
 
@@ -411,8 +420,8 @@ def am_to_static(am_inference,
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
     if am_name == 'fastspeech2':
-        if am_dataset in {"aishell3", "vctk",
-                          "mix"} and speaker_dict is not None:
+        if am_dataset in {"aishell3", "vctk", "mix",
+                          "canton"} and speaker_dict is not None:
             am_inference = jit.to_static(
                 am_inference,
                 input_spec=[
@@ -424,8 +433,8 @@ def am_to_static(am_inference,
                 am_inference,
                 input_spec=[InputSpec([-1], dtype=paddle.int64)])
     elif am_name == 'speedyspeech':
-        if am_dataset in {"aishell3", "vctk",
-                          "mix"} and speaker_dict is not None:
+        if am_dataset in {"aishell3", "vctk", "mix",
+                          "canton"} and speaker_dict is not None:
             am_inference = jit.to_static(
                 am_inference,
                 input_spec=[
@@ -575,7 +584,7 @@ def get_am_output(
     get_tone_ids = False
     if am_name == 'speedyspeech':
         get_tone_ids = True
-    if am_dataset in {"aishell3", "vctk", "mix"} and speaker_dict:
+    if am_dataset in {"aishell3", "vctk", "mix", "canton"} and speaker_dict:
         get_spk_id = True
         spk_id = np.array([spk_id])
 
diff --git a/paddlespeech/t2s/exps/synthesize.py b/paddlespeech/t2s/exps/synthesize.py
index a8e18150e..70e52244f 100644
--- a/paddlespeech/t2s/exps/synthesize.py
+++ b/paddlespeech/t2s/exps/synthesize.py
@@ -136,7 +136,8 @@ def parse_args():
         choices=[
             'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
             'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc',
-            'tacotron2_ljspeech', 'tacotron2_aishell3', 'fastspeech2_mix'
+            'tacotron2_ljspeech', 'tacotron2_aishell3', 'fastspeech2_mix',
+            'fastspeech2_canton'
         ],
         help='Choose acoustic model type of tts task.')
     parser.add_argument(
diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py
index 625002477..3b87d9e16 100644
--- a/paddlespeech/t2s/exps/synthesize_e2e.py
+++ b/paddlespeech/t2s/exps/synthesize_e2e.py
@@ -119,7 +119,7 @@ def evaluate(args):
             # acoustic model
             if am_name == 'fastspeech2':
                 # multi speaker
-                if am_dataset in {"aishell3", "vctk", "mix"}:
+                if am_dataset in {"aishell3", "vctk", "mix", "canton"}:
                     spk_id = paddle.to_tensor(args.spk_id)
                     mel = am_inference(part_phone_ids, spk_id)
                 else:
@@ -167,7 +167,8 @@ def parse_args():
         choices=[
             'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc',
             'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk',
-            'tacotron2_csmsc', 'tacotron2_ljspeech', 'fastspeech2_mix'
+            'tacotron2_csmsc', 'tacotron2_ljspeech', 'fastspeech2_mix',
+            'fastspeech2_canton'
         ],
         help='Choose acoustic model type of tts task.')
     parser.add_argument(
diff --git a/paddlespeech/t2s/frontend/canton_frontend.py b/paddlespeech/t2s/frontend/canton_frontend.py
new file mode 100644
index 000000000..9891c9447
--- /dev/null
+++ b/paddlespeech/t2s/frontend/canton_frontend.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+from typing import List
+
+import numpy as np
+import paddle
+import ToJyutping
+
+from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer
+
+INITIALS = [
+    'p', 'b', 't', 'd', 'ts', 'dz', 'k', 'g', 'kw', 'gw', 'f', 'h', 'l', 'm',
+    'ng', 'n', 's', 'y', 'w', 'c', 'z', 'j'
+]
+INITIALS += ['sp', 'spl', 'spn', 'sil']
+
+
+def get_lines(cantons):
+    # split each jyutping syllable into initial + final, preferring the longest initial
+    phones = []
+    for canton in cantons:
+        consonant = max((i for i in INITIALS if canton.startswith(i)), key=len, default='')
+        c, v = canton[:len(consonant)], canton[len(consonant):]
+        phones = phones + ([c, v] if c and v else [canton])
+    return phones
+
+
+class CantonFrontend():
+    def __init__(self, phone_vocab_path: str):
+        self.text_normalizer = TextNormalizer()
+        self.punc = ":,;。?!“”‘’':,;.?!"
+
+        self.vocab_phones = {}
+        if phone_vocab_path:
+            with open(phone_vocab_path, 'rt', encoding='utf-8') as f:
+                phn_id = [line.strip().split() for line in f.readlines()]
+            for phn, id in phn_id:
+                self.vocab_phones[phn] = int(id)
+
+    # if merge_sentences, merge all sentences into one phone sequence
+    def _g2p(self, sentences: List[str],
+             merge_sentences: bool=True) -> List[List[str]]:
+        phones_list = []
+        for sentence in sentences:
+            phones_str = ToJyutping.get_jyutping_text(sentence)
+            phones_split = get_lines(phones_str.split(' '))
+            phones_list.append(phones_split)
+        return phones_list
+
+    def _p2id(self, phonemes: List[str]) -> np.ndarray:
+        # replace unk phone with sp
+        phonemes = [
+            phn if phn in self.vocab_phones else "sp" for phn in phonemes
+        ]
+        phone_ids = [self.vocab_phones[item] for item in phonemes]
+        return np.array(phone_ids, np.int64)
+
+    def get_phonemes(self,
+                     sentence: str,
+                     merge_sentences: bool=True,
+                     print_info: bool=False) -> List[List[str]]:
+        sentences = self.text_normalizer.normalize(sentence)
+        phonemes = self._g2p(sentences, merge_sentences=merge_sentences)
+        if print_info:
+            print("----------------------------")
+            print("text norm results:")
+            print(sentences)
+            print("----------------------------")
+            print("g2p results:")
+            print(phonemes)
+            print("----------------------------")
+        return phonemes
+
+    def get_input_ids(self,
+                      sentence: str,
+                      merge_sentences: bool=True,
+                      print_info: bool=False,
+                      to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
+
+        phonemes = self.get_phonemes(
+            sentence, merge_sentences=merge_sentences, print_info=print_info)
+        result = {}
+        temp_phone_ids = []
+
+        for phones in phonemes:
+            if phones:
+                phone_ids = self._p2id(phones)
+                # the first call of paddle.to_tensor() can be very slow (e.g. under onnxruntime), so keep it optional
+                if to_tensor:
+                    phone_ids = paddle.to_tensor(phone_ids)
+                temp_phone_ids.append(phone_ids)
+        if temp_phone_ids:
+            result["phone_ids"] = temp_phone_ids
+        return result
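
As a quick illustration of what the new g2p step produces: ToJyutping romanizes the normalized text into space-separated jyutping, and get_lines splits each syllable into an initial plus a final with its tone. A minimal sketch (the romanization in the comments is an assumed example, not output captured from a run):

import ToJyutping

from paddlespeech.t2s.frontend.canton_frontend import get_lines

# ToJyutping returns space-separated jyutping for the input text,
# e.g. "gwong2 dung1 waa2" for 广东话 (assumed example).
jyutping = ToJyutping.get_jyutping_text("广东话")
# each syllable is split into initial + final(+tone); phones missing from
# phone_id_map.txt later fall back to "sp" inside CantonFrontend._p2id
print(get_lines(jyutping.split(" ")))  # expected: ['gw', 'ong2', 'd', 'ung1', 'w', 'aa2']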
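
A minimal sketch of driving the new CantonFrontend directly, for example to sanity-check a phone dictionary before running full synthesis (assumes the ToJyutping package is installed and that dump/phone_id_map.txt, a placeholder path, is the phone dictionary produced by preprocessing):

from paddlespeech.t2s.frontend.canton_frontend import CantonFrontend

# placeholder path: use the phone_id_map.txt that matches the trained fastspeech2_canton model
frontend = CantonFrontend(phone_vocab_path="dump/phone_id_map.txt")
# to_tensor=False keeps the ids as numpy int64 arrays instead of paddle tensors
outs = frontend.get_input_ids(
    "白云山爬过一次嘅，好远啊",
    merge_sentences=True,
    print_info=True,
    to_tensor=False)
for phone_ids in outs["phone_ids"]:
    print(phone_ids)  # one array of phone ids per normalized sentence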
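
And a sketch of the wiring that synthesize_e2e.py now follows for --lang=canton, stopping before the acoustic model and vocoder (the dump/ path is a placeholder for the preprocessing artifacts referenced in local/synthesize_e2e.sh):

from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import run_frontend

# get_frontend returns a CantonFrontend when lang='canton'
frontend = get_frontend(lang="canton", phones_dict="dump/phone_id_map.txt")
# sentences_canton.txt uses the "<utt_id> <sentence>" format, one utterance per line
sentences = get_sentences("paddlespeech/t2s/exps/sentences_canton.txt", lang="canton")
for utt_id, sentence in sentences:
    # run_frontend returns a dict whose "phone_ids" entry is a list of paddle int64 tensors
    outs = run_frontend(
        frontend=frontend, text=sentence, merge_sentences=True, lang="canton")
    print(utt_id, [ids.shape for ids in outs["phone_ids"]])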