diff --git a/paddlespeech/cls/exps/panns/deploy/predict.py b/paddlespeech/cls/exps/panns/deploy/predict.py index 464685331..a6b735335 100644 --- a/paddlespeech/cls/exps/panns/deploy/predict.py +++ b/paddlespeech/cls/exps/panns/deploy/predict.py @@ -58,7 +58,7 @@ def extract_features(files: str, **kwargs): feature_extractor = LogMelSpectrogram(sr, **kwargs) feat = feature_extractor(paddle.to_tensor(waveforms[i])) - feat = paddle.transpose(feat, perm=[1, 0]) + feat = paddle.transpose(feat, perm=[1, 0]).unsqueeze(0) feats.append(feat) return np.stack(feats, axis=0)