Merge pull request #2327 from Zth9730/fix_multigpu_train

[s2t] fix conformer/transformer multi-gpu training; may impact dy2st
pull/2347/head
Hui Zhang 2 years ago committed by GitHub
commit 94e750c4c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -605,8 +605,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.

@ -15,7 +15,6 @@
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Multi-Head Attention layer definition."""
import math
from typing import Optional
from typing import Tuple
import paddle
@ -83,11 +82,12 @@ class MultiHeadedAttention(nn.Layer):
return q, k, v
def forward_attention(self,
value: paddle.Tensor,
def forward_attention(
self,
value: paddle.Tensor,
scores: paddle.Tensor,
mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
) -> paddle.Tensor:
mask: paddle.Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool)
) -> paddle.Tensor:
"""Compute attention context vector.
Args:
value (paddle.Tensor): Transformed value, size
@ -108,7 +108,7 @@ class MultiHeadedAttention(nn.Layer):
# When will `if mask.size(2) > 0` be False?
# 1. onnx(16/-1, -1/-1, 16/0)
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
if paddle.shape(mask)[2] > 0: # time2 > 0
if paddle.shape(mask)[2] > 0: # time2 > 0
mask = mask.unsqueeze(1).equal(0) # (batch, 1, *, time2)
# for last chunk, time2 might be larger than scores.size(-1)
mask = mask[:, :, :, :paddle.shape(scores)[-1]]
@ -127,14 +127,15 @@ class MultiHeadedAttention(nn.Layer):
return self.linear_out(x) # (batch, time1, d_model)
def forward(self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
pos_emb: paddle.Tensor = paddle.empty([0]),
cache: paddle.Tensor = paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
def forward(
self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
pos_emb: paddle.Tensor, # paddle.empty([0])
cache: paddle.Tensor # paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute scaled dot product attention.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
@ -243,14 +244,15 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
return x
def forward(self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
pos_emb: paddle.Tensor = paddle.empty([0]),
cache: paddle.Tensor = paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
def forward(
self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
pos_emb: paddle.Tensor, # paddle.empty([0])
cache: paddle.Tensor # paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).

@ -14,7 +14,6 @@
# limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""ConvolutionModule definition."""
from typing import Optional
from typing import Tuple
import paddle
@ -106,11 +105,12 @@ class ConvolutionModule(nn.Layer):
)
self.activation = activation
def forward(self,
x: paddle.Tensor,
mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
cache: paddle.Tensor= paddle.zeros([0,0,0]),
) -> Tuple[paddle.Tensor, paddle.Tensor]:
def forward(
self,
x: paddle.Tensor,
mask_pad: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
cache: paddle.Tensor # paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute convolution module.
Args:
x (paddle.Tensor): Input tensor (#batch, time, channels).
@ -127,11 +127,11 @@ class ConvolutionModule(nn.Layer):
x = x.transpose([0, 2, 1]) # [B, C, T]
# mask batch padding
if paddle.shape(mask_pad)[2] > 0: # time > 0
if paddle.shape(mask_pad)[2] > 0: # time > 0
x = x.masked_fill(mask_pad, 0.0)
if self.lorder > 0:
if paddle.shape(cache)[2] == 0: # cache_t == 0
if paddle.shape(cache)[2] == 0: # cache_t == 0
x = nn.functional.pad(
x, [self.lorder, 0], 'constant', 0.0, data_format='NCL')
else:
@ -161,7 +161,7 @@ class ConvolutionModule(nn.Layer):
x = self.pointwise_conv2(x)
# mask batch padding
if paddle.shape(mask_pad)[2] > 0: # time > 0
if paddle.shape(mask_pad)[2] > 0: # time > 0
x = x.masked_fill(mask_pad, 0.0)
x = x.transpose([0, 2, 1]) # [B, T, C]

@ -121,11 +121,16 @@ class DecoderLayer(nn.Layer):
if self.concat_after:
tgt_concat = paddle.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]), dim=-1)
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask,
paddle.empty([0]),
paddle.zeros([0, 0, 0, 0]))[0]),
dim=-1)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask,
paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[
0])
if not self.normalize_before:
x = self.norm1(x)
@ -134,11 +139,15 @@ class DecoderLayer(nn.Layer):
x = self.norm2(x)
if self.concat_after:
x_concat = paddle.cat(
(x, self.src_attn(x, memory, memory, memory_mask)[0]), dim=-1)
(x, self.src_attn(x, memory, memory, memory_mask,
paddle.empty([0]),
paddle.zeros([0, 0, 0, 0]))[0]),
dim=-1)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(
self.src_attn(x, memory, memory, memory_mask)[0])
self.src_attn(x, memory, memory, memory_mask,
paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[0])
if not self.normalize_before:
x = self.norm2(x)

