diff --git a/paddlespeech/s2t/modules/embedding.py b/paddlespeech/s2t/modules/embedding.py
index 3aeebd29b..54324c2f6 100644
--- a/paddlespeech/s2t/modules/embedding.py
+++ b/paddlespeech/s2t/modules/embedding.py
@@ -89,7 +89,7 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
         self.max_len = max_len
         self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
         self.dropout = nn.Dropout(p=dropout_rate)
-        self.pe = paddle.zeros([self.max_len, self.d_model])  #[T,D]
+        self.pe = paddle.zeros([1, self.max_len, self.d_model])  #[B=1,T,D]

         position = paddle.arange(
             0, self.max_len, dtype=paddle.float32).unsqueeze(1)  #[T, 1]
@@ -97,9 +97,8 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
         div_term = paddle.exp(
             paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
             -(math.log(10000.0) / self.d_model))
-        self.pe[:, 0::2] = paddle.sin(position * div_term)
-        self.pe[:, 1::2] = paddle.cos(position * div_term)
-        self.pe = self.pe.unsqueeze(0)  #[1, T, D]
+        self.pe[:, :, 0::2] = paddle.sin(position * div_term)
+        self.pe[:, :, 1::2] = paddle.cos(position * div_term)

     def forward(self, x: paddle.Tensor,
                 offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py
index 87b83ef55..2e76ccb05 100644
--- a/paddlespeech/s2t/modules/encoder.py
+++ b/paddlespeech/s2t/modules/encoder.py
@@ -264,15 +264,15 @@ class BaseEncoder(nn.Layer):
             # new_att_cache = (1, head, attention_key_size, d_k*2)
             # new_cnn_cache = (B=1, hidden-dim, cache_t2)
             r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
-            r_cnn_cache.append(new_cnn_cache.unsqueeze(0))  # add elayer dim
+            r_cnn_cache.append(new_cnn_cache)  # elayer dim is added by stack below
         if self.normalize_before:
             xs = self.after_norm(xs)

         # r_att_cache (elayers, head, T, d_k*2)
-        # r_cnn_cache (elayers, B=1, hidden-dim, cache_t2)
+        # r_cnn_cache (elayers, B=1, hidden-dim, cache_t2)
         r_att_cache = paddle.concat(r_att_cache, axis=0)
-        r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
+        r_cnn_cache = paddle.stack(r_cnn_cache, axis=0)
         return xs, r_att_cache, r_cnn_cache

     def forward_chunk_by_chunk(
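
Note: both changes are shape/reshape refactors rather than numeric changes. The standalone sketch below is a minimal sanity check of that equivalence; the sizes used here (d_model=8, max_len=16, elayers=3, hidden=4, cache_t2=5) are illustrative placeholders, not the values PaddleSpeech actually uses.

import math

import paddle

# --- 1) Positional-encoding table built directly with a leading batch dim ---
d_model, max_len = 8, 16  # illustrative sizes only

pe = paddle.zeros([1, max_len, d_model])  # [B=1, T, D]
position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)  # [T, 1]
div_term = paddle.exp(
    paddle.arange(0, d_model, 2, dtype=paddle.float32) *
    -(math.log(10000.0) / d_model))  # [D/2]
pe[:, :, 0::2] = paddle.sin(position * div_term)
pe[:, :, 1::2] = paddle.cos(position * div_term)

# Old construction: fill a [T, D] buffer, then unsqueeze to [1, T, D].
pe_old = paddle.zeros([max_len, d_model])
pe_old[:, 0::2] = paddle.sin(position * div_term)
pe_old[:, 1::2] = paddle.cos(position * div_term)
pe_old = pe_old.unsqueeze(0)

assert paddle.allclose(pe, pe_old).item()  # identical values, one reshape fewer

# --- 2) Per-layer CNN caches: stack vs. unsqueeze + concat ---
elayers, hidden, cache_t2 = 3, 4, 5  # illustrative sizes only
caches = [paddle.randn([1, hidden, cache_t2]) for _ in range(elayers)]  # each (B=1, hidden, cache_t2)

stacked = paddle.stack(caches, axis=0)                                  # (elayers, B=1, hidden, cache_t2)
concatenated = paddle.concat([c.unsqueeze(0) for c in caches], axis=0)  # old path

assert paddle.allclose(stacked, concatenated).item()
print(stacked.shape)  # [3, 1, 4, 5]

In both cases the new form states the intent more directly: building self.pe with the batch axis up front removes the trailing unsqueeze(0) and keeps the #[B=1,T,D] comment in sync with the tensor it describes, and paddle.stack(r_cnn_cache, axis=0) adds the leading elayers axis in one place instead of unsqueezing every per-layer cache before concatenation.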