Support dy2st for conformer

Branch: pull/2162/head
Author: 0x45f, 3 years ago
parent e81849277e
commit 294b7b00bd

@@ -159,9 +159,7 @@ if not hasattr(paddle.Tensor, 'new_full'):
 def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
     if convert_dtype_to_string(xs.dtype) == paddle.bool:
         xs = xs.astype(paddle.int)
-    return xs.equal(
-        paddle.to_tensor(
-            ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place))
+    return xs.equal(ys)
 
 
 if not hasattr(paddle.Tensor, 'eq'):
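
Note: the old `eq` materialized `ys` as a fresh tensor on every call via `paddle.to_tensor(..., place=xs.place)`, which dynamic-to-static (dy2st) conversion cannot trace cleanly; the new version lets `equal` broadcast the scalar itself. A minimal runnable sketch of the patched helper, assuming a Paddle 2.x release whose `equal` accepts a Python scalar, and substituting a plain dtype check for the module's `convert_dtype_to_string` plumbing:

    from typing import Union

    import paddle

    def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
        # Cast bool inputs first, since elementwise equal is not defined
        # for bool tensors; 'int64' stands in for the module's
        # paddle.int / convert_dtype_to_string handling (assumption).
        if xs.dtype == paddle.bool:
            xs = xs.astype('int64')
        return xs.equal(ys)

    x = paddle.to_tensor([1, 2, 3])
    print(eq(x, 2))  # Tensor([False, True, False])
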
@@ -219,13 +217,22 @@ def is_broadcastable(shp1, shp2):
     return True
 
 
+def broadcast_shape(shp1, shp2):
+    result = []
+    for a, b in zip(shp1[::-1], shp2[::-1]):
+        result.append(max(a, b))
+    return result[::-1]
+
+
 def masked_fill(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]):
-    assert is_broadcastable(xs.shape, mask.shape) is True, (xs.shape,
-                                                            mask.shape)
-    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
-    mask = mask.broadcast_to(bshape)
+    bshape = broadcast_shape(xs.shape, mask.shape)
+    mask.stop_gradient = True
+    tmp = paddle.ones(shape=[len(bshape)], dtype='int32')
+    for index in range(len(bshape)):
+        tmp[index] = bshape[index]
+    mask = mask.broadcast_to(tmp)
     trues = paddle.ones_like(xs) * value
     xs = paddle.where(mask, trues, xs)
     return xs
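
Note: `paddle.broadcast_shape` plus a Python `assert` only work on shapes that are concrete at call time, which breaks once the function is traced for dy2st export; the rewrite computes the broadcast shape in plain Python and hands it to `broadcast_to` as an int32 tensor, so the reshape stays inside the static graph (`mask.stop_gradient = True` additionally keeps gradients out of the mask). A usage sketch of the new `masked_fill` with illustrative shapes (not from the PR); note the helper assumes `xs` and `mask` have the same rank:

    import paddle

    scores = paddle.randn([2, 3, 4])           # e.g. attention scores
    pad_mask = paddle.rand([2, 1, 4]) > 0.5    # broadcast along dim 1
    out = masked_fill(scores, pad_mask, -1e9)  # masked entries become -1e9
    print(out.shape)                           # [2, 3, 4]
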

@@ -253,8 +253,8 @@ class BaseEncoder(nn.Layer):
             # cnn_cache[i] = (B=1, hidden-dim, cache_t2)
             xs, _, new_att_cache, new_cnn_cache = layer(
                 xs, att_mask, pos_emb,
-                att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
-                cnn_cache=cnn_cache[i] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
+                att_cache=att_cache if elayers == 0 else att_cache[i:i+1],
+                cnn_cache=cnn_cache if paddle.shape(cnn_cache)[0] == 0 else cnn_cache[i],
             )
             # new_att_cache = (1, head, attention_key_size, d_k*2)
             # new_cnn_cache = (B=1, hidden-dim, cache_t2)
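
Note: on the first streaming chunk both caches are empty along their leading dim, so the guard must select the un-indexed tensor. Flipping the conditions to `== 0` puts that empty-cache fallback in the taken branch, presumably a form the dy2st converter handles more robustly than slicing a zero-length tensor. A small sketch of just the guard on a first chunk, assuming `elayers` comes from `paddle.shape(att_cache)[0]` as in the surrounding method (the all-zero shapes are illustrative):

    import paddle

    att_cache = paddle.zeros([0, 0, 0, 0])  # empty before the first chunk
    cnn_cache = paddle.zeros([0, 0, 0])
    elayers = paddle.shape(att_cache)[0]    # 0 -> pass the caches as-is
    i = 0
    layer_att_cache = att_cache if elayers == 0 else att_cache[i:i + 1]
    layer_cnn_cache = (cnn_cache if paddle.shape(cnn_cache)[0] == 0
                       else cnn_cache[i])
    print(layer_att_cache.shape, layer_cnn_cache.shape)  # [0,0,0,0] [0,0,0]
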
