|
|
@ -498,6 +498,8 @@
|
|
|
|
" waveforms, labels = batch\n",
|
|
|
|
" waveforms, labels = batch\n",
|
|
|
|
" feats = feature_extractor(waveforms)\n",
|
|
|
|
" feats = feature_extractor(waveforms)\n",
|
|
|
|
" feats = paddle.transpose(feats, [0, 2, 1]) # [B, N, T] -> [B, T, N]\n",
|
|
|
|
" feats = paddle.transpose(feats, [0, 2, 1]) # [B, N, T] -> [B, T, N]\n",
|
|
|
|
|
|
|
|
" if feats.dim() == 3:\n",
|
|
|
|
|
|
|
|
" feats = feats.unsqueeze(1)\n",
|
|
|
|
" logits = model(feats)\n",
|
|
|
|
" logits = model(feats)\n",
|
|
|
|
"\n",
|
|
|
|
"\n",
|
|
|
|
" loss = criterion(logits, labels)\n",
|
|
|
|
" loss = criterion(logits, labels)\n",
|
|
|
@ -541,7 +543,9 @@
|
|
|
|
" waveforms, labels = batch\n",
|
|
|
|
" waveforms, labels = batch\n",
|
|
|
|
" feats = feature_extractor(waveforms)\n",
|
|
|
|
" feats = feature_extractor(waveforms)\n",
|
|
|
|
" feats = paddle.transpose(feats, [0, 2, 1])\n",
|
|
|
|
" feats = paddle.transpose(feats, [0, 2, 1])\n",
|
|
|
|
" \n",
|
|
|
|
" if feats.dim() == 3:\n",
|
|
|
|
|
|
|
|
" feats = feats.unsqueeze(1)\n",
|
|
|
|
|
|
|
|
"\n",
|
|
|
|
" logits = model(feats)\n",
|
|
|
|
" logits = model(feats)\n",
|
|
|
|
"\n",
|
|
|
|
"\n",
|
|
|
|
" preds = paddle.argmax(logits, axis=1)\n",
|
|
|
|
" preds = paddle.argmax(logits, axis=1)\n",
|
|
|
@ -576,6 +580,8 @@
|
|
|
|
"feats = feature_extractor(paddle.to_tensor(paddle.to_tensor(waveform).unsqueeze(0)))\n",
|
|
|
|
"feats = feature_extractor(paddle.to_tensor(paddle.to_tensor(waveform).unsqueeze(0)))\n",
|
|
|
|
"feats = paddle.transpose(feats, [0, 2, 1]) # [B, N, T] -> [B, T, N]\n",
|
|
|
|
"feats = paddle.transpose(feats, [0, 2, 1]) # [B, N, T] -> [B, T, N]\n",
|
|
|
|
"print(feats.shape)\n",
|
|
|
|
"print(feats.shape)\n",
|
|
|
|
|
|
|
|
"if feats.dim() == 3:\n",
|
|
|
|
|
|
|
|
" feats = feats.unsqueeze(1)\n",
|
|
|
|
"\n",
|
|
|
|
"\n",
|
|
|
|
"logits = model(feats)\n",
|
|
|
|
"logits = model(feats)\n",
|
|
|
|
"probs = nn.functional.softmax(logits, axis=1).numpy()\n",
|
|
|
|
"probs = nn.functional.softmax(logits, axis=1).numpy()\n",
|
|
|
|