Merge pull request #1516 from jerryuhoo/fix_speedyspeech

[TTS] fix Speedyspeech multi-speaker inference, test=tts
pull/1520/head
TianYuan 3 years ago committed by GitHub
commit dc1dc04536
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -194,10 +194,10 @@ def evaluate(args):
am_inference = jit.to_static( am_inference = jit.to_static(
am_inference, am_inference,
input_spec=[ input_spec=[
InputSpec([-1], dtype=paddle.int64), # text InputSpec([-1], dtype=paddle.int64), # text
InputSpec([-1], dtype=paddle.int64), # tone InputSpec([-1], dtype=paddle.int64), # tone
None, # duration InputSpec([1], dtype=paddle.int64), # spk_id
InputSpec([-1], dtype=paddle.int64) # spk_id None # duration
]) ])
else: else:
am_inference = jit.to_static( am_inference = jit.to_static(

@ -247,7 +247,7 @@ class SpeedySpeechInference(nn.Layer):
self.normalizer = normalizer self.normalizer = normalizer
self.acoustic_model = speedyspeech_model self.acoustic_model = speedyspeech_model
def forward(self, phones, tones, durations=None, spk_id=None): def forward(self, phones, tones, spk_id=None, durations=None):
normalized_mel = self.acoustic_model.inference( normalized_mel = self.acoustic_model.inference(
phones, tones, durations=durations, spk_id=spk_id) phones, tones, durations=durations, spk_id=spk_id)
logmel = self.normalizer.inverse(normalized_mel) logmel = self.normalizer.inverse(normalized_mel)

Loading…
Cancel
Save