@ -14,8 +14,6 @@
# limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Encoder definition."""
from typing import List
from typing import Optional
from typing import Tuple
import paddle
@ -177,7 +175,9 @@ class BaseEncoder(nn.Layer):
decoding_chunk_size, self.static_chunk_size,
num_decoding_left_chunks)
for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad,
paddle.zeros([0, 0, 0, 0]),
paddle.zeros([0, 0, 0, 0]))
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
@ -190,9 +190,9 @@ class BaseEncoder(nn.Layer):
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
att_cache: paddle.Tensor = paddle.zeros([0,0,0,0]),
cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0]),
att_mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
att_cache: paddle.Tensor, # paddle.zeros([0,0,0,0])
cnn_cache: paddle.Tensor, # paddle.zeros([0,0,0,0])
att_mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Forward just one chunk
Args:
@ -227,7 +227,7 @@ class BaseEncoder(nn.Layer):
xs = self.global_cmvn(xs)
# before embed, xs=(B, T, D1), pos_emb=(B=1, T, D)
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
# after embed, xs=(B=1, chunk_size, hidden-dim)
elayers = paddle.shape(att_cache)[0]
@ -252,14 +252,16 @@ class BaseEncoder(nn.Layer):
# att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
# cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
xs, _, new_att_cache, new_cnn_cache = layer(
xs, att_mask, pos_emb,
att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
)
xs,
att_mask,
pos_emb,
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
cnn_cache=cnn_cache[i:i + 1]
if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
# new_att_cache = (1, head, attention_key_size, d_k*2)
# new_cnn_cache = (B=1, hidden-dim, cache_t2)
r_att_cache.append(new_att_cache[:,:, next_cache_start:, :])
r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) # add elayer dim
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) # add elayer dim
if self.normalize_before:
xs = self.after_norm(xs)
@ -270,7 +272,6 @@ class BaseEncoder(nn.Layer):
r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
return xs, r_att_cache, r_cnn_cache
def forward_chunk_by_chunk(
self,
xs: paddle.Tensor,
@ -315,8 +316,8 @@ class BaseEncoder(nn.Layer):
num_frames = xs.shape[1]
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
att_cache: paddle.Tensor = paddle.zeros([0,0,0,0])
cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0])
att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
outputs = []
offset = 0
@ -326,7 +327,8 @@ class BaseEncoder(nn.Layer):
chunk_xs = xs[:, cur:end, :]
(y, att_cache, cnn_cache) = self.forward_chunk(
chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
chunk_xs, offset, required_cache_size, att_cache, cnn_cache,
paddle.ones([0, 0, 0], dtype=paddle.bool))
outputs.append(y)
offset += y.shape[1]

@ -76,9 +76,10 @@ class TransformerEncoderLayer(nn.Layer):
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
mask_pad: paddle.
Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool)
att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:
@ -105,7 +106,8 @@ class TransformerEncoderLayer(nn.Layer):
if self.normalize_before:
x = self.norm1(x)
x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)
x_att, new_att_cache = self.self_attn(
x, x, x, mask, paddle.empty([0]), cache=att_cache)
if self.concat_after:
x_concat = paddle.concat((x, x_att), axis=-1)
@ -193,9 +195,10 @@ class ConformerEncoderLayer(nn.Layer):
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
mask_pad: paddle.
Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool)
att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:

Loading…
Cancel
Save