diff --git a/paddlespeech/audio/utils/tensor_utils.py b/paddlespeech/audio/utils/tensor_utils.py index 93883c94d..b246a6459 100644 --- a/paddlespeech/audio/utils/tensor_utils.py +++ b/paddlespeech/audio/utils/tensor_utils.py @@ -177,8 +177,9 @@ def th_accuracy(pad_outputs: paddle.Tensor, Returns: float: Accuracy value (0.0 - 1.0). """ - pad_pred = pad_outputs.view(pad_targets.shape[0], pad_targets.shape[1], - pad_outputs.shape[1]).argmax(2) + pad_pred = pad_outputs.reshape( + [pad_targets.shape[0], pad_targets.shape[1], + pad_outputs.shape[1]]).argmax(2) mask = pad_targets != ignore_label #TODO(Hui Zhang): sum not support bool type # numerator = paddle.sum( @@ -248,7 +249,7 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor, # >>> tensor([[ 2, 1, 0], # >>> [ 2, 1, 0], # >>> [ 0, -1, -2]]) - index = index * seq_mask + index = index * seq_mask.astype(index.dtype) # >>> index # >>> tensor([[2, 1, 0],