rename spembs

pull/1003/head
TianYuan 3 years ago
parent 8d025451de
commit a97c7b5206

@@ -35,7 +35,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
         --config=${config_path} \
         --num-cpu=20 \
         --cut-sil=True \
-        --embed-dir=dump/embed
+        --spk_emb_dir=dump/embed
 fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then

@@ -3,7 +3,7 @@
 set -e
 source path.sh
-gpus=0,1
+gpus=4,5
 stage=0
 stop_stage=100

@@ -100,7 +100,7 @@ def fastspeech2_single_spk_batch_fn(examples):
 def fastspeech2_multi_spk_batch_fn(examples):
-    # fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spembs"]
+    # fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spk_emb"]
     text = [np.array(item["text"], dtype=np.int64) for item in examples]
     speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
     pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples]
@@ -139,14 +139,14 @@ def fastspeech2_multi_spk_batch_fn(examples):
         "pitch": pitch,
         "energy": energy
     }
-    # spembs has a higher priority than spk_id
-    if "spembs" in examples[0]:
-        spembs = [
-            np.array(item["spembs"], dtype=np.float32) for item in examples
+    # spk_emb has a higher priority than spk_id
+    if "spk_emb" in examples[0]:
+        spk_emb = [
+            np.array(item["spk_emb"], dtype=np.float32) for item in examples
         ]
-        spembs = batch_sequences(spembs)
-        spembs = paddle.to_tensor(spembs)
-        batch["spembs"] = spembs
+        spk_emb = batch_sequences(spk_emb)
+        spk_emb = paddle.to_tensor(spk_emb)
+        batch["spk_emb"] = spk_emb
     elif "spk_id" in examples[0]:
         spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples]
         spk_id = paddle.to_tensor(spk_id)
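The renamed key keeps the same precedence rule as before: when an example carries both a speaker embedding and a speaker id, the embedding wins. A minimal, self-contained sketch of just that branch; `pad_sequences` here is a stand-in for the repo's `batch_sequences`, and plain NumPy arrays stand in for paddle tensors:

```python
import numpy as np

def pad_sequences(seqs, pad_value=0.0):
    """Stand-in for batch_sequences: right-pad arrays to a common first-dim length."""
    max_len = max(s.shape[0] for s in seqs)
    padded = np.full((len(seqs), max_len) + seqs[0].shape[1:], pad_value,
                     dtype=seqs[0].dtype)
    for i, s in enumerate(seqs):
        padded[i, :s.shape[0]] = s
    return padded

def collate_speaker_fields(examples, batch):
    # spk_emb has a higher priority than spk_id
    if "spk_emb" in examples[0]:
        spk_emb = [np.array(item["spk_emb"], dtype=np.float32) for item in examples]
        batch["spk_emb"] = pad_sequences(spk_emb)  # paddle.to_tensor(...) in the real collate fn
    elif "spk_id" in examples[0]:
        batch["spk_id"] = np.array([item["spk_id"] for item in examples], dtype=np.int64)
    return batch

# Voice-cloning style examples carry a 256-dim embedding, so spk_id is ignored.
examples = [{"spk_emb": np.random.rand(256), "spk_id": 3},
            {"spk_emb": np.random.rand(256), "spk_id": 7}]
print(collate_speaker_fields(examples, {}).keys())  # dict_keys(['spk_emb'])
```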

@@ -167,9 +167,9 @@ def main():
             "pitch": str(pitch_path),
             "energy": str(energy_path)
         }
-        # add spembs for voice cloning
-        if "spembs" in item:
-            record["spembs"] = str(item["spembs"])
+        # add spk_emb for voice cloning
+        if "spk_emb" in item:
+            record["spk_emb"] = str(item["spk_emb"])
         output_metadata.append(record)
     output_metadata.sort(key=itemgetter('utt_id'))

@@ -45,7 +45,7 @@ def process_sentence(config: Dict[str, Any],
                      pitch_extractor=None,
                      energy_extractor=None,
                      cut_sil: bool=True,
-                     embed_dir: Path=None):
+                     spk_emb_dir: Path=None):
     utt_id = fp.stem
     # for vctk
     if utt_id.endswith("_mic2"):
@@ -117,12 +117,12 @@ def process_sentence(config: Dict[str, Any],
             "energy": str(energy_path),
             "speaker": speaker
         }
-        if embed_dir:
-            if speaker in os.listdir(embed_dir):
+        if spk_emb_dir:
+            if speaker in os.listdir(spk_emb_dir):
                 embed_name = utt_id + ".npy"
-                embed_path = embed_dir / speaker / embed_name
+                embed_path = spk_emb_dir / speaker / embed_name
                 if embed_path.is_file():
-                    record["spembs"] = str(embed_path)
+                    record["spk_emb"] = str(embed_path)
                 else:
                     return None
     return record
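As the hunk above shows, an utterance only keeps a `spk_emb` entry in the metadata when `<spk_emb_dir>/<speaker>/<utt_id>.npy` exists, and is dropped otherwise. A simplified sketch of that lookup (it omits the original's extra `os.listdir` speaker check; the speaker and utterance ids in the demo call are hypothetical, the `dump/embed` directory matches the run script above):

```python
from pathlib import Path
from typing import Optional

def attach_spk_emb(record: dict, utt_id: str, speaker: str,
                   spk_emb_dir: Optional[Path]) -> Optional[dict]:
    """Expects the layout <spk_emb_dir>/<speaker>/<utt_id>.npy."""
    if spk_emb_dir:
        embed_path = spk_emb_dir / speaker / f"{utt_id}.npy"
        if embed_path.is_file():
            record["spk_emb"] = str(embed_path)
        else:
            # no embedding extracted for this utterance -> drop it from the metadata
            return None
    return record

# e.g. dump/embed/SSB0005/SSB00050001.npy (ids are illustrative)
print(attach_spk_emb({"utt_id": "SSB00050001"}, "SSB00050001", "SSB0005",
                     Path("dump/embed")))
```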
@@ -137,13 +137,13 @@ def process_sentences(config,
                       energy_extractor=None,
                       nprocs: int=1,
                       cut_sil: bool=True,
-                      embed_dir: Path=None):
+                      spk_emb_dir: Path=None):
     if nprocs == 1:
         results = []
         for fp in fps:
             record = process_sentence(config, fp, sentences, output_dir,
                                       mel_extractor, pitch_extractor,
-                                      energy_extractor, cut_sil, embed_dir)
+                                      energy_extractor, cut_sil, spk_emb_dir)
             if record:
                 results.append(record)
     else:
@@ -154,7 +154,7 @@ def process_sentences(config,
                 future = pool.submit(process_sentence, config, fp,
                                      sentences, output_dir, mel_extractor,
                                      pitch_extractor, energy_extractor,
-                                     cut_sil, embed_dir)
+                                     cut_sil, spk_emb_dir)
                 future.add_done_callback(lambda p: progress.update())
                 futures.append(future)
@@ -213,7 +213,7 @@ def main():
         help="whether cut sil in the edge of audio")
     parser.add_argument(
-        "--embed-dir",
+        "--spk_emb_dir",
         default=None,
         type=str,
         help="directory to speaker embedding files.")
@@ -226,10 +226,10 @@ def main():
     dumpdir.mkdir(parents=True, exist_ok=True)
     dur_file = Path(args.dur_file).expanduser()
-    if args.embed_dir:
-        embed_dir = Path(args.embed_dir).expanduser().resolve()
+    if args.spk_emb_dir:
+        spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve()
     else:
-        embed_dir = None
+        spk_emb_dir = None
     assert rootdir.is_dir()
     assert dur_file.is_file()
@@ -339,7 +339,7 @@ def main():
             energy_extractor,
             nprocs=args.num_cpu,
             cut_sil=args.cut_sil,
-            embed_dir=embed_dir)
+            spk_emb_dir=spk_emb_dir)
     if dev_wav_files:
         process_sentences(
             config,
@@ -350,7 +350,7 @@ def main():
             pitch_extractor,
             energy_extractor,
             cut_sil=args.cut_sil,
-            embed_dir=embed_dir)
+            spk_emb_dir=spk_emb_dir)
     if test_wav_files:
         process_sentences(
             config,
@@ -362,7 +362,7 @@ def main():
             energy_extractor,
             nprocs=args.num_cpu,
             cut_sil=args.cut_sil,
-            embed_dir=embed_dir)
+            spk_emb_dir=spk_emb_dir)
 if __name__ == "__main__":

@@ -49,7 +49,7 @@ def evaluate(args, fastspeech2_config, pwg_config):
         fields += ["spk_id"]
     elif args.voice_cloning:
         print("voice cloning!")
-        fields += ["spembs"]
+        fields += ["spk_emb"]
     else:
         print("single speaker fastspeech2!")
     print("num_speakers:", num_speakers)
@@ -99,15 +99,15 @@ def evaluate(args, fastspeech2_config, pwg_config):
     for datum in test_dataset:
         utt_id = datum["utt_id"]
         text = paddle.to_tensor(datum["text"])
-        spembs = None
+        spk_emb = None
         spk_id = None
-        if args.voice_cloning and "spembs" in datum:
-            spembs = paddle.to_tensor(np.load(datum["spembs"]))
+        if args.voice_cloning and "spk_emb" in datum:
+            spk_emb = paddle.to_tensor(np.load(datum["spk_emb"]))
         elif "spk_id" in datum:
             spk_id = paddle.to_tensor(datum["spk_id"])
         with paddle.no_grad():
             wav = pwg_inference(
-                fastspeech2_inference(text, spk_id=spk_id, spembs=spembs))
+                fastspeech2_inference(text, spk_id=spk_id, spk_emb=spk_emb))
         sf.write(
             str(output_dir / (utt_id + ".wav")),
             wav.numpy(),

@@ -73,8 +73,8 @@ def train_sp(args, config):
     elif args.voice_cloning:
         print("Training voice cloning!")
         collate_fn = fastspeech2_multi_spk_batch_fn
-        fields += ["spembs"]
-        converters["spembs"] = np.load
+        fields += ["spk_emb"]
+        converters["spk_emb"] = np.load
     else:
         print("single speaker fastspeech2!")
         collate_fn = fastspeech2_single_spk_batch_fn

@@ -107,24 +107,25 @@ def voice_cloning(args, fastspeech2_config, pwg_config):
         mel_sequences = p.extract_mel_partials(p.preprocess_wav(ref_audio_path))
         # print("mel_sequences: ", mel_sequences.shape)
         with paddle.no_grad():
-            spembs = speaker_encoder.embed_utterance(
+            spk_emb = speaker_encoder.embed_utterance(
                 paddle.to_tensor(mel_sequences))
-        # print("spembs shape: ", spembs.shape)
+        # print("spk_emb shape: ", spk_emb.shape)
         with paddle.no_grad():
-            wav = pwg_inference(fastspeech2_inference(phone_ids, spembs=spembs))
+            wav = pwg_inference(
+                fastspeech2_inference(phone_ids, spk_emb=spk_emb))
         sf.write(
             str(output_dir / (utt_id + ".wav")),
             wav.numpy(),
             samplerate=fastspeech2_config.fs)
         print(f"{utt_id} done!")
-    # Randomly generate numbers of 0 ~ 0.2, 256 is the dim of spembs
-    random_spembs = np.random.rand(256) * 0.2
-    random_spembs = paddle.to_tensor(random_spembs)
-    utt_id = "random_spembs"
+    # Randomly generate numbers of 0 ~ 0.2, 256 is the dim of spk_emb
+    random_spk_emb = np.random.rand(256) * 0.2
+    random_spk_emb = paddle.to_tensor(random_spk_emb)
+    utt_id = "random_spk_emb"
     with paddle.no_grad():
-        wav = pwg_inference(fastspeech2_inference(phone_ids, spembs=spembs))
+        wav = pwg_inference(fastspeech2_inference(phone_ids, spk_emb=spk_emb))
     sf.write(
         str(output_dir / (utt_id + ".wav")),
         wav.numpy(),
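The voice-cloning entry point now threads the renamed variable from the speaker encoder into the acoustic model. A hedged sketch of that flow as a single helper, with the preprocessor `p`, `speaker_encoder`, `fastspeech2_inference`, and `pwg_inference` supplied by the caller exactly as in the script above (nothing here is new API, only a regrouping for illustration):

```python
import paddle
import soundfile as sf

def clone_from_reference(ref_wav_path, phone_ids, p, speaker_encoder,
                         fastspeech2_inference, pwg_inference, out_path, fs):
    """Extract a speaker embedding from a reference wav and synthesize with it."""
    # mel partials for the speaker encoder, as in the loop above
    mel_sequences = p.extract_mel_partials(p.preprocess_wav(ref_wav_path))
    with paddle.no_grad():
        spk_emb = speaker_encoder.embed_utterance(paddle.to_tensor(mel_sequences))
        # spk_emb selects the voice; phone_ids select the content
        wav = pwg_inference(fastspeech2_inference(phone_ids, spk_emb=spk_emb))
    sf.write(str(out_path), wav.numpy(), samplerate=fs)
    return spk_emb
```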

@@ -297,7 +297,7 @@ class FastSpeech2(nn.Layer):
                 pitch: paddle.Tensor,
                 energy: paddle.Tensor,
                 tone_id: paddle.Tensor=None,
-                spembs: paddle.Tensor=None,
+                spk_emb: paddle.Tensor=None,
                 spk_id: paddle.Tensor=None
                 ) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]:
         """Calculate forward propagation.
@@ -320,7 +320,7 @@ class FastSpeech2(nn.Layer):
             Batch of padded token-averaged energy (B, Tmax, 1).
         tone_id : Tensor, optional(int64)
             Batch of padded tone ids (B, Tmax).
-        spembs : Tensor, optional
+        spk_emb : Tensor, optional
             Batch of speaker embeddings (B, spk_embed_dim).
         spk_id : Tnesor, optional(int64)
             Batch of speaker ids (B,)
@@ -364,7 +364,7 @@ class FastSpeech2(nn.Layer):
             ps,
             es,
             is_inference=False,
-            spembs=spembs,
+            spk_emb=spk_emb,
             spk_id=spk_id,
             tone_id=tone_id)
         # modify mod part of groundtruth
@@ -385,7 +385,7 @@ class FastSpeech2(nn.Layer):
                  es: paddle.Tensor=None,
                  is_inference: bool=False,
                  alpha: float=1.0,
-                 spembs=None,
+                 spk_emb=None,
                  spk_id=None,
                  tone_id=None) -> Sequence[paddle.Tensor]:
         # forward encoder
@@ -395,12 +395,12 @@ class FastSpeech2(nn.Layer):
         # integrate speaker embedding
         if self.spk_embed_dim is not None:
-            # spembs has a higher priority than spk_id
-            if spembs is not None:
-                hs = self._integrate_with_spk_embed(hs, spembs)
+            # spk_emb has a higher priority than spk_id
+            if spk_emb is not None:
+                hs = self._integrate_with_spk_embed(hs, spk_emb)
             elif spk_id is not None:
-                spembs = self.spk_embedding_table(spk_id)
-                hs = self._integrate_with_spk_embed(hs, spembs)
+                spk_emb = self.spk_embedding_table(spk_id)
+                hs = self._integrate_with_spk_embed(hs, spk_emb)
         # integrate tone embedding
         if self.tone_embed_dim is not None:
@@ -488,7 +488,7 @@ class FastSpeech2(nn.Layer):
                   energy: paddle.Tensor=None,
                   alpha: float=1.0,
                   use_teacher_forcing: bool=False,
-                  spembs=None,
+                  spk_emb=None,
                   spk_id=None,
                   tone_id=None,
                   ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
@@ -511,7 +511,7 @@ class FastSpeech2(nn.Layer):
         use_teacher_forcing : bool, optional
             Whether to use teacher forcing.
             If true, groundtruth of duration, pitch and energy will be used.
-        spembs : Tensor, optional
+        spk_emb : Tensor, optional
             peaker embedding vector (spk_embed_dim,).
         spk_id : Tensor, optional(int64)
             Batch of padded spk ids (1,).
@@ -535,8 +535,8 @@ class FastSpeech2(nn.Layer):
         if y is not None:
             ys = y.unsqueeze(0)
-        if spembs is not None:
-            spembs = spembs.unsqueeze(0)
+        if spk_emb is not None:
+            spk_emb = spk_emb.unsqueeze(0)
         if tone_id is not None:
             tone_id = tone_id.unsqueeze(0)
@@ -555,7 +555,7 @@ class FastSpeech2(nn.Layer):
                 ds=ds,
                 ps=ps,
                 es=es,
-                spembs=spembs,
+                spk_emb=spk_emb,
                 spk_id=spk_id,
                 tone_id=tone_id,
                 is_inference=True)
@@ -567,19 +567,19 @@ class FastSpeech2(nn.Layer):
                 ys,
                 is_inference=True,
                 alpha=alpha,
-                spembs=spembs,
+                spk_emb=spk_emb,
                 spk_id=spk_id,
                 tone_id=tone_id)
         return outs[0], d_outs[0], p_outs[0], e_outs[0]

-    def _integrate_with_spk_embed(self, hs, spembs):
+    def _integrate_with_spk_embed(self, hs, spk_emb):
         """Integrate speaker embedding with hidden states.

         Parameters
         ----------
         hs : Tensor
             Batch of hidden state sequences (B, Tmax, adim).
-        spembs : Tensor
+        spk_emb : Tensor
             Batch of speaker embeddings (B, spk_embed_dim).

         Returns
@@ -589,13 +589,13 @@ class FastSpeech2(nn.Layer):
         """
         if self.spk_embed_integration_type == "add":
             # apply projection and then add to hidden states
-            spembs = self.spk_projection(F.normalize(spembs))
-            hs = hs + spembs.unsqueeze(1)
+            spk_emb = self.spk_projection(F.normalize(spk_emb))
+            hs = hs + spk_emb.unsqueeze(1)
         elif self.spk_embed_integration_type == "concat":
             # concat hidden states with spk embeds and then apply projection
-            spembs = F.normalize(spembs).unsqueeze(1).expand(
-                shape=[-1, hs.shape[1], -1])
-            hs = self.spk_projection(paddle.concat([hs, spembs], axis=-1))
+            spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
+                shape=[-1, hs.shape[1], -1])
+            hs = self.spk_projection(paddle.concat([hs, spk_emb], axis=-1))
         else:
             raise NotImplementedError("support only add or concat.")
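Both branches of `_integrate_with_spk_embed` broadcast a single per-utterance embedding across all `Tmax` frames: `add` projects the normalized embedding to `adim` and adds it to every frame, while `concat` tiles it along time, concatenates it with the hidden states, and projects back to `adim`. A standalone sketch of the two branches (the layer names and dimensions here are made up for the demo; the real model stores its own projection as `self.spk_projection`):

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

adim, spk_embed_dim, B, Tmax = 8, 4, 2, 5
hs = paddle.randn([B, Tmax, adim])           # hidden state sequences
spk_emb = paddle.randn([B, spk_embed_dim])   # one embedding per utterance

# "add": project spk_emb to adim, then add it to every frame
add_proj = nn.Linear(spk_embed_dim, adim)
hs_add = hs + add_proj(F.normalize(spk_emb)).unsqueeze(1)        # (B, Tmax, adim)

# "concat": tile spk_emb along time, concat with hs, project back to adim
concat_proj = nn.Linear(adim + spk_embed_dim, adim)
tiled = F.normalize(spk_emb).unsqueeze(1).expand(shape=[-1, hs.shape[1], -1])
hs_cat = concat_proj(paddle.concat([hs, tiled], axis=-1))        # (B, Tmax, adim)

print(hs_add.shape, hs_cat.shape)  # both [2, 5, 8]
```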
@@ -680,9 +680,9 @@ class FastSpeech2Inference(nn.Layer):
         self.normalizer = normalizer
         self.acoustic_model = model

-    def forward(self, text, spk_id=None, spembs=None):
+    def forward(self, text, spk_id=None, spk_emb=None):
         normalized_mel, d_outs, p_outs, e_outs = self.acoustic_model.inference(
-            text, spk_id=spk_id, spembs=spembs)
+            text, spk_id=spk_id, spk_emb=spk_emb)
         logmel = self.normalizer.inverse(normalized_mel)
         return logmel

@@ -54,9 +54,9 @@ class FastSpeech2Updater(StandardUpdater):
         losses_dict = {}
         # spk_id!=None in multiple spk fastspeech2
         spk_id = batch["spk_id"] if "spk_id" in batch else None
-        spembs = batch["spembs"] if "spembs" in batch else None
+        spk_emb = batch["spk_emb"] if "spk_emb" in batch else None
         # No explicit speaker identifier labels are used during voice cloning training.
-        if spembs is not None:
+        if spk_emb is not None:
             spk_id = None

         before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
@@ -68,7 +68,7 @@ class FastSpeech2Updater(StandardUpdater):
             pitch=batch["pitch"],
             energy=batch["energy"],
             spk_id=spk_id,
-            spembs=spembs)
+            spk_emb=spk_emb)

         l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(
             after_outs=after_outs,
@@ -131,8 +131,8 @@ class FastSpeech2Evaluator(StandardEvaluator):
         losses_dict = {}
         # spk_id!=None in multiple spk fastspeech2
         spk_id = batch["spk_id"] if "spk_id" in batch else None
-        spembs = batch["spembs"] if "spembs" in batch else None
-        if spembs is not None:
+        spk_emb = batch["spk_emb"] if "spk_emb" in batch else None
+        if spk_emb is not None:
             spk_id = None

         before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
@@ -144,7 +144,7 @@ class FastSpeech2Evaluator(StandardEvaluator):
             pitch=batch["pitch"],
             energy=batch["energy"],
             spk_id=spk_id,
-            spembs=spembs)
+            spk_emb=spk_emb)

         l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(
             after_outs=after_outs,

@@ -391,7 +391,7 @@ class TransformerTTS(nn.Layer):
                 text_lengths: paddle.Tensor,
                 speech: paddle.Tensor,
                 speech_lengths: paddle.Tensor,
-                spembs: paddle.Tensor=None,
+                spk_emb: paddle.Tensor=None,
                 ) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]:
         """Calculate forward propagation.
@@ -405,7 +405,7 @@ class TransformerTTS(nn.Layer):
             Batch of padded target features (B, Lmax, odim).
         speech_lengths : Tensor(int64)
             Batch of the lengths of each target (B,).
-        spembs : Tensor, optional
+        spk_emb : Tensor, optional
             Batch of speaker embeddings (B, spk_embed_dim).

         Returns
@@ -439,7 +439,7 @@ class TransformerTTS(nn.Layer):
         # calculate transformer outputs
         after_outs, before_outs, logits = self._forward(xs, ilens, ys, olens,
-                                                        spembs)
+                                                        spk_emb)

         # modifiy mod part of groundtruth
@@ -467,7 +467,7 @@ class TransformerTTS(nn.Layer):
                  ilens: paddle.Tensor,
                  ys: paddle.Tensor,
                  olens: paddle.Tensor,
-                 spembs: paddle.Tensor,
+                 spk_emb: paddle.Tensor,
                  ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         # forward encoder
         x_masks = self._source_mask(ilens)
@@ -480,7 +480,7 @@ class TransformerTTS(nn.Layer):
         # integrate speaker embedding
         if self.spk_embed_dim is not None:
-            hs = self._integrate_with_spk_embed(hs, spembs)
+            hs = self._integrate_with_spk_embed(hs, spk_emb)

         # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
         if self.reduction_factor > 1:
@@ -514,7 +514,7 @@ class TransformerTTS(nn.Layer):
             self,
             text: paddle.Tensor,
             speech: paddle.Tensor=None,
-            spembs: paddle.Tensor=None,
+            spk_emb: paddle.Tensor=None,
             threshold: float=0.5,
             minlenratio: float=0.0,
             maxlenratio: float=10.0,
@@ -528,7 +528,7 @@ class TransformerTTS(nn.Layer):
             Input sequence of characters (T,).
         speech : Tensor, optional
             Feature sequence to extract style (N, idim).
-        spembs : Tensor, optional
+        spk_emb : Tensor, optional
             Speaker embedding vector (spk_embed_dim,).
         threshold : float, optional
             Threshold in inference.
@@ -551,7 +551,6 @@ class TransformerTTS(nn.Layer):
         """
         # input of embedding must be int64
         y = speech
-        spemb = spembs

         # add eos at the last of sequence
         text = numpy.pad(
@@ -564,12 +563,12 @@ class TransformerTTS(nn.Layer):
             # get teacher forcing outputs
             xs, ys = x.unsqueeze(0), y.unsqueeze(0)
-            spembs = None if spemb is None else spemb.unsqueeze(0)
+            spk_emb = None if spk_emb is None else spk_emb.unsqueeze(0)
             ilens = paddle.to_tensor(
                 [xs.shape[1]], dtype=paddle.int64, place=xs.place)
             olens = paddle.to_tensor(
                 [ys.shape[1]], dtype=paddle.int64, place=ys.place)
-            outs, *_ = self._forward(xs, ilens, ys, olens, spembs)
+            outs, *_ = self._forward(xs, ilens, ys, olens, spk_emb)

             # get attention weights
             att_ws = []
@@ -590,9 +589,9 @@ class TransformerTTS(nn.Layer):
             hs = hs + style_embs.unsqueeze(1)

         # integrate speaker embedding
-        if self.spk_embed_dim is not None:
-            spembs = spemb.unsqueeze(0)
-            hs = self._integrate_with_spk_embed(hs, spembs)
+        if spk_emb is not None:
+            spk_emb = spk_emb.unsqueeze(0)
+            hs = self._integrate_with_spk_embed(hs, spk_emb)

         # set limits of length
         maxlen = int(hs.shape[1] * maxlenratio / self.reduction_factor)
@@ -726,14 +725,14 @@ class TransformerTTS(nn.Layer):
     def _integrate_with_spk_embed(self,
                                   hs: paddle.Tensor,
-                                  spembs: paddle.Tensor) -> paddle.Tensor:
+                                  spk_emb: paddle.Tensor) -> paddle.Tensor:
         """Integrate speaker embedding with hidden states.

         Parameters
         ----------
         hs : Tensor
             Batch of hidden state sequences (B, Tmax, adim).
-        spembs : Tensor
+        spk_emb : Tensor
             Batch of speaker embeddings (B, spk_embed_dim).

         Returns
@@ -744,13 +743,13 @@ class TransformerTTS(nn.Layer):
         """
         if self.spk_embed_integration_type == "add":
             # apply projection and then add to hidden states
-            spembs = self.projection(F.normalize(spembs))
-            hs = hs + spembs.unsqueeze(1)
+            spk_emb = self.projection(F.normalize(spk_emb))
+            hs = hs + spk_emb.unsqueeze(1)
         elif self.spk_embed_integration_type == "concat":
             # concat hidden states with spk embeds and then apply projection
-            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.shape[1],
-                                                             -1)
-            hs = self.projection(paddle.concat([hs, spembs], axis=-1))
+            spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(-1, hs.shape[1],
+                                                               -1)
+            hs = self.projection(paddle.concat([hs, spk_emb], axis=-1))
         else:
             raise NotImplementedError("support only add or concat.")
