From 593f247460319be9d642ab894c593d06e627fa52 Mon Sep 17 00:00:00 2001 From: liangym Date: Tue, 18 Oct 2022 10:06:56 +0000 Subject: [PATCH] new_fs2 --- examples/zh_en_tts/tts3/conf/default.yaml | 4 ++ examples/zh_en_tts/tts3/local/train.sh | 2 +- examples/zh_en_tts/tts3/run.sh | 6 +- .../t2s/models/fastspeech2/fastspeech2.py | 71 +++++++++++++++++-- .../models/fastspeech2/fastspeech2_updater.py | 31 +++++--- .../multi_speakers/gradient_reversal.py | 60 ++++++++++++++++ .../multi_speakers/speaker_classifier.py | 50 +++++++++++++ 7 files changed, 205 insertions(+), 19 deletions(-) create mode 100644 paddlespeech/t2s/modules/multi_speakers/gradient_reversal.py create mode 100644 paddlespeech/t2s/modules/multi_speakers/speaker_classifier.py diff --git a/examples/zh_en_tts/tts3/conf/default.yaml b/examples/zh_en_tts/tts3/conf/default.yaml index e65b5d0ec..06bf1fcc1 100644 --- a/examples/zh_en_tts/tts3/conf/default.yaml +++ b/examples/zh_en_tts/tts3/conf/default.yaml @@ -74,6 +74,9 @@ model: stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder spk_embed_dim: 256 # speaker embedding dimension spk_embed_integration_type: concat # speaker embedding integration type + enable_speaker_classifier: True # Whether to use speaker classifier module + hidden_sc_dim: 256 # The hidden layer dim of speaker classifier + @@ -82,6 +85,7 @@ model: ########################################################### updater: use_masking: True # whether to apply masking for padded part in loss calculation + spk_loss_scale: 0.02 # The scales of speaker classifier loss ########################################################### diff --git a/examples/zh_en_tts/tts3/local/train.sh b/examples/zh_en_tts/tts3/local/train.sh index 1da72f117..3a5076505 100755 --- a/examples/zh_en_tts/tts3/local/train.sh +++ b/examples/zh_en_tts/tts3/local/train.sh @@ -8,6 +8,6 @@ python3 ${BIN_DIR}/train.py \ --dev-metadata=dump/dev/norm/metadata.jsonl \ 
--config=${config_path} \ --output-dir=${train_output_path} \ - --ngpu=2 \ + --ngpu=1 \ --phones-dict=dump/phone_id_map.txt \ --speaker-dict=dump/speaker_id_map.txt diff --git a/examples/zh_en_tts/tts3/run.sh b/examples/zh_en_tts/tts3/run.sh index 12f99081a..a0c58f35c 100755 --- a/examples/zh_en_tts/tts3/run.sh +++ b/examples/zh_en_tts/tts3/run.sh @@ -3,9 +3,9 @@ set -e source path.sh -gpus=0,1 -stage=0 -stop_stage=100 +gpus=0 +stage=1 +stop_stage=1 datasets_root_dir=~/datasets mfa_root_dir=./mfa_results/ diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index 9905765db..1db7973b8 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -37,6 +37,8 @@ from paddlespeech.t2s.modules.transformer.encoder import CNNDecoder from paddlespeech.t2s.modules.transformer.encoder import CNNPostnet from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder from paddlespeech.t2s.modules.transformer.encoder import TransformerEncoder +from paddlespeech.t2s.modules.multi_speakers.speaker_classifier import SpeakerClassifier +from paddlespeech.t2s.modules.multi_speakers.gradient_reversal import GradientReversalLayer class FastSpeech2(nn.Layer): @@ -138,7 +140,10 @@ class FastSpeech2(nn.Layer): # training related init_type: str="xavier_uniform", init_enc_alpha: float=1.0, - init_dec_alpha: float=1.0, ): + init_dec_alpha: float=1.0, + # speaker classifier + enable_speaker_classifier: bool=False, + hidden_sc_dim: int=256,): """Initialize FastSpeech2 module. Args: idim (int): @@ -268,6 +273,10 @@ class FastSpeech2(nn.Layer): Initial value of alpha in scaled pos encoding of the encoder. init_dec_alpha (float): Initial value of alpha in scaled pos encoding of the decoder. 
+ enable_speaker_classifier (bool): + Whether to use speaker classifier module + hidden_sc_dim (int): + The hidden layer dim of speaker classifier """ assert check_argument_types() @@ -281,6 +290,9 @@ class FastSpeech2(nn.Layer): self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor self.use_scaled_pos_enc = use_scaled_pos_enc + self.hidden_sc_dim = hidden_sc_dim + self.spk_num = spk_num + self.enable_speaker_classifier = enable_speaker_classifier self.spk_embed_dim = spk_embed_dim if self.spk_embed_dim is not None: @@ -373,6 +385,11 @@ class FastSpeech2(nn.Layer): self.tone_projection = nn.Linear(adim + self.tone_embed_dim, adim) + if self.spk_num and self.enable_speaker_classifier: + # set lambda = 1 + self.grad_reverse = GradientReversalLayer(1) + self.speaker_classifier = SpeakerClassifier(idim=adim, hidden_sc_dim=self.hidden_sc_dim, spk_num=spk_num) + # define duration predictor self.duration_predictor = DurationPredictor( idim=adim, @@ -547,7 +564,7 @@ class FastSpeech2(nn.Layer): if tone_id is not None: tone_id = paddle.cast(tone_id, 'int64') # forward propagation - before_outs, after_outs, d_outs, p_outs, e_outs = self._forward( + before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits = self._forward( xs, ilens, olens, @@ -564,7 +581,7 @@ class FastSpeech2(nn.Layer): max_olen = max(olens) ys = ys[:, :max_olen] - return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens + return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits def _forward(self, xs: paddle.Tensor, @@ -584,6 +601,12 @@ class FastSpeech2(nn.Layer): # (B, Tmax, adim) hs, _ = self.encoder(xs, x_masks) + if self.spk_num and self.enable_speaker_classifier: + hs_for_spk_cls = self.grad_reverse(hs) + spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens) + else: + spk_logits = None + # integrate speaker embedding if self.spk_embed_dim is not None: # spk_emb 
has a higher priority than spk_id @@ -676,7 +699,7 @@ class FastSpeech2(nn.Layer): after_outs = before_outs + self.postnet( before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) - return before_outs, after_outs, d_outs, p_outs, e_outs + return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits def encoder_infer( self, @@ -1058,6 +1081,7 @@ class FastSpeech2Loss(nn.Layer): self.l1_criterion = nn.L1Loss(reduction=reduction) self.mse_criterion = nn.MSELoss(reduction=reduction) self.duration_criterion = DurationPredictorLoss(reduction=reduction) + self.ce_criterion = nn.CrossEntropyLoss() def forward( self, @@ -1072,7 +1096,9 @@ class FastSpeech2Loss(nn.Layer): es: paddle.Tensor, ilens: paddle.Tensor, olens: paddle.Tensor, - ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + spk_logits: paddle.Tensor=None, + spk_ids: paddle.Tensor=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,]: """Calculate forward propagation. Args: @@ -1098,11 +1124,20 @@ class FastSpeech2Loss(nn.Layer): Batch of the lengths of each input (B,). olens(Tensor): Batch of the lengths of each target (B,). 
+ spk_logits(Optional[Tensor]): + Batch of outputs after speaker classifier (B, Lmax, num_spk) + spk_ids(Optional[Tensor]): + Batch of target spk_id (B,) + Returns: """ + speaker_loss = 0.0 + + # speaker_loss is overwritten below when spk_logits/spk_ids are provided + # apply mask to remove padded part if self.use_masking: out_masks = make_non_pad_mask(olens).unsqueeze(-1) @@ -1124,6 +1159,15 @@ ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape)) es = es.masked_select(pitch_masks.broadcast_to(es.shape)) + if spk_logits is not None and spk_ids is not None: + spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1], None) + spk_logits = paddle.reshape(spk_logits, [-1, spk_logits.shape[-1]]) + mask_index = spk_logits.abs().sum(axis=1)!=0 + spk_ids = spk_ids[mask_index] + spk_logits = spk_logits[mask_index] + speaker_loss = self.ce_criterion(spk_logits, spk_ids)/spk_logits.shape[0] + # NOTE(review): ce_criterion already mean-reduces; confirm the extra /shape[0] is intended + # calculate loss l1_loss = self.l1_criterion(before_outs, ys) if after_outs is not None: @@ -1131,6 +1175,7 @@ duration_loss = self.duration_criterion(d_outs, ds) pitch_loss = self.mse_criterion(p_outs, ps) energy_loss = self.mse_criterion(e_outs, es) + # make weighted mask and apply it if self.use_weighted_masking: @@ -1161,4 +1206,4 @@ class
FastSpeech2Loss(nn.Layer): energy_loss = energy_loss.masked_select( pitch_masks.broadcast_to(energy_loss.shape)).sum() - return l1_loss, duration_loss, pitch_loss, energy_loss + return l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py index 92aa9dfc7..1eb0f60fd 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py @@ -37,18 +37,20 @@ class FastSpeech2Updater(StandardUpdater): dataloader: DataLoader, init_state=None, use_masking: bool=False, + spk_loss_scale: float=0.02, use_weighted_masking: bool=False, output_dir: Path=None): super().__init__(model, optimizer, dataloader, init_state=None) self.criterion = FastSpeech2Loss( - use_masking=use_masking, use_weighted_masking=use_weighted_masking) + use_masking=use_masking, use_weighted_masking=use_weighted_masking,) log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) self.filehandler = logging.FileHandler(str(log_file)) logger.addHandler(self.filehandler) self.logger = logger self.msg = "" + self.spk_loss_scale = spk_loss_scale def update_core(self, batch): self.msg = "Rank: {}, ".format(dist.get_rank()) @@ -60,7 +62,7 @@ class FastSpeech2Updater(StandardUpdater): if spk_emb is not None: spk_id = None - before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model( + before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model( text=batch["text"], text_lengths=batch["text_lengths"], speech=batch["speech"], @@ -71,7 +73,7 @@ class FastSpeech2Updater(StandardUpdater): spk_id=spk_id, spk_emb=spk_emb) - l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion( + l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion( after_outs=after_outs, before_outs=before_outs, d_outs=d_outs, @@ -82,9 +84,11 @@ class 
FastSpeech2Updater(StandardUpdater): ps=batch["pitch"], es=batch["energy"], ilens=batch["text_lengths"], - olens=olens) + olens=olens, + spk_logits=spk_logits, + spk_ids=spk_id,) - loss = l1_loss + duration_loss + pitch_loss + energy_loss + loss = l1_loss + duration_loss + pitch_loss + energy_loss + self.spk_loss_scale * speaker_loss optimizer = self.optimizer optimizer.clear_grad() @@ -96,11 +100,16 @@ report("train/duration_loss", float(duration_loss)) report("train/pitch_loss", float(pitch_loss)) report("train/energy_loss", float(energy_loss)) + report("train/speaker_loss", float(speaker_loss)) + report("train/speaker_loss_0.02", float(self.spk_loss_scale * speaker_loss)) losses_dict["l1_loss"] = float(l1_loss) losses_dict["duration_loss"] = float(duration_loss) losses_dict["pitch_loss"] = float(pitch_loss) losses_dict["energy_loss"] = float(energy_loss) + # track speaker-classifier loss terms alongside the base losses + losses_dict["speaker_loss"] = float(speaker_loss) + losses_dict["speaker_loss_0.02"] = float(self.spk_loss_scale * speaker_loss) losses_dict["loss"] = float(loss) self.msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_dict.items()) @@ -112,6 +121,7 @@ class FastSpeech2Evaluator(StandardEvaluator): dataloader: DataLoader, use_masking: bool=False, use_weighted_masking: bool=False, + spk_loss_scale: float=0.02, output_dir: Path=None): super().__init__(model, dataloader) @@ -120,6 +130,7 @@ class FastSpeech2Evaluator(StandardEvaluator): logger.addHandler(self.filehandler) self.logger = logger self.msg = "" + self.spk_loss_scale = spk_loss_scale self.criterion = FastSpeech2Loss( use_masking=use_masking, use_weighted_masking=use_weighted_masking) @@ -133,7 +144,7 @@ class FastSpeech2Evaluator(StandardEvaluator): if spk_emb is not None: spk_id = None - before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model( + before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
text=batch["text"], text_lengths=batch["text_lengths"], speech=batch["speech"], @@ -144,7 +155,7 @@ class FastSpeech2Evaluator(StandardEvaluator): spk_id=spk_id, spk_emb=spk_emb) - l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion( + l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion( after_outs=after_outs, before_outs=before_outs, d_outs=d_outs, @@ -155,8 +166,10 @@ class FastSpeech2Evaluator(StandardEvaluator): ps=batch["pitch"], es=batch["energy"], ilens=batch["text_lengths"], - olens=olens, ) - loss = l1_loss + duration_loss + pitch_loss + energy_loss + olens=olens, + spk_logits=spk_logits, + spk_ids=spk_id,) + loss = l1_loss + duration_loss + pitch_loss + energy_loss + self.spk_loss_scale * speaker_loss report("eval/loss", float(loss)) report("eval/l1_loss", float(l1_loss)) diff --git a/paddlespeech/t2s/modules/multi_speakers/gradient_reversal.py b/paddlespeech/t2s/modules/multi_speakers/gradient_reversal.py new file mode 100644 index 000000000..5250f1df1 --- /dev/null +++ b/paddlespeech/t2s/modules/multi_speakers/gradient_reversal.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +from paddle.autograd import PyLayer +import paddle.nn as nn + +class GradientReversalFunction(PyLayer): + """Gradient Reversal Layer from: + Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015) + + Forward pass is the identity function. In the backward pass, + the upstream gradients are multiplied by -lambda (i.e. gradient is reversed) + """ + + @staticmethod + def forward(ctx, x, lambda_=1): + """Forward in networks + """ + ctx.save_for_backward(lambda_) + return x.clone() + + @staticmethod + def backward(ctx, grads): + """Backward in networks + """ + lambda_, = ctx.saved_tensor() + dx = -lambda_ * grads + # clip reversed gradients to [-0.5, 0.5] to stabilize adversarial training + return paddle.clip(dx, min=-0.5, max=0.5) + + +class GradientReversalLayer(nn.Layer): + """Gradient Reversal Layer from: + Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015) + + Forward pass is the identity function. In the backward pass, + the upstream gradients are multiplied by -lambda (i.e. gradient is reversed) + """ + + def __init__(self, lambda_=1): + super(GradientReversalLayer, self).__init__() + self.lambda_ = lambda_ + + def forward(self, x): + """Forward in networks + """ + return GradientReversalFunction.apply(x, self.lambda_) + diff --git a/paddlespeech/t2s/modules/multi_speakers/speaker_classifier.py b/paddlespeech/t2s/modules/multi_speakers/speaker_classifier.py new file mode 100644 index 000000000..a64f6d5b8 --- /dev/null +++ b/paddlespeech/t2s/modules/multi_speakers/speaker_classifier.py @@ -0,0 +1,50 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from Cross-Lingual-Voice-Cloning(https://github.com/deterministic-algorithms-lab/Cross-Lingual-Voice-Cloning) + +from paddle import nn +import paddle +from typeguard import check_argument_types + +class SpeakerClassifier(nn.Layer): + + def __init__(self, idim: int, hidden_sc_dim: int, spk_num: int, ): + assert check_argument_types() + super().__init__() + # store hyperparameters + self.idim = idim + self.hidden_sc_dim = hidden_sc_dim + self.spk_num = spk_num + + self.model = nn.Sequential(nn.Linear(self.idim, self.hidden_sc_dim), + nn.Linear(self.hidden_sc_dim, self.spk_num)) + + def parse_outputs(self, out, text_lengths): + mask = paddle.arange(out.shape[1]).expand([out.shape[0], out.shape[1]]) < text_lengths.unsqueeze(1) + out = paddle.transpose(out, perm=[2, 0, 1]) + out = out * mask + out = paddle.transpose(out, perm=[1, 2, 0]) + return out + + def forward(self, encoder_outputs, text_lengths): + """ + encoder_outputs = [batch_size, seq_len, encoder_embedding_size] + text_lengths = [batch_size] + + log probabilities of speaker classification = [batch_size, seq_len, spk_num] + """ + + out = self.model(encoder_outputs) + out = self.parse_outputs(out, text_lengths) + return out