refactor reshape

pull/2425/head
Hui Zhang 3 years ago
parent 6de81d74d9
commit c2c8a662b1

@ -110,12 +110,10 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...)
paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
"""
T = x.shape[1]
assert offset + x.shape[
1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
offset, x.shape[1], self.max_len)
#TODO(Hui Zhang): using T = paddle.shape(x)[1], __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + T]
pos_emb = self.pe[:, offset:offset + x.shape[1]]
x = x * self.xscale + pos_emb
return self.dropout(x), self.dropout(pos_emb)
@ -164,6 +162,5 @@ class RelPositionalEncoding(PositionalEncoding):
1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
offset, x.shape[1], self.max_len)
x = x * self.xscale
#TODO(Hui Zhang): using paddle.shape(x)[1], __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + x.shape[1]]
return self.dropout(x), self.dropout(pos_emb)

@ -139,8 +139,8 @@ class Conv2dSubsampling4(Conv2dSubsampling):
"""
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
b, c, t, f = paddle.shape(x)
x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
b, c, t, f = x.shape
x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f]))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]

Loading…
Cancel
Save