diff --git a/examples/hey_snips/README.md b/examples/hey_snips/README.md
index ba263906a..6311ad928 100644
--- a/examples/hey_snips/README.md
+++ b/examples/hey_snips/README.md
@@ -2,7 +2,7 @@
 ## Metrics
 We measure FRRs at a fixed number of false alarms per hour:
-
+The released model: https://paddlespeech.bj.bcebos.com/kws/heysnips/kws0_mdtc_heysnips_ckpt.tar.gz
 |Model|False Alarm| False Reject Rate|
 |--|--|--|
 |MDTC| 1| 0.003559 |
diff --git a/examples/zh_en_tts/tts3/README.md b/examples/zh_en_tts/tts3/README.md
index b4b683089..f63d5d8fe 100644
--- a/examples/zh_en_tts/tts3/README.md
+++ b/examples/zh_en_tts/tts3/README.md
@@ -116,6 +116,8 @@ optional arguments:
 5. `--phones-dict` is the path of the phone vocabulary file.
 6. `--speaker-dict` is the path of the speaker id map file when training a multi-speaker FastSpeech2.
 
+We have **added a speaker classifier module** with reference to [Learning to Speak Fluently in a Foreign Language: Multilingual Speech Synthesis and Cross-Language Voice Cloning](https://arxiv.org/pdf/1907.04448.pdf). The main configuration parameters are `config["model"]["enable_speaker_classifier"]`, `config["model"]["hidden_sc_dim"]`, and `config["updater"]["spk_loss_scale"]` in `conf/default.yaml`. Current experimental results show that this module can decouple text information from speaker information; more experiments are still in progress. This module is disabled by default; if you are interested, you can try it yourself.
+
 ### Synthesizing
 We use [parallel wavegan](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1) as the default neural vocoder.
diff --git a/examples/zh_en_tts/tts3/conf/default.yaml b/examples/zh_en_tts/tts3/conf/default.yaml
index 06bf1fcc1..efa8b3ea2 100644
--- a/examples/zh_en_tts/tts3/conf/default.yaml
+++ b/examples/zh_en_tts/tts3/conf/default.yaml
@@ -74,7 +74,7 @@ model:
     stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder
     spk_embed_dim: 256 # speaker embedding dimension
    spk_embed_integration_type: concat # speaker embedding integration type
-    enable_speaker_classifier: True # Whether to use speaker classifier module
+    enable_speaker_classifier: False # Whether to use speaker classifier module
     hidden_sc_dim: 256 # The hidden layer dim of speaker classifier
diff --git a/examples/zh_en_tts/tts3/local/train.sh b/examples/zh_en_tts/tts3/local/train.sh
index 3a5076505..1da72f117 100755
--- a/examples/zh_en_tts/tts3/local/train.sh
+++ b/examples/zh_en_tts/tts3/local/train.sh
@@ -8,6 +8,6 @@ python3 ${BIN_DIR}/train.py \
     --dev-metadata=dump/dev/norm/metadata.jsonl \
     --config=${config_path} \
     --output-dir=${train_output_path} \
-    --ngpu=1 \
+    --ngpu=2 \
     --phones-dict=dump/phone_id_map.txt \
     --speaker-dict=dump/speaker_id_map.txt
diff --git a/examples/zh_en_tts/tts3/run.sh b/examples/zh_en_tts/tts3/run.sh
index a0c58f35c..12f99081a 100755
--- a/examples/zh_en_tts/tts3/run.sh
+++ b/examples/zh_en_tts/tts3/run.sh
@@ -3,9 +3,9 @@
 set -e
 source path.sh
 
-gpus=0
-stage=1
-stop_stage=1
+gpus=0,1
+stage=0
+stop_stage=100
 
 datasets_root_dir=~/datasets
 mfa_root_dir=./mfa_results/
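The example scripts now default to a full two-GPU run (`gpus=0,1`, stages 0 through 100). For orientation, the three knobs the README introduces can be inspected straight from the config; a minimal sketch, assuming it is run from `examples/zh_en_tts/tts3` and using plain `yaml` for illustration (the training script wraps this dict in its own config object):

```python
import yaml

# Read the three speaker-classifier knobs named in the README above.
with open("conf/default.yaml") as f:
    config = yaml.safe_load(f)

print(config["model"]["enable_speaker_classifier"])  # False after this change
print(config["model"]["hidden_sc_dim"])              # 256
print(config["updater"]["spk_loss_scale"])           # weight on the adversarial loss
```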
diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
index 34a2ff98a..0eb44beb6 100644
--- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
+++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
@@ -25,6 +25,8 @@ import paddle.nn.functional as F
 from paddle import nn
 from typeguard import check_argument_types
 
+from paddlespeech.t2s.modules.adversarial_loss.gradient_reversal import GradientReversalLayer
+from paddlespeech.t2s.modules.adversarial_loss.speaker_classifier import SpeakerClassifier
 from paddlespeech.t2s.modules.nets_utils import initialize
 from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
 from paddlespeech.t2s.modules.nets_utils import make_pad_mask
@@ -37,8 +39,6 @@ from paddlespeech.t2s.modules.transformer.encoder import CNNDecoder
 from paddlespeech.t2s.modules.transformer.encoder import CNNPostnet
 from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder
 from paddlespeech.t2s.modules.transformer.encoder import TransformerEncoder
-from paddlespeech.t2s.modules.multi_speakers.speaker_classifier import SpeakerClassifier
-from paddlespeech.t2s.modules.multi_speakers.gradient_reversal import GradientReversalLayer
 
 
 class FastSpeech2(nn.Layer):
@@ -140,10 +140,10 @@ class FastSpeech2(nn.Layer):
             # training related
             init_type: str="xavier_uniform",
             init_enc_alpha: float=1.0,
-            init_dec_alpha: float=1.0,
+            init_dec_alpha: float=1.0,
+            # speaker classifier
             enable_speaker_classifier: bool=False,
-            hidden_sc_dim: int=256,):
+            hidden_sc_dim: int=256, ):
         """Initialize FastSpeech2 module.
         Args:
             idim (int):
@@ -388,7 +388,8 @@ class FastSpeech2(nn.Layer):
         if self.spk_num and self.enable_speaker_classifier:
             # set lambda = 1
             self.grad_reverse = GradientReversalLayer(1)
-            self.speaker_classifier = SpeakerClassifier(idim=adim, hidden_sc_dim=self.hidden_sc_dim, spk_num=spk_num)
+            self.speaker_classifier = SpeakerClassifier(
+                idim=adim, hidden_sc_dim=self.hidden_sc_dim, spk_num=spk_num)
 
         # define duration predictor
         self.duration_predictor = DurationPredictor(
@@ -601,7 +602,7 @@ class FastSpeech2(nn.Layer):
         # (B, Tmax, adim)
         hs, _ = self.encoder(xs, x_masks)
 
-        if self.spk_num and self.enable_speaker_classifier:
+        if self.spk_num and self.enable_speaker_classifier and not is_inference:
             hs_for_spk_cls = self.grad_reverse(hs)
             spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens)
         else:
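Taken together, the hunks above add an adversarial branch on the encoder output: hidden states pass through a gradient-reversal layer and then a per-frame speaker classifier, and the branch now runs only during training (`not is_inference`). A toy-shape sketch of that branch in isolation; the dimensions (`adim=384`, `spk_num=4`) and batch contents are made up for the demo, not taken from this patch:

```python
import paddle

from paddlespeech.t2s.modules.adversarial_loss.gradient_reversal import GradientReversalLayer
from paddlespeech.t2s.modules.adversarial_loss.speaker_classifier import SpeakerClassifier

# Illustrative sizes; in the model they come from the network config.
adim, spk_num = 384, 4
grad_reverse = GradientReversalLayer(1)  # lambda = 1, as in FastSpeech2.__init__
speaker_classifier = SpeakerClassifier(
    idim=adim, hidden_sc_dim=256, spk_num=spk_num)

hs = paddle.randn([2, 7, adim])   # (B, Tmax, adim) encoder states
ilens = paddle.to_tensor([7, 5])  # true text lengths per utterance
# The GRL is an identity in the forward pass and flips gradients in backward,
# so training the classifier pushes the encoder toward speaker-invariant hs.
spk_logits = speaker_classifier(grad_reverse(hs), ilens)
print(spk_logits.shape)  # [2, 7, 4]; frames beyond each ilens entry are zeroed
```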
@@ -794,7 +795,7 @@
             es = e.unsqueeze(0) if e is not None else None
 
             # (1, L, odim)
-            _, outs, d_outs, p_outs, e_outs = self._inference(
+            _, outs, d_outs, p_outs, e_outs, _ = self._forward(
                 xs,
                 ilens,
                 ds=ds,
@@ -806,7 +807,7 @@
                 is_inference=True)
         else:
             # (1, L, odim)
-            _, outs, d_outs, p_outs, e_outs = self._inference(
+            _, outs, d_outs, p_outs, e_outs, _ = self._forward(
                 xs,
                 ilens,
                 is_inference=True,
@@ -815,121 +816,8 @@
                 spk_id=spk_id,
                 tone_id=tone_id)
 
         return outs[0], d_outs[0], p_outs[0], e_outs[0]
 
-    def _inference(self,
-                   xs: paddle.Tensor,
-                   ilens: paddle.Tensor,
-                   olens: paddle.Tensor=None,
-                   ds: paddle.Tensor=None,
-                   ps: paddle.Tensor=None,
-                   es: paddle.Tensor=None,
-                   is_inference: bool=False,
-                   return_after_enc=False,
-                   alpha: float=1.0,
-                   spk_emb=None,
-                   spk_id=None,
-                   tone_id=None) -> Sequence[paddle.Tensor]:
-        # forward encoder
-        x_masks = self._source_mask(ilens)
-        # (B, Tmax, adim)
-        hs, _ = self.encoder(xs, x_masks)
-
-        # integrate speaker embedding
-        if self.spk_embed_dim is not None:
-            # spk_emb has a higher priority than spk_id
-            if spk_emb is not None:
-                hs = self._integrate_with_spk_embed(hs, spk_emb)
-            elif spk_id is not None:
-                spk_emb = self.spk_embedding_table(spk_id)
-                hs = self._integrate_with_spk_embed(hs, spk_emb)
-
-        # integrate tone embedding
-        if self.tone_embed_dim is not None:
-            if tone_id is not None:
-                tone_embs = self.tone_embedding_table(tone_id)
-                hs = self._integrate_with_tone_embed(hs, tone_embs)
-        # forward duration predictor and variance predictors
-        d_masks = make_pad_mask(ilens)
-
-        if self.stop_gradient_from_pitch_predictor:
-            p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1))
-        else:
-            p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1))
-        if self.stop_gradient_from_energy_predictor:
-            e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1))
-        else:
-            e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1))
-
-        if is_inference:
-            # (B, Tmax)
-            if ds is not None:
-                d_outs = ds
-            else:
-                d_outs = self.duration_predictor.inference(hs, d_masks)
-            if ps is not None:
-                p_outs = ps
-            if es is not None:
-                e_outs = es
-
-            # use prediction in inference
-            # (B, Tmax, 1)
-
-            p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
-                (0, 2, 1))
-            e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
-                (0, 2, 1))
-            hs = hs + e_embs + p_embs
-
-            # (B, Lmax, adim)
-            hs = self.length_regulator(hs, d_outs, alpha, is_inference=True)
-        else:
-            d_outs = self.duration_predictor(hs, d_masks)
-            # use groundtruth in training
-            p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
-                (0, 2, 1))
-            e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
-                (0, 2, 1))
-            hs = hs + e_embs + p_embs
-
-            # (B, Lmax, adim)
-            hs = self.length_regulator(hs, ds, is_inference=False)
-
-        # forward decoder
-        if olens is not None and not is_inference:
-            if self.reduction_factor > 1:
-                olens_in = paddle.to_tensor(
-                    [olen // self.reduction_factor for olen in olens.numpy()])
-            else:
-                olens_in = olens
-            # (B, 1, T)
-            h_masks = self._source_mask(olens_in)
-        else:
-            h_masks = None
-
-        if return_after_enc:
-            return hs, h_masks
-
-        if self.decoder_type == 'cnndecoder':
-            # remove output masks for dygraph to static graph
-            zs = self.decoder(hs, h_masks)
-            before_outs = zs
-        else:
-            # (B, Lmax, adim)
-            zs, _ = self.decoder(hs, h_masks)
-            # (B, Lmax, odim)
-            before_outs = self.feat_out(zs).reshape(
-                (paddle.shape(zs)[0], -1, self.odim))
-
-        # postnet -> (B, Lmax//r * r, odim)
-        if self.postnet is None:
-            after_outs = before_outs
-        else:
-            after_outs = before_outs + self.postnet(
-                before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
-
-        return before_outs, after_outs, d_outs, p_outs, e_outs
 
     def _integrate_with_spk_embed(self, hs, spk_emb):
         """Integrate speaker embedding with hidden states.
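One pattern worth noting in the deleted `_inference` (its logic lives on in `_forward`): the pitch and energy tracks are embedded by convolutional layers that expect channel-first input, hence the transposes on both sides of each call. A self-contained sketch, with a hypothetical single `Conv1D` standing in for the real `self.pitch_embed` built in `__init__`:

```python
import paddle
import paddle.nn as nn

# Hypothetical stand-in for self.pitch_embed (the real module and its kernel
# size are configured in FastSpeech2.__init__).
adim = 384
pitch_embed = nn.Conv1D(in_channels=1, out_channels=adim, kernel_size=1)

hs = paddle.randn([2, 7, adim])   # (B, Tmax, adim) encoder states
p_outs = paddle.randn([2, 7, 1])  # (B, Tmax, 1) pitch track
# Conv1D wants (B, C, T), so transpose in, embed, and transpose back out.
p_embs = pitch_embed(p_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
hs = hs + p_embs  # variance information is added onto the hidden states
```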
@@ -1212,7 +1100,8 @@ class FastSpeech2Loss(nn.Layer):
             olens: paddle.Tensor,
             spk_logits: paddle.Tensor=None,
             spk_ids: paddle.Tensor=None,
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,]:
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
+               paddle.Tensor, ]:
         """Calculate forward propagation.
         Args:
@@ -1249,7 +1138,7 @@ class FastSpeech2Loss(nn.Layer):
         """
         speaker_loss = 0.0
-        
+
         # apply mask to remove padded part
         if self.use_masking:
             out_masks = make_non_pad_mask(olens).unsqueeze(-1)
@@ -1273,12 +1162,13 @@ class FastSpeech2Loss(nn.Layer):
 
         if spk_logits is not None and spk_ids is not None:
             batch_size = spk_ids.shape[0]
-            spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1], None)
-            spk_logits = paddle.reshape(spk_logits, [-1, spk_logits.shape[-1]])
-            mask_index = spk_logits.abs().sum(axis=1)!=0
+            spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1],
+                                               None)
+            spk_logits = paddle.reshape(spk_logits,
+                                        [-1, spk_logits.shape[-1]])
+            mask_index = spk_logits.abs().sum(axis=1) != 0
             spk_ids = spk_ids[mask_index]
             spk_logits = spk_logits[mask_index]
-
 
         # calculate loss
         l1_loss = self.l1_criterion(before_outs, ys)
@@ -1289,7 +1179,7 @@ class FastSpeech2Loss(nn.Layer):
         energy_loss = self.mse_criterion(e_outs, es)
 
         if spk_logits is not None and spk_ids is not None:
-            speaker_loss = self.ce_criterion(spk_logits, spk_ids)/batch_size
+            speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size
 
         # make weighted mask and apply it
         if self.use_weighted_masking:
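The loss hunks above compute a per-frame speaker cross-entropy over non-padded frames only. The same computation as a runnable sketch; random tensors stand in for real model outputs, and `F.cross_entropy` stands in for `self.ce_criterion`, whose definition is outside this patch:

```python
import paddle
import paddle.nn.functional as F

# Toy inputs: 2 utterances, 5 frames, 4 speakers. Frames past the text length
# were zeroed by SpeakerClassifier.parse_outputs; the abs-sum mask finds them.
spk_logits = paddle.randn([2, 5, 4])
spk_logits[1, 3:, :] = 0.0          # padded frames of the second utterance
spk_ids = paddle.to_tensor([0, 2])  # one speaker id per utterance

batch_size = spk_ids.shape[0]
# one label per frame: repeat each utterance's id Tmax times, then flatten
spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1], None)
spk_logits = paddle.reshape(spk_logits, [-1, spk_logits.shape[-1]])
# drop the all-zero (padded) frames so they contribute nothing to the loss
mask_index = spk_logits.abs().sum(axis=1) != 0
speaker_loss = F.cross_entropy(spk_logits[mask_index],
                               spk_ids[mask_index]) / batch_size
print(float(speaker_loss))
```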
diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py
index 7690a9cea..2b25b6a62 100644
--- a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py
+++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py
@@ -14,6 +14,7 @@
 import logging
 from pathlib import Path
 
+from paddle import DataParallel
 from paddle import distributed as dist
 from paddle.io import DataLoader
 from paddle.nn import Layer
@@ -23,6 +24,7 @@ from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
 from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
 from paddlespeech.t2s.training.reporter import report
 from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
+
 logging.basicConfig(
     format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
     datefmt='[%Y-%m-%d %H:%M:%S]')
@@ -43,7 +45,8 @@
         super().__init__(model, optimizer, dataloader, init_state=None)
 
         self.criterion = FastSpeech2Loss(
-            use_masking=use_masking, use_weighted_masking=use_weighted_masking,)
+            use_masking=use_masking,
+            use_weighted_masking=use_weighted_masking, )
 
         log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
         self.filehandler = logging.FileHandler(str(log_file))
@@ -62,7 +65,21 @@
         if spk_emb is not None:
             spk_id = None
 
-        with self.model.no_sync():
+        if type(
+                self.model
+        ) == DataParallel and self.model._layers.spk_num and self.model._layers.enable_speaker_classifier:
+            with self.model.no_sync():
+                before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
+                    text=batch["text"],
+                    text_lengths=batch["text_lengths"],
+                    speech=batch["speech"],
+                    speech_lengths=batch["speech_lengths"],
+                    durations=batch["durations"],
+                    pitch=batch["pitch"],
+                    energy=batch["energy"],
+                    spk_id=spk_id,
+                    spk_emb=spk_emb)
+        else:
             before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
                 text=batch["text"],
                 text_lengths=batch["text_lengths"],
@@ -87,7 +104,7 @@
             ilens=batch["text_lengths"],
             olens=olens,
             spk_logits=spk_logits,
-            spk_ids=spk_id,)
+            spk_ids=spk_id, )
 
         loss = l1_loss + duration_loss + pitch_loss + energy_loss + self.spk_loss_scale * speaker_loss
@@ -101,16 +118,20 @@
         report("train/duration_loss", float(duration_loss))
         report("train/pitch_loss", float(pitch_loss))
         report("train/energy_loss", float(energy_loss))
-        report("train/speaker_loss", float(speaker_loss))
-        report("train/scale_speaker_loss", float(self.spk_loss_scale * speaker_loss))
+        if speaker_loss != 0.0:
+            report("train/speaker_loss", float(speaker_loss))
+            report("train/scale_speaker_loss",
+                   float(self.spk_loss_scale * speaker_loss))
 
         losses_dict["l1_loss"] = float(l1_loss)
         losses_dict["duration_loss"] = float(duration_loss)
         losses_dict["pitch_loss"] = float(pitch_loss)
         losses_dict["energy_loss"] = float(energy_loss)
         losses_dict["energy_loss"] = float(energy_loss)
-        losses_dict["speaker_loss"] = float(speaker_loss)
-        losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale * speaker_loss)
+        if speaker_loss != 0.0:
+            losses_dict["speaker_loss"] = float(speaker_loss)
+            losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale *
+                                                      speaker_loss)
         losses_dict["loss"] = float(loss)
         self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                               for k, v in losses_dict.items())
@@ -145,7 +166,21 @@
         if spk_emb is not None:
             spk_id = None
 
-        with self.model.no_sync():
+        if type(
+                self.model
+        ) == DataParallel and self.model._layers.spk_num and self.model._layers.enable_speaker_classifier:
+            with self.model.no_sync():
+                before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
+                    text=batch["text"],
+                    text_lengths=batch["text_lengths"],
+                    speech=batch["speech"],
+                    speech_lengths=batch["speech_lengths"],
+                    durations=batch["durations"],
+                    pitch=batch["pitch"],
+                    energy=batch["energy"],
+                    spk_id=spk_id,
+                    spk_emb=spk_emb)
+        else:
             before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
                 text=batch["text"],
                 text_lengths=batch["text_lengths"],
@@ -168,9 +203,9 @@
             ps=batch["pitch"],
             es=batch["energy"],
             ilens=batch["text_lengths"],
-            olens=olens, 
+            olens=olens,
             spk_logits=spk_logits,
-            spk_ids=spk_id,)
+            spk_ids=spk_id, )
 
         loss = l1_loss + duration_loss + pitch_loss + energy_loss + self.spk_loss_scale * speaker_loss
 
         report("eval/loss", float(loss))
@@ -178,15 +213,19 @@
         report("eval/duration_loss", float(duration_loss))
         report("eval/pitch_loss", float(pitch_loss))
         report("eval/energy_loss", float(energy_loss))
-        report("train/speaker_loss", float(speaker_loss))
-        report("train/scale_speaker_loss", float(self.spk_loss_scale * speaker_loss))
+        if speaker_loss != 0.0:
+            report("train/speaker_loss", float(speaker_loss))
+            report("train/scale_speaker_loss",
+                   float(self.spk_loss_scale * speaker_loss))
 
         losses_dict["l1_loss"] = float(l1_loss)
         losses_dict["duration_loss"] = float(duration_loss)
         losses_dict["pitch_loss"] = float(pitch_loss)
         losses_dict["energy_loss"] = float(energy_loss)
-        losses_dict["speaker_loss"] = float(speaker_loss)
-        losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale * speaker_loss)
+        if speaker_loss != 0.0:
+            losses_dict["speaker_loss"] = float(speaker_loss)
+            losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale *
+                                                      speaker_loss)
         losses_dict["loss"] = float(loss)
         self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                               for k, v in losses_dict.items())
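The dispatch added to both the updater and the evaluator is easier to read reduced to a predicate. `use_no_sync` below is a hypothetical condensation, not part of this patch (`isinstance` being the more idiomatic spelling of the `type(...) ==` comparison):

```python
from paddle import DataParallel

# no_sync() pauses DataParallel's automatic gradient synchronization; the
# updater enters it only for multi-GPU runs with the adversarial speaker
# classifier switched on, and takes the plain forward path otherwise.
def use_no_sync(model) -> bool:
    return (isinstance(model, DataParallel) and
            bool(model._layers.spk_num) and
            model._layers.enable_speaker_classifier)
```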
diff --git a/paddlespeech/t2s/modules/multi_speakers/__init__.py b/paddlespeech/t2s/modules/adversarial_loss/__init__.py
similarity index 100%
rename from paddlespeech/t2s/modules/multi_speakers/__init__.py
rename to paddlespeech/t2s/modules/adversarial_loss/__init__.py
diff --git a/paddlespeech/t2s/modules/multi_speakers/gradient_reversal.py b/paddlespeech/t2s/modules/adversarial_loss/gradient_reversal.py
similarity index 99%
rename from paddlespeech/t2s/modules/multi_speakers/gradient_reversal.py
rename to paddlespeech/t2s/modules/adversarial_loss/gradient_reversal.py
index 5250f1df1..e98758099 100644
--- a/paddlespeech/t2s/modules/multi_speakers/gradient_reversal.py
+++ b/paddlespeech/t2s/modules/adversarial_loss/gradient_reversal.py
@@ -11,10 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import paddle
-from paddle.autograd import PyLayer
 import paddle.nn as nn
+from paddle.autograd import PyLayer
+
 
 class GradientReversalFunction(PyLayer):
     """Gradient Reversal Layer from:
@@ -57,4 +57,3 @@ class GradientReversalLayer(nn.Layer):
         """Forward in networks
         """
         return GradientReversalFunction.apply(x, self.lambda_)
-
diff --git a/paddlespeech/t2s/modules/multi_speakers/speaker_classifier.py b/paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py
similarity index 78%
rename from paddlespeech/t2s/modules/multi_speakers/speaker_classifier.py
rename to paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py
index a64f6d5b8..d731b2d27 100644
--- a/paddlespeech/t2s/modules/multi_speakers/speaker_classifier.py
+++ b/paddlespeech/t2s/modules/adversarial_loss/speaker_classifier.py
@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Modified from Cross-Lingual-Voice-Cloning(https://github.com/deterministic-algorithms-lab/Cross-Lingual-Voice-Cloning)
-
-from paddle import nn
 import paddle
+from paddle import nn
 from typeguard import check_argument_types
+
 
 class SpeakerClassifier(nn.Layer):
-
-    def __init__(self, idim: int, hidden_sc_dim: int, spk_num: int, ):
+    def __init__(
+            self,
+            idim: int,
+            hidden_sc_dim: int,
+            spk_num: int, ):
         assert check_argument_types()
         super().__init__()
         # store hyperparameters
@@ -27,11 +30,13 @@ class SpeakerClassifier(nn.Layer):
         self.hidden_sc_dim = hidden_sc_dim
         self.spk_num = spk_num
 
-        self.model = nn.Sequential(nn.Linear(self.idim, self.hidden_sc_dim),
-                                   nn.Linear(self.hidden_sc_dim, self.spk_num))
-
+        self.model = nn.Sequential(
+            nn.Linear(self.idim, self.hidden_sc_dim),
+            nn.Linear(self.hidden_sc_dim, self.spk_num))
+
     def parse_outputs(self, out, text_lengths):
-        mask = paddle.arange(out.shape[1]).expand([out.shape[0], out.shape[1]]) < text_lengths.unsqueeze(1)
+        mask = paddle.arange(out.shape[1]).expand(
+            [out.shape[0], out.shape[1]]) < text_lengths.unsqueeze(1)
         out = paddle.transpose(out, perm=[2, 0, 1])
         out = out * mask
         out = paddle.transpose(out, perm=[1, 2, 0])
@@ -44,7 +49,7 @@ class SpeakerClassifier(nn.Layer):
            log probabilities of speaker classification = [batch_size, seq_len, spk_num]
         """
-        
-        out = self.model(encoder_outputs) 
+
+        out = self.model(encoder_outputs)
         out = self.parse_outputs(out, text_lengths)
         return out
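Finally, a quick sanity check of the renamed gradient-reversal module, assuming the standard GRL behavior its docstring cites: identity in the forward pass, gradients scaled by `-lambda` in the backward pass:

```python
import paddle

from paddlespeech.t2s.modules.adversarial_loss.gradient_reversal import GradientReversalLayer

x = paddle.ones([2, 3])
x.stop_gradient = False

y = GradientReversalLayer(1)(x)  # forward: y equals x
y.sum().backward()
# d(sum)/dx is +1 everywhere, so the reversed gradient should be -1 everywhere
print(x.grad)
```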