|
|
@ -14,8 +14,6 @@
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
# Modified from wenet(https://github.com/wenet-e2e/wenet)
|
|
|
|
# Modified from wenet(https://github.com/wenet-e2e/wenet)
|
|
|
|
"""Encoder definition."""
|
|
|
|
"""Encoder definition."""
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
from typing import Tuple
|
|
|
|
from typing import Tuple
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
@ -177,7 +175,9 @@ class BaseEncoder(nn.Layer):
|
|
|
|
decoding_chunk_size, self.static_chunk_size,
|
|
|
|
decoding_chunk_size, self.static_chunk_size,
|
|
|
|
num_decoding_left_chunks)
|
|
|
|
num_decoding_left_chunks)
|
|
|
|
for layer in self.encoders:
|
|
|
|
for layer in self.encoders:
|
|
|
|
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
|
|
|
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad,
|
|
|
|
|
|
|
|
paddle.zeros([0, 0, 0, 0]),
|
|
|
|
|
|
|
|
paddle.zeros([0, 0, 0, 0]))
|
|
|
|
if self.normalize_before:
|
|
|
|
if self.normalize_before:
|
|
|
|
xs = self.after_norm(xs)
|
|
|
|
xs = self.after_norm(xs)
|
|
|
|
# Here we assume the mask is not changed in encoder layers, so just
|
|
|
|
# Here we assume the mask is not changed in encoder layers, so just
|
|
|
@ -190,9 +190,9 @@ class BaseEncoder(nn.Layer):
|
|
|
|
xs: paddle.Tensor,
|
|
|
|
xs: paddle.Tensor,
|
|
|
|
offset: int,
|
|
|
|
offset: int,
|
|
|
|
required_cache_size: int,
|
|
|
|
required_cache_size: int,
|
|
|
|
att_cache: paddle.Tensor = paddle.zeros([0,0,0,0]),
|
|
|
|
att_cache: paddle.Tensor,
|
|
|
|
cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0]),
|
|
|
|
cnn_cache: paddle.Tensor,
|
|
|
|
att_mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
|
|
|
|
att_mask: paddle.Tensor,
|
|
|
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
|
|
|
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
|
|
|
""" Forward just one chunk
|
|
|
|
""" Forward just one chunk
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
@ -252,10 +252,12 @@ class BaseEncoder(nn.Layer):
|
|
|
|
# att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
|
|
|
|
# att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
|
|
|
|
# cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
|
|
|
|
# cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
|
|
|
|
xs, _, new_att_cache, new_cnn_cache = layer(
|
|
|
|
xs, _, new_att_cache, new_cnn_cache = layer(
|
|
|
|
xs, att_mask, pos_emb,
|
|
|
|
xs,
|
|
|
|
|
|
|
|
att_mask,
|
|
|
|
|
|
|
|
pos_emb,
|
|
|
|
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
|
|
|
|
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
|
|
|
|
cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
|
|
|
|
cnn_cache=cnn_cache[i:i + 1]
|
|
|
|
)
|
|
|
|
if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
|
|
|
|
# new_att_cache = (1, head, attention_key_size, d_k*2)
|
|
|
|
# new_att_cache = (1, head, attention_key_size, d_k*2)
|
|
|
|
# new_cnn_cache = (B=1, hidden-dim, cache_t2)
|
|
|
|
# new_cnn_cache = (B=1, hidden-dim, cache_t2)
|
|
|
|
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
|
|
|
|
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
|
|
|
@ -270,7 +272,6 @@ class BaseEncoder(nn.Layer):
|
|
|
|
r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
|
|
|
|
r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
|
|
|
|
return xs, r_att_cache, r_cnn_cache
|
|
|
|
return xs, r_att_cache, r_cnn_cache
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_chunk_by_chunk(
|
|
|
|
def forward_chunk_by_chunk(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
xs: paddle.Tensor,
|
|
|
|
xs: paddle.Tensor,
|
|
|
@ -326,7 +327,8 @@ class BaseEncoder(nn.Layer):
|
|
|
|
chunk_xs = xs[:, cur:end, :]
|
|
|
|
chunk_xs = xs[:, cur:end, :]
|
|
|
|
|
|
|
|
|
|
|
|
(y, att_cache, cnn_cache) = self.forward_chunk(
|
|
|
|
(y, att_cache, cnn_cache) = self.forward_chunk(
|
|
|
|
chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
|
|
|
|
chunk_xs, offset, required_cache_size, att_cache, cnn_cache,
|
|
|
|
|
|
|
|
paddle.ones([0, 0, 0], dtype=paddle.bool))
|
|
|
|
|
|
|
|
|
|
|
|
outputs.append(y)
|
|
|
|
outputs.append(y)
|
|
|
|
offset += y.shape[1]
|
|
|
|
offset += y.shape[1]
|
|
|
|