diff --git a/deepspeech/modules/encoder.py b/deepspeech/modules/encoder.py index a9b5e8a6..0f8f1075 100644 --- a/deepspeech/modules/encoder.py +++ b/deepspeech/modules/encoder.py @@ -383,7 +383,7 @@ class TransformerEncoder(BaseEncoder): """Encode input frame. Args: - xs (paddle.Tensor): Input tensor. (B, T, D) + xs (paddle.Tensor): (Prefix) Input tensor. (B, T, D) masks (paddle.Tensor): Mask tensor. (B, T, T) cache (List[paddle.Tensor]): List of cache tensors. @@ -396,7 +396,6 @@ class TransformerEncoder(BaseEncoder): xs = self.global_cmvn(xs) if isinstance(self.embed, Conv2dSubsampling): - # xs, masks = self.embed(xs, masks) #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0) else: