fix multigpu training test=asr

pull/2327/head
tianhao zhang 2 years ago
parent 733ec7f2bc
commit ed80b0e2c3

@ -605,8 +605,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
att_cache: paddle.Tensor,
cnn_cache: paddle.Tensor,
att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.

@ -86,7 +86,8 @@ class MultiHeadedAttention(nn.Layer):
self,
value: paddle.Tensor,
scores: paddle.Tensor,
mask: paddle.Tensor, ) -> paddle.Tensor:
mask: paddle.Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool)
) -> paddle.Tensor:
"""Compute attention context vector.
Args:
value (paddle.Tensor): Transformed value, size
@ -126,13 +127,15 @@ class MultiHeadedAttention(nn.Layer):
return self.linear_out(x) # (batch, time1, d_model)
def forward(self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
cache: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
def forward(
self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
pos_emb: paddle.Tensor, # paddle.empty([0])
cache: paddle.Tensor # paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute scaled dot product attention.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
@ -241,13 +244,15 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
return x
def forward(self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
cache: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
def forward(
self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
pos_emb: paddle.Tensor, # paddle.empty([0])
cache: paddle.Tensor # paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).

@ -105,10 +105,12 @@ class ConvolutionModule(nn.Layer):
)
self.activation = activation
def forward(self,
x: paddle.Tensor,
mask_pad: paddle.Tensor,
cache: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
def forward(
self,
x: paddle.Tensor,
mask_pad: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
cache: paddle.Tensor # paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute convolution module.
Args:
x (paddle.Tensor): Input tensor (#batch, time, channels).

@ -190,9 +190,9 @@ class BaseEncoder(nn.Layer):
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
att_cache: paddle.Tensor,
cnn_cache: paddle.Tensor,
att_mask: paddle.Tensor,
att_cache: paddle.Tensor, # paddle.zeros([0,0,0,0])
cnn_cache: paddle.Tensor, # paddle.zeros([0,0,0,0]),
att_mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Forward just one chunk
Args:

@ -76,9 +76,10 @@ class TransformerEncoderLayer(nn.Layer):
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
mask_pad: paddle.Tensor,
att_cache: paddle.Tensor,
cnn_cache: paddle.Tensor,
mask_pad: paddle.
Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool)
att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:
@ -194,9 +195,10 @@ class ConformerEncoderLayer(nn.Layer):
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
mask_pad: paddle.Tensor,
att_cache: paddle.Tensor,
cnn_cache: paddle.Tensor,
mask_pad: paddle.
Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool)
att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:

Loading…
Cancel
Save