Update cls tutorial. (#1221)

pull/1228/head
KP 4 years ago committed by GitHub
parent 5692b0ff04
commit 425b085f94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,9 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"<a href=\"https://github.com/PaddlePaddle/PaddleSpeech\"><img style=\"position: absolute; z-index: 999; top: 0; right: 0; border: 0; width: 128px; height: 128px;\" src=\"https://nosir.github.io/cleave.js/images/right-graphite@2x.png\" alt=\"Fork me on GitHub\"></a>\n", "<a href=\"https://github.com/PaddlePaddle/PaddleSpeech\"><img style=\"position: absolute; z-index: 999; top: 0; right: 0; border: 0; width: 128px; height: 128px;\" src=\"https://nosir.github.io/cleave.js/images/right-graphite@2x.png\" alt=\"Fork me on GitHub\"></a>\n",
"\n", "\n",
@ -32,9 +30,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"%%HTML\n", "%%HTML\n",
@ -45,9 +41,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"# 2. 音频和特征提取" "# 2. 音频和特征提取"
] ]
@ -55,9 +49,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"# 环境准备安装paddlespeech和paddleaudio\n", "# 环境准备安装paddlespeech和paddleaudio\n",
@ -67,9 +59,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"import warnings\n", "import warnings\n",
@ -82,9 +72,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"\n", "\n",
"\n", "\n",
@ -98,9 +86,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"# 获取示例音频\n", "# 获取示例音频\n",
@ -111,9 +97,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from paddleaudio import load\n", "from paddleaudio import load\n",
@ -130,9 +114,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"!paddlespeech cls --input ./dog.wav" "!paddlespeech cls --input ./dog.wav"
@ -140,9 +122,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"## 2.2 音频特征提取\n", "## 2.2 音频特征提取\n",
"\n", "\n",
@ -162,21 +142,20 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"import paddle\n", "import paddle\n",
"import numpy as np\n", "import numpy as np\n",
"\n", "\n",
"data, sr = load(file='./dog.wav', sr=32000, mono=True, dtype='float32')\n",
"x = paddle.to_tensor(data)\n", "x = paddle.to_tensor(data)\n",
"n_fft = 1024\n", "n_fft = 1024\n",
"win_length = 1024\n", "win_length = 1024\n",
"hop_length = 512\n", "hop_length = 320\n",
"\n", "\n",
"# [D, T]\n", "# [D, T]\n",
"spectrogram = paddle.signal.stft(x, n_fft=1024, win_length=1024, hop_length=512, onesided=True) \n", "spectrogram = paddle.signal.stft(x, n_fft=n_fft, win_length=win_length, hop_length=hop_length, onesided=True) \n",
"print('spectrogram.shape: {}'.format(spectrogram.shape))\n", "print('spectrogram.shape: {}'.format(spectrogram.shape))\n",
"print('spectrogram.dtype: {}'.format(spectrogram.dtype))\n", "print('spectrogram.dtype: {}'.format(spectrogram.dtype))\n",
"\n", "\n",
@ -190,9 +169,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"### 2.2.2 LogFBank\n", "### 2.2.2 LogFBank\n",
"\n", "\n",
@ -220,13 +197,15 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from paddleaudio.features import LogMelSpectrogram\n", "from paddleaudio.features import LogMelSpectrogram\n",
"\n", "\n",
"f_min=50.0\n",
"f_max=14000.0\n",
"n_mels=64\n",
"\n",
"# - sr: 音频文件的采样率。\n", "# - sr: 音频文件的采样率。\n",
"# - n_fft: FFT样本点个数。\n", "# - n_fft: FFT样本点个数。\n",
"# - hop_length: 音频帧之间的间隔。\n", "# - hop_length: 音频帧之间的间隔。\n",
@ -239,7 +218,9 @@
" hop_length=hop_length, \n", " hop_length=hop_length, \n",
" win_length=win_length, \n", " win_length=win_length, \n",
" window='hann', \n", " window='hann', \n",
" n_mels=64)\n", " f_min=f_min,\n",
" f_max=f_max,\n",
" n_mels=n_mels)\n",
"\n", "\n",
"x = paddle.to_tensor(data).unsqueeze(0) # [B, L]\n", "x = paddle.to_tensor(data).unsqueeze(0) # [B, L]\n",
"log_fbank = feature_extractor2(x) # [B, D, T]\n", "log_fbank = feature_extractor2(x) # [B, D, T]\n",
@ -253,9 +234,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"## 2.3 声音分类方法\n", "## 2.3 声音分类方法\n",
"\n", "\n",
@ -272,9 +251,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"### 2.3.2 深度学习方法\n", "### 2.3.2 深度学习方法\n",
"传统机器学习方法可以捕捉声音特征的差异(例如男声和女声的声音在音高上往往差异较大)并实现分类任务。\n", "传统机器学习方法可以捕捉声音特征的差异(例如男声和女声的声音在音高上往往差异较大)并实现分类任务。\n",
@ -288,9 +265,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"### 2.3.3 Pretrain + Finetune\n", "### 2.3.3 Pretrain + Finetune\n",
"\n", "\n",
@ -315,9 +290,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"# 3. 实践:环境声音分类\n", "# 3. 实践:环境声音分类\n",
"\n", "\n",
@ -361,22 +334,18 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from paddleaudio.datasets import ESC50\n", "from paddleaudio.datasets import ESC50\n",
"\n", "\n",
"train_ds = ESC50(mode='train')\n", "train_ds = ESC50(mode='train', sample_rate=sr)\n",
"dev_ds = ESC50(mode='dev')" "dev_ds = ESC50(mode='dev', sample_rate=sr)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"### 3.1.2 特征提取\n", "### 3.1.2 特征提取\n",
"通过下列代码,用 `paddleaudio.features.LogMelSpectrogram` 初始化一个音频特征提取器,在训练过程中实时提取音频的 LogFBank 特征,其中主要的参数如下: " "通过下列代码,用 `paddleaudio.features.LogMelSpectrogram` 初始化一个音频特征提取器,在训练过程中实时提取音频的 LogFBank 特征,其中主要的参数如下: "
@ -385,19 +354,23 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"feature_extractor = LogMelSpectrogram(sr=44100, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window='hann', n_mels=64)" "feature_extractor = LogMelSpectrogram(\n",
" sr=sr, \n",
" n_fft=n_fft, \n",
" hop_length=hop_length, \n",
" win_length=win_length, \n",
" window='hann', \n",
" f_min=f_min,\n",
" f_max=f_max,\n",
" n_mels=n_mels)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"## 3.2 模型\n", "## 3.2 模型\n",
"\n", "\n",
@ -409,9 +382,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from paddlespeech.cls.models import cnn14\n", "from paddlespeech.cls.models import cnn14\n",
@ -420,9 +391,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"### 3.2.2 构建分类模型\n", "### 3.2.2 构建分类模型\n",
"\n", "\n",
@ -432,9 +401,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"import paddle.nn as nn\n", "import paddle.nn as nn\n",
@ -461,18 +428,14 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"## 3.3 Finetune" "## 3.3 Finetune"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"1. 创建 DataLoader " "1. 创建 DataLoader "
] ]
@ -480,9 +443,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"batch_size = 16\n", "batch_size = 16\n",
@ -492,9 +453,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"2. 定义优化器和 Loss" "2. 定义优化器和 Loss"
] ]
@ -502,9 +461,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"optimizer = paddle.optimizer.Adam(learning_rate=1e-4, parameters=model.parameters())\n", "optimizer = paddle.optimizer.Adam(learning_rate=1e-4, parameters=model.parameters())\n",
@ -513,19 +470,15 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"3. 启动模型训练 " "3. 启动模型训练 "
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from paddleaudio.utils import logger\n", "from paddleaudio.utils import logger\n",
@ -603,9 +556,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"## 3.4 音频预测\n", "## 3.4 音频预测\n",
"\n", "\n",
@ -615,16 +566,13 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": false
},
"outputs": [], "outputs": [],
"source": [ "source": [
"top_k = 10\n", "top_k = 10\n",
"wav_file = './dog.wav'\n", "wav_file = './dog.wav'\n",
"\n", "\n",
"waveform, sr = load(wav_file)\n", "waveform, _ = load(wav_file, sr)\n",
"feature_extractor = LogMelSpectrogram(sr=sr, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window='hann', n_mels=64)\n",
"feats = feature_extractor(paddle.to_tensor(paddle.to_tensor(waveform).unsqueeze(0)))\n", "feats = feature_extractor(paddle.to_tensor(paddle.to_tensor(waveform).unsqueeze(0)))\n",
"feats = paddle.transpose(feats, [0, 2, 1]) # [B, N, T] -> [B, T, N]\n", "feats = paddle.transpose(feats, [0, 2, 1]) # [B, N, T] -> [B, T, N]\n",
"print(feats.shape)\n", "print(feats.shape)\n",
@ -635,16 +583,14 @@
"sorted_indices = probs[0].argsort()\n", "sorted_indices = probs[0].argsort()\n",
"\n", "\n",
"msg = f'[{wav_file}]\\n'\n", "msg = f'[{wav_file}]\\n'\n",
"for idx in sorted_indices[-top_k:]:\n", "for idx in sorted_indices[-1:-top_k-1:-1]:\n",
" msg += f'{ESC50.label_list[idx]}: {probs[0][idx]:.5f}\\n'\n", " msg += f'{ESC50.label_list[idx]}: {probs[0][idx]:.5f}\\n'\n",
"print(msg)" "print(msg)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {},
"collapsed": false
},
"source": [ "source": [
"# 4. 作业\n", "# 4. 作业\n",
"1. 使用开发模式安装 [PaddleSpeech](https://github.com/PaddlePaddle/PaddleSpeech) \n", "1. 使用开发模式安装 [PaddleSpeech](https://github.com/PaddlePaddle/PaddleSpeech) \n",
@ -653,6 +599,7 @@
"1. 在 [MusicSpeech](http://marsyas.info/downloads/datasets.html) 数据集上完成 music/speech 二分类。 \n", "1. 在 [MusicSpeech](http://marsyas.info/downloads/datasets.html) 数据集上完成 music/speech 二分类。 \n",
"2. 在 [GTZAN Genre Collection](http://marsyas.info/downloads/datasets.html) 音乐分类数据集上利用 PANNs 预训练模型实现音乐类别十分类。\n", "2. 在 [GTZAN Genre Collection](http://marsyas.info/downloads/datasets.html) 音乐分类数据集上利用 PANNs 预训练模型实现音乐类别十分类。\n",
"\n", "\n",
"关于如何自定义分类数据集,请参考文档 [PaddleSpeech/docs/source/cls/custom_dataset.md](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/cls/custom_dataset.md)\n",
"\n", "\n",
"# 5. 关注 PaddleSpeech\n", "# 5. 关注 PaddleSpeech\n",
"\n", "\n",
@ -681,9 +628,9 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "py37",
"language": "python", "language": "python",
"name": "py35-paddle1.2.0" "name": "py37"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
@ -695,7 +642,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.7.4" "version": "3.7.7"
} }
}, },
"nbformat": 4, "nbformat": 4,

Loading…
Cancel
Save