[Fix] transpose use numpy (#3933)

pull/3935/head
megemini 3 weeks ago committed by GitHub
parent c0fafd0647
commit ff539ef007
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1267,7 +1267,7 @@ class TransposeLast(nn.Layer):
def forward(self, x): def forward(self, x):
if self.deconstruct_idx is not None: if self.deconstruct_idx is not None:
x = x[self.deconstruct_idx] x = x[self.deconstruct_idx]
trans_dim = paddle.arange(x.dim()) trans_dim = np.arange(x.dim())
trans_dim[-1], trans_dim[-2] = trans_dim[-2], trans_dim[-1] trans_dim[-1], trans_dim[-2] = trans_dim[-2], trans_dim[-1]
return x.transpose(trans_dim) return x.transpose(trans_dim)

Loading…
Cancel
Save