diff --git a/paddlespeech/cli/st/infer.py b/paddlespeech/cli/st/infer.py
index bc2bdd1ac..0867e8158 100644
--- a/paddlespeech/cli/st/infer.py
+++ b/paddlespeech/cli/st/infer.py
@@ -252,7 +252,7 @@ class STExecutor(BaseExecutor):
             norm_feat = dict(kaldiio.load_ark(process.stdout))[utt_name]
             self._inputs["audio"] = paddle.to_tensor(norm_feat).unsqueeze(0)
             self._inputs["audio_len"] = paddle.to_tensor(
-                self._inputs["audio"].shape[1], dtype="int64")
+                self._inputs["audio"].shape[1:2], dtype="int64")
         else:
             raise ValueError("Wrong model type.")
 
diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py
index 4787e1eeb..beba7f602 100644
--- a/paddlespeech/cli/tts/infer.py
+++ b/paddlespeech/cli/tts/infer.py
@@ -491,7 +491,7 @@ class TTSExecutor(BaseExecutor):
                 # multi speaker
                 if am_dataset in {'aishell3', 'vctk', 'mix', 'canton'}:
                     mel = self.am_inference(
-                        part_phone_ids, spk_id=paddle.to_tensor(spk_id))
+                        part_phone_ids, spk_id=paddle.to_tensor([spk_id]))
                 else:
                     mel = self.am_inference(part_phone_ids)
                 self.am_time += (time.time() - am_st)
diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
index 8ce19795e..a95a9b288 100644
--- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
+++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
@@ -783,7 +783,7 @@ class FastSpeech2(nn.Layer):
         x = paddle.cast(text, 'int64')
         d, p, e = durations, pitch, energy
         # setup batch axis
-        ilens = paddle.shape(x)[0]
+        ilens = paddle.shape(x)[0:1]
 
         xs = x.unsqueeze(0)
 
diff --git a/paddlespeech/t2s/modules/nets_utils.py b/paddlespeech/t2s/modules/nets_utils.py
index 3d1b48dec..57c46e3a8 100644
--- a/paddlespeech/t2s/modules/nets_utils.py
+++ b/paddlespeech/t2s/modules/nets_utils.py
@@ -181,7 +181,7 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
     if length_dim == 0:
         raise ValueError("length_dim cannot be 0: {}".format(length_dim))
 
-    bs = paddle.shape(lengths)[0]
+    bs = paddle.shape(lengths)
     if xs is None:
         maxlen = paddle.cast(lengths.max(), dtype=bs.dtype)
     else: