From 3f6afc48349fa4257c3174ee53b807a8d885d0f6 Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Fri, 9 Dec 2022 14:10:45 +0800
Subject: [PATCH] [TTS]Add slim for TTS (#2729)

---
 examples/csmsc/tts2/local/PTQ_static.sh       |   1 +
 examples/csmsc/tts2/run.sh                    |   5 +
 examples/csmsc/tts3/local/PTQ_dynamic.sh      |   8 +
 examples/csmsc/tts3/local/PTQ_static.sh       |   8 +
 examples/csmsc/tts3/run.sh                    |  13 ++
 examples/csmsc/tts3/run_cnndecoder.sh         |   5 +
 examples/csmsc/voc1/local/PTQ_static.sh       |   8 +
 examples/csmsc/voc1/run.sh                    |   5 +
 examples/csmsc/voc3/local/PTQ_static.sh       |   1 +
 examples/csmsc/voc3/run.sh                    |   5 +
 examples/csmsc/voc5/local/PTQ_static.sh       |   1 +
 examples/csmsc/voc5/run.sh                    |   5 +
 paddlespeech/t2s/datasets/am_batch_fn.py      |  67 ++++++++
 paddlespeech/t2s/datasets/vocoder_batch_fn.py |  55 +++++-
 paddlespeech/t2s/exps/PTQ_dynamic.py          |  80 +++++++++
 paddlespeech/t2s/exps/PTQ_static.py           | 156 ++++++++++++++++++
 paddlespeech/t2s/exps/syn_utils.py            |  98 +++++++++++
 17 files changed, 513 insertions(+), 8 deletions(-)
 create mode 120000 examples/csmsc/tts2/local/PTQ_static.sh
 create mode 100755 examples/csmsc/tts3/local/PTQ_dynamic.sh
 create mode 100755 examples/csmsc/tts3/local/PTQ_static.sh
 create mode 100755 examples/csmsc/voc1/local/PTQ_static.sh
 create mode 120000 examples/csmsc/voc3/local/PTQ_static.sh
 create mode 120000 examples/csmsc/voc5/local/PTQ_static.sh
 create mode 100644 paddlespeech/t2s/exps/PTQ_dynamic.py
 create mode 100644 paddlespeech/t2s/exps/PTQ_static.py

diff --git a/examples/csmsc/tts2/local/PTQ_static.sh b/examples/csmsc/tts2/local/PTQ_static.sh
new file mode 120000
index 00000000..f9ce35be
--- /dev/null
+++ b/examples/csmsc/tts2/local/PTQ_static.sh
@@ -0,0 +1 @@
+../../tts3/local/PTQ_static.sh
\ No newline at end of file
diff --git a/examples/csmsc/tts2/run.sh b/examples/csmsc/tts2/run.sh
index 1b608992..6279ec57 100755
--- a/examples/csmsc/tts2/run.sh
+++ b/examples/csmsc/tts2/run.sh
@@ -72,3 +72,8 @@ fi
 if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
     CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1
 fi
+
+# PTQ_static
+if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/PTQ_static.sh ${train_output_path} speedyspeech_csmsc || exit -1
+fi
diff --git a/examples/csmsc/tts3/local/PTQ_dynamic.sh b/examples/csmsc/tts3/local/PTQ_dynamic.sh
new file mode 100755
index 00000000..5eb64bee
--- /dev/null
+++ b/examples/csmsc/tts3/local/PTQ_dynamic.sh
@@ -0,0 +1,8 @@
+train_output_path=$1
+model_name=$2
+weight_bits=$3
+
+python3 ${BIN_DIR}/../PTQ_dynamic.py \
+    --inference_dir ${train_output_path}/inference \
+    --model_name ${model_name} \
+    --weight_bits ${weight_bits}
\ No newline at end of file
diff --git a/examples/csmsc/tts3/local/PTQ_static.sh b/examples/csmsc/tts3/local/PTQ_static.sh
new file mode 100755
index 00000000..a70a77b5
--- /dev/null
+++ b/examples/csmsc/tts3/local/PTQ_static.sh
@@ -0,0 +1,8 @@
+train_output_path=$1
+model_name=$2
+
+python3 ${BIN_DIR}/../PTQ_static.py \
+    --dev-metadata=dump/dev/norm/metadata.jsonl \
+    --inference_dir ${train_output_path}/inference \
+    --model_name ${model_name} \
+    --onnx_format=True
\ No newline at end of file
diff --git a/examples/csmsc/tts3/run.sh b/examples/csmsc/tts3/run.sh
index 14308af4..dd8c9f3e 100755
--- a/examples/csmsc/tts3/run.sh
+++ b/examples/csmsc/tts3/run.sh
@@ -76,3 +76,16 @@ fi
 if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
     CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1
 fi
+
+# PTQ_dynamic
+if [ ${stage} -le 9 ] && [ 
${stop_stage} -ge 9 ]; then + ./local/PTQ_dynamic.sh ${train_output_path} fastspeech2_csmsc 8 + # ./local/PTQ_dynamic.sh ${train_output_path} pwgan_csmsc 8 + # ./local/PTQ_dynamic.sh ${train_output_path} mb_melgan_csmsc 8 + # ./local/PTQ_dynamic.sh ${train_output_path} hifigan_csmsc 8 +fi + +# PTQ_static +if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then + CUDA_VISIBLE_DEVICES=${gpus} ./local/PTQ_static.sh ${train_output_path} fastspeech2_csmsc || exit -1 +fi diff --git a/examples/csmsc/tts3/run_cnndecoder.sh b/examples/csmsc/tts3/run_cnndecoder.sh index 8cc9c5da..96b446c5 100755 --- a/examples/csmsc/tts3/run_cnndecoder.sh +++ b/examples/csmsc/tts3/run_cnndecoder.sh @@ -122,3 +122,8 @@ fi if [ ${stage} -le 14 ] && [ ${stop_stage} -ge 14 ]; then CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict_streaming.sh ${train_output_path} || exit -1 fi + +# PTQ_static +if [ ${stage} -le 15 ] && [ ${stop_stage} -ge 15 ]; then + CUDA_VISIBLE_DEVICES=${gpus} ./local/PTQ_static.sh ${train_output_path} fastspeech2_csmsc || exit -1 +fi \ No newline at end of file diff --git a/examples/csmsc/voc1/local/PTQ_static.sh b/examples/csmsc/voc1/local/PTQ_static.sh new file mode 100755 index 00000000..2e516614 --- /dev/null +++ b/examples/csmsc/voc1/local/PTQ_static.sh @@ -0,0 +1,8 @@ +train_output_path=$1 +model_name=$2 + +python3 ${BIN_DIR}/../../PTQ_static.py \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --inference_dir ${train_output_path}/inference \ + --model_name ${model_name} \ + --onnx_format=True \ No newline at end of file diff --git a/examples/csmsc/voc1/run.sh b/examples/csmsc/voc1/run.sh index cab1ac38..d1122620 100755 --- a/examples/csmsc/voc1/run.sh +++ b/examples/csmsc/voc1/run.sh @@ -30,3 +30,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # synthesize CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 fi + +# PTQ_static +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + CUDA_VISIBLE_DEVICES=${gpus} ./local/PTQ_static.sh ${train_output_path} pwgan_csmsc || exit -1 +fi diff --git a/examples/csmsc/voc3/local/PTQ_static.sh b/examples/csmsc/voc3/local/PTQ_static.sh new file mode 120000 index 00000000..fb9f42f6 --- /dev/null +++ b/examples/csmsc/voc3/local/PTQ_static.sh @@ -0,0 +1 @@ +../../voc1/local/PTQ_static.sh \ No newline at end of file diff --git a/examples/csmsc/voc3/run.sh b/examples/csmsc/voc3/run.sh index 3e7d7e2a..d4268ad1 100755 --- a/examples/csmsc/voc3/run.sh +++ b/examples/csmsc/voc3/run.sh @@ -30,3 +30,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # synthesize CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 fi + +# PTQ_static +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + CUDA_VISIBLE_DEVICES=${gpus} ./local/PTQ_static.sh ${train_output_path} mb_melgan_csmsc || exit -1 +fi diff --git a/examples/csmsc/voc5/local/PTQ_static.sh b/examples/csmsc/voc5/local/PTQ_static.sh new file mode 120000 index 00000000..fb9f42f6 --- /dev/null +++ b/examples/csmsc/voc5/local/PTQ_static.sh @@ -0,0 +1 @@ +../../voc1/local/PTQ_static.sh \ No newline at end of file diff --git a/examples/csmsc/voc5/run.sh b/examples/csmsc/voc5/run.sh index 3e7d7e2a..90dc9da2 100755 --- a/examples/csmsc/voc5/run.sh +++ b/examples/csmsc/voc5/run.sh @@ -30,3 +30,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # synthesize CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 fi + +# 
PTQ_static
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/PTQ_static.sh ${train_output_path} hifigan_csmsc || exit -1
+fi
\ No newline at end of file
diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py
index c00648b1..c95d908d 100644
--- a/paddlespeech/t2s/datasets/am_batch_fn.py
+++ b/paddlespeech/t2s/datasets/am_batch_fn.py
@@ -538,3 +538,70 @@ def vits_multi_spk_batch_fn(examples):
         spk_id = paddle.to_tensor(spk_id)
         batch["spk_id"] = spk_id
     return batch
+
+
+# for PaddleSlim
+def fastspeech2_single_spk_batch_fn_static(examples):
+    text = [np.array(item["text"], dtype=np.int64) for item in examples]
+    text = np.array(text)
+    # do not need batch axis in infer
+    text = text[0]
+    batch = {
+        "text": text,
+    }
+    return batch
+
+
+def fastspeech2_multi_spk_batch_fn_static(examples):
+    text = [np.array(item["text"], dtype=np.int64) for item in examples]
+    text = np.array(text)
+    text = text[0]
+    batch = {
+        "text": text,
+    }
+    if "spk_id" in examples[0]:
+        spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples]
+        spk_id = np.array(spk_id)
+        spk_id = spk_id[0]
+        batch["spk_id"] = spk_id
+    if "spk_emb" in examples[0]:
+        spk_emb = [
+            np.array(item["spk_emb"], dtype=np.float32) for item in examples
+        ]
+        spk_emb = np.array(spk_emb)
+        spk_emb = spk_emb[0]
+        batch["spk_emb"] = spk_emb
+    return batch
+
+
+def speedyspeech_single_spk_batch_fn_static(examples):
+    phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
+    tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
+    phones = np.array(phones)
+    tones = np.array(tones)
+    phones = phones[0]
+    tones = tones[0]
+    batch = {
+        "phones": phones,
+        "tones": tones,
+    }
+    return batch
+
+
+def speedyspeech_multi_spk_batch_fn_static(examples):
+    phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
+    tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
+    phones = np.array(phones)
+    tones = np.array(tones)
+    phones = phones[0]
+    tones = tones[0]
+    batch = {
+        "phones": phones,
+        "tones": tones,
+    }
+    if "spk_id" in examples[0]:
+        spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples]
+        spk_id = np.array(spk_id)
+        spk_id = spk_id[0]
+        batch["spk_id"] = spk_id
+    return batch
diff --git a/paddlespeech/t2s/datasets/vocoder_batch_fn.py b/paddlespeech/t2s/datasets/vocoder_batch_fn.py
index 08748de0..92f32b27 100644
--- a/paddlespeech/t2s/datasets/vocoder_batch_fn.py
+++ b/paddlespeech/t2s/datasets/vocoder_batch_fn.py
@@ -55,13 +55,12 @@ class Clip(object):
         Args:
             batch (list): list of tuple of the pair of audio and features. Audio shape (T, ), features shape(T', C).
-        Returns:
+        Returns:
+            Tensor:
+                Target signal batch (B, 1, T).
             Tensor:
                 Auxiliary feature batch (B, C, T'), where
                 T = (T' - 2 * aux_context_window) * hop_size.
-            Tensor:
-                Target signal batch (B, 1, T).
- """ # check length batch = [ @@ -106,11 +105,7 @@ class Clip(object): if len(x) < c.shape[0] * self.hop_size: x = np.pad(x, (0, c.shape[0] * self.hop_size - len(x)), mode="edge") elif len(x) > c.shape[0] * self.hop_size: - # print( - # f"wave length: ({len(x)}), mel length: ({c.shape[0]}), hop size: ({self.hop_size })" - # ) x = x[:c.shape[0] * self.hop_size] - # check the legnth is valid assert len(x) == c.shape[ 0] * self.hop_size, f"wave length: ({len(x)}), mel length: ({c.shape[0]})" @@ -218,3 +213,47 @@ class WaveRNNClip(Clip): y = label_2_float(paddle.cast(y, dtype='float32'), self.bits) return x, y, mels + + +# for paddleslim + + +class Clip_static(Clip): + """Collate functor for training vocoders. + """ + + def __call__(self, batch): + """Convert into batch tensors. + + Args: + batch (list): list of tuple of the pair of audio and features. Audio shape (T, ), features shape(T', C). + + Returns: + Dict[str, np.array]: + Auxiliary feature batch (B, C, T'), where + T = (T' - 2 * aux_context_window) * hop_size. + """ + # check length + batch = [ + self._adjust_length(b['wave'], b['feats']) for b in batch + if b['feats'].shape[0] > self.mel_threshold + ] + xs, cs = [b[0] for b in batch], [b[1] for b in batch] + + # make batch with random cut + c_lengths = [c.shape[0] for c in cs] + start_frames = np.array([ + np.random.randint(self.start_offset, cl + self.end_offset) + for cl in c_lengths + ]) + + c_starts = start_frames - self.aux_context_window + c_ends = start_frames + self.batch_max_frames + self.aux_context_window + c_batch = np.stack( + [c[start:end] for c, start, end in zip(cs, c_starts, c_ends)]) + # infer axis (T',C) is different with train axis (B, C, T') + # c_batch = c_batch.transpose([0, 2, 1]) # (B, C, T') + # do not need batch axis in infer + c_batch = c_batch[0] + batch = {"logmel": c_batch} + return batch diff --git a/paddlespeech/t2s/exps/PTQ_dynamic.py b/paddlespeech/t2s/exps/PTQ_dynamic.py new file mode 100644 index 00000000..3a38ed81 --- /dev/null +++ b/paddlespeech/t2s/exps/PTQ_dynamic.py @@ -0,0 +1,80 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse + +import paddle +from paddleslim.quant import quant_post_dynamic + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Paddle Slim Dynamic with acoustic model & vocoder.") + # acoustic model + parser.add_argument( + '--model_name', + type=str, + default='fastspeech2_csmsc', + choices=[ + 'speedyspeech_csmsc', + 'fastspeech2_csmsc', + 'fastspeech2_aishell3', + 'fastspeech2_ljspeech', + 'fastspeech2_vctk', + 'tacotron2_csmsc', + 'fastspeech2_mix', + 'pwgan_csmsc', + 'pwgan_aishell3', + 'pwgan_ljspeech', + 'pwgan_vctk', + 'mb_melgan_csmsc', + 'hifigan_csmsc', + 'hifigan_aishell3', + 'hifigan_ljspeech', + 'hifigan_vctk', + 'wavernn_csmsc', + ], + help='Choose model type of tts task.') + + parser.add_argument( + "--inference_dir", type=str, help="dir to save inference models") + parser.add_argument( + "--weight_bits", + type=int, + default=8, + choices=[8, 16], + help="The bits for the quantized weight, and it should be 8 or 16. Default is 8.", + ) + + args, _ = parser.parse_known_args() + return args + + +# only inference for models trained with csmsc now +def main(): + args = parse_args() + paddle.enable_static() + quant_post_dynamic( + model_dir=args.inference_dir, + save_model_dir=args.inference_dir, + model_filename=args.model_name + ".pdmodel", + params_filename=args.model_name + ".pdiparams", + save_model_filename=args.model_name + "_" + str(args.weight_bits) + + "bits.pdmodel", + save_params_filename=args.model_name + "_" + str(args.weight_bits) + + "bits.pdiparams", + weight_bits=args.weight_bits, ) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/PTQ_static.py b/paddlespeech/t2s/exps/PTQ_static.py new file mode 100644 index 00000000..16b3ae98 --- /dev/null +++ b/paddlespeech/t2s/exps/PTQ_static.py @@ -0,0 +1,156 @@ +import argparse +import random + +import jsonlines +import numpy as np +import paddle +from paddleslim.quant import quant_post_static + +from paddlespeech.t2s.exps.syn_utils import get_dev_dataloader +from paddlespeech.t2s.utils import str2bool + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Paddle Slim Static with acoustic model & vocoder.") + + parser.add_argument( + "--batch_size", type=int, default=1, help="Minibatch size.") + parser.add_argument("--batch_num", type=int, default=1, help="Batch number") + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.") + # model_path save_path + parser.add_argument( + "--inference_dir", type=str, help="dir to save inference models") + parser.add_argument( + '--model_name', + type=str, + default='fastspeech2_csmsc', + choices=[ + 'speedyspeech_csmsc', + 'fastspeech2_csmsc', + 'fastspeech2_aishell3', + 'fastspeech2_ljspeech', + 'fastspeech2_vctk', + 'fastspeech2_mix', + 'pwgan_csmsc', + 'pwgan_aishell3', + 'pwgan_ljspeech', + 'pwgan_vctk', + 'mb_melgan_csmsc', + 'hifigan_csmsc', + 'hifigan_aishell3', + 'hifigan_ljspeech', + 'hifigan_vctk', + ], + help='Choose model type of tts task.') + + parser.add_argument( + "--algo", type=str, default='avg', help="calibration algorithm.") + parser.add_argument( + "--round_type", + type=str, + default='round', + help="The method of converting the quantized weights.") + parser.add_argument( + "--hist_percent", + type=float, + default=0.9999, + help="The percentile of algo:hist.") + parser.add_argument( + "--is_full_quantize", + type=str2bool, + default=False, + help="Whether is full quantization or not.") + parser.add_argument( + "--bias_correction", + type=str2bool, + 
default=False, + help="Whether to use bias correction.") + parser.add_argument( + "--ce_test", type=str2bool, default=False, help="Whether to CE test.") + parser.add_argument( + "--onnx_format", + type=str2bool, + default=False, + help="Whether to export the quantized model with format of ONNX.") + parser.add_argument( + "--phones-dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--speaker-dict", + type=str, + default=None, + help="speaker id map file for multiple speaker model.") + parser.add_argument("--dev-metadata", type=str, help="dev data.") + parser.add_argument( + "--quantizable_op_type", + type=list, + nargs='+', + default=[ + "conv2d_transpose", "conv2d", "depthwise_conv2d", "mul", "matmul", + "matmul_v2" + ], + help="The list of op types that will be quantized.") + + args = parser.parse_args() + return args + + +def quantize(args): + shuffle = True + if args.ce_test: + # set seed + seed = 111 + np.random.seed(seed) + paddle.seed(seed) + random.seed(seed) + shuffle = False + + place = paddle.CUDAPlace(0) if args.ngpu > 0 else paddle.CPUPlace() + with jsonlines.open(args.dev_metadata, 'r') as reader: + dev_metadata = list(reader) + + dataloader = get_dev_dataloader( + dev_metadata=dev_metadata, + am=args.model_name, + batch_size=args.batch_size, + speaker_dict=args.speaker_dict, + shuffle=shuffle) + + exe = paddle.static.Executor(place) + exe.run() + + print("onnx_format:", args.onnx_format) + + quant_post_static( + executor=exe, + model_dir=args.inference_dir, + quantize_model_path=args.inference_dir + "/" + args.model_name + + "_quant", + data_loader=dataloader, + model_filename=args.model_name + ".pdmodel", + params_filename=args.model_name + ".pdiparams", + save_model_filename=args.model_name + ".pdmodel", + save_params_filename=args.model_name + ".pdiparams", + batch_size=args.batch_size, + algo=args.algo, + round_type=args.round_type, + hist_percent=args.hist_percent, + is_full_quantize=args.is_full_quantize, + bias_correction=args.bias_correction, + onnx_format=args.onnx_format, + quantizable_op_type=args.quantizable_op_type) + + +def main(): + args = parse_args() + new_quantizable_op_type = [] + for item in args.quantizable_op_type: + new_quantizable_op_type.append(''.join(item)) + args.quantizable_op_type = new_quantizable_op_type + paddle.enable_static() + quantize(args) + + +if __name__ == '__main__': + main() diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index 82b71848..0ac79981 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -25,10 +25,13 @@ import onnxruntime as ort import paddle from paddle import inference from paddle import jit +from paddle.io import DataLoader from paddle.static import InputSpec from yacs.config import CfgNode +from paddlespeech.t2s.datasets.am_batch_fn import * from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip_static from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend.mix_frontend import MixFrontend from paddlespeech.t2s.frontend.zh_frontend import Frontend @@ -118,6 +121,7 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'): return sentences +# am only def get_test_dataset(test_metadata: List[Dict[str, Any]], am: str, speaker_dict: Optional[os.PathLike]=None, @@ -158,6 +162,100 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]], return test_dataset +# am and voc, for PTQ_static +def 
get_dev_dataloader(dev_metadata: List[Dict[str, Any]], + am: str, + batch_size: int=1, + speaker_dict: Optional[os.PathLike]=None, + voice_cloning: bool=False, + n_shift: int=300, + batch_max_steps: int=16200, + shuffle: bool=True): + # model: {model_name}_{dataset} + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] + converters = {} + if am_name == 'fastspeech2': + fields = ["utt_id", "text"] + if am_dataset in {"aishell3", "vctk", + "mix"} and speaker_dict is not None: + print("multiple speaker fastspeech2!") + collate_fn = fastspeech2_multi_spk_batch_fn_static + fields += ["spk_id"] + elif voice_cloning: + print("voice cloning!") + collate_fn = fastspeech2_multi_spk_batch_fn_static + fields += ["spk_emb"] + else: + print("single speaker fastspeech2!") + collate_fn = fastspeech2_single_spk_batch_fn_static + elif am_name == 'speedyspeech': + fields = ["utt_id", "phones", "tones"] + if am_dataset in {"aishell3", "vctk", + "mix"} and speaker_dict is not None: + print("multiple speaker speedyspeech!") + collate_fn = speedyspeech_multi_spk_batch_fn_static + fields += ["spk_id"] + else: + print("single speaker speedyspeech!") + collate_fn = speedyspeech_single_spk_batch_fn_static + fields = ["utt_id", "phones", "tones"] + elif am_name == 'tacotron2': + fields = ["utt_id", "text"] + if voice_cloning: + print("voice cloning!") + collate_fn = tacotron2_multi_spk_batch_fn_static + fields += ["spk_emb"] + else: + print("single speaker tacotron2!") + collate_fn = tacotron2_single_spk_batch_fn_static + else: + print("voc dataloader") + + # am + if am_name not in {'pwgan', 'mb_melgan', 'hifigan'}: + dev_dataset = DataTable( + data=dev_metadata, + fields=fields, + converters=converters, ) + + dev_dataloader = DataLoader( + dev_dataset, + shuffle=shuffle, + drop_last=False, + batch_size=batch_size, + collate_fn=collate_fn) + # vocoder + else: + # pwgan: batch_max_steps: 25500 aux_context_window: 2 + # mb_melgan: batch_max_steps: 16200 aux_context_window 0 + # hifigan: batch_max_steps: 8400 aux_context_window 0 + aux_context_window = 0 + if am_name == 'pwgan': + aux_context_window = 2 + + train_batch_fn = Clip_static( + batch_max_steps=batch_max_steps, + hop_size=n_shift, + aux_context_window=aux_context_window) + dev_dataset = DataTable( + data=dev_metadata, + fields=["wave", "feats"], + converters={ + "wave": np.load, + "feats": np.load, + }, ) + + dev_dataloader = DataLoader( + dev_dataset, + shuffle=shuffle, + drop_last=False, + batch_size=batch_size, + collate_fn=train_batch_fn) + + return dev_dataloader + + # frontend def get_frontend(lang: str='zh', phones_dict: Optional[os.PathLike]=None,