diff --git a/docs/source/cls/custom_dataset.md b/docs/source/cls/custom_dataset.md index 7482d5edf..813abafa4 100644 --- a/docs/source/cls/custom_dataset.md +++ b/docs/source/cls/custom_dataset.md @@ -98,6 +98,8 @@ for epoch in range(1, epochs + 1): # Need a padding when lengths of waveforms differ in a batch. feats = feature_extractor(waveforms) feats = paddle.transpose(feats, [0, 2, 1]) + if feats.dim() == 3: + feats = feats.unsqueeze(1) logits = model(feats) loss = criterion(logits, labels) loss.backward() diff --git a/docs/tutorial/cls/cls_tutorial.ipynb b/docs/tutorial/cls/cls_tutorial.ipynb index 3cee64991..5941ed6c2 100644 --- a/docs/tutorial/cls/cls_tutorial.ipynb +++ b/docs/tutorial/cls/cls_tutorial.ipynb @@ -498,6 +498,8 @@ " waveforms, labels = batch\n", " feats = feature_extractor(waveforms)\n", " feats = paddle.transpose(feats, [0, 2, 1]) # [B, N, T] -> [B, T, N]\n", + " if feats.dim() == 3:\n", + " feats = feats.unsqueeze(1)\n", " logits = model(feats)\n", "\n", " loss = criterion(logits, labels)\n", @@ -541,7 +543,9 @@ " waveforms, labels = batch\n", " feats = feature_extractor(waveforms)\n", " feats = paddle.transpose(feats, [0, 2, 1])\n", - " \n", + " if feats.dim() == 3:\n", + " feats = feats.unsqueeze(1)\n", + "\n", " logits = model(feats)\n", "\n", " preds = paddle.argmax(logits, axis=1)\n", @@ -576,6 +580,8 @@ "feats = feature_extractor(paddle.to_tensor(paddle.to_tensor(waveform).unsqueeze(0)))\n", "feats = paddle.transpose(feats, [0, 2, 1]) # [B, N, T] -> [B, T, N]\n", "print(feats.shape)\n", + "if feats.dim() == 3:\n", + " feats = feats.unsqueeze(1)\n", "\n", "logits = model(feats)\n", "probs = nn.functional.softmax(logits, axis=1).numpy()\n", diff --git a/examples/tess/cls0/local/train.py b/examples/tess/cls0/local/train.py index f023a37b7..8576aedf2 100644 --- a/examples/tess/cls0/local/train.py +++ b/examples/tess/cls0/local/train.py @@ -109,6 +109,8 @@ if __name__ == "__main__": num_samples = 0 for batch_idx, batch in enumerate(train_loader): feats, labels, length = batch # feats-->(N, length, n_mels) + if feats.dim() == 3: + feats = feats.unsqueeze(1) logits = model(feats) @@ -170,6 +172,9 @@ if __name__ == "__main__": with logger.processing('Evaluation on validation dataset'): for batch_idx, batch in enumerate(dev_loader): feats, labels, length = batch + if feats.dim() == 3: + feats = feats.unsqueeze(1) + logits = model(feats) preds = paddle.argmax(logits, axis=1) diff --git a/paddlespeech/cls/exps/panns/export_model.py b/paddlespeech/cls/exps/panns/export_model.py index 63b22981a..e84418bad 100644 --- a/paddlespeech/cls/exps/panns/export_model.py +++ b/paddlespeech/cls/exps/panns/export_model.py @@ -38,8 +38,9 @@ if __name__ == '__main__': model, input_spec=[ paddle.static.InputSpec( - shape=[None, None, 64], dtype=paddle.float32) - ]) + shape=[None, 1, None, 64], dtype=paddle.float32) + ], + full_graph=True) # Save in static graph model. paddle.jit.save(model, os.path.join(args.output_dir, "inference")) diff --git a/paddlespeech/cls/exps/panns/predict.py b/paddlespeech/cls/exps/panns/predict.py index 4681e4dc9..fd6588df4 100644 --- a/paddlespeech/cls/exps/panns/predict.py +++ b/paddlespeech/cls/exps/panns/predict.py @@ -62,6 +62,9 @@ if __name__ == '__main__': model.eval() feat = extract_features(predicting_conf['audio_file'], **feat_conf) + if feat.dim() == 3: + feat = feat.unsqueeze(1) + logits = model(feat) probs = F.softmax(logits, axis=1).numpy() diff --git a/paddlespeech/cls/exps/panns/train.py b/paddlespeech/cls/exps/panns/train.py index b768919be..da4c176db 100644 --- a/paddlespeech/cls/exps/panns/train.py +++ b/paddlespeech/cls/exps/panns/train.py @@ -89,6 +89,8 @@ if __name__ == "__main__": waveforms ) # Need a padding when lengths of waveforms differ in a batch. feats = paddle.transpose(feats, [0, 2, 1]) # To [N, length, n_mels] + if feats.dim() == 3: + feats = feats.unsqueeze(1) logits = model(feats) @@ -150,6 +152,8 @@ if __name__ == "__main__": waveforms, labels = batch feats = feature_extractor(waveforms) feats = paddle.transpose(feats, [0, 2, 1]) + if feats.dim() == 3: + feats = feats.unsqueeze(1) logits = model(feats) diff --git a/paddlespeech/cls/models/panns/classifier.py b/paddlespeech/cls/models/panns/classifier.py index df64158ff..6510820fb 100644 --- a/paddlespeech/cls/models/panns/classifier.py +++ b/paddlespeech/cls/models/panns/classifier.py @@ -28,7 +28,9 @@ class SoundClassifier(nn.Layer): def forward(self, x): # x: (batch_size, num_frames, num_melbins) -> (batch_size, 1, num_frames, num_melbins) - x = x.unsqueeze(1) + if x.dim() == 3: + x = x.unsqueeze(1) + x = self.backbone(x) x = self.dropout(x) logits = self.fc(x)