eliminate useless unsqueese

pull/2425/head
Hui Zhang 2 years ago
parent 7382050e21
commit b7388ce25a

@ -89,7 +89,7 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
self.max_len = max_len self.max_len = max_len
self.xscale = paddle.to_tensor(math.sqrt(self.d_model)) self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
self.dropout = nn.Dropout(p=dropout_rate) self.dropout = nn.Dropout(p=dropout_rate)
self.pe = paddle.zeros([self.max_len, self.d_model]) #[T,D] self.pe = paddle.zeros([1, self.max_len, self.d_model]) #[B=1,T,D]
position = paddle.arange( position = paddle.arange(
0, self.max_len, dtype=paddle.float32).unsqueeze(1) #[T, 1] 0, self.max_len, dtype=paddle.float32).unsqueeze(1) #[T, 1]
@ -97,9 +97,8 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) * paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
-(math.log(10000.0) / self.d_model)) -(math.log(10000.0) / self.d_model))
self.pe[:, 0::2] = paddle.sin(position * div_term) self.pe[:, :, 0::2] = paddle.sin(position * div_term)
self.pe[:, 1::2] = paddle.cos(position * div_term) self.pe[:, :, 1::2] = paddle.cos(position * div_term)
self.pe = self.pe.unsqueeze(0) #[1, T, D]
def forward(self, x: paddle.Tensor, def forward(self, x: paddle.Tensor,
offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]: offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:

@ -264,15 +264,15 @@ class BaseEncoder(nn.Layer):
# new_att_cache = (1, head, attention_key_size, d_k*2) # new_att_cache = (1, head, attention_key_size, d_k*2)
# new_cnn_cache = (B=1, hidden-dim, cache_t2) # new_cnn_cache = (B=1, hidden-dim, cache_t2)
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) # add elayer dim r_cnn_cache.append(new_cnn_cache) # add elayer dim
if self.normalize_before: if self.normalize_before:
xs = self.after_norm(xs) xs = self.after_norm(xs)
# r_att_cache (elayers, head, T, d_k*2) # r_att_cache (elayers, head, T, d_k*2)
# r_cnn_cache elayers, B=1, hidden-dim, cache_t2) # r_cnn_cache (elayers, B=1, hidden-dim, cache_t2)
r_att_cache = paddle.concat(r_att_cache, axis=0) r_att_cache = paddle.concat(r_att_cache, axis=0)
r_cnn_cache = paddle.concat(r_cnn_cache, axis=0) r_cnn_cache = paddle.stack(r_cnn_cache, axis=0)
return xs, r_att_cache, r_cnn_cache return xs, r_att_cache, r_cnn_cache
def forward_chunk_by_chunk( def forward_chunk_by_chunk(

Loading…
Cancel
Save