diff --git a/paddlespeech/t2s/modules/masked_fill.py b/paddlespeech/t2s/modules/masked_fill.py
index 1445a926a..d143fe62f 100644
--- a/paddlespeech/t2s/modules/masked_fill.py
+++ b/paddlespeech/t2s/modules/masked_fill.py
@@ -29,7 +29,27 @@ def is_broadcastable(shp1, shp2):
 def broadcast_shape(shp1, shp2):
     result = []
     for a, b in zip(shp1[::-1], shp2[::-1]):
-        result.append(max(a, b))
+        is_a_int = isinstance(a, int)
+        is_b_int = isinstance(b, int)
+
+        if is_a_int and is_b_int:
+            result.append(max(a, b))
+
+        else:
+            dtype = None
+            if hasattr(a, 'dtype'):
+                dtype = a.dtype
+            if hasattr(b, 'dtype'):
+                dtype = b.dtype
+
+            if (is_a_int):
+                a = paddle.full((), a, dtype=dtype)
+
+            if (is_b_int):
+                b = paddle.full((), b, dtype=dtype)
+
+            result.append(paddle.maximum(a, b))
+
     return result[::-1]
 
 
diff --git a/paddlespeech/t2s/modules/transformer/embedding.py b/paddlespeech/t2s/modules/transformer/embedding.py
index f90eb44a4..e4331cff0 100644
--- a/paddlespeech/t2s/modules/transformer/embedding.py
+++ b/paddlespeech/t2s/modules/transformer/embedding.py
@@ -67,7 +67,7 @@ class PositionalEncoding(nn.Layer):
         pe[:, 0::2] = paddle.sin(position * div_term)
         pe[:, 1::2] = paddle.cos(position * div_term)
         pe = pe.unsqueeze(0)
-        self.pe = pe
+        self.pe = paddle.assign(pe)
 
     def forward(self, x: paddle.Tensor):
         """Add positional encoding.