From 2edc79f96548c9d04362820f7c7394c765c317ea Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 20 Apr 2023 09:51:37 +0000 Subject: [PATCH] fix clip bug --- examples/vctk/vc3/local/train.sh | 3 +- paddlespeech/t2s/datasets/am_batch_fn.py | 10 +++--- paddlespeech/t2s/exps/starganv2_vc/train.py | 17 +++++++++- .../t2s/models/starganv2_vc/losses.py | 33 ++++++++++--------- .../starganv2_vc/starganv2_vc_updater.py | 4 +-- 5 files changed, 44 insertions(+), 23 deletions(-) diff --git a/examples/vctk/vc3/local/train.sh b/examples/vctk/vc3/local/train.sh index bdd8deaed..d4ea02da8 100755 --- a/examples/vctk/vc3/local/train.sh +++ b/examples/vctk/vc3/local/train.sh @@ -8,4 +8,5 @@ python3 ${BIN_DIR}/train.py \ --dev-metadata=dump/dev/norm/metadata.jsonl \ --config=${config_path} \ --output-dir=${train_output_path} \ - --ngpu=1 + --ngpu=1 \ + --speaker-dict=dump/speaker_id_map.txt diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py index ae46f1e1a..85959aa25 100644 --- a/paddlespeech/t2s/datasets/am_batch_fn.py +++ b/paddlespeech/t2s/datasets/am_batch_fn.py @@ -820,12 +820,13 @@ class StarGANv2VCCollateFn: self.max_mel_length = max_mel_length def random_clip(self, mel: np.array): - # [80, T] - mel_length = mel.shape[1] + # [T, 80] + mel_length = mel.shape[0] if mel_length > self.max_mel_length: random_start = np.random.randint(0, mel_length - self.max_mel_length) - mel = mel[:, random_start:random_start + self.max_mel_length] + + mel = mel[random_start:random_start + self.max_mel_length, :] return mel def __call__(self, exmaples): @@ -843,8 +844,9 @@ class StarGANv2VCCollateFn: mel = [self.random_clip(item["mel"]) for item in examples] ref_mel = [self.random_clip(item["ref_mel"]) for item in examples] ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples] - + print("mel[0].shape after batch_sequences:", mel[0].shape) mel = batch_sequences(mel) + print("mel.shape after batch_sequences:", mel.shape) ref_mel = batch_sequences(ref_mel) ref_mel_2 = batch_sequences(ref_mel_2) diff --git a/paddlespeech/t2s/exps/starganv2_vc/train.py b/paddlespeech/t2s/exps/starganv2_vc/train.py index 616591e79..94fa3032c 100644 --- a/paddlespeech/t2s/exps/starganv2_vc/train.py +++ b/paddlespeech/t2s/exps/starganv2_vc/train.py @@ -113,6 +113,16 @@ def train_sp(args, config): model_version = '1.0' uncompress_path = download_and_decompress(StarGANv2VC_source[model_version], MODEL_HOME) + # 根据 speaker 的个数修改 num_domains + # 源码的预训练模型和 default.yaml 里面默认是 20 + if args.speaker_dict is not None: + with open(args.speaker_dict, 'rt', encoding='utf-8') as f: + spk_id = [line.strip().split() for line in f.readlines()] + spk_num = len(spk_id) + print("spk_num:", spk_num) + config['mapping_network_params']['num_domains'] = spk_num + config['style_encoder_params']['num_domains'] = spk_num + config['discriminator_params']['num_domains'] = spk_num generator = Generator(**config['generator_params']) mapping_network = MappingNetwork(**config['mapping_network_params']) @@ -123,7 +133,7 @@ def train_sp(args, config): jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz') asr_model_dir = os.path.join(uncompress_path, 'asr.pdz') - F0_model = JDCNet(num_class=1, seq_len=192) + F0_model = JDCNet(num_class=1, seq_len=config['max_mel_length']) F0_model.set_state_dict(paddle.load(jdc_model_dir)['main_params']) F0_model.eval() @@ -234,6 +244,11 @@ def main(): parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument( "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument( + "--speaker-dict", + type=str, + default=None, + help="speaker id map file for multiple speaker model.") args = parser.parse_args() diff --git a/paddlespeech/t2s/models/starganv2_vc/losses.py b/paddlespeech/t2s/models/starganv2_vc/losses.py index aef7559f9..f4a308da0 100644 --- a/paddlespeech/t2s/models/starganv2_vc/losses.py +++ b/paddlespeech/t2s/models/starganv2_vc/losses.py @@ -21,33 +21,35 @@ from .transforms import build_transforms # 这些都写到 updater 里 -def compute_d_loss(nets: Dict[str, Any], - x_real: paddle.Tensor, - y_org: paddle.Tensor, - y_trg: paddle.Tensor, - z_trg: paddle.Tensor=None, - x_ref: paddle.Tensor=None, - use_r1_reg: bool=True, - use_adv_cls: bool=False, - use_con_reg: bool=False, - lambda_reg: float=1., - lambda_adv_cls: float=0.1, - lambda_con_reg: float=10.): +def compute_d_loss( + nets: Dict[str, Any], + x_real: paddle.Tensor, + y_org: paddle.Tensor, + y_trg: paddle.Tensor, + z_trg: paddle.Tensor=None, + x_ref: paddle.Tensor=None, + # TODO: should be True here, but r1_reg has some bug now + use_r1_reg: bool=False, + use_adv_cls: bool=False, + use_con_reg: bool=False, + lambda_reg: float=1., + lambda_adv_cls: float=0.1, + lambda_con_reg: float=10.): assert (z_trg is None) != (x_ref is None) # with real audios x_real.stop_gradient = False out = nets['discriminator'](x_real, y_org) loss_real = adv_loss(out, 1) - # R1 regularizaition (https://arxiv.org/abs/1801.04406v4) if use_r1_reg: loss_reg = r1_reg(out, x_real) else: - loss_reg = paddle.to_tensor([0.], dtype=paddle.float32) + # loss_reg = paddle.to_tensor([0.], dtype=paddle.float32) + loss_reg = paddle.zeros([1]) # consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724) - loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32) + loss_con_reg = paddle.zeros([1]) if use_con_reg: t = build_transforms() out_aug = nets['discriminator'](t(x_real).detach(), y_org) @@ -119,6 +121,7 @@ def compute_g_loss(nets: Dict[str, Any], # compute ASR/F0 features (real) with paddle.no_grad(): + print("x_real.shape:", x_real.shape) F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real) ASR_real = nets['asr_model'].get_feature(x_real) diff --git a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py index 6a77fbb2c..1b811a3f7 100644 --- a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py +++ b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py @@ -259,7 +259,7 @@ class StarGANv2VCEvaluator(StandardEvaluator): y_org=y_org, y_trg=y_trg, z_trg=z_trg, - use_r1_reg=False, + use_r1_reg=self.use_r1_reg, use_adv_cls=use_adv_cls, **self.d_loss_params) @@ -269,7 +269,7 @@ class StarGANv2VCEvaluator(StandardEvaluator): y_org=y_org, y_trg=y_trg, x_ref=x_ref, - use_r1_reg=False, + use_r1_reg=self.use_r1_reg, use_adv_cls=use_adv_cls, **self.d_loss_params)