|
|
@ -29,7 +29,27 @@ def is_broadcastable(shp1, shp2):
|
|
|
|
def broadcast_shape(shp1, shp2):
|
|
|
|
def broadcast_shape(shp1, shp2):
|
|
|
|
result = []
|
|
|
|
result = []
|
|
|
|
for a, b in zip(shp1[::-1], shp2[::-1]):
|
|
|
|
for a, b in zip(shp1[::-1], shp2[::-1]):
|
|
|
|
result.append(max(a, b))
|
|
|
|
is_a_int = isinstance(a, int)
|
|
|
|
|
|
|
|
is_b_int = isinstance(b, int)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if is_a_int and is_b_int:
|
|
|
|
|
|
|
|
result.append(max(a, b))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
dtype = None
|
|
|
|
|
|
|
|
if hasattr(a, 'dtype'):
|
|
|
|
|
|
|
|
dtype = a.dtype
|
|
|
|
|
|
|
|
if hasattr(b, 'dtype'):
|
|
|
|
|
|
|
|
dtype = b.dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (is_a_int):
|
|
|
|
|
|
|
|
a = paddle.full((), a, dtype=dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (is_b_int):
|
|
|
|
|
|
|
|
b = paddle.full((), b, dtype=dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result.append(paddle.maximum(a, b))
|
|
|
|
|
|
|
|
|
|
|
|
return result[::-1]
|
|
|
|
return result[::-1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|