fix conformer multi-gpu training test=asr

pull/2327/head
tianhao zhang 2 years ago
parent ed16f96a9c
commit 733ec7f2bc

@ -605,8 +605,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
xs: paddle.Tensor, xs: paddle.Tensor,
offset: int, offset: int,
required_cache_size: int, required_cache_size: int,
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), att_cache: paddle.Tensor,
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), cnn_cache: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Export interface for c++ call, give input chunk xs, and return """ Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk. output from time 0 to current chunk.

@ -15,7 +15,6 @@
# Modified from wenet(https://github.com/wenet-e2e/wenet) # Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Multi-Head Attention layer definition.""" """Multi-Head Attention layer definition."""
import math import math
from typing import Optional
from typing import Tuple from typing import Tuple
import paddle import paddle
@ -83,11 +82,11 @@ class MultiHeadedAttention(nn.Layer):
return q, k, v return q, k, v
def forward_attention(self, def forward_attention(
self,
value: paddle.Tensor, value: paddle.Tensor,
scores: paddle.Tensor, scores: paddle.Tensor,
mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool), mask: paddle.Tensor, ) -> paddle.Tensor:
) -> paddle.Tensor:
"""Compute attention context vector. """Compute attention context vector.
Args: Args:
value (paddle.Tensor): Transformed value, size value (paddle.Tensor): Transformed value, size
@ -131,10 +130,9 @@ class MultiHeadedAttention(nn.Layer):
query: paddle.Tensor, query: paddle.Tensor,
key: paddle.Tensor, key: paddle.Tensor,
value: paddle.Tensor, value: paddle.Tensor,
mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool), mask: paddle.Tensor,
pos_emb: paddle.Tensor = paddle.empty([0]), pos_emb: paddle.Tensor,
cache: paddle.Tensor = paddle.zeros([0,0,0,0]) cache: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute scaled dot product attention. """Compute scaled dot product attention.
Args: Args:
query (paddle.Tensor): Query tensor (#batch, time1, size). query (paddle.Tensor): Query tensor (#batch, time1, size).
@ -247,10 +245,9 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
query: paddle.Tensor, query: paddle.Tensor,
key: paddle.Tensor, key: paddle.Tensor,
value: paddle.Tensor, value: paddle.Tensor,
mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool), mask: paddle.Tensor,
pos_emb: paddle.Tensor = paddle.empty([0]), pos_emb: paddle.Tensor,
cache: paddle.Tensor = paddle.zeros([0,0,0,0]) cache: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding. """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args: Args:
query (paddle.Tensor): Query tensor (#batch, time1, size). query (paddle.Tensor): Query tensor (#batch, time1, size).

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet) # Modified from wenet(https://github.com/wenet-e2e/wenet)
"""ConvolutionModule definition.""" """ConvolutionModule definition."""
from typing import Optional
from typing import Tuple from typing import Tuple
import paddle import paddle
@ -108,9 +107,8 @@ class ConvolutionModule(nn.Layer):
def forward(self, def forward(self,
x: paddle.Tensor, x: paddle.Tensor,
mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool), mask_pad: paddle.Tensor,
cache: paddle.Tensor= paddle.zeros([0,0,0]), cache: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute convolution module. """Compute convolution module.
Args: Args:
x (paddle.Tensor): Input tensor (#batch, time, channels). x (paddle.Tensor): Input tensor (#batch, time, channels).

@ -121,11 +121,16 @@ class DecoderLayer(nn.Layer):
if self.concat_after: if self.concat_after:
tgt_concat = paddle.cat( tgt_concat = paddle.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]), dim=-1) (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask,
paddle.empty([0]),
paddle.zeros([0, 0, 0, 0]))[0]),
dim=-1)
x = residual + self.concat_linear1(tgt_concat) x = residual + self.concat_linear1(tgt_concat)
else: else:
x = residual + self.dropout( x = residual + self.dropout(
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) self.self_attn(tgt_q, tgt, tgt, tgt_q_mask,
paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[
0])
if not self.normalize_before: if not self.normalize_before:
x = self.norm1(x) x = self.norm1(x)
@ -134,11 +139,15 @@ class DecoderLayer(nn.Layer):
x = self.norm2(x) x = self.norm2(x)
if self.concat_after: if self.concat_after:
x_concat = paddle.cat( x_concat = paddle.cat(
(x, self.src_attn(x, memory, memory, memory_mask)[0]), dim=-1) (x, self.src_attn(x, memory, memory, memory_mask,
paddle.empty([0]),
paddle.zeros([0, 0, 0, 0]))[0]),
dim=-1)
x = residual + self.concat_linear2(x_concat) x = residual + self.concat_linear2(x_concat)
else: else:
x = residual + self.dropout( x = residual + self.dropout(
self.src_attn(x, memory, memory, memory_mask)[0]) self.src_attn(x, memory, memory, memory_mask,
paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[0])
if not self.normalize_before: if not self.normalize_before:
x = self.norm2(x) x = self.norm2(x)

@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet) # Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Encoder definition.""" """Encoder definition."""
from typing import List
from typing import Optional
from typing import Tuple from typing import Tuple
import paddle import paddle
@ -177,7 +175,9 @@ class BaseEncoder(nn.Layer):
decoding_chunk_size, self.static_chunk_size, decoding_chunk_size, self.static_chunk_size,
num_decoding_left_chunks) num_decoding_left_chunks)
for layer in self.encoders: for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad,
paddle.zeros([0, 0, 0, 0]),
paddle.zeros([0, 0, 0, 0]))
if self.normalize_before: if self.normalize_before:
xs = self.after_norm(xs) xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just # Here we assume the mask is not changed in encoder layers, so just
@ -190,9 +190,9 @@ class BaseEncoder(nn.Layer):
xs: paddle.Tensor, xs: paddle.Tensor,
offset: int, offset: int,
required_cache_size: int, required_cache_size: int,
att_cache: paddle.Tensor = paddle.zeros([0,0,0,0]), att_cache: paddle.Tensor,
cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0]), cnn_cache: paddle.Tensor,
att_mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool), att_mask: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Forward just one chunk """ Forward just one chunk
Args: Args:
@ -252,10 +252,12 @@ class BaseEncoder(nn.Layer):
# att_cache[i:i+1] = (1, head, cache_t1, d_k*2) # att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
# cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2) # cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
xs, _, new_att_cache, new_cnn_cache = layer( xs, _, new_att_cache, new_cnn_cache = layer(
xs, att_mask, pos_emb, xs,
att_mask,
pos_emb,
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, cnn_cache=cnn_cache[i:i + 1]
) if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
# new_att_cache = (1, head, attention_key_size, d_k*2) # new_att_cache = (1, head, attention_key_size, d_k*2)
# new_cnn_cache = (B=1, hidden-dim, cache_t2) # new_cnn_cache = (B=1, hidden-dim, cache_t2)
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
@ -270,7 +272,6 @@ class BaseEncoder(nn.Layer):
r_cnn_cache = paddle.concat(r_cnn_cache, axis=0) r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
return xs, r_att_cache, r_cnn_cache return xs, r_att_cache, r_cnn_cache
def forward_chunk_by_chunk( def forward_chunk_by_chunk(
self, self,
xs: paddle.Tensor, xs: paddle.Tensor,
@ -326,7 +327,8 @@ class BaseEncoder(nn.Layer):
chunk_xs = xs[:, cur:end, :] chunk_xs = xs[:, cur:end, :]
(y, att_cache, cnn_cache) = self.forward_chunk( (y, att_cache, cnn_cache) = self.forward_chunk(
chunk_xs, offset, required_cache_size, att_cache, cnn_cache) chunk_xs, offset, required_cache_size, att_cache, cnn_cache,
paddle.ones([0, 0, 0], dtype=paddle.bool))
outputs.append(y) outputs.append(y)
offset += y.shape[1] offset += y.shape[1]

@ -76,9 +76,9 @@ class TransformerEncoderLayer(nn.Layer):
x: paddle.Tensor, x: paddle.Tensor,
mask: paddle.Tensor, mask: paddle.Tensor,
pos_emb: paddle.Tensor, pos_emb: paddle.Tensor,
mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), mask_pad: paddle.Tensor,
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), att_cache: paddle.Tensor,
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), cnn_cache: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features. """Compute encoded features.
Args: Args:
@ -105,7 +105,8 @@ class TransformerEncoderLayer(nn.Layer):
if self.normalize_before: if self.normalize_before:
x = self.norm1(x) x = self.norm1(x)
x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache) x_att, new_att_cache = self.self_attn(
x, x, x, mask, paddle.empty([0]), cache=att_cache)
if self.concat_after: if self.concat_after:
x_concat = paddle.concat((x, x_att), axis=-1) x_concat = paddle.concat((x, x_att), axis=-1)
@ -193,9 +194,9 @@ class ConformerEncoderLayer(nn.Layer):
x: paddle.Tensor, x: paddle.Tensor,
mask: paddle.Tensor, mask: paddle.Tensor,
pos_emb: paddle.Tensor, pos_emb: paddle.Tensor,
mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), mask_pad: paddle.Tensor,
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), att_cache: paddle.Tensor,
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), cnn_cache: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features. """Compute encoded features.
Args: Args:

Loading…
Cancel
Save