encoder.forward_one_step for prefix feat or ys

pull/930/head
Hui Zhang 3 years ago
parent 0cd30e48a9
commit 6a9daa800d

@ -383,7 +383,7 @@ class TransformerEncoder(BaseEncoder):
"""Encode input frame. """Encode input frame.
Args: Args:
xs (paddle.Tensor): Input tensor. (B, T, D) xs (paddle.Tensor): (Prefix) Input tensor. (B, T, D)
masks (paddle.Tensor): Mask tensor. (B, T, T) masks (paddle.Tensor): Mask tensor. (B, T, T)
cache (List[paddle.Tensor]): List of cache tensors. cache (List[paddle.Tensor]): List of cache tensors.
@ -396,7 +396,6 @@ class TransformerEncoder(BaseEncoder):
xs = self.global_cmvn(xs) xs = self.global_cmvn(xs)
if isinstance(self.embed, Conv2dSubsampling): if isinstance(self.embed, Conv2dSubsampling):
# xs, masks = self.embed(xs, masks)
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0) xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
else: else:

Loading…
Cancel
Save