diff --git a/README.md b/README.md index 1144d3ab5..a90498293 100644 --- a/README.md +++ b/README.md @@ -180,7 +180,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision 2021.12.14: We would like to have an online courses to introduce basics and research of speech, as well as code practice with `paddlespeech`. Please pay attention to our [Calendar](https://www.paddlepaddle.org.cn/live). ---> - 👏🏻 2022.03.28: PaddleSpeech Server is available for Audio Classification, Automatic Speech Recognition and Text-to-Speech. -- 👏🏻 2022.03.28: PaddleSpeech CLI is available for Speaker Verfication. +- 👏🏻 2022.03.28: PaddleSpeech CLI is available for Speaker Verification. - 🤗 2021.12.14: Our PaddleSpeech [ASR](https://huggingface.co/spaces/KPatrick/PaddleSpeechASR) and [TTS](https://huggingface.co/spaces/KPatrick/PaddleSpeechTTS) Demos on Hugging Face Spaces are available! - 👏🏻 2021.12.10: PaddleSpeech CLI is available for Audio Classification, Automatic Speech Recognition, Speech Translation (English to Chinese) and Text-to-Speech. diff --git a/docs/source/released_model.md b/docs/source/released_model.md index 9a423e03e..48ceaf843 100644 --- a/docs/source/released_model.md +++ b/docs/source/released_model.md @@ -80,7 +80,7 @@ PANN | ESC-50 |[pann-esc50](../../examples/esc50/cls0)|[esc50_cnn6.tar.gz](https Model Type | Dataset| Example Link | Pretrained Models | Static Models :-------------:| :------------:| :-----: | :-----: | :-----: -PANN | VoxCeleb| [voxceleb_ecapatdnn](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/voxceleb/sv0) | [ecapatdnn.tar.gz](https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_1.tar.gz) | - +PANN | VoxCeleb| [voxceleb_ecapatdnn](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/voxceleb/sv0) | [ecapatdnn.tar.gz](https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz) | - ## Punctuation Restoration Models Model Type | Dataset| Example Link | Pretrained Models diff --git a/examples/esc50/README.md b/examples/esc50/README.md index 911a72ad7..9eab95d26 100644 --- a/examples/esc50/README.md +++ b/examples/esc50/README.md @@ -4,7 +4,7 @@ 对于声音分类任务,传统机器学习的一个常用做法是首先人工提取音频的时域和频域的多种特征并做特征选择、组合、变换等,然后基于SVM或决策树进行分类。而端到端的深度学习则通常利用深度网络如RNN,CNN等直接对声间波形(waveform)或时频特征(time-frequency)进行特征学习(representation learning)和分类预测。 -在IEEE ICASSP 2017 大会上,谷歌开放了一个大规模的音频数据集[Audioset](https://research.google.com/audioset/)。该数据集包含了 632 类的音频类别以及 2,084,320 条人工标记的每段 10 秒长度的声音剪辑片段(来源于YouTube视频)。目前该数据集已经有210万个已标注的视频数据,5800小时的音频数据,经过标记的声音样本的标签类别为527。 +在IEEE ICASSP 2017 大会上,谷歌开放了一个大规模的音频数据集[Audioset](https://research.google.com/audioset/)。该数据集包含了 632 类的音频类别以及 2,084,320 条人工标记的每段 **10 秒**长度的声音剪辑片段(来源于YouTube视频)。目前该数据集已经有 210万 个已标注的视频数据,5800 小时的音频数据,经过标记的声音样本的标签类别为 527。 `PANNs`([PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition](https://arxiv.org/pdf/1912.10211.pdf))是基于Audioset数据集训练的声音分类/识别的模型。经过预训练后,模型可以用于提取音频的embbedding。本示例将使用`PANNs`的预训练模型Finetune完成声音分类的任务。 @@ -12,14 +12,14 @@ ## 模型简介 PaddleAudio提供了PANNs的CNN14、CNN10和CNN6的预训练模型,可供用户选择使用: -- CNN14: 该模型主要包含12个卷积层和2个全连接层,模型参数的数量为79.6M,embbedding维度是2048。 -- CNN10: 该模型主要包含8个卷积层和2个全连接层,模型参数的数量为4.9M,embbedding维度是512。 -- CNN6: 该模型主要包含4个卷积层和2个全连接层,模型参数的数量为4.5M,embbedding维度是512。 +- CNN14: 该模型主要包含12个卷积层和2个全连接层,模型参数的数量为 79.6M,embbedding维度是 2048。 +- CNN10: 该模型主要包含8个卷积层和2个全连接层,模型参数的数量为 4.9M,embbedding维度是 512。 +- CNN6: 该模型主要包含4个卷积层和2个全连接层,模型参数的数量为 4.5M,embbedding维度是 512。 ## 数据集 -[ESC-50: Dataset for Environmental Sound Classification](https://github.com/karolpiczak/ESC-50) 是一个包含有 2000 个带标签的环境声音样本,音频样本采样率为 44,100Hz 的单通道音频文件,所有样本根据标签被划分为 50 个类别,每个类别有 40 个样本。 +[ESC-50: Dataset for Environmental Sound Classification](https://github.com/karolpiczak/ESC-50) 是一个包含有 2000 个带标签的时长为 **5 秒**的环境声音样本,音频样本采样率为 44,100Hz 的单通道音频文件,所有样本根据标签被划分为 50 个类别,每个类别有 40 个样本。 ## 模型指标 @@ -43,13 +43,13 @@ $ CUDA_VISIBLE_DEVICES=0 ./run.sh 1 conf/panns.yaml ``` 训练的参数可在 `conf/panns.yaml` 的 `training` 中配置,其中: -- `epochs`: 训练轮次,默认为50。 +- `epochs`: 训练轮次,默认为 50。 - `learning_rate`: Fine-tune的学习率;默认为5e-5。 -- `batch_size`: 批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为16。 +- `batch_size`: 批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为 16。 - `num_workers`: Dataloader获取数据的子进程数。默认为0,加载数据的流程在主进程执行。 - `checkpoint_dir`: 模型参数文件和optimizer参数文件的保存目录,默认为`./checkpoint`。 -- `save_freq`: 训练过程中的模型保存频率,默认为10。 -- `log_freq`: 训练过程中的信息打印频率,默认为10。 +- `save_freq`: 训练过程中的模型保存频率,默认为 10。 +- `log_freq`: 训练过程中的信息打印频率,默认为 10。 示例代码中使用的预训练模型为`CNN14`,如果想更换为其他预训练模型,可通过修改 `conf/panns.yaml` 的 `model` 中配置: ```yaml @@ -76,7 +76,7 @@ $ CUDA_VISIBLE_DEVICES=0 ./run.sh 2 conf/panns.yaml 训练的参数可在 `conf/panns.yaml` 的 `predicting` 中配置,其中: - `audio_file`: 指定预测的音频文件。 -- `top_k`: 预测显示的top k标签的得分,默认为1。 +- `top_k`: 预测显示的top k标签的得分,默认为 1。 - `checkpoint`: 模型参数checkpoint文件。 输出的预测结果如下: diff --git a/examples/voxceleb/sv0/RESULT.md b/examples/voxceleb/sv0/RESULT.md index fcf5a7b36..3a3f67d09 100644 --- a/examples/voxceleb/sv0/RESULT.md +++ b/examples/voxceleb/sv0/RESULT.md @@ -4,4 +4,4 @@ | Model | Number of Params | Release | Config | dim | Test set | Cosine | Cosine + S-Norm | | --- | --- | --- | --- | --- | --- | --- | ---- | -| ECAPA-TDNN | 85M | 0.1.2 | conf/ecapa_tdnn.yaml |192 | test | 1.02 | 0.95 | +| ECAPA-TDNN | 85M | 0.2.0 | conf/ecapa_tdnn.yaml |192 | test | 1.02 | 0.95 | diff --git a/paddleaudio/paddleaudio/metric/__init__.py b/paddleaudio/paddleaudio/metric/__init__.py index 8e5ca9f75..d2b3a1360 100644 --- a/paddleaudio/paddleaudio/metric/__init__.py +++ b/paddleaudio/paddleaudio/metric/__init__.py @@ -14,4 +14,3 @@ from .dtw import dtw_distance from .eer import compute_eer from .eer import compute_minDCF -from .mcd import mcd_distance diff --git a/paddleaudio/paddleaudio/metric/mcd.py b/paddleaudio/paddleaudio/metric/mcd.py deleted file mode 100644 index 63a25fc23..000000000 --- a/paddleaudio/paddleaudio/metric/mcd.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Callable - -import mcd.metrics_fast as mt -import numpy as np -from mcd import dtw - -__all__ = [ - 'mcd_distance', -] - - -def mcd_distance(xs: np.ndarray, - ys: np.ndarray, - cost_fn: Callable=mt.logSpecDbDist) -> float: - """Mel cepstral distortion (MCD), dtw distance. - - Dynamic Time Warping. - Uses dynamic programming to compute: - - Examples: - .. code-block:: python - - wps[i, j] = cost_fn(xs[i], ys[j]) + min( - wps[i-1, j ], // vertical / insertion / expansion - wps[i , j-1], // horizontal / deletion / compression - wps[i-1, j-1]) // diagonal / match - - dtw = sqrt(wps[-1, -1]) - - Cost Function: - Examples: - .. code-block:: python - - logSpecDbConst = 10.0 / math.log(10.0) * math.sqrt(2.0) - - def logSpecDbDist(x, y): - diff = x - y - return logSpecDbConst * math.sqrt(np.inner(diff, diff)) - - Args: - xs (np.ndarray): ref sequence, [T,D] - ys (np.ndarray): hyp sequence, [T,D] - cost_fn (Callable, optional): Cost function. Defaults to mt.logSpecDbDist. - - Returns: - float: dtw distance - """ - - min_cost, path = dtw.dtw(xs, ys, cost_fn) - return min_cost diff --git a/paddleaudio/setup.py b/paddleaudio/setup.py index e08b88a3b..c92e5c73f 100644 --- a/paddleaudio/setup.py +++ b/paddleaudio/setup.py @@ -19,7 +19,7 @@ from setuptools.command.install import install from setuptools.command.test import test # set the version here -VERSION = '0.2.0' +VERSION = '0.2.1' # Inspired by the example at https://pytest.org/latest/goodpractises.html @@ -83,9 +83,8 @@ setuptools.setup( python_requires='>=3.6', install_requires=[ 'numpy >= 1.15.0', 'scipy >= 1.0.0', 'resampy >= 0.2.2', - 'soundfile >= 0.9.0', 'colorlog', 'dtaidistance == 2.3.1', 'mcd >= 0.4', - 'pathos' - ], + 'soundfile >= 0.9.0', 'colorlog', 'dtaidistance == 2.3.1', 'pathos' + ], extras_require={ 'test': [ 'nose', 'librosa==0.8.1', 'soundfile==0.10.3.post1', diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index 9904b5eda..68e832ac7 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -43,7 +43,7 @@ pretrained_models = { # "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav" "ecapatdnn_voxceleb12-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_1_2.tar.gz', + 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz', 'md5': 'cc33023c54ab346cd318408f43fcaf95', 'cfg_path': diff --git a/speechx/examples/feat/linear_spectrogram_main.cc b/speechx/examples/feat/linear_spectrogram_main.cc index ca76d85c7..2d75bb5df 100644 --- a/speechx/examples/feat/linear_spectrogram_main.cc +++ b/speechx/examples/feat/linear_spectrogram_main.cc @@ -181,6 +181,10 @@ int main(int argc, char* argv[]) { ppspeech::LinearSpectrogramOptions opt; opt.frame_opts.frame_length_ms = 20; opt.frame_opts.frame_shift_ms = 10; + opt.frame_opts.dither = 0.0; + opt.frame_opts.remove_dc_offset = false; + opt.frame_opts.window_type = "hanning"; + opt.frame_opts.preemph_coeff = 0.0; LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms; LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms; diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.cc b/speechx/speechx/frontend/audio/linear_spectrogram.cc index 827b8eccf..d6ae3d012 100644 --- a/speechx/speechx/frontend/audio/linear_spectrogram.cc +++ b/speechx/speechx/frontend/audio/linear_spectrogram.cc @@ -14,6 +14,8 @@ #include "frontend/audio/linear_spectrogram.h" #include "kaldi/base/kaldi-math.h" +#include "kaldi/feat/feature-common.h" +#include "kaldi/feat/feature-functions.h" #include "kaldi/matrix/matrix-functions.h" namespace ppspeech { @@ -21,30 +23,23 @@ namespace ppspeech { using kaldi::int32; using kaldi::BaseFloat; using kaldi::Vector; +using kaldi::SubVector; using kaldi::VectorBase; using kaldi::Matrix; using std::vector; LinearSpectrogram::LinearSpectrogram( const LinearSpectrogramOptions& opts, - std::unique_ptr base_extractor) { - opts_ = opts; + std::unique_ptr base_extractor) + : opts_(opts), feature_window_funtion_(opts.frame_opts) { base_extractor_ = std::move(base_extractor); int32 window_size = opts.frame_opts.WindowSize(); int32 window_shift = opts.frame_opts.WindowShift(); - fft_points_ = window_size; + dim_ = window_size / 2 + 1; chunk_sample_size_ = static_cast(opts.streaming_chunk * opts.frame_opts.samp_freq); - hanning_window_.resize(window_size); - - double a = M_2PI / (window_size - 1); - hanning_window_energy_ = 0; - for (int i = 0; i < window_size; ++i) { - hanning_window_[i] = 0.5 - 0.5 * cos(a * i); - hanning_window_energy_ += hanning_window_[i] * hanning_window_[i]; - } - - dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz + hanning_window_energy_ = kaldi::VecVec(feature_window_funtion_.window, + feature_window_funtion_.window); } void LinearSpectrogram::Accept(const VectorBase& inputs) { @@ -56,99 +51,57 @@ bool LinearSpectrogram::Read(Vector* feats) { bool flag = base_extractor_->Read(&input_feats); if (flag == false || input_feats.Dim() == 0) return false; - vector input_feats_vec(input_feats.Dim()); - std::memcpy(input_feats_vec.data(), - input_feats.Data(), - input_feats.Dim() * sizeof(BaseFloat)); - vector> result; - Compute(input_feats_vec, result); - int32 feat_size = 0; - if (result.size() != 0) { - feat_size = result.size() * result[0].size(); - } - feats->Resize(feat_size); - // todo refactor (SimleGoat) - for (size_t idx = 0; idx < feat_size; ++idx) { - (*feats)(idx) = result[idx / dim_][idx % dim_]; - } - return true; -} - -void LinearSpectrogram::Hanning(vector* data) const { - CHECK_GE(data->size(), hanning_window_.size()); - - for (size_t i = 0; i < hanning_window_.size(); ++i) { - data->at(i) *= hanning_window_[i]; - } -} - -bool LinearSpectrogram::NumpyFft(vector* v, - vector* real, - vector* img) const { - Vector v_tmp; - v_tmp.Resize(v->size()); - std::memcpy(v_tmp.Data(), v->data(), sizeof(BaseFloat) * (v->size())); - RealFft(&v_tmp, true); - v->resize(v_tmp.Dim()); - std::memcpy(v->data(), v_tmp.Data(), sizeof(BaseFloat) * (v->size())); - - real->push_back(v->at(0)); - img->push_back(0); - for (int i = 1; i < v->size() / 2; i++) { - real->push_back(v->at(2 * i)); - img->push_back(v->at(2 * i + 1)); - } - real->push_back(v->at(1)); - img->push_back(0); - + int32 feat_len = input_feats.Dim(); + int32 left_len = reminded_wav_.Dim(); + Vector waves(feat_len + left_len); + waves.Range(0, left_len).CopyFromVec(reminded_wav_); + waves.Range(left_len, feat_len).CopyFromVec(input_feats); + Compute(waves, feats); + int32 frame_shift = opts_.frame_opts.WindowShift(); + int32 num_frames = kaldi::NumFrames(waves.Dim(), opts_.frame_opts); + int32 left_samples = waves.Dim() - frame_shift * num_frames; + reminded_wav_.Resize(left_samples); + reminded_wav_.CopyFromVec( + waves.Range(frame_shift * num_frames, left_samples)); return true; } // Compute spectrogram feat -// todo: refactor later (SmileGoat) -bool LinearSpectrogram::Compute(const vector& waves, - vector>& feats) { - int num_samples = waves.size(); - const int& frame_length = opts_.frame_opts.WindowSize(); - const int& sample_rate = opts_.frame_opts.samp_freq; - const int& frame_shift = opts_.frame_opts.WindowShift(); - const int& fft_points = fft_points_; - const float scale = hanning_window_energy_ * sample_rate; +bool LinearSpectrogram::Compute(const Vector& waves, + Vector* feats) { + int32 num_samples = waves.Dim(); + int32 frame_length = opts_.frame_opts.WindowSize(); + int32 sample_rate = opts_.frame_opts.samp_freq; + BaseFloat scale = 2.0 / (hanning_window_energy_ * sample_rate); if (num_samples < frame_length) { return true; } - int num_frames = 1 + ((num_samples - frame_length) / frame_shift); - feats.resize(num_frames); - vector fft_real((fft_points_ / 2 + 1), 0); - vector fft_img((fft_points_ / 2 + 1), 0); - vector v(frame_length, 0); - vector power((fft_points / 2 + 1)); - - for (int i = 0; i < num_frames; ++i) { - vector data(waves.data() + i * frame_shift, - waves.data() + i * frame_shift + frame_length); - Hanning(&data); - fft_img.clear(); - fft_real.clear(); - v.assign(data.begin(), data.end()); - NumpyFft(&v, &fft_real, &fft_img); - - feats[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz - for (int j = 0; j < (fft_points / 2 + 1); ++j) { - power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j]; - feats[i][j] = power[j]; - - if (j == 0 || j == feats[0].size() - 1) { - feats[i][j] /= scale; - } else { - feats[i][j] *= (2.0 / scale); - } - - // log added eps=1e-14 - feats[i][j] = std::log(feats[i][j] + 1e-14); - } + int32 num_frames = kaldi::NumFrames(num_samples, opts_.frame_opts); + feats->Resize(num_frames * dim_); + Vector window; + + for (int frame_idx = 0; frame_idx < num_frames; ++frame_idx) { + kaldi::ExtractWindow(0, + waves, + frame_idx, + opts_.frame_opts, + feature_window_funtion_, + &window, + NULL); + + SubVector output_row(feats->Data() + frame_idx * dim_, dim_); + window.Resize(frame_length, kaldi::kCopyData); + RealFft(&window, true); + kaldi::ComputePowerSpectrum(&window); + SubVector power_spectrum(window, 0, dim_); + power_spectrum.Scale(scale); + power_spectrum(0) = power_spectrum(0) / 2; + power_spectrum(dim_ - 1) = power_spectrum(dim_ - 1) / 2; + power_spectrum.Add(1e-14); + power_spectrum.ApplyLog(); + output_row.CopyFromVec(power_spectrum); } return true; } diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.h b/speechx/speechx/frontend/audio/linear_spectrogram.h index bbf8d6853..896c494dd 100644 --- a/speechx/speechx/frontend/audio/linear_spectrogram.h +++ b/speechx/speechx/frontend/audio/linear_spectrogram.h @@ -49,19 +49,15 @@ class LinearSpectrogram : public FrontendInterface { virtual void Reset() { base_extractor_->Reset(); } private: - void Hanning(std::vector* data) const; - bool Compute(const std::vector& waves, - std::vector>& feats); - bool NumpyFft(std::vector* v, - std::vector* real, - std::vector* img) const; + bool Compute(const kaldi::Vector& waves, + kaldi::Vector* feats); - kaldi::int32 fft_points_; size_t dim_; - std::vector hanning_window_; + kaldi::FeatureWindowFunction feature_window_funtion_; kaldi::BaseFloat hanning_window_energy_; LinearSpectrogramOptions opts_; std::unique_ptr base_extractor_; + kaldi::Vector reminded_wav_; int chunk_sample_size_; DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram); };