diff --git a/demos/speech_server/conf/application.yaml b/demos/speech_server/conf/application.yaml index 9c171c470..b5ee80095 100644 --- a/demos/speech_server/conf/application.yaml +++ b/demos/speech_server/conf/application.yaml @@ -61,7 +61,7 @@ tts_python: phones_dict: tones_dict: speaker_dict: - spk_id: 0 + # voc (vocoder) choices=['pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', # 'pwgan_vctk', 'mb_melgan_csmsc', 'style_melgan_csmsc', @@ -87,7 +87,7 @@ tts_inference: phones_dict: tones_dict: speaker_dict: - spk_id: 0 + am_predictor_conf: device: # set 'gpu:id' or 'cpu' diff --git a/demos/streaming_tts_server/conf/tts_online_application.yaml b/demos/streaming_tts_server/conf/tts_online_application.yaml index e617912fe..f5ec9dc8e 100644 --- a/demos/streaming_tts_server/conf/tts_online_application.yaml +++ b/demos/streaming_tts_server/conf/tts_online_application.yaml @@ -29,7 +29,7 @@ tts_online: phones_dict: tones_dict: speaker_dict: - spk_id: 0 + # voc (vocoder) choices=['mb_melgan_csmsc, hifigan_csmsc'] # Both mb_melgan_csmsc and hifigan_csmsc support streaming voc inference @@ -70,7 +70,6 @@ tts_online-onnx: phones_dict: tones_dict: speaker_dict: - spk_id: 0 am_sample_rate: 24000 am_sess_conf: device: "cpu" # set 'gpu:id' or 'cpu' diff --git a/demos/streaming_tts_server/conf/tts_online_ws_application.yaml b/demos/streaming_tts_server/conf/tts_online_ws_application.yaml index 329f882cc..c65633917 100644 --- a/demos/streaming_tts_server/conf/tts_online_ws_application.yaml +++ b/demos/streaming_tts_server/conf/tts_online_ws_application.yaml @@ -29,7 +29,7 @@ tts_online: phones_dict: tones_dict: speaker_dict: - spk_id: 0 + # voc (vocoder) choices=['mb_melgan_csmsc, hifigan_csmsc'] # Both mb_melgan_csmsc and hifigan_csmsc support streaming voc inference @@ -70,7 +70,6 @@ tts_online-onnx: phones_dict: tones_dict: speaker_dict: - spk_id: 0 am_sample_rate: 24000 am_sess_conf: device: "cpu" # set 'gpu:id' or 'cpu' diff --git a/docs/source/tts/tts_papers.md b/docs/source/tts/tts_papers.md index 681b21066..f3ca1b624 100644 --- a/docs/source/tts/tts_papers.md +++ b/docs/source/tts/tts_papers.md @@ -5,6 +5,7 @@ - [Disambiguation of Chinese Polyphones in an End-to-End Framework with Semantic Features Extracted by Pre-trained BERT](https://www1.se.cuhk.edu.hk/~hccl/publications/pub/201909_INTERSPEECH_DongyangDAI.pdf) - [Polyphone Disambiguation in Mandarin Chinese with Semi-Supervised Learning](https://www.isca-speech.org/archive/pdfs/interspeech_2021/shi21d_interspeech.pdf) * github: https://github.com/PaperMechanica/SemiPPL +- [WikipediaHomographData](https://github.com/google-research-datasets/WikipediaHomographData) ### Text Normalization #### English - [applenob/text_normalization](https://github.com/applenob/text_normalization) diff --git a/examples/aishell3/ernie_sat/README.md b/examples/aishell3/ernie_sat/README.md index 707ee1381..eb867ab75 100644 --- a/examples/aishell3/ernie_sat/README.md +++ b/examples/aishell3/ernie_sat/README.md @@ -1,11 +1,10 @@ -# ERNIE-SAT with AISHELL3 dataset +# ERNIE-SAT with AISHELL3 dataset +ERNIE-SAT is a speech-text joint pretraining framework that achieves SOTA results in cross-lingual multi-speaker speech synthesis and cross-lingual speech editing tasks. It can be applied to a series of scenarios such as speech editing, personalized speech synthesis, and voice cloning.
-ERNIE-SAT 是可以同时处理中英文的跨语言的语音-语言跨模态大模型,其在语音编辑、个性化语音合成以及跨语言的语音合成等多个任务取得了领先效果。可以应用于语音编辑、个性化合成、语音克隆、同传翻译等一系列场景,该项目供研究使用。 - -## 模型框架 -ERNIE-SAT 中我们提出了两项创新: -- 在预训练过程中将中英双语对应的音素作为输入,实现了跨语言、个性化的软音素映射 -- 采用语言和语音的联合掩码学习实现了语言和语音的对齐 +## Model Framework +In ERNIE-SAT, we propose two innovations: +- In the pretraining process, the phonemes corresponding to Chinese and English are used as input to achieve cross-language and personalized soft phoneme mapping +- The joint mask learning of speech and text is used to realize the alignment of speech and text

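The editing scenario described above comes down to masking the span of audio that corresponds to the edited phonemes, letting the model regenerate that span, and splicing the regenerated audio back into the recording; for plain synthesis the regenerated audio is used directly. A minimal NumPy sketch of that splice step (the array names and boundaries are placeholders; the actual logic lives in `paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py`, which this patch also touches):

```python
import numpy as np

def splice_edited_span(wav_org: np.ndarray,
                       alt_wav: np.ndarray,
                       old_span_bdy,
                       n_shift: int = 300) -> np.ndarray:
    """Replace one span of the original waveform with regenerated audio.

    `old_span_bdy` holds the (left, right) frame indices of the edited span;
    `n_shift` converts frames to samples, matching the mel extractor's hop size.
    """
    left, right = (n_shift * b for b in old_span_bdy)
    return np.concatenate([wav_org[:left], alt_wav, wav_org[right:]])

# task_name='edit'       -> splice_edited_span(wav_org, alt_wav, old_span_bdy)
# task_name='synthesize' -> use alt_wav directly
```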
diff --git a/examples/aishell3/ernie_sat/local/synthesize_e2e.sh b/examples/aishell3/ernie_sat/local/synthesize_e2e.sh index b33e8ca09..77b353b52 100755 --- a/examples/aishell3/ernie_sat/local/synthesize_e2e.sh +++ b/examples/aishell3/ernie_sat/local/synthesize_e2e.sh @@ -13,9 +13,9 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then FLAGS_fraction_of_gpu_memory_to_use=0.01 \ python3 ${BIN_DIR}/synthesize_e2e.py \ --task_name=synthesize \ - --wav_path=source/SSB03540307.wav\ - --old_str='请播放歌曲小苹果。' \ - --new_str='歌曲真好听。' \ + --wav_path=source/SSB03540307.wav \ + --old_str='请播放歌曲小苹果' \ + --new_str='歌曲真好听' \ --source_lang=zh \ --target_lang=zh \ --erniesat_config=${config_path} \ diff --git a/examples/aishell3_vctk/ernie_sat/README.md b/examples/aishell3_vctk/ernie_sat/README.md index a849488d5..d55af6756 100644 --- a/examples/aishell3_vctk/ernie_sat/README.md +++ b/examples/aishell3_vctk/ernie_sat/README.md @@ -1,11 +1,10 @@ -# ERNIE-SAT with AISHELL3 and VCTK dataset +# ERNIE-SAT with AISHELL3 and VCTK dataset +ERNIE-SAT is a speech-text joint pretraining framework that achieves SOTA results in cross-lingual multi-speaker speech synthesis and cross-lingual speech editing tasks. It can be applied to a series of scenarios such as speech editing, personalized speech synthesis, and voice cloning. -ERNIE-SAT 是可以同时处理中英文的跨语言的语音-语言跨模态大模型,其在语音编辑、个性化语音合成以及跨语言的语音合成等多个任务取得了领先效果。可以应用于语音编辑、个性化合成、语音克隆、同传翻译等一系列场景,该项目供研究使用。 - -## 模型框架 -ERNIE-SAT 中我们提出了两项创新: -- 在预训练过程中将中英双语对应的音素作为输入,实现了跨语言、个性化的软音素映射 -- 采用语言和语音的联合掩码学习实现了语言和语音的对齐 +## Model Framework +In ERNIE-SAT, we propose two innovations: +- In the pretraining process, the phonemes corresponding to Chinese and English are used as input to achieve cross-language and personalized soft phoneme mapping +- The joint mask learning of speech and text is used to realize the alignment of speech and text

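As a toy illustration of the first innovation above, Chinese and English phonemes are mapped into a single shared id table so that one model can consume both languages; the phone lists below are invented for the example, and the real mapping comes from the example's `phone_id_map.txt`:

```python
# Toy bilingual phone inventory -- not the project's real phone set.
zh_phones = ["sil", "sp", "b", "o3", "g", "e1"]
en_phones = ["AA1", "B", "K", "IY1", "S", "T"]

# One shared id space for both languages, which is what lets a single
# ERNIE-SAT model attend over mixed zh/en phoneme sequences.
phone_id_map = {p: i for i, p in enumerate(dict.fromkeys(zh_phones + en_phones))}

def phones_to_ids(phones):
    """Map a mixed zh/en phoneme sequence onto the shared id space."""
    return [phone_id_map[p] for p in phones]

print(phones_to_ids(["sil", "b", "o3", "K", "IY1", "sp"]))
```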
diff --git a/examples/aishell3_vctk/ernie_sat/local/synthesize_e2e.sh b/examples/aishell3_vctk/ernie_sat/local/synthesize_e2e.sh index c30af6e85..446ac8791 100755 --- a/examples/aishell3_vctk/ernie_sat/local/synthesize_e2e.sh +++ b/examples/aishell3_vctk/ernie_sat/local/synthesize_e2e.sh @@ -15,7 +15,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then python3 ${BIN_DIR}/synthesize_e2e.py \ --task_name=synthesize \ --wav_path=source/p243_313.wav \ - --old_str='For that reason cover should not be given.' \ + --old_str='For that reason cover should not be given' \ --new_str='今天天气很好' \ --source_lang=en \ --target_lang=zh \ @@ -36,8 +36,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then python3 ${BIN_DIR}/synthesize_e2e.py \ --task_name=synthesize \ --wav_path=source/SSB03540307.wav \ - --old_str='请播放歌曲小苹果。' \ - --new_str="Thank you!" \ + --old_str='请播放歌曲小苹果' \ + --new_str="Thank you" \ --source_lang=zh \ --target_lang=en \ --erniesat_config=${config_path} \ diff --git a/examples/other/tts_finetune/tts3/README.md b/examples/other/tts_finetune/tts3/README.md index 1ad30328b..192ee7ff4 100644 --- a/examples/other/tts_finetune/tts3/README.md +++ b/examples/other/tts_finetune/tts3/README.md @@ -75,6 +75,15 @@ When "Prepare" done. The structure of the current directory is listed below. ``` +### Set finetune.yaml +`finetune.yaml` contains some configurations for fine-tuning. You can try various options to get a better result. +Arguments: + - `batch_size`: fine-tuning batch size. Default: -1, which means 64, the same as the pretrained model + - `learning_rate`: learning rate. Default: 0.0001 + - `num_snapshots`: number of saved models. Default: -1, which means 5, the same as the pretrained model + - `frozen_layers`: layers to freeze; must be a list. If you don't want to freeze any layer, set it to []. + + ## Get Started Run the command below to diff --git a/examples/other/tts_finetune/tts3/finetune.py b/examples/other/tts_finetune/tts3/finetune.py index 0f060b44d..207e2dbc5 100644 --- a/examples/other/tts_finetune/tts3/finetune.py +++ b/examples/other/tts_finetune/tts3/finetune.py @@ -14,6 +14,7 @@ import argparse import os from pathlib import Path +from typing import List from typing import Union import yaml @@ -21,10 +22,10 @@ from local.check_oov import get_check_result from local.extract import extract_feature from local.label_process import get_single_label from local.prepare_env import generate_finetune_env +from local.train import train_sp from paddle import distributed as dist from yacs.config import CfgNode -from paddlespeech.t2s.exps.fastspeech2.train import train_sp from utils.gen_duration_from_textgrid import gen_duration_from_textgrid DICT_EN = 'tools/aligner/cmudict-0.7b' @@ -38,15 +39,24 @@ os.environ['PATH'] = MFA_PATH + '/:' + os.environ['PATH'] class TrainArgs(): - def __init__(self, ngpu, config_file, dump_dir: Path, output_dir: Path): + def __init__(self, + ngpu, + config_file, + dump_dir: Path, + output_dir: Path, + frozen_layers: List[str]): + # config: fastspeech2 config file. self.config = str(config_file) self.train_metadata = str(dump_dir / "train/norm/metadata.jsonl") self.dev_metadata = str(dump_dir / "dev/norm/metadata.jsonl") + # model output dir.
self.output_dir = str(output_dir) self.ngpu = ngpu self.phones_dict = str(dump_dir / "phone_id_map.txt") self.speaker_dict = str(dump_dir / "speaker_id_map.txt") self.voice_cloning = False + # frozen layers + self.frozen_layers = frozen_layers def get_mfa_result( @@ -122,12 +132,11 @@ if __name__ == '__main__': "--ngpu", type=int, default=2, help="if ngpu=0, use cpu.") parser.add_argument("--epoch", type=int, default=100, help="finetune epoch") - parser.add_argument( - "--batch_size", - type=int, - default=-1, - help="batch size, default -1 means same as pretrained model") + "--finetune_config", + type=str, + default="./finetune.yaml", + help="Path to finetune config file") args = parser.parse_args() @@ -147,8 +156,14 @@ if __name__ == '__main__': with open(config_file) as f: config = CfgNode(yaml.safe_load(f)) config.max_epoch = config.max_epoch + args.epoch - if args.batch_size > 0: - config.batch_size = args.batch_size + + with open(args.finetune_config) as f2: + finetune_config = CfgNode(yaml.safe_load(f2)) + config.batch_size = finetune_config.batch_size if finetune_config.batch_size > 0 else config.batch_size + config.optimizer.learning_rate = finetune_config.learning_rate if finetune_config.learning_rate > 0 else config.optimizer.learning_rate + config.num_snapshots = finetune_config.num_snapshots if finetune_config.num_snapshots > 0 else config.num_snapshots + frozen_layers = finetune_config.frozen_layers + assert type(frozen_layers) == list, "frozen_layers should be set a list." if args.lang == 'en': lexicon_file = DICT_EN @@ -158,6 +173,13 @@ if __name__ == '__main__': mfa_phone_file = MFA_PHONE_ZH else: print('please input right lang!!') + + print(f"finetune max_epoch: {config.max_epoch}") + print(f"finetune batch_size: {config.batch_size}") + print(f"finetune learning_rate: {config.optimizer.learning_rate}") + print(f"finetune num_snapshots: {config.num_snapshots}") + print(f"finetune frozen_layers: {frozen_layers}") + am_phone_file = pretrained_model_dir / "phone_id_map.txt" label_file = input_dir / "labels.txt" @@ -181,7 +203,8 @@ if __name__ == '__main__': generate_finetune_env(output_dir, pretrained_model_dir) # create a new args for training - train_args = TrainArgs(args.ngpu, config_file, dump_dir, output_dir) + train_args = TrainArgs(args.ngpu, config_file, dump_dir, output_dir, + frozen_layers) # finetune models # dispatch diff --git a/examples/other/tts_finetune/tts3/finetune.yaml b/examples/other/tts_finetune/tts3/finetune.yaml new file mode 100644 index 000000000..374a69f3d --- /dev/null +++ b/examples/other/tts_finetune/tts3/finetune.yaml @@ -0,0 +1,12 @@ +########################################################### +# PARAS SETTING # +########################################################### +# Set to -1 to indicate that the parameter is the same as the pretrained model configuration + +batch_size: -1 +learning_rate: 0.0001 # learning rate +num_snapshots: -1 + +# frozen_layers should be a list +# if you don't need to freeze, set frozen_layers to [] +frozen_layers: ["encoder", "duration_predictor"] diff --git a/examples/other/tts_finetune/tts3/local/extract.py b/examples/other/tts_finetune/tts3/local/extract.py index edd92420b..630b58ce3 100644 --- a/examples/other/tts_finetune/tts3/local/extract.py +++ b/examples/other/tts_finetune/tts3/local/extract.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -import math import os from operator import itemgetter from pathlib import Path @@ -211,9 +210,9 @@ def extract_feature(duration_file: str, mel_extractor, pitch_extractor, energy_extractor = get_extractor(config) wav_files = sorted(list((input_dir).rglob("*.wav"))) - # split data into 3 sections, train: 80%, dev: 10%, test: 10% - num_train = math.ceil(len(wav_files) * 0.8) - num_dev = math.ceil(len(wav_files) * 0.1) + # split data into 3 sections, train: len(wav_files) - 2, dev: 1, test: 1 + num_train = len(wav_files) - 2 + num_dev = 1 print(num_train, num_dev) train_wav_files = wav_files[:num_train] diff --git a/examples/other/tts_finetune/tts3/local/train.py b/examples/other/tts_finetune/tts3/local/train.py new file mode 100644 index 000000000..d065ae593 --- /dev/null +++ b/examples/other/tts_finetune/tts3/local/train.py @@ -0,0 +1,178 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import shutil +from pathlib import Path +from typing import List + +import jsonlines +import numpy as np +import paddle +from paddle import DataParallel +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler + +from paddlespeech.t2s.datasets.am_batch_fn import fastspeech2_multi_spk_batch_fn +from paddlespeech.t2s.datasets.am_batch_fn import fastspeech2_single_spk_batch_fn +from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.t2s.models.fastspeech2 import FastSpeech2 +from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Evaluator +from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Updater +from paddlespeech.t2s.training.extensions.snapshot import Snapshot +from paddlespeech.t2s.training.extensions.visualizer import VisualDL +from paddlespeech.t2s.training.optimizer import build_optimizers +from paddlespeech.t2s.training.seeding import seed_everything +from paddlespeech.t2s.training.trainer import Trainer + + +def freeze_layer(model, layers: List[str]): + """freeze layers + + Args: + layers (List[str]): frozen layers + """ + for layer in layers: + for param in eval("model." 
+ layer + ".parameters()"): + param.trainable = False + + +def train_sp(args, config): + # decides device type and whether to run in parallel + # setup running environment correctly + if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0: + paddle.set_device("cpu") + else: + paddle.set_device("gpu") + world_size = paddle.distributed.get_world_size() + if world_size > 1: + paddle.distributed.init_parallel_env() + + # set the random seed, it is a must for multiprocess training + seed_everything(config.seed) + + print( + f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}", + ) + fields = [ + "text", "text_lengths", "speech", "speech_lengths", "durations", + "pitch", "energy" + ] + converters = {"speech": np.load, "pitch": np.load, "energy": np.load} + spk_num = None + if args.speaker_dict is not None: + print("multiple speaker fastspeech2!") + collate_fn = fastspeech2_multi_spk_batch_fn + with open(args.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + spk_num = len(spk_id) + fields += ["spk_id"] + elif args.voice_cloning: + print("Training voice cloning!") + collate_fn = fastspeech2_multi_spk_batch_fn + fields += ["spk_emb"] + converters["spk_emb"] = np.load + else: + print("single speaker fastspeech2!") + collate_fn = fastspeech2_single_spk_batch_fn + print("spk_num:", spk_num) + + # dataloader has been too verbose + logging.getLogger("DataLoader").disabled = True + + # construct dataset for training and validation + with jsonlines.open(args.train_metadata, 'r') as reader: + train_metadata = list(reader) + train_dataset = DataTable( + data=train_metadata, + fields=fields, + converters=converters, ) + with jsonlines.open(args.dev_metadata, 'r') as reader: + dev_metadata = list(reader) + dev_dataset = DataTable( + data=dev_metadata, + fields=fields, + converters=converters, ) + + # collate function and dataloader + + train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=config.batch_size, + shuffle=True, + drop_last=True) + + print("samplers done!") + + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=collate_fn, + num_workers=config.num_workers) + + dev_dataloader = DataLoader( + dev_dataset, + shuffle=False, + drop_last=False, + batch_size=config.batch_size, + collate_fn=collate_fn, + num_workers=config.num_workers) + print("dataloaders done!") + + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + odim = config.n_mels + model = FastSpeech2( + idim=vocab_size, odim=odim, spk_num=spk_num, **config["model"]) + + # freeze layer + if args.frozen_layers != []: + freeze_layer(model, args.frozen_layers) + + if world_size > 1: + model = DataParallel(model) + print("model done!") + + optimizer = build_optimizers(model, **config["optimizer"]) + print("optimizer done!") + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + if dist.get_rank() == 0: + config_name = args.config.split("/")[-1] + # copy conf to output_dir + shutil.copyfile(args.config, output_dir / config_name) + + updater = FastSpeech2Updater( + model=model, + optimizer=optimizer, + dataloader=train_dataloader, + output_dir=output_dir, + **config["updater"]) + + trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir) + + evaluator = FastSpeech2Evaluator( + model, dev_dataloader, output_dir=output_dir, **config["updater"]) + + if dist.get_rank() == 0: + 
trainer.extend(evaluator, trigger=(1, "epoch")) + trainer.extend(VisualDL(output_dir), trigger=(1, "iteration")) + trainer.extend( + Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch')) + trainer.run() diff --git a/examples/other/tts_finetune/tts3/run.sh b/examples/other/tts_finetune/tts3/run.sh index 9bb7ec6f0..9c877e642 100755 --- a/examples/other/tts_finetune/tts3/run.sh +++ b/examples/other/tts_finetune/tts3/run.sh @@ -10,11 +10,12 @@ mfa_dir=./mfa_result dump_dir=./dump output_dir=./exp/default lang=zh -ngpu=2 +ngpu=1 +finetune_config=./finetune.yaml -ckpt=snapshot_iter_96600 +ckpt=snapshot_iter_96699 -gpus=0,1 +gpus=1 CUDA_VISIBLE_DEVICES=${gpus} stage=0 stop_stage=100 @@ -35,7 +36,8 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then --output_dir=${output_dir} \ --lang=${lang} \ --ngpu=${ngpu} \ - --epoch=100 + --epoch=100 \ + --finetune_config=${finetune_config} fi @@ -54,7 +56,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --voc_stat=pretrained_models/hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \ --lang=zh \ --text=${BIN_DIR}/../sentences.txt \ - --output_dir=./test_e2e \ + --output_dir=./test_e2e/ \ --phones_dict=${dump_dir}/phone_id_map.txt \ --speaker_dict=${dump_dir}/speaker_id_map.txt \ --spk_id=0 diff --git a/examples/vctk/ernie_sat/README.md b/examples/vctk/ernie_sat/README.md index 0a2f9359e..94c7ae25d 100644 --- a/examples/vctk/ernie_sat/README.md +++ b/examples/vctk/ernie_sat/README.md @@ -1,11 +1,10 @@ # ERNIE-SAT with VCTK dataset +ERNIE-SAT is a speech-text joint pretraining framework that achieves SOTA results in cross-lingual multi-speaker speech synthesis and cross-lingual speech editing tasks. It can be applied to a series of scenarios such as speech editing, personalized speech synthesis, and voice cloning. -ERNIE-SAT 是可以同时处理中英文的跨语言的语音-语言跨模态大模型,其在语音编辑、个性化语音合成以及跨语言的语音合成等多个任务取得了领先效果。可以应用于语音编辑、个性化合成、语音克隆、同传翻译等一系列场景,该项目供研究使用。 - -## 模型框架 -ERNIE-SAT 中我们提出了两项创新: -- 在预训练过程中将中英双语对应的音素作为输入,实现了跨语言、个性化的软音素映射 -- 采用语言和语音的联合掩码学习实现了语言和语音的对齐 +## Model Framework +In ERNIE-SAT, we propose two innovations: +- In the pretraining process, the phonemes corresponding to Chinese and English are used as input to achieve cross-language and personalized soft phoneme mapping +- The joint mask learning of speech and text is used to realize the alignment of speech and text

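The new `examples/other/tts_finetune/tts3/local/train.py` above freezes the sub-layers named in `finetune.yaml` by evaluating `"model." + layer + ".parameters()"`. A minimal equivalent sketch without `eval`, assuming each entry in `frozen_layers` is an attribute path on the FastSpeech2 model (as with the default `["encoder", "duration_predictor"]`):

```python
from typing import List

import paddle.nn as nn


def freeze_layers(model: nn.Layer, layers: List[str]) -> None:
    """Mark every parameter under the named sub-layers as non-trainable."""
    for name in layers:
        sublayer = model
        for attr in name.split("."):  # allow dotted paths such as "encoder.embed"
            sublayer = getattr(sublayer, attr)
        for param in sublayer.parameters():
            param.trainable = False
```

With the default `finetune.yaml`, the encoder and duration predictor keep their pretrained weights while the remaining sub-layers are updated on the small fine-tuning set.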
diff --git a/examples/vctk/ernie_sat/local/synthesize_e2e.sh b/examples/vctk/ernie_sat/local/synthesize_e2e.sh index fee540169..dcc710447 100755 --- a/examples/vctk/ernie_sat/local/synthesize_e2e.sh +++ b/examples/vctk/ernie_sat/local/synthesize_e2e.sh @@ -14,7 +14,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then python3 ${BIN_DIR}/synthesize_e2e.py \ --task_name=synthesize \ --wav_path=source/p243_313.wav \ - --old_str='For that reason cover should not be given.' \ + --old_str='For that reason cover should not be given' \ --new_str='I love you very much do you love me' \ --source_lang=en \ --target_lang=en \ @@ -36,8 +36,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then python3 ${BIN_DIR}/synthesize_e2e.py \ --task_name=edit \ --wav_path=source/p243_313.wav \ - --old_str='For that reason cover should not be given.' \ - --new_str='For that reason cover is not impossible to be given.' \ + --old_str='For that reason cover should not be given' \ + --new_str='For that reason cover is not impossible to be given' \ --source_lang=en \ --target_lang=en \ --erniesat_config=${config_path} \ diff --git a/examples/voxceleb/sv0/README.md b/examples/voxceleb/sv0/README.md index 26c95aca9..7fe759ebc 100644 --- a/examples/voxceleb/sv0/README.md +++ b/examples/voxceleb/sv0/README.md @@ -148,4 +148,4 @@ source path.sh CUDA_VISIBLE_DEVICES= bash ./local/test.sh ./data sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1/model/ conf/ecapa_tdnn.yaml ``` -The performance of the released models are shown in [this](./RESULTS.md) +The performance of the released models is shown in [this](./RESULT.md) diff --git a/examples/wenetspeech/asr1/RESULTS.md b/examples/wenetspeech/asr1/RESULTS.md index cc209db75..f22c652e6 100644 --- a/examples/wenetspeech/asr1/RESULTS.md +++ b/examples/wenetspeech/asr1/RESULTS.md @@ -34,3 +34,22 @@ Pretrain model from http://mobvoi-speech-public.ufile.ucloud.cn/public/wenet/wen | conformer | 32.52 M | conf/conformer.yaml | spec_aug | aishell1 | ctc_greedy_search | - | 0.052534 | | conformer | 32.52 M | conf/conformer.yaml | spec_aug | aishell1 | ctc_prefix_beam_search | - | 0.052915 | | conformer | 32.52 M | conf/conformer.yaml | spec_aug | aishell1 | attention_rescoring | - | 0.047904 | + + +## Conformer Streaming Pretrained Model + +Pretrain model from https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz + +| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size | CER | +| --- | --- | --- | --- | --- | --- | --- | --- | +| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention | 16 | 0.056273 | +| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_greedy_search | 16 | 0.078918 | +| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_prefix_beam_search | 16 | 0.079080 | +| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention_rescoring | 16 | 0.054401 | + +| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size | CER | +| --- | --- | --- | --- | --- | --- | --- | --- | +| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention | -1 | 0.050767 | +| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_greedy_search | -1 | 0.061884 | +| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_prefix_beam_search | -1 | 0.062056 | +| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 |
attention_rescoring | -1 | 0.052110 | diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 813e1e529..8a9849492 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -605,8 +605,8 @@ class U2BaseModel(ASRInterface, nn.Layer): xs: paddle.Tensor, offset: int, required_cache_size: int, - att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0]) - cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0]) + att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]) ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """ Export interface for c++ call, give input chunk xs, and return output from time 0 to current chunk. diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index 92990048d..2d236743a 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -86,7 +86,7 @@ class MultiHeadedAttention(nn.Layer): self, value: paddle.Tensor, scores: paddle.Tensor, - mask: paddle.Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool) + mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool) ) -> paddle.Tensor: """Compute attention context vector. Args: @@ -127,15 +127,14 @@ class MultiHeadedAttention(nn.Layer): return self.linear_out(x) # (batch, time1, d_model) - def forward( - self, - query: paddle.Tensor, - key: paddle.Tensor, - value: paddle.Tensor, - mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool) - pos_emb: paddle.Tensor, # paddle.empty([0]) - cache: paddle.Tensor # paddle.zeros([0,0,0,0]) - ) -> Tuple[paddle.Tensor, paddle.Tensor]: + def forward(self, + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + pos_emb: paddle.Tensor=paddle.empty([0]), + cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]) + ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute scaled dot product attention. Args: query (paddle.Tensor): Query tensor (#batch, time1, size). @@ -244,15 +243,14 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): return x - def forward( - self, - query: paddle.Tensor, - key: paddle.Tensor, - value: paddle.Tensor, - mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool) - pos_emb: paddle.Tensor, # paddle.empty([0]) - cache: paddle.Tensor # paddle.zeros([0,0,0,0]) - ) -> Tuple[paddle.Tensor, paddle.Tensor]: + def forward(self, + query: paddle.Tensor, + key: paddle.Tensor, + value: paddle.Tensor, + mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + pos_emb: paddle.Tensor=paddle.empty([0]), + cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]) + ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute 'Scaled Dot Product Attention' with rel. positional encoding. Args: query (paddle.Tensor): Query tensor (#batch, time1, size). diff --git a/paddlespeech/s2t/modules/conformer_convolution.py b/paddlespeech/s2t/modules/conformer_convolution.py index b35fea5b9..be6056546 100644 --- a/paddlespeech/s2t/modules/conformer_convolution.py +++ b/paddlespeech/s2t/modules/conformer_convolution.py @@ -108,8 +108,8 @@ class ConvolutionModule(nn.Layer): def forward( self, x: paddle.Tensor, - mask_pad: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool) - cache: paddle.Tensor # paddle.zeros([0,0,0,0]) + mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]) ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute convolution module. 
Args: diff --git a/paddlespeech/s2t/modules/decoder_layer.py b/paddlespeech/s2t/modules/decoder_layer.py index c8843b723..37b124e84 100644 --- a/paddlespeech/s2t/modules/decoder_layer.py +++ b/paddlespeech/s2t/modules/decoder_layer.py @@ -121,16 +121,11 @@ class DecoderLayer(nn.Layer): if self.concat_after: tgt_concat = paddle.cat( - (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask, - paddle.empty([0]), - paddle.zeros([0, 0, 0, 0]))[0]), - dim=-1) + (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]), dim=-1) x = residual + self.concat_linear1(tgt_concat) else: x = residual + self.dropout( - self.self_attn(tgt_q, tgt, tgt, tgt_q_mask, - paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[ - 0]) + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) if not self.normalize_before: x = self.norm1(x) @@ -139,15 +134,11 @@ class DecoderLayer(nn.Layer): x = self.norm2(x) if self.concat_after: x_concat = paddle.cat( - (x, self.src_attn(x, memory, memory, memory_mask, - paddle.empty([0]), - paddle.zeros([0, 0, 0, 0]))[0]), - dim=-1) + (x, self.src_attn(x, memory, memory, memory_mask)[0]), dim=-1) x = residual + self.concat_linear2(x_concat) else: x = residual + self.dropout( - self.src_attn(x, memory, memory, memory_mask, - paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[0]) + self.src_attn(x, memory, memory, memory_mask)[0]) if not self.normalize_before: x = self.norm2(x) diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index cf4e32fa4..2f4ad1b29 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -175,9 +175,7 @@ class BaseEncoder(nn.Layer): decoding_chunk_size, self.static_chunk_size, num_decoding_left_chunks) for layer in self.encoders: - xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad, - paddle.zeros([0, 0, 0, 0]), - paddle.zeros([0, 0, 0, 0])) + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) if self.normalize_before: xs = self.after_norm(xs) # Here we assume the mask is not changed in encoder layers, so just @@ -190,9 +188,9 @@ class BaseEncoder(nn.Layer): xs: paddle.Tensor, offset: int, required_cache_size: int, - att_cache: paddle.Tensor, # paddle.zeros([0,0,0,0]) - cnn_cache: paddle.Tensor, # paddle.zeros([0,0,0,0]), - att_mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool) + att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + att_mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool) ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """ Forward just one chunk Args: @@ -255,7 +253,6 @@ class BaseEncoder(nn.Layer): xs, att_mask, pos_emb, - mask_pad=paddle.ones([0, 0, 0], dtype=paddle.bool), att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, cnn_cache=cnn_cache[i:i + 1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, ) @@ -328,8 +325,7 @@ class BaseEncoder(nn.Layer): chunk_xs = xs[:, cur:end, :] (y, att_cache, cnn_cache) = self.forward_chunk( - chunk_xs, offset, required_cache_size, att_cache, cnn_cache, - paddle.ones([0, 0, 0], dtype=paddle.bool)) + chunk_xs, offset, required_cache_size, att_cache, cnn_cache) outputs.append(y) offset += y.shape[1] diff --git a/paddlespeech/s2t/modules/encoder_layer.py b/paddlespeech/s2t/modules/encoder_layer.py index 4555b535f..dac62bce3 100644 --- a/paddlespeech/s2t/modules/encoder_layer.py +++ b/paddlespeech/s2t/modules/encoder_layer.py @@ -76,10 +76,9 @@ class TransformerEncoderLayer(nn.Layer): x: paddle.Tensor, mask: paddle.Tensor, pos_emb: 
paddle.Tensor, - mask_pad: paddle. - Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool) - att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0]) - cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0]) + mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]) ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Compute encoded features. Args: @@ -106,8 +105,7 @@ class TransformerEncoderLayer(nn.Layer): if self.normalize_before: x = self.norm1(x) - x_att, new_att_cache = self.self_attn( - x, x, x, mask, paddle.empty([0]), cache=att_cache) + x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache) if self.concat_after: x_concat = paddle.concat((x, x_att), axis=-1) @@ -195,9 +193,9 @@ class ConformerEncoderLayer(nn.Layer): x: paddle.Tensor, mask: paddle.Tensor, pos_emb: paddle.Tensor, - mask_pad: paddle.Tensor, #paddle.ones([0, 0, 0],dtype=paddle.bool) - att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0]) - cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0]) + mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]) ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Compute encoded features. Args: diff --git a/paddlespeech/s2t/training/trainer.py b/paddlespeech/s2t/training/trainer.py index a7eb9892d..4a69d78a4 100644 --- a/paddlespeech/s2t/training/trainer.py +++ b/paddlespeech/s2t/training/trainer.py @@ -19,6 +19,10 @@ from pathlib import Path import paddle from paddle import distributed as dist +world_size = dist.get_world_size() +if world_size > 1: + dist.init_parallel_env() + from visualdl import LogWriter from paddlespeech.s2t.training.reporter import ObsScope @@ -122,9 +126,6 @@ class Trainer(): else: raise Exception("invalid device") - if self.parallel: - self.init_parallel() - self.checkpoint = Checkpoint( kbest_n=self.config.checkpoint.kbest_n, latest_n=self.config.checkpoint.latest_n) @@ -173,11 +174,6 @@ class Trainer(): """ return self.args.ngpu > 1 - def init_parallel(self): - """Init environment for multiprocess training. - """ - dist.init_parallel_env() - @mp_tools.rank_zero_only def save(self, tag=None, infos: dict=None): """Save checkpoint (model parameters and optimizer states). diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py index 87d88ee60..5782d7035 100644 --- a/paddlespeech/server/engine/asr/online/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py @@ -480,8 +480,7 @@ class PaddleASRConnectionHanddler: self.offset, required_cache_size, att_cache=self.att_cache, - cnn_cache=self.cnn_cache, - att_mask=paddle.ones([0, 0, 0], dtype=paddle.bool)) + cnn_cache=self.cnn_cache) outputs.append(y) # update the global offset, in decoding frame unit diff --git a/paddlespeech/server/engine/engine_warmup.py b/paddlespeech/server/engine/engine_warmup.py index 3751554c2..ff65dff97 100644 --- a/paddlespeech/server/engine/engine_warmup.py +++ b/paddlespeech/server/engine/engine_warmup.py @@ -27,8 +27,10 @@ def warm_up(engine_and_type: str, warm_up_time: int=3) -> bool: sentence = "您好,欢迎使用语音合成服务。" elif tts_engine.lang == 'en': sentence = "Hello and welcome to the speech synthesis service." 
+ elif tts_engine.lang == 'mix': + sentence = "您好,欢迎使用TTS多语种服务。" else: - logger.error("tts engine only support lang: zh or en.") + logger.error("tts engine only support lang: zh or en or mix.") sys.exit(-1) if engine_and_type == "tts_python": diff --git a/paddlespeech/t2s/exps/ernie_sat/align.py b/paddlespeech/t2s/exps/ernie_sat/align.py index 464f51a3b..8dbe685f5 100755 --- a/paddlespeech/t2s/exps/ernie_sat/align.py +++ b/paddlespeech/t2s/exps/ernie_sat/align.py @@ -58,7 +58,7 @@ def _readtg(tg_path: str, lang: str='en', fs: int=24000, n_shift: int=300): durations[-2] += durations[-1] durations = durations[:-1] - # replace ' and 'sil' with 'sp' + # replace '' and 'sil' with 'sp' phones = ['sp' if (phn == '' or phn == 'sil') else phn for phn in phones] if lang == 'en': @@ -195,7 +195,7 @@ def words2phns(text: str, lang='en'): wrd = wrd.upper() if (wrd not in ds): wrd2phns[str(index) + '_' + wrd] = 'spn' - phns.extend('spn') + phns.extend(['spn']) else: wrd2phns[str(index) + '_' + wrd] = word2phns_dict[wrd].split() phns.extend(word2phns_dict[wrd].split()) diff --git a/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py b/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py index 21c9ae044..e450aa1a0 100644 --- a/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py @@ -137,9 +137,6 @@ def prep_feats_with_dur(wav_path: str, new_wav = np.concatenate( [wav_org[:wav_left_idx], blank_wav, wav_org[wav_right_idx:]]) - # 音频是正常遮住了 - sf.write(str("mask_wav.wav"), new_wav, samplerate=fs) - # 4. get old and new mel span to be mask old_span_bdy = get_span_bdy( mfa_start=mfa_start, mfa_end=mfa_end, span_to_repl=span_to_repl) @@ -274,7 +271,8 @@ def get_wav(wav_path: str, new_str: str='', duration_adjust: bool=True, fs: int=24000, - n_shift: int=300): + n_shift: int=300, + task_name: str='synthesize'): outs = get_mlm_output( wav_path=wav_path, @@ -298,9 +296,11 @@ def get_wav(wav_path: str, alt_wav = np.squeeze(alt_wav) old_time_bdy = [n_shift * x for x in old_span_bdy] - wav_replaced = np.concatenate( - [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]]) - + if task_name == 'edit': + wav_replaced = np.concatenate( + [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]]) + else: + wav_replaced = alt_wav wav_dict = {"origin": wav_org, "output": wav_replaced} return wav_dict @@ -356,7 +356,11 @@ def parse_args(): "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") # ernie sat related - parser.add_argument("--task_name", type=str, help="task name") + parser.add_argument( + "--task_name", + type=str, + choices=['edit', 'synthesize'], + help="task name.") parser.add_argument("--wav_path", type=str, help="path of old wav") parser.add_argument("--old_str", type=str, help="old string") parser.add_argument("--new_str", type=str, help="new string") @@ -410,10 +414,9 @@ if __name__ == '__main__': if args.task_name == 'edit': new_str = new_str elif args.task_name == 'synthesize': - new_str = old_str + new_str + new_str = old_str + ' ' + new_str else: - new_str = old_str + new_str - print("new_str:", new_str) + new_str = old_str + ' ' + new_str # Extractor mel_extractor = LogMelFBank( @@ -467,7 +470,8 @@ if __name__ == '__main__': new_str=new_str, duration_adjust=args.duration_adjust, fs=erniesat_config.fs, - n_shift=erniesat_config.n_shift) + n_shift=erniesat_config.n_shift, + task_name=args.task_name) sf.write( args.output_name, wav_dict['output'], samplerate=erniesat_config.fs) diff --git a/paddlespeech/t2s/frontend/g2pw/__init__.py 
b/paddlespeech/t2s/frontend/g2pw/__init__.py index 0eaeee5df..89b3af3ca 100644 --- a/paddlespeech/t2s/frontend/g2pw/__init__.py +++ b/paddlespeech/t2s/frontend/g2pw/__init__.py @@ -1 +1 @@ -from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter +from .onnx_api import G2PWOnnxConverter diff --git a/paddlespeech/t2s/frontend/g2pw/dataset.py b/paddlespeech/t2s/frontend/g2pw/dataset.py index 98af5f463..8a1c2e0bf 100644 --- a/paddlespeech/t2s/frontend/g2pw/dataset.py +++ b/paddlespeech/t2s/frontend/g2pw/dataset.py @@ -15,6 +15,10 @@ Credits This code is modified from https://github.com/GitYCC/g2pW """ +from typing import Dict +from typing import List +from typing import Tuple + import numpy as np from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map @@ -23,22 +27,17 @@ ANCHOR_CHAR = '▁' def prepare_onnx_input(tokenizer, - labels, - char2phonemes, - chars, - texts, - query_ids, - phonemes=None, - pos_tags=None, - use_mask=False, - use_char_phoneme=False, - use_pos=False, - window_size=None, - max_len=512): + labels: List[str], + char2phonemes: Dict[str, List[int]], + chars: List[str], + texts: List[str], + query_ids: List[int], + use_mask: bool=False, + window_size: int=None, + max_len: int=512) -> Dict[str, np.array]: if window_size is not None: - truncated_texts, truncated_query_ids = _truncate_texts(window_size, - texts, query_ids) - + truncated_texts, truncated_query_ids = _truncate_texts( + window_size=window_size, texts=texts, query_ids=query_ids) input_ids = [] token_type_ids = [] attention_masks = [] @@ -51,13 +50,19 @@ def prepare_onnx_input(tokenizer, query_id = (truncated_query_ids if window_size else query_ids)[idx] try: - tokens, text2token, token2text = tokenize_and_map(tokenizer, text) + tokens, text2token, token2text = tokenize_and_map( + tokenizer=tokenizer, text=text) except Exception: print(f'warning: text "{text}" is invalid') return {} text, query_id, tokens, text2token, token2text = _truncate( - max_len, text, query_id, tokens, text2token, token2text) + max_len=max_len, + text=text, + query_id=query_id, + tokens=tokens, + text2token=text2token, + token2text=token2text) processed_tokens = ['[CLS]'] + tokens + ['[SEP]'] @@ -91,7 +96,8 @@ def prepare_onnx_input(tokenizer, return outputs -def _truncate_texts(window_size, texts, query_ids): +def _truncate_texts(window_size: int, texts: List[str], + query_ids: List[int]) -> Tuple[List[str], List[int]]: truncated_texts = [] truncated_query_ids = [] for text, query_id in zip(texts, query_ids): @@ -105,7 +111,12 @@ def _truncate_texts(window_size, texts, query_ids): return truncated_texts, truncated_query_ids -def _truncate(max_len, text, query_id, tokens, text2token, token2text): +def _truncate(max_len: int, + text: str, + query_id: int, + tokens: List[str], + text2token: List[int], + token2text: List[Tuple[int]]): truncate_len = max_len - 2 if len(tokens) <= truncate_len: return (text, query_id, tokens, text2token, token2text) @@ -132,18 +143,8 @@ def _truncate(max_len, text, query_id, tokens, text2token, token2text): ], [(s - start, e - start) for s, e in token2text[token_start:token_end]]) -def prepare_data(sent_path, lb_path=None): - raw_texts = open(sent_path).read().rstrip().split('\n') - query_ids = [raw.index(ANCHOR_CHAR) for raw in raw_texts] - texts = [raw.replace(ANCHOR_CHAR, '') for raw in raw_texts] - if lb_path is None: - return texts, query_ids - else: - phonemes = open(lb_path).read().rstrip().split('\n') - return texts, query_ids, phonemes - - -def get_phoneme_labels(polyphonic_chars): 
+def get_phoneme_labels(polyphonic_chars: List[List[str]] + ) -> Tuple[List[str], Dict[str, List[int]]]: labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars]))) char2phonemes = {} for char, phoneme in polyphonic_chars: @@ -153,7 +154,8 @@ def get_phoneme_labels(polyphonic_chars): return labels, char2phonemes -def get_char_phoneme_labels(polyphonic_chars): +def get_char_phoneme_labels(polyphonic_chars: List[List[str]] + ) -> Tuple[List[str], Dict[str, List[int]]]: labels = sorted( list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars]))) char2phonemes = {} diff --git a/paddlespeech/t2s/frontend/g2pw/onnx_api.py b/paddlespeech/t2s/frontend/g2pw/onnx_api.py index 180e8ae15..ad32c4050 100644 --- a/paddlespeech/t2s/frontend/g2pw/onnx_api.py +++ b/paddlespeech/t2s/frontend/g2pw/onnx_api.py @@ -17,6 +17,10 @@ Credits """ import json import os +from typing import Any +from typing import Dict +from typing import List +from typing import Tuple import numpy as np import onnxruntime @@ -37,7 +41,8 @@ from paddlespeech.utils.env import MODEL_HOME model_version = '1.1' -def predict(session, onnx_input, labels): +def predict(session, onnx_input: Dict[str, Any], + labels: List[str]) -> Tuple[List[str], List[float]]: all_preds = [] all_confidences = [] probs = session.run([], { @@ -61,10 +66,10 @@ def predict(session, onnx_input, labels): class G2PWOnnxConverter: def __init__(self, - model_dir=MODEL_HOME, - style='bopomofo', - model_source=None, - enable_non_tradional_chinese=False): + model_dir: os.PathLike=MODEL_HOME, + style: str='bopomofo', + model_source: str=None, + enable_non_tradional_chinese: bool=False): uncompress_path = download_and_decompress( g2pw_onnx_models['G2PWModel'][model_version], model_dir) @@ -76,7 +81,8 @@ class G2PWOnnxConverter: os.path.join(uncompress_path, 'g2pW.onnx'), sess_options=sess_options) self.config = load_config( - os.path.join(uncompress_path, 'config.py'), use_default=True) + config_path=os.path.join(uncompress_path, 'config.py'), + use_default=True) self.model_source = model_source if model_source else self.config.model_source self.enable_opencc = enable_non_tradional_chinese @@ -103,9 +109,9 @@ class G2PWOnnxConverter: .strip().split('\n') ] self.labels, self.char2phonemes = get_char_phoneme_labels( - self.polyphonic_chars + polyphonic_chars=self.polyphonic_chars ) if self.config.use_char_phoneme else get_phoneme_labels( - self.polyphonic_chars) + polyphonic_chars=self.polyphonic_chars) self.chars = sorted(list(self.char2phonemes.keys())) @@ -146,7 +152,7 @@ class G2PWOnnxConverter: if self.enable_opencc: self.cc = OpenCC('s2tw') - def _convert_bopomofo_to_pinyin(self, bopomofo): + def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str: tone = bopomofo[-1] assert tone in '12345' component = self.bopomofo_convert_dict.get(bopomofo[:-1]) @@ -156,7 +162,7 @@ class G2PWOnnxConverter: print(f'Warning: "{bopomofo}" cannot convert to pinyin') return None - def __call__(self, sentences): + def __call__(self, sentences: List[str]) -> List[List[str]]: if isinstance(sentences, str): sentences = [sentences] @@ -169,23 +175,25 @@ class G2PWOnnxConverter: sentences = translated_sentences texts, query_ids, sent_ids, partial_results = self._prepare_data( - sentences) + sentences=sentences) if len(texts) == 0: # sentences no polyphonic words return partial_results onnx_input = prepare_onnx_input( - self.tokenizer, - self.labels, - self.char2phonemes, - self.chars, - texts, - query_ids, + tokenizer=self.tokenizer, + labels=self.labels, + 
char2phonemes=self.char2phonemes, + chars=self.chars, + texts=texts, + query_ids=query_ids, use_mask=self.config.use_mask, - use_char_phoneme=self.config.use_char_phoneme, window_size=None) - preds, confidences = predict(self.session_g2pW, onnx_input, self.labels) + preds, confidences = predict( + session=self.session_g2pW, + onnx_input=onnx_input, + labels=self.labels) if self.config.use_char_phoneme: preds = [pred.split(' ')[1] for pred in preds] @@ -195,7 +203,9 @@ class G2PWOnnxConverter: return results - def _prepare_data(self, sentences): + def _prepare_data( + self, sentences: List[str] + ) -> Tuple[List[str], List[int], List[int], List[List[str]]]: texts, query_ids, sent_ids, partial_results = [], [], [], [] for sent_id, sent in enumerate(sentences): # pypinyin works well for Simplified Chinese than Traditional Chinese diff --git a/paddlespeech/t2s/frontend/g2pw/utils.py b/paddlespeech/t2s/frontend/g2pw/utils.py index ad02c4c1d..ba9ce51ba 100644 --- a/paddlespeech/t2s/frontend/g2pw/utils.py +++ b/paddlespeech/t2s/frontend/g2pw/utils.py @@ -15,10 +15,11 @@ Credits This code is modified from https://github.com/GitYCC/g2pW """ +import os import re -def wordize_and_map(text): +def wordize_and_map(text: str): words = [] index_map_from_text_to_word = [] index_map_from_word_to_text = [] @@ -54,8 +55,8 @@ def wordize_and_map(text): return words, index_map_from_text_to_word, index_map_from_word_to_text -def tokenize_and_map(tokenizer, text): - words, text2word, word2text = wordize_and_map(text) +def tokenize_and_map(tokenizer, text: str): + words, text2word, word2text = wordize_and_map(text=text) tokens = [] index_map_from_token_to_text = [] @@ -82,7 +83,7 @@ def tokenize_and_map(tokenizer, text): return tokens, index_map_from_text_to_token, index_map_from_token_to_text -def _load_config(config_path): +def _load_config(config_path: os.PathLike): import importlib.util spec = importlib.util.spec_from_file_location('__init__', config_path) config = importlib.util.module_from_spec(spec) @@ -130,7 +131,7 @@ default_config_dict = { } -def load_config(config_path, use_default=False): +def load_config(config_path: os.PathLike, use_default: bool=False): config = _load_config(config_path) if use_default: for attr, val in default_config_dict.items(): diff --git a/tests/test_tipc/prepare.sh b/tests/test_tipc/prepare.sh old mode 100644 new mode 100755 index 2a2272813..cb05a1d0f --- a/tests/test_tipc/prepare.sh +++ b/tests/test_tipc/prepare.sh @@ -15,6 +15,7 @@ dataline=$(cat ${FILENAME}) # parser params IFS=$'\n' lines=(${dataline}) +python=python # The training params model_name=$(func_parser_value "${lines[1]}") @@ -68,7 +69,7 @@ if [[ ${MODE} = "benchmark_train" ]];then if [[ ${model_name} == "pwgan" ]]; then # 下载 csmsc 数据集并解压缩 - wget -nc https://weixinxcxdb.oss-cn-beijing.aliyuncs.com/gwYinPinKu/BZNSYP.rar + wget -nc https://paddle-wheel.bj.bcebos.com/benchmark/BZNSYP.rar mkdir -p BZNSYP unrar x BZNSYP.rar BZNSYP wget -nc https://paddlespeech.bj.bcebos.com/Parakeet/benchmark/durations.txt @@ -80,6 +81,10 @@ if [[ ${MODE} = "benchmark_train" ]];then python ../paddlespeech/t2s/exps/gan_vocoder/normalize.py --metadata=dump/test/raw/metadata.jsonl --dumpdir=dump/test/norm --stats=dump/train/feats_stats.npy fi + echo "barrier start" + PYTHON="${python}" bash test_tipc/barrier.sh + echo "barrier end" + if [[ ${model_name} == "mdtc" ]]; then # 下载 Snips 数据集并解压缩 wget https://paddlespeech.bj.bcebos.com/datasets/hey_snips_kws_4.0.tar.gz.1
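The signature changes in `u2.py`, `attention.py`, `conformer_convolution.py`, `encoder.py`, and `encoder_layer.py` promote the previously comment-documented placeholder tensors to real default arguments, which is why call sites such as `decoder_layer.py` and the streaming `asr_engine.py` can drop their explicit `paddle.empty([0])` / `paddle.zeros([0, 0, 0, 0])` / `paddle.ones([0, 0, 0], dtype=paddle.bool)` arguments. A rough sketch of a chunk-by-chunk encoding loop under the new `forward_chunk` signature (`encoder` and `feature_chunks` are stand-ins, not the real server objects):

```python
import paddle


def encode_by_chunks(encoder, feature_chunks, required_cache_size: int = -1):
    """Run streaming encoding chunk by chunk, carrying the caches forward.

    `encoder` is assumed to expose the forward_chunk() interface of
    paddlespeech.s2t.modules.encoder.BaseEncoder; `feature_chunks` is an
    iterable of (1, T, feat_dim) feature tensors.
    """
    # Zero-sized caches mean "no history yet"; after this change they are
    # also the defaults, so the first call could omit them entirely.
    att_cache = paddle.zeros([0, 0, 0, 0])
    cnn_cache = paddle.zeros([0, 0, 0, 0])
    offset = 0
    outputs = []
    for chunk in feature_chunks:
        y, att_cache, cnn_cache = encoder.forward_chunk(
            chunk, offset, required_cache_size,
            att_cache=att_cache, cnn_cache=cnn_cache)  # att_mask uses its default
        offset += y.shape[1]
        outputs.append(y)
    return paddle.concat(outputs, axis=1)
```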