From 96b8e42008661f7edd421fcf1bd075fdc834b781 Mon Sep 17 00:00:00 2001
From: liangym
Date: Thu, 27 Oct 2022 08:14:01 +0000
Subject: [PATCH] add no_sync to fix PyLayer in DataParallel

---
 .../t2s/models/fastspeech2/fastspeech2.py     | 184 ++++++++----------
 .../models/fastspeech2/fastspeech2_updater.py |  50 ++---
 2 files changed, 113 insertions(+), 121 deletions(-)

diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
index a87db3e8e..34a2ff98a 100644
--- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
+++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
@@ -735,6 +735,89 @@ class FastSpeech2(nn.Layer):
             tone_id=tone_id)
         return hs
 
+    def inference(
+            self,
+            text: paddle.Tensor,
+            durations: paddle.Tensor=None,
+            pitch: paddle.Tensor=None,
+            energy: paddle.Tensor=None,
+            alpha: float=1.0,
+            use_teacher_forcing: bool=False,
+            spk_emb=None,
+            spk_id=None,
+            tone_id=None,
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Generate the sequence of features given the sequences of characters.
+
+        Args:
+            text(Tensor(int64)):
+                Input sequence of characters (T,).
+            durations(Tensor(int64), optional):
+                Groundtruth of duration (T,).
+            pitch(Tensor, optional):
+                Groundtruth of token-averaged pitch (T, 1).
+            energy(Tensor, optional):
+                Groundtruth of token-averaged energy (T, 1).
+            alpha(float, optional):
+                Alpha to control the speed.
+            use_teacher_forcing(bool, optional):
+                Whether to use teacher forcing.
+                If true, groundtruth of duration, pitch and energy will be used.
+            spk_emb(Tensor, optional):
+                Speaker embedding vector (spk_embed_dim,). (Default value = None)
+            spk_id(Tensor(int64), optional):
+                Speaker ids (1,). (Default value = None)
+            tone_id(Tensor(int64), optional):
+                Tone ids (T,). (Default value = None)
+
+        Returns:
+            Tensor: Output sequence of features (L, odim).
+            Tensor: Predicted durations (T,).
+            Tensor: Predicted token-averaged pitch (T, 1).
+            Tensor: Predicted token-averaged energy (T, 1).
+        """
+        # input of embedding must be int64
+        x = paddle.cast(text, 'int64')
+        d, p, e = durations, pitch, energy
+        # setup batch axis
+        ilens = paddle.shape(x)[0]
+
+        xs = x.unsqueeze(0)
+
+        if spk_emb is not None:
+            spk_emb = spk_emb.unsqueeze(0)
+
+        if tone_id is not None:
+            tone_id = tone_id.unsqueeze(0)
+
+        if use_teacher_forcing:
+            # use groundtruth of duration, pitch, and energy
+            ds = d.unsqueeze(0) if d is not None else None
+            ps = p.unsqueeze(0) if p is not None else None
+            es = e.unsqueeze(0) if e is not None else None
+
+            # (1, L, odim)
+            _, outs, d_outs, p_outs, e_outs = self._inference(
+                xs,
+                ilens,
+                ds=ds,
+                ps=ps,
+                es=es,
+                spk_emb=spk_emb,
+                spk_id=spk_id,
+                tone_id=tone_id,
+                is_inference=True)
+        else:
+            # (1, L, odim)
+            _, outs, d_outs, p_outs, e_outs = self._inference(
+                xs,
+                ilens,
+                is_inference=True,
+                alpha=alpha,
+                spk_emb=spk_emb,
+                spk_id=spk_id,
+                tone_id=tone_id)
+
+        return outs[0], d_outs[0], p_outs[0], e_outs[0]
+
     def _inference(self,
                    xs: paddle.Tensor,
                    ilens: paddle.Tensor,
@@ -847,89 +930,6 @@ class FastSpeech2(nn.Layer):
         return before_outs, after_outs, d_outs, p_outs, e_outs
 
-    def inference(
-            self,
-            text: paddle.Tensor,
-            durations: paddle.Tensor=None,
-            pitch: paddle.Tensor=None,
-            energy: paddle.Tensor=None,
-            alpha: float=1.0,
-            use_teacher_forcing: bool=False,
-            spk_emb=None,
-            spk_id=None,
-            tone_id=None,
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
-        """Generate the sequence of features given the sequences of characters.
-
-        Args:
-            text(Tensor(int64)):
-                Input sequence of characters (T,).
-            durations(Tensor, optional (int64)):
-                Groundtruth of duration (T,).
-            pitch(Tensor, optional):
-                Groundtruth of token-averaged pitch (T, 1).
-            energy(Tensor, optional):
-                Groundtruth of token-averaged energy (T, 1).
-            alpha(float, optional):
-                Alpha to control the speed.
-            use_teacher_forcing(bool, optional):
-                Whether to use teacher forcing.
-                If true, groundtruth of duration, pitch and energy will be used.
-            spk_emb(Tensor, optional, optional):
-                peaker embedding vector (spk_embed_dim,). (Default value = None)
-            spk_id(Tensor, optional(int64), optional):
-                spk ids (1,). (Default value = None)
-            tone_id(Tensor, optional(int64), optional):
-                tone ids (T,). (Default value = None)
-
-        Returns:
-
-        """
-        # input of embedding must be int64
-        x = paddle.cast(text, 'int64')
-        d, p, e = durations, pitch, energy
-        # setup batch axis
-        ilens = paddle.shape(x)[0]
-
-        xs = x.unsqueeze(0)
-
-        if spk_emb is not None:
-            spk_emb = spk_emb.unsqueeze(0)
-
-        if tone_id is not None:
-            tone_id = tone_id.unsqueeze(0)
-
-        if use_teacher_forcing:
-            # use groundtruth of duration, pitch, and energy
-            ds = d.unsqueeze(0) if d is not None else None
-            ps = p.unsqueeze(0) if p is not None else None
-            es = e.unsqueeze(0) if e is not None else None
-
-            # (1, L, odim)
-            _, outs, d_outs, p_outs, e_outs = self._inference(
-                xs,
-                ilens,
-                ds=ds,
-                ps=ps,
-                es=es,
-                spk_emb=spk_emb,
-                spk_id=spk_id,
-                tone_id=tone_id,
-                is_inference=True)
-        else:
-            # (1, L, odim)
-            _, outs, d_outs, p_outs, e_outs = self._inference(
-                xs,
-                ilens,
-                is_inference=True,
-                alpha=alpha,
-                spk_emb=spk_emb,
-                spk_id=spk_id,
-                tone_id=tone_id)
-
-
-        return outs[0], d_outs[0], p_outs[0], e_outs[0]
-
     def _integrate_with_spk_embed(self, hs, spk_emb):
         """Integrate speaker embedding with hidden states.
 
@@ -1272,6 +1272,7 @@ class FastSpeech2Loss(nn.Layer):
             es = es.masked_select(pitch_masks.broadcast_to(es.shape))
 
         if spk_logits is not None and spk_ids is not None:
+            batch_size = spk_ids.shape[0]
             spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1], None)
             spk_logits = paddle.reshape(spk_logits, [-1, spk_logits.shape[-1]])
             mask_index = spk_logits.abs().sum(axis=1)!=0
@@ -1288,22 +1289,7 @@ class FastSpeech2Loss(nn.Layer):
             energy_loss = self.mse_criterion(e_outs, es)
 
         if spk_logits is not None and spk_ids is not None:
-            speaker_loss = self.ce_criterion(spk_logits, spk_ids)/spk_logits.shape[0]
-
-            # if spk_logits is None or spk_ids is None:
-            #     speaker_loss = 0.0
-            # else:
-            #     Tmax = spk_logits.shape[1]
-            #     batch_num = spk_logits.shape[0]
-            #     spk_ids =
-            #     speaker_loss = self.ce_criterion(spk_logits, spk_ids)/batch_num
-
-            # index_into_spkr_logits = batched_speakers.repeat_interleave(spkr_clsfir_logits.shape[1])
-            # spkr_clsfir_logits = spkr_clsfir_logits.reshape(-1, spkr_clsfir_logits.shape[-1])
-            # mask_index = spkr_clsfir_logits.abs().sum(dim=1)!=0
-            # spkr_clsfir_logits = spkr_clsfir_logits[mask_index]
-            # index_into_spkr_logits = index_into_spkr_logits[mask_index]
-            # speaker_loss = self.ce_criterion(spkr_clsfir_logits, index_into_spkr_logits)/batched_speakers.shape[0]
+            speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size
 
         # make weighted mask and apply it
         if self.use_weighted_masking:
diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py
index 1eb0f60fd..7690a9cea 100644
--- a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py
+++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py
@@ -62,16 +62,17 @@ class FastSpeech2Updater(StandardUpdater):
         if spk_emb is not None:
             spk_id = None
 
-        before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
-            text=batch["text"],
-            text_lengths=batch["text_lengths"],
-            speech=batch["speech"],
-            speech_lengths=batch["speech_lengths"],
-            durations=batch["durations"],
-            pitch=batch["pitch"],
-            energy=batch["energy"],
-            spk_id=spk_id,
-            spk_emb=spk_emb)
+        with self.model.no_sync():
+            before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
+                text=batch["text"],
+                text_lengths=batch["text_lengths"],
+                speech=batch["speech"],
+                speech_lengths=batch["speech_lengths"],
+                durations=batch["durations"],
+                pitch=batch["pitch"],
+                energy=batch["energy"],
+                spk_id=spk_id,
+                spk_emb=spk_emb)
 
         l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion(
             after_outs=after_outs,
@@ -101,7 +102,7 @@ class FastSpeech2Updater(StandardUpdater):
         report("train/pitch_loss", float(pitch_loss))
         report("train/energy_loss", float(energy_loss))
         report("train/speaker_loss", float(speaker_loss))
-        report("train/speaker_loss_0.02", float(self.spk_loss_scale * speaker_loss))
+        report("train/scale_speaker_loss", float(self.spk_loss_scale * speaker_loss))
 
         losses_dict["l1_loss"] = float(l1_loss)
         losses_dict["duration_loss"] = float(duration_loss)
@@ -109,7 +110,7 @@
         losses_dict["energy_loss"] = float(energy_loss)
         losses_dict["energy_loss"] = float(energy_loss)
         losses_dict["speaker_loss"] = float(speaker_loss)
-        losses_dict["speaker_loss_0.02"] = float(self.spk_loss_scale * speaker_loss)
+        losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale * speaker_loss)
         losses_dict["loss"] = float(loss)
         self.msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_dict.items())
@@ -144,16 +145,17 @@ class FastSpeech2Evaluator(StandardEvaluator):
         if spk_emb is not None:
             spk_id = None
 
-        before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
-            text=batch["text"],
-            text_lengths=batch["text_lengths"],
-            speech=batch["speech"],
-            speech_lengths=batch["speech_lengths"],
-            durations=batch["durations"],
-            pitch=batch["pitch"],
-            energy=batch["energy"],
-            spk_id=spk_id,
-            spk_emb=spk_emb)
+        with self.model.no_sync():
+            before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
+                text=batch["text"],
+                text_lengths=batch["text_lengths"],
+                speech=batch["speech"],
+                speech_lengths=batch["speech_lengths"],
+                durations=batch["durations"],
+                pitch=batch["pitch"],
+                energy=batch["energy"],
+                spk_id=spk_id,
+                spk_emb=spk_emb)
 
         l1_loss, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion(
             after_outs=after_outs,
@@ -176,11 +178,15 @@ class FastSpeech2Evaluator(StandardEvaluator):
         report("eval/duration_loss", float(duration_loss))
         report("eval/pitch_loss", float(pitch_loss))
         report("eval/energy_loss", float(energy_loss))
+        report("eval/speaker_loss", float(speaker_loss))
+        report("eval/scale_speaker_loss", float(self.spk_loss_scale * speaker_loss))
 
         losses_dict["l1_loss"] = float(l1_loss)
         losses_dict["duration_loss"] = float(duration_loss)
         losses_dict["pitch_loss"] = float(pitch_loss)
         losses_dict["energy_loss"] = float(energy_loss)
+        losses_dict["speaker_loss"] = float(speaker_loss)
+        losses_dict["scale_speaker_loss"] = float(self.spk_loss_scale * speaker_loss)
         losses_dict["loss"] = float(loss)
         self.msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_dict.items())
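
Note (not part of the patch): paddle.DataParallel decides at forward time
whether gradients will be synchronized, so wrapping only the model call in
model.no_sync(), as this patch does, is enough to disable the automatic
allreduce hooks that a custom PyLayer cannot drive. The gradients then have to
be allreduced manually after backward; that half of the pattern lives in the
trainer and is not shown in this diff. Below is a minimal sketch of the
documented workaround, assuming the fused_allreduce_gradients helper from
paddle.distributed.fleet.utils.hybrid_parallel_util and a toy Linear model
standing in for FastSpeech2:

    import paddle
    import paddle.distributed as dist
    from paddle.distributed.fleet.utils.hybrid_parallel_util import (
        fused_allreduce_gradients, )


    def train_step(dp_model, optimizer, batch):
        # Run forward/backward with DataParallel's automatic gradient
        # synchronization disabled: a custom PyLayer inside the model cannot
        # trigger the per-parameter allreduce hooks, so they are skipped here...
        with dp_model.no_sync():
            loss = dp_model(batch).mean()
            loss.backward()
        # ...and the cross-rank gradient allreduce is issued manually instead,
        # before the optimizer consumes the gradients.
        fused_allreduce_gradients(list(dp_model.parameters()), None)
        optimizer.step()
        optimizer.clear_grad()


    if __name__ == "__main__":
        dist.init_parallel_env()
        dp_model = paddle.DataParallel(paddle.nn.Linear(8, 8))
        optimizer = paddle.optimizer.Adam(parameters=dp_model.parameters())
        train_step(dp_model, optimizer, paddle.randn([4, 8]))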