|
|
|
@ -383,7 +383,7 @@ class TransformerEncoder(BaseEncoder):
|
|
|
|
|
"""Encode input frame.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
xs (paddle.Tensor): Input tensor. (B, T, D)
|
|
|
|
|
xs (paddle.Tensor): (Prefix) Input tensor. (B, T, D)
|
|
|
|
|
masks (paddle.Tensor): Mask tensor. (B, T, T)
|
|
|
|
|
cache (List[paddle.Tensor]): List of cache tensors.
|
|
|
|
|
|
|
|
|
@ -396,7 +396,6 @@ class TransformerEncoder(BaseEncoder):
|
|
|
|
|
xs = self.global_cmvn(xs)
|
|
|
|
|
|
|
|
|
|
if isinstance(self.embed, Conv2dSubsampling):
|
|
|
|
|
# xs, masks = self.embed(xs, masks)
|
|
|
|
|
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
|
|
|
|
|
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
|
|
|
|
|
else:
|
|
|
|
|