From 7fd5abd75d16852c5d6c1ab385b37603d9be7c77 Mon Sep 17 00:00:00 2001 From: megemini Date: Mon, 25 Nov 2024 11:25:37 +0800 Subject: [PATCH] [Fix] max between int and value (#3903) --- paddlespeech/t2s/modules/masked_fill.py | 22 ++++++++++++++++++- .../t2s/modules/transformer/embedding.py | 2 +- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/paddlespeech/t2s/modules/masked_fill.py b/paddlespeech/t2s/modules/masked_fill.py index 1445a926..d143fe62 100644 --- a/paddlespeech/t2s/modules/masked_fill.py +++ b/paddlespeech/t2s/modules/masked_fill.py @@ -29,7 +29,27 @@ def is_broadcastable(shp1, shp2): def broadcast_shape(shp1, shp2): result = [] for a, b in zip(shp1[::-1], shp2[::-1]): - result.append(max(a, b)) + is_a_int = isinstance(a, int) + is_b_int = isinstance(b, int) + + if is_a_int and is_b_int: + result.append(max(a, b)) + + else: + dtype = None + if hasattr(a, 'dtype'): + dtype = a.dtype + if hasattr(b, 'dtype'): + dtype = b.dtype + + if (is_a_int): + a = paddle.full((), a, dtype=dtype) + + if (is_b_int): + b = paddle.full((), b, dtype=dtype) + + result.append(paddle.maximum(a, b)) + return result[::-1] diff --git a/paddlespeech/t2s/modules/transformer/embedding.py b/paddlespeech/t2s/modules/transformer/embedding.py index f90eb44a..e4331cff 100644 --- a/paddlespeech/t2s/modules/transformer/embedding.py +++ b/paddlespeech/t2s/modules/transformer/embedding.py @@ -67,7 +67,7 @@ class PositionalEncoding(nn.Layer): pe[:, 0::2] = paddle.sin(position * div_term) pe[:, 1::2] = paddle.cos(position * div_term) pe = pe.unsqueeze(0) - self.pe = pe + self.pe = paddle.assign(pe) def forward(self, x: paddle.Tensor): """Add positional encoding.