fix format test=asr

pull/2336/head
tianhao zhang 2 years ago
parent 1dfca4ef73
commit ed2819d7af

@ -195,8 +195,7 @@ class ConformerEncoderLayer(nn.Layer):
x: paddle.Tensor, x: paddle.Tensor,
mask: paddle.Tensor, mask: paddle.Tensor,
pos_emb: paddle.Tensor, pos_emb: paddle.Tensor,
mask_pad: paddle. mask_pad: paddle.Tensor, #paddle.ones([0, 0, 0],dtype=paddle.bool)
Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool)
att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0]) att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0]) cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:

@ -476,8 +476,12 @@ class PaddleASRConnectionHanddler:
# forward chunk # forward chunk
(y, self.att_cache, (y, self.att_cache,
self.cnn_cache) = self.model.encoder.forward_chunk( self.cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size, self.att_cache, chunk_xs,
self.cnn_cache, paddle.ones([0, 0, 0], dtype=paddle.bool)) self.offset,
required_cache_size,
att_cache=self.att_cache,
cnn_cache=self.cnn_cache,
att_mask=paddle.ones([0, 0, 0], dtype=paddle.bool))
outputs.append(y) outputs.append(y)
# update the global offset, in decoding frame unit # update the global offset, in decoding frame unit

Loading…
Cancel
Save