diff --git a/examples/zh_en_tts/tts3/README.md b/examples/zh_en_tts/tts3/README.md index b4b683089..012028007 100644 --- a/examples/zh_en_tts/tts3/README.md +++ b/examples/zh_en_tts/tts3/README.md @@ -116,6 +116,8 @@ optional arguments: 5. `--phones-dict` is the path of the phone vocabulary file. 6. `--speaker-dict` is the path of the speaker id map file when training a multi-speaker FastSpeech2. +We have **added module speaker classifier** with reference to [Learning to Speak Fluently in a Foreign Language: Multilingual Speech Synthesis and Cross-Language Voice Cloning](https://arxiv.org/pdf/1907.04448.pdf). The main parameter configuration: `config["model"]["enable_speaker_classifier"]`, `config["model"]["hidden_sc_dim"]` and `config["updater"]["spk_loss_scale"]` in `conf/default.yaml`. The current experimental results show that this module can decouple text information and speaker information, and more experiments are still being sorted out. This module is currently not enabled by default, if you are interested, you can try it yourself. + ### Synthesizing We use [parallel wavegan](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1) as the default neural vocoder. diff --git a/examples/zh_en_tts/tts3/conf/default.yaml b/examples/zh_en_tts/tts3/conf/default.yaml index e65b5d0ec..efa8b3ea2 100644 --- a/examples/zh_en_tts/tts3/conf/default.yaml +++ b/examples/zh_en_tts/tts3/conf/default.yaml @@ -74,6 +74,9 @@ model: stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder spk_embed_dim: 256 # speaker embedding dimension spk_embed_integration_type: concat # speaker embedding integration type + enable_speaker_classifier: False # Whether to use speaker classifier module + hidden_sc_dim: 256 # The hidden layer dim of speaker classifier + @@ -82,6 +85,7 @@ model: ########################################################### updater: use_masking: True # whether to apply masking for padded part in loss calculation + spk_loss_scale: 0.02 # The scales of speaker classifier loss ########################################################### diff --git a/paddlespeech/t2s/exps/fastspeech2/train.py b/paddlespeech/t2s/exps/fastspeech2/train.py index 10e023d0c..d31e62a82 100644 --- a/paddlespeech/t2s/exps/fastspeech2/train.py +++ b/paddlespeech/t2s/exps/fastspeech2/train.py @@ -145,17 +145,27 @@ def train_sp(args, config): # copy conf to output_dir shutil.copyfile(args.config, output_dir / config_name) + if "enable_speaker_classifier" in config.model: + enable_spk_cls = config.model.enable_speaker_classifier + else: + enable_spk_cls = False + updater = FastSpeech2Updater( model=model, optimizer=optimizer, dataloader=train_dataloader, output_dir=output_dir, - **config["updater"]) + enable_spk_cls=enable_spk_cls, + **config["updater"], ) trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir) evaluator = FastSpeech2Evaluator( - model, dev_dataloader, output_dir=output_dir, **config["updater"]) + model, + dev_dataloader, + output_dir=output_dir, + enable_spk_cls=enable_spk_cls, + **config["updater"], ) if dist.get_rank() == 0: trainer.extend(evaluator, trigger=(1, "epoch")) diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index 9905765db..0eb44beb6 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -25,6 +25,8 @@ import paddle.nn.functional as F from paddle import nn from typeguard import check_argument_types +from paddlespeech.t2s.modules.adversarial_loss.gradient_reversal import GradientReversalLayer +from paddlespeech.t2s.modules.adversarial_loss.speaker_classifier import SpeakerClassifier from paddlespeech.t2s.modules.nets_utils import initialize from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask from paddlespeech.t2s.modules.nets_utils import make_pad_mask @@ -138,7 +140,10 @@ class FastSpeech2(nn.Layer): # training related init_type: str="xavier_uniform", init_enc_alpha: float=1.0, - init_dec_alpha: float=1.0, ): + init_dec_alpha: float=1.0, + # speaker classifier + enable_speaker_classifier: bool=False, + hidden_sc_dim: int=256, ): """Initialize FastSpeech2 module. Args: idim (int): @@ -268,6 +273,10 @@ class FastSpeech2(nn.Layer): Initial value of alpha in scaled pos encoding of the encoder. init_dec_alpha (float): Initial value of alpha in scaled pos encoding of the decoder. + enable_speaker_classifier (bool): + Whether to use speaker classifier module + hidden_sc_dim (int): + The hidden layer dim of speaker classifier """ assert check_argument_types() @@ -281,6 +290,9 @@ class FastSpeech2(nn.Layer): self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor self.use_scaled_pos_enc = use_scaled_pos_enc + self.hidden_sc_dim = hidden_sc_dim + self.spk_num = spk_num + self.enable_speaker_classifier = enable_speaker_classifier self.spk_embed_dim = spk_embed_dim if self.spk_embed_dim is not None: @@ -373,6 +385,12 @@ class FastSpeech2(nn.Layer): self.tone_projection = nn.Linear(adim + self.tone_embed_dim, adim) + if self.spk_num and self.enable_speaker_classifier: + # set lambda = 1 + self.grad_reverse = GradientReversalLayer(1) + self.speaker_classifier = SpeakerClassifier( + idim=adim, hidden_sc_dim=self.hidden_sc_dim, spk_num=spk_num) + # define duration predictor self.duration_predictor = DurationPredictor( idim=adim, @@ -547,7 +565,7 @@ class FastSpeech2(nn.Layer): if tone_id is not None: tone_id = paddle.cast(tone_id, 'int64') # forward propagation - before_outs, after_outs, d_outs, p_outs, e_outs = self._forward( + before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits = self._forward( xs, ilens, olens, @@ -564,7 +582,7 @@ class FastSpeech2(nn.Layer): max_olen = max(olens) ys = ys[:, :max_olen] - return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens + return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits def _forward(self, xs: paddle.Tensor, @@ -584,6 +602,12 @@ class FastSpeech2(nn.Layer): # (B, Tmax, adim) hs, _ = self.encoder(xs, x_masks) + if self.spk_num and self.enable_speaker_classifier and not is_inference: + hs_for_spk_cls = self.grad_reverse(hs) + spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens) + else: + spk_logits = None + # integrate speaker embedding if self.spk_embed_dim is not None: # spk_emb has a higher priority than spk_id @@ -676,7 +700,7 @@ class FastSpeech2(nn.Layer): after_outs = before_outs + self.postnet( before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) - return before_outs, after_outs, d_outs, p_outs, e_outs + return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits def encoder_infer( self, @@ -771,7 +795,7 @@ class FastSpeech2(nn.Layer): es = e.unsqueeze(0) if e is not None else None # (1, L, odim) - _, outs, d_outs, p_outs, e_outs = self._forward( + _, outs, d_outs, p_outs, e_outs, _ = self._forward( xs, ilens, ds=ds, @@ -783,7 +807,7 @@ class FastSpeech2(nn.Layer): is_inference=True) else: # (1, L, odim) - _, outs, d_outs, p_outs, e_outs = self._forward( + _, outs, d_outs, p_outs, e_outs, _ = self._forward( xs, ilens, is_inference=True, @@ -791,6 +815,7 @@ class FastSpeech2(nn.Layer): spk_emb=spk_emb, spk_id=spk_id, tone_id=tone_id) + return outs[0], d_outs[0], p_outs[0], e_outs[0] def _integrate_with_spk_embed(self, hs, spk_emb): @@ -1058,6 +1083,7 @@ class FastSpeech2Loss(nn.Layer): self.l1_criterion = nn.L1Loss(reduction=reduction) self.mse_criterion = nn.MSELoss(reduction=reduction) self.duration_criterion = DurationPredictorLoss(reduction=reduction) + self.ce_criterion = nn.CrossEntropyLoss() def forward( self, @@ -1072,7 +1098,10 @@ class FastSpeech2Loss(nn.Layer): es: paddle.Tensor, ilens: paddle.Tensor, olens: paddle.Tensor, - ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + spk_logits: paddle.Tensor=None, + spk_ids: paddle.Tensor=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, + paddle.Tensor, ]: """Calculate forward propagation. Args: @@ -1098,11 +1127,18 @@ class FastSpeech2Loss(nn.Layer): Batch of the lengths of each input (B,). olens(Tensor): Batch of the lengths of each target (B,). + spk_logits(Option[Tensor]): + Batch of outputs after speaker classifier (B, Lmax, num_spk) + spk_ids(Option[Tensor]): + Batch of target spk_id (B,) + Returns: """ + speaker_loss = 0.0 + # apply mask to remove padded part if self.use_masking: out_masks = make_non_pad_mask(olens).unsqueeze(-1) @@ -1124,6 +1160,16 @@ class FastSpeech2Loss(nn.Layer): ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape)) es = es.masked_select(pitch_masks.broadcast_to(es.shape)) + if spk_logits is not None and spk_ids is not None: + batch_size = spk_ids.shape[0] + spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1], + None) + spk_logits = paddle.reshape(spk_logits, + [-1, spk_logits.shape[-1]]) + mask_index = spk_logits.abs().sum(axis=1) != 0 + spk_ids = spk_ids[mask_index] + spk_logits = spk_logits[mask_index] + # calculate loss l1_loss = self.l1_criterion(before_outs, ys) if after_outs is not None: @@ -1132,6 +1178,9 @@ class FastSpeech2Loss(nn.Layer): pitch_loss = self.mse_criterion(p_outs, ps) energy_loss = self.mse_criterion(e_outs, es) + if spk_logits is not None and spk_ids is not None: + speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size + # make weighted mask and apply it if self.use_weighted_masking: out_masks = make_non_pad_mask(olens).unsqueeze(-1) @@ -1161,4 +1210,4 @@ class FastSpeech2Loss(nn.Layer): energy_loss = energy_loss.masked_select( pitch_masks.broadcast_to(energy_loss.shape)).sum() - return l1_loss, duration_loss, pitch_loss, energy_loss + return l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py index 92aa9dfc7..b398267e6 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py @@ -14,6 +14,7 @@ import logging from pathlib import Path +from paddle import DataParallel from paddle import distributed as dist from paddle.io import DataLoader from paddle.nn import Layer @@ -23,6 +24,7 @@ from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator from paddlespeech.t2s.training.reporter import report from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater + logging.basicConfig( format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s', datefmt='[%Y-%m-%d %H:%M:%S]') @@ -31,24 +33,30 @@ logger.setLevel(logging.INFO) class FastSpeech2Updater(StandardUpdater): - def __init__(self, - model: Layer, - optimizer: Optimizer, - dataloader: DataLoader, - init_state=None, - use_masking: bool=False, - use_weighted_masking: bool=False, - output_dir: Path=None): + def __init__( + self, + model: Layer, + optimizer: Optimizer, + dataloader: DataLoader, + init_state=None, + use_masking: bool=False, + spk_loss_scale: float=0.02, + use_weighted_masking: bool=False, + output_dir: Path=None, + enable_spk_cls: bool=False, ): super().__init__(model, optimizer, dataloader, init_state=None) self.criterion = FastSpeech2Loss( - use_masking=use_masking, use_weighted_masking=use_weighted_masking) + use_masking=use_masking, + use_weighted_masking=use_weighted_masking, ) log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) self.filehandler = logging.FileHandler(str(log_file)) logger.addHandler(self.filehandler) self.logger = logger self.msg = "" + self.spk_loss_scale = spk_loss_scale + self.enable_spk_cls = enable_spk_cls def update_core(self, batch): self.msg = "Rank: {}, ".format(dist.get_rank()) @@ -60,18 +68,33 @@ class FastSpeech2Updater(StandardUpdater): if spk_emb is not None: spk_id = None - before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model( - text=batch["text"], - text_lengths=batch["text_lengths"], - speech=batch["speech"], - speech_lengths=batch["speech_lengths"], - durations=batch["durations"], - pitch=batch["pitch"], - energy=batch["energy"], - spk_id=spk_id, - spk_emb=spk_emb) - - l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion( + if type( + self.model + ) == DataParallel and self.model._layers.spk_num and self.model._layers.enable_speaker_classifier: + with self.model.no_sync(): + before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model( + text=batch["text"], + text_lengths=batch["text_lengths"], + speech=batch["speech"], + speech_lengths=batch["speech_lengths"], + durations=batch["durations"], + pitch=batch["pitch"], + energy=batch["energy"], + spk_id=spk_id, + spk_emb=spk_emb) + else: + before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model( + text=batch["text"], + text_lengths=batch["text_lengths"], + speech=batch["speech"], + speech_lengths=batch["speech_lengths"], + durations=batch["durations"], + pitch=batch["pitch"], + energy=batch["energy"], + spk_id=spk_id, + spk_emb=spk_emb) + + l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion( after_outs=after_outs, before_outs=before_outs, d_outs=d_outs, @@ -82,9 +105,12 @@ class FastSpeech2Updater(StandardUpdater): ps=batch["pitch"], es=batch["energy"], ilens=batch["text_lengths"], - olens=olens) + olens=olens, + spk_logits=spk_logits, + spk_ids=spk_id, ) - loss = l1_loss + duration_loss + pitch_loss + energy_loss + scaled_speaker_loss = self.spk_loss_scale * speaker_loss + loss = l1_loss + duration_loss + pitch_loss + energy_loss + scaled_speaker_loss optimizer = self.optimizer optimizer.clear_grad() @@ -96,11 +122,18 @@ class FastSpeech2Updater(StandardUpdater): report("train/duration_loss", float(duration_loss)) report("train/pitch_loss", float(pitch_loss)) report("train/energy_loss", float(energy_loss)) + if self.enable_spk_cls: + report("train/speaker_loss", float(speaker_loss)) + report("train/scaled_speaker_loss", float(scaled_speaker_loss)) losses_dict["l1_loss"] = float(l1_loss) losses_dict["duration_loss"] = float(duration_loss) losses_dict["pitch_loss"] = float(pitch_loss) losses_dict["energy_loss"] = float(energy_loss) + losses_dict["energy_loss"] = float(energy_loss) + if self.enable_spk_cls: + losses_dict["speaker_loss"] = float(speaker_loss) + losses_dict["scaled_speaker_loss"] = float(scaled_speaker_loss) losses_dict["loss"] = float(loss) self.msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_dict.items()) @@ -112,7 +145,9 @@ class FastSpeech2Evaluator(StandardEvaluator): dataloader: DataLoader, use_masking: bool=False, use_weighted_masking: bool=False, - output_dir: Path=None): + spk_loss_scale: float=0.02, + output_dir: Path=None, + enable_spk_cls: bool=False): super().__init__(model, dataloader) log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) @@ -120,6 +155,8 @@ class FastSpeech2Evaluator(StandardEvaluator): logger.addHandler(self.filehandler) self.logger = logger self.msg = "" + self.spk_loss_scale = spk_loss_scale + self.enable_spk_cls = enable_spk_cls self.criterion = FastSpeech2Loss( use_masking=use_masking, use_weighted_masking=use_weighted_masking) @@ -133,18 +170,33 @@ class FastSpeech2Evaluator(StandardEvaluator): if spk_emb is not None: spk_id = None - before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model( - text=batch["text"], - text_lengths=batch["text_lengths"], - speech=batch["speech"], - speech_lengths=batch["speech_lengths"], - durations=batch["durations"], - pitch=batch["pitch"], - energy=batch["energy"], - spk_id=spk_id, - spk_emb=spk_emb) - - l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion( + if type( + self.model + ) == DataParallel and self.model._layers.spk_num and self.model._layers.enable_speaker_classifier: + with self.model.no_sync(): + before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model( + text=batch["text"], + text_lengths=batch["text_lengths"], + speech=batch["speech"], + speech_lengths=batch["speech_lengths"], + durations=batch["durations"], + pitch=batch["pitch"], + energy=batch["energy"], + spk_id=spk_id, + spk_emb=spk_emb) + else: + before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model( + text=batch["text"], + text_lengths=batch["text_lengths"], + speech=batch["speech"], + speech_lengths=batch["speech_lengths"], + durations=batch["durations"], + pitch=batch["pitch"], + energy=batch["energy"], + spk_id=spk_id, + spk_emb=spk_emb) + + l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion( after_outs=after_outs, before_outs=before_outs, d_outs=d_outs, @@ -155,19 +207,29 @@ class FastSpeech2Evaluator(StandardEvaluator): ps=batch["pitch"], es=batch["energy"], ilens=batch["text_lengths"], - olens=olens, ) - loss = l1_loss + duration_loss + pitch_loss + energy_loss + olens=olens, + spk_logits=spk_logits, + spk_ids=spk_id, ) + + scaled_speaker_loss = self.spk_loss_scale * speaker_loss + loss = l1_loss + duration_loss + pitch_loss + energy_loss + scaled_speaker_loss report("eval/loss", float(loss)) report("eval/l1_loss", float(l1_loss)) report("eval/duration_loss", float(duration_loss)) report("eval/pitch_loss", float(pitch_loss)) report("eval/energy_loss", float(energy_loss)) + if self.enable_spk_cls: + report("train/speaker_loss", float(speaker_loss)) + report("train/scaled_speaker_loss", float(scaled_speaker_loss)) losses_dict["l1_loss"] = float(l1_loss) losses_dict["duration_loss"] = float(duration_loss) losses_dict["pitch_loss"] = float(pitch_loss) losses_dict["energy_loss"] = float(energy_loss) + if self.enable_spk_cls: + losses_dict["speaker_loss"] = float(speaker_loss) + losses_dict["scaled_speaker_loss"] = float(scaled_speaker_loss) losses_dict["loss"] = float(loss) self.msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_dict.items()) diff --git a/paddlespeech/t2s/modules/adversarial_loss/__init__.py b/paddlespeech/t2s/modules/adversarial_loss/__init__.py new file mode 100644 index 000000000..abf198b97 --- /dev/null +++ b/paddlespeech/t2s/modules/adversarial_loss/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/t2s/modules/adversarial_loss/gradient_reversal.py b/paddlespeech/t2s/modules/adversarial_loss/gradient_reversal.py new file mode 100644 index 000000000..64da16053 --- /dev/null +++ b/paddlespeech/t2s/modules/adversarial_loss/gradient_reversal.py @@ -0,0 +1,58 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +import paddle.nn as nn +from paddle.autograd import PyLayer + + +class GradientReversalFunction(PyLayer): + """Gradient Reversal Layer from: + Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015) + + Forward pass is the identity function. In the backward pass, + the upstream gradients are multiplied by -lambda (i.e. gradient is reversed) + """ + + @staticmethod + def forward(ctx, x, lambda_=1): + """Forward in networks + """ + ctx.save_for_backward(lambda_) + return x.clone() + + @staticmethod + def backward(ctx, grads): + """Backward in networks + """ + lambda_, = ctx.saved_tensor() + dx = -lambda_ * grads + return paddle.clip(dx, min=-0.5, max=0.5) + + +class GradientReversalLayer(nn.Layer): + """Gradient Reversal Layer from: + Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015) + + Forward pass is the identity function. In the backward pass, + the upstream gradients are multiplied by -lambda (i.e. gradient is reversed) + """ + + def __init__(self, lambda_=1): + super(GradientReversalLayer, self).__init__() + self.lambda_ = lambda_ + + def forward(self, x): + """Forward in networks + """ + return GradientReversalFunction.apply(x, self.lambda_) diff --git a/paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py b/paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py new file mode 100644 index 000000000..d731b2d27 --- /dev/null +++ b/paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from Cross-Lingual-Voice-Cloning(https://github.com/deterministic-algorithms-lab/Cross-Lingual-Voice-Cloning) +import paddle +from paddle import nn +from typeguard import check_argument_types + + +class SpeakerClassifier(nn.Layer): + def __init__( + self, + idim: int, + hidden_sc_dim: int, + spk_num: int, ): + assert check_argument_types() + super().__init__() + # store hyperparameters + self.idim = idim + self.hidden_sc_dim = hidden_sc_dim + self.spk_num = spk_num + + self.model = nn.Sequential( + nn.Linear(self.idim, self.hidden_sc_dim), + nn.Linear(self.hidden_sc_dim, self.spk_num)) + + def parse_outputs(self, out, text_lengths): + mask = paddle.arange(out.shape[1]).expand( + [out.shape[0], out.shape[1]]) < text_lengths.unsqueeze(1) + out = paddle.transpose(out, perm=[2, 0, 1]) + out = out * mask + out = paddle.transpose(out, perm=[1, 2, 0]) + return out + + def forward(self, encoder_outputs, text_lengths): + """ + encoder_outputs = [batch_size, seq_len, encoder_embedding_size] + text_lengths = [batch_size] + + log probabilities of speaker classification = [batch_size, seq_len, spk_num] + """ + + out = self.model(encoder_outputs) + out = self.parse_outputs(out, text_lengths) + return out