fix Tensor.numpy()[0] to float(Tensor) to adapt 0D (#2884)

pull/2896/head
Zhou Wei 3 years ago committed by GitHub
parent 089c060756
commit 16d84367c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -121,7 +121,7 @@ if __name__ == "__main__":
optimizer.clear_grad() optimizer.clear_grad()
# Calculate loss # Calculate loss
avg_loss += loss.numpy()[0] avg_loss += float(loss)
# Calculate metrics # Calculate metrics
preds = paddle.argmax(logits, axis=1) preds = paddle.argmax(logits, axis=1)

Loading…
Cancel
Save