|
|
@ -110,14 +110,14 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
|
|
|
|
concat_after=concat_after, ) for _ in range(num_blocks)
|
|
|
|
concat_after=concat_after, ) for _ in range(num_blocks)
|
|
|
|
])
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
def forward(self,
|
|
|
|
self,
|
|
|
|
|
|
|
|
memory: paddle.Tensor,
|
|
|
|
memory: paddle.Tensor,
|
|
|
|
memory_mask: paddle.Tensor,
|
|
|
|
memory_mask: paddle.Tensor,
|
|
|
|
ys_in_pad: paddle.Tensor,
|
|
|
|
ys_in_pad: paddle.Tensor,
|
|
|
|
ys_in_lens: paddle.Tensor,
|
|
|
|
ys_in_lens: paddle.Tensor,
|
|
|
|
r_ys_in_pad: paddle.Tensor=paddle.empty([0]),
|
|
|
|
r_ys_in_pad: paddle.Tensor=paddle.empty([0]),
|
|
|
|
reverse_weight: float=0.0) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
|
|
|
reverse_weight: float=0.0
|
|
|
|
|
|
|
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
|
|
|
"""Forward decoder.
|
|
|
|
"""Forward decoder.
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
|
|
|
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
|
|
|