@@ -164,12 +164,8 @@ class BaseEncoder(nn.Layer):
         if self.global_cmvn is not None:
             xs = self.global_cmvn(xs)
-        #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
-        xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
-        #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
-        masks = masks.astype(paddle.bool)
-        #TODO(Hui Zhang): mask_pad = ~masks
-        mask_pad = masks.logical_not()
+        xs, pos_emb, masks = self.embed(xs, masks, offset=0)
+        mask_pad = ~masks
         chunk_masks = add_optional_chunk_mask(
             xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
             decoding_chunk_size, self.static_chunk_size,
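
A minimal sketch, not part of the patch, of the equivalence the simplified lines rely on: for a boolean mask, `~masks` produces the same padding mask as `masks.logical_not()`. Shapes and values are illustrative.

import paddle

# Illustrative (B=1, 1, T=4) non-padding mask; True marks valid frames.
masks = paddle.to_tensor([[[True, True, True, False]]])

mask_pad_old = masks.logical_not()  # previous workaround
mask_pad_new = ~masks               # what the patched code uses

assert bool((mask_pad_old == mask_pad_new).all())
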
@@ -215,11 +211,8 @@ class BaseEncoder(nn.Layer):
                 same shape as the original cnn_cache
         """
         assert xs.shape[0] == 1 # batch size must be one
-        # tmp_masks is just for interface compatibility
-        # TODO(Hui Zhang): stride_slice not support bool tensor
-        # tmp_masks = paddle.ones([1, paddle.shape(xs)[1]], dtype=paddle.bool)
-        tmp_masks = paddle.ones([1, xs.shape[1]], dtype=paddle.int32)
-        tmp_masks = tmp_masks.unsqueeze(1) #[B=1, C=1, T]
+        # tmp_masks is just for interface compatibility, [B=1, C=1, T]
+        tmp_masks = paddle.ones([1, 1, xs.shape[1]], dtype=paddle.bool)

         if self.global_cmvn is not None:
             xs = self.global_cmvn(xs)
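
For reference, a hedged comparison, not from the patch, of the two ways of building the all-ones chunk mask; T stands in for xs.shape[1].

import paddle

T = 16  # stand-in chunk length

# Old: int32 ones of shape [1, T], then unsqueeze to [1, 1, T].
old = paddle.ones([1, T], dtype=paddle.int32).unsqueeze(1)

# New: build the [B=1, C=1, T] boolean mask in one step.
new = paddle.ones([1, 1, T], dtype=paddle.bool)

assert old.shape == new.shape == [1, 1, T]
assert bool((old.astype(paddle.bool) == new).all())
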
@@ -228,9 +221,8 @@
         xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
         # after embed, xs=(B=1, chunk_size, hidden-dim)

-        elayers = paddle.shape(att_cache)[0]
-        cache_t1 = paddle.shape(att_cache)[2]
-        chunk_size = paddle.shape(xs)[1]
+        elayers, _, cache_t1, _ = att_cache.shape
+        chunk_size = xs.shape[1]
         attention_key_size = cache_t1 + chunk_size

         # only used when using `RelPositionMultiHeadedAttention`
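
A sketch of the shape-reading change with made-up cache dimensions: unpacking `att_cache.shape` reads the Python shape list directly, whereas `paddle.shape(...)` returns a runtime Tensor that has to be indexed; both give the same values here.

import paddle

# Illustrative cache of shape (elayers, head, cache_t1, d_k * 2).
att_cache = paddle.zeros([12, 4, 8, 128])
xs = paddle.zeros([1, 16, 256])  # (B=1, chunk_size, hidden-dim)

# Old: index the shape Tensor returned by paddle.shape().
elayers_old = int(paddle.shape(att_cache)[0])
cache_t1_old = int(paddle.shape(att_cache)[2])

# New: unpack the shape list directly.
elayers, _, cache_t1, _ = att_cache.shape
chunk_size = xs.shape[1]

assert (elayers_old, cache_t1_old) == (elayers, cache_t1) == (12, 8)
assert cache_t1 + chunk_size == 24  # attention_key_size
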
@@ -402,11 +394,7 @@ class TransformerEncoder(BaseEncoder):
         if self.global_cmvn is not None:
             xs = self.global_cmvn(xs)

-        #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
-        xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
-        #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
-        masks = masks.astype(paddle.bool)
-
+        xs, pos_emb, masks = self.embed(xs, masks, offset=0)
         if cache is None:
             cache = [None for _ in range(len(self.encoders))]
         new_cache = []
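
A hedged sketch, not from the patch, of why dropping the casts is safe here as well: converting a boolean mask to the feature dtype and back recovers the original mask, so the bool mask can be fed to the embed call unchanged.

import paddle

xs = paddle.randn([1, 4, 80])  # illustrative (B, T, feat) features
masks = paddle.to_tensor([[[True, True, True, False]]])

# Old workaround: cast to xs.dtype for the embed call, then back to bool.
roundtrip = masks.astype(xs.dtype).astype(paddle.bool)

assert bool((roundtrip == masks).all())
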