[Fix] transpose use numpy

pull/3933/head
megemini 10 months ago
parent 67ae7c8dd2
commit 39ec5dc5e7

@ -1267,7 +1267,7 @@ class TransposeLast(nn.Layer):
def forward(self, x): def forward(self, x):
if self.deconstruct_idx is not None: if self.deconstruct_idx is not None:
x = x[self.deconstruct_idx] x = x[self.deconstruct_idx]
trans_dim = paddle.arange(x.dim()) trans_dim = np.arange(x.dim())
trans_dim[-1], trans_dim[-2] = trans_dim[-2], trans_dim[-1] trans_dim[-1], trans_dim[-2] = trans_dim[-2], trans_dim[-1]
return x.transpose(trans_dim) return x.transpose(trans_dim)

Loading…
Cancel
Save