[Hackathon 7th] 修复 opencopop的svs1中的shape问题 (#3912)

* fix svs1

* fix

* fix

* fix

* fix

* add comment
pull/3939/head
cyberslack_lee 9 months ago committed by GitHub
parent d17361cf8c
commit a34bf501a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -120,6 +120,10 @@ class SinusoidalPosEmb(nn.Layer):
self.dim = dim self.dim = dim
def forward(self, x: paddle.Tensor): def forward(self, x: paddle.Tensor):
# check if x is 0-dim tensor, if so, add a dimension
if x.ndim == 0:
x = paddle.cast(x.unsqueeze(0), 'float32')
else:
x = paddle.cast(x, 'float32') x = paddle.cast(x, 'float32')
half_dim = self.dim // 2 half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1) emb = math.log(10000) / (half_dim - 1)

@ -181,11 +181,12 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
if length_dim == 0: if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim)) raise ValueError("length_dim cannot be 0: {}".format(length_dim))
# check if ilens is 0-dim tensor, if so, add a dimension # check if lengths is 0-dim tensor, if so, add a dimension
if lengths.ndim == 0: if lengths.ndim == 0:
lengths = lengths.unsqueeze(0) bs = paddle.shape(lengths.unsqueeze(0))
else:
bs = paddle.shape(lengths) bs = paddle.shape(lengths)
if xs is None: if xs is None:
maxlen = paddle.cast(lengths.max(), dtype=bs.dtype) maxlen = paddle.cast(lengths.max(), dtype=bs.dtype)
else: else:

Loading…
Cancel
Save