[Fix] max between int and value (#3903)

pull/3918/head
megemini 2 months ago committed by GitHub
parent afa9466c89
commit 7fd5abd75d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -29,7 +29,27 @@ def is_broadcastable(shp1, shp2):
def broadcast_shape(shp1, shp2): def broadcast_shape(shp1, shp2):
result = [] result = []
for a, b in zip(shp1[::-1], shp2[::-1]): for a, b in zip(shp1[::-1], shp2[::-1]):
is_a_int = isinstance(a, int)
is_b_int = isinstance(b, int)
if is_a_int and is_b_int:
result.append(max(a, b)) result.append(max(a, b))
else:
dtype = None
if hasattr(a, 'dtype'):
dtype = a.dtype
if hasattr(b, 'dtype'):
dtype = b.dtype
if (is_a_int):
a = paddle.full((), a, dtype=dtype)
if (is_b_int):
b = paddle.full((), b, dtype=dtype)
result.append(paddle.maximum(a, b))
return result[::-1] return result[::-1]

@ -67,7 +67,7 @@ class PositionalEncoding(nn.Layer):
pe[:, 0::2] = paddle.sin(position * div_term) pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term) pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0) pe = pe.unsqueeze(0)
self.pe = pe self.pe = paddle.assign(pe)
def forward(self, x: paddle.Tensor): def forward(self, x: paddle.Tensor):
"""Add positional encoding. """Add positional encoding.

Loading…
Cancel
Save