Merge pull request #2327 from Zth9730/fix_multigpu_train

[s2t] fix conformer/transformer multi-gpu training, maybe impact dy2st
pull/2347/head
Hui Zhang authored 2 years ago · committed by GitHub
commit 94e750c4c4
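The recurring change in this commit is the removal of tensor-valued default arguments. In Python a default is evaluated once, when the function is defined, so every call that relies on it shares the same tensor object; that shared state is presumably what interferes with multi-GPU training and dynamic-to-static (dy2st) export, as the title suggests. A quick, self-contained demonstration of the pitfall (not from the PR itself):

```python
import paddle


def forward(cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])):
    """Old-style signature: the default tensor is built once, at definition time."""
    return cache


# Both calls return the very same tensor object, created when the function was
# defined, not a fresh placeholder per call (or per device replica).
print(forward() is forward())  # True
```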

@@ -605,8 +605,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
- att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
- cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+ att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
+ cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.
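This first hunk shows the pattern that repeats through the rest of the diff: the cache arguments keep their type annotation, the old default moves into a trailing comment, and callers must now build the placeholder themselves. A minimal sketch of the idea, with a made-up `StreamingBlock` layer standing in for the real model (names and shapes are assumptions, not PaddleSpeech API):

```python
import paddle
from paddle import nn


class StreamingBlock(nn.Layer):
    """Hypothetical layer following the new signature convention."""

    def forward(self,
                xs: paddle.Tensor,
                att_cache: paddle.Tensor,  # paddle.zeros([0, 0, 0, 0])
                cnn_cache: paddle.Tensor,  # paddle.zeros([0, 0, 0, 0])
                ) -> paddle.Tensor:
        # A zero-sized leading dimension is the "no cache yet" sentinel,
        # mirroring the paddle.shape(cache)[0] > 0 checks in the encoder below.
        if paddle.shape(att_cache)[0] > 0:
            pass  # a real layer would concatenate the cached keys/values here
        return xs


block = StreamingBlock()
xs = paddle.randn([1, 16, 8])
# The caller now spells out the placeholders the defaults used to provide.
ys = block(xs, paddle.zeros([0, 0, 0, 0]), paddle.zeros([0, 0, 0, 0]))
```

Keeping the old default as a comment documents the expected placeholder without reintroducing a shared tensor.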

@@ -15,7 +15,6 @@
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Multi-Head Attention layer definition."""
import math
- from typing import Optional
from typing import Tuple
import paddle
@@ -83,10 +82,11 @@ class MultiHeadedAttention(nn.Layer):
return q, k, v
- def forward_attention(self,
+ def forward_attention(
+ self,
value: paddle.Tensor,
scores: paddle.Tensor,
- mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
+ mask: paddle.Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool)
) -> paddle.Tensor:
"""Compute attention context vector.
Args:
@@ -127,13 +127,14 @@ class MultiHeadedAttention(nn.Layer):
return self.linear_out(x) # (batch, time1, d_model)
- def forward(self,
+ def forward(
+ self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
- mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
- pos_emb: paddle.Tensor = paddle.empty([0]),
- cache: paddle.Tensor = paddle.zeros([0,0,0,0])
+ mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
+ pos_emb: paddle.Tensor, # paddle.empty([0])
+ cache: paddle.Tensor # paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute scaled dot product attention.
Args:
@@ -243,13 +244,14 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
return x
- def forward(self,
+ def forward(
+ self,
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
- mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
- pos_emb: paddle.Tensor = paddle.empty([0]),
- cache: paddle.Tensor = paddle.zeros([0,0,0,0])
+ mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
+ pos_emb: paddle.Tensor, # paddle.empty([0])
+ cache: paddle.Tensor # paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
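With the defaults gone, every caller of these attention layers has to pass the mask, positional embedding, and cache explicitly, using the same sentinels the comments record. A hedged usage sketch; the `paddlespeech.s2t.modules.attention` import path and the `(n_head, n_feat, dropout_rate)` constructor order are assumptions based on the wenet code this file says it was modified from:

```python
import paddle
# Assumed import path and constructor order (wenet-style); verify against the repo.
from paddlespeech.s2t.modules.attention import MultiHeadedAttention

attn = MultiHeadedAttention(4, 256, 0.0)  # n_head, n_feat, dropout_rate
q = k = v = paddle.randn([1, 10, 256])

out, new_cache = attn(
    q, k, v,
    paddle.ones([0, 0, 0], dtype=paddle.bool),  # empty mask -> no masking
    paddle.empty([0]),                          # pos_emb unused by this class
    paddle.zeros([0, 0, 0, 0]))                 # zero-sized cache -> nothing cached yet
```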

@@ -14,7 +14,6 @@
# limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""ConvolutionModule definition."""
- from typing import Optional
from typing import Tuple
import paddle
@@ -106,10 +105,11 @@ class ConvolutionModule(nn.Layer):
)
self.activation = activation
- def forward(self,
+ def forward(
+ self,
x: paddle.Tensor,
- mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
- cache: paddle.Tensor= paddle.zeros([0,0,0]),
+ mask_pad: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
+ cache: paddle.Tensor # paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute convolution module.
Args:

@@ -121,11 +121,16 @@ class DecoderLayer(nn.Layer):
if self.concat_after:
tgt_concat = paddle.cat(
- (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]), dim=-1)
+ (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask,
+ paddle.empty([0]),
+ paddle.zeros([0, 0, 0, 0]))[0]),
+ dim=-1)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(
- self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
+ self.self_attn(tgt_q, tgt, tgt, tgt_q_mask,
+ paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[
+ 0])
if not self.normalize_before:
x = self.norm1(x)
@@ -134,11 +139,15 @@
x = self.norm2(x)
if self.concat_after:
x_concat = paddle.cat(
- (x, self.src_attn(x, memory, memory, memory_mask)[0]), dim=-1)
+ (x, self.src_attn(x, memory, memory, memory_mask,
+ paddle.empty([0]),
+ paddle.zeros([0, 0, 0, 0]))[0]),
+ dim=-1)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(
- self.src_attn(x, memory, memory, memory_mask)[0])
+ self.src_attn(x, memory, memory, memory_mask,
+ paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[0])
if not self.normalize_before:
x = self.norm2(x)

@@ -14,8 +14,6 @@
# limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Encoder definition."""
- from typing import List
- from typing import Optional
from typing import Tuple
import paddle
@@ -177,7 +175,9 @@ class BaseEncoder(nn.Layer):
decoding_chunk_size, self.static_chunk_size,
num_decoding_left_chunks)
for layer in self.encoders:
- xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad,
+ paddle.zeros([0, 0, 0, 0]),
+ paddle.zeros([0, 0, 0, 0]))
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
@@ -190,9 +190,9 @@
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
- att_cache: paddle.Tensor = paddle.zeros([0,0,0,0]),
- cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0]),
- att_mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
+ att_cache: paddle.Tensor, # paddle.zeros([0,0,0,0])
+ cnn_cache: paddle.Tensor, # paddle.zeros([0,0,0,0])
+ att_mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool)
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Forward just one chunk
Args:
@@ -252,10 +252,12 @@
# att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
# cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
xs, _, new_att_cache, new_cnn_cache = layer(
- xs, att_mask, pos_emb,
+ xs,
+ att_mask,
+ pos_emb,
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
- cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
- )
+ cnn_cache=cnn_cache[i:i + 1]
+ if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
# new_att_cache = (1, head, attention_key_size, d_k*2)
# new_cnn_cache = (B=1, hidden-dim, cache_t2)
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
@@ -270,7 +272,6 @@
r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
return xs, r_att_cache, r_cnn_cache
def forward_chunk_by_chunk(
self,
xs: paddle.Tensor,
@@ -326,7 +327,8 @@ class BaseEncoder(nn.Layer):
chunk_xs = xs[:, cur:end, :]
(y, att_cache, cnn_cache) = self.forward_chunk(
- chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
+ chunk_xs, offset, required_cache_size, att_cache, cnn_cache,
+ paddle.ones([0, 0, 0], dtype=paddle.bool))
outputs.append(y)
offset += y.shape[1]
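`forward_chunk_by_chunk` now also forwards an explicit boolean mask sentinel into `forward_chunk`. The loop itself is unchanged: it slides over the input, feeds the caches returned by the previous chunk back in, and advances `offset`. A runnable toy version of that driver, with `fake_forward_chunk` standing in for the real encoder (all shapes and sizes are made up):

```python
import paddle


def fake_forward_chunk(xs, offset, required_cache_size, att_cache, cnn_cache,
                       att_mask):
    """Stub with the same argument order as BaseEncoder.forward_chunk above."""
    new_att_cache = paddle.zeros([12, 4, offset + xs.shape[1], 128])
    new_cnn_cache = paddle.zeros([12, 1, 256, 7])
    return xs, new_att_cache, new_cnn_cache


xs = paddle.randn([1, 64, 80])
stride = decoding_window = 16
att_cache = paddle.zeros([0, 0, 0, 0])   # empty caches for the first chunk
cnn_cache = paddle.zeros([0, 0, 0, 0])
offset, outputs = 0, []
for cur in range(0, xs.shape[1] - decoding_window + 1, stride):
    chunk_xs = xs[:, cur:cur + decoding_window, :]
    y, att_cache, cnn_cache = fake_forward_chunk(
        chunk_xs, offset, 16, att_cache, cnn_cache,
        paddle.ones([0, 0, 0], dtype=paddle.bool))  # the now-explicit mask sentinel
    outputs.append(y)
    offset += y.shape[1]
ys = paddle.concat(outputs, axis=1)
```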

@@ -76,9 +76,10 @@ class TransformerEncoderLayer(nn.Layer):
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
- mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
- att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
- cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+ mask_pad: paddle.
+ Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool)
+ att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
+ cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:
@@ -105,7 +106,8 @@ class TransformerEncoderLayer(nn.Layer):
if self.normalize_before:
x = self.norm1(x)
- x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)
+ x_att, new_att_cache = self.self_attn(
+ x, x, x, mask, paddle.empty([0]), cache=att_cache)
if self.concat_after:
x_concat = paddle.concat((x, x_att), axis=-1)
@@ -193,9 +195,10 @@ class ConformerEncoderLayer(nn.Layer):
x: paddle.Tensor,
mask: paddle.Tensor,
pos_emb: paddle.Tensor,
- mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
- att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
- cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+ mask_pad: paddle.
+ Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool)
+ att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
+ cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features.
Args:
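In `TransformerEncoderLayer` the self-attention takes no relative positional embedding, so after the signature change an empty placeholder has to fill the `pos_emb` slot before the keyword `cache` argument (as seen in the hunk above); `ConformerEncoderLayer`, whose attention is the relative-position variant, passes the real `pos_emb` in that slot. A small stand-in function with the same parameter order makes the two call shapes explicit (it is not the actual attention implementation):

```python
import paddle


def attn_forward(query, key, value, mask, pos_emb, cache):
    """Stand-in mirroring the attention parameter order shown above."""
    return query, cache


x = paddle.randn([1, 10, 256])
empty_mask = paddle.ones([0, 0, 0], dtype=paddle.bool)

# Transformer layer: the pos_emb slot carries only a placeholder.
attn_forward(x, x, x, empty_mask, paddle.empty([0]),
             cache=paddle.zeros([0, 0, 0, 0]))

# Conformer layer: the same slot carries the real positional embedding
# (its exact shape depends on the positional encoding module used).
pos_emb = paddle.randn([1, 10, 256])
attn_forward(x, x, x, empty_mask, pos_emb,
             cache=paddle.zeros([0, 0, 0, 0]))
```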
