diff --git a/.gitignore b/.gitignore
index 778824f5e..ad8e74925 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@
*.whl
*.egg-info
build
+*output/
docs/build/
docs/topic/ctc/warp-ctc/
@@ -33,6 +34,4 @@ tools/activate_python.sh
tools/miniconda.sh
tools/CRF++-0.58/
-speechx/fc_patch/
-
-*output/
+speechx/fc_patch/
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7fb01708a..09e92a667 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -50,13 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
- exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
+ exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker
name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
- exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
+ exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
hooks:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6e8315e76..62fead470 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,13 @@
# Changelog
+Date: 2022-3-08, Author: yt605155624.
+Add features to: T2S:
+ - Add aishell3 hifigan egs.
+ - PRLink: https://github.com/PaddlePaddle/PaddleSpeech/pull/1545
+
+Date: 2022-3-08, Author: yt605155624.
+Add features to: T2S:
+ - Add vctk hifigan egs.
+ - PRLink: https://github.com/PaddlePaddle/PaddleSpeech/pull/1544
Date: 2022-1-29, Author: yt605155624.
Add features to: T2S:
diff --git a/README.md b/README.md
index 46f492e99..ceef15af6 100644
--- a/README.md
+++ b/README.md
@@ -178,7 +178,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
-- 🤗 2021.12.14: Our PaddleSpeech [ASR](https://huggingface.co/spaces/KPatrick/PaddleSpeechASR) and [TTS](https://huggingface.co/spaces/akhaliq/paddlespeech) Demos on Hugging Face Spaces are available!
+- 🤗 2021.12.14: Our PaddleSpeech [ASR](https://huggingface.co/spaces/KPatrick/PaddleSpeechASR) and [TTS](https://huggingface.co/spaces/KPatrick/PaddleSpeechTTS) Demos on Hugging Face Spaces are available!
- 👏🏻 2021.12.10: PaddleSpeech CLI is available for Audio Classification, Automatic Speech Recognition, Speech Translation (English to Chinese) and Text-to-Speech.
### Community
@@ -207,6 +207,7 @@ paddlespeech cls --input input.wav
```shell
paddlespeech asr --lang zh --input input_16k.wav
```
+- Web demo for Automatic Speech Recognition is integrated into [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See Demo: [ASR Demo](https://huggingface.co/spaces/KPatrick/PaddleSpeechASR)
**Speech Translation** (English to Chinese)
(not support for Mac and Windows now)
@@ -218,7 +219,7 @@ paddlespeech st --input input_16k.wav
```shell
paddlespeech tts --input "你好,欢迎使用飞桨深度学习框架!" --output output.wav
```
-- web demo for Text to Speech is integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See Demo: [TTS Demo](https://huggingface.co/spaces/akhaliq/paddlespeech)
+- Web demo for Text to Speech is integrated into [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See Demo: [TTS Demo](https://huggingface.co/spaces/KPatrick/PaddleSpeechTTS)
**Text Postprocessing**
- Punctuation Restoration
@@ -397,9 +398,9 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
HiFiGAN |
- CSMSC |
+ LJSpeech / VCTK / CSMSC / AISHELL-3 |
- HiFiGAN-csmsc
+ HiFiGAN-ljspeech / HiFiGAN-vctk / HiFiGAN-csmsc / HiFiGAN-aishell3
|
@@ -573,7 +574,6 @@ You are warmly welcome to submit questions in [discussions](https://github.com/P
- Many thanks to [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) for years of attention, constructive advice and great help.
-- Many thanks to [AK391](https://github.com/AK391) for TTS web demo on Huggingface Spaces using Gradio.
- Many thanks to [mymagicpower](https://github.com/mymagicpower) for the Java implementation of ASR upon [short](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk) and [long](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk) audio files.
- Many thanks to [JiehangXie](https://github.com/JiehangXie)/[PaddleBoBo](https://github.com/JiehangXie/PaddleBoBo) for developing Virtual Uploader(VUP)/Virtual YouTuber(VTuber) with PaddleSpeech TTS function.
- Many thanks to [745165806](https://github.com/745165806)/[PaddleSpeechTask](https://github.com/745165806/PaddleSpeechTask) for contributing Punctuation Restoration model.
diff --git a/README_cn.md b/README_cn.md
index e84947372..8ea91e98d 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -392,9 +392,9 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
HiFiGAN |
- CSMSC |
+ LJSpeech / VCTK / CSMSC / AISHELL-3 |
- HiFiGAN-csmsc
+ HiFiGAN-ljspeech / HiFiGAN-vctk / HiFiGAN-csmsc / HiFiGAN-aishell3
|
diff --git a/demos/audio_searching/README.md b/demos/audio_searching/README.md
new file mode 100644
index 000000000..2b417c0eb
--- /dev/null
+++ b/demos/audio_searching/README.md
@@ -0,0 +1,171 @@
+([简体中文](./README_cn.md)|English)
+
+# Audio Searching
+
+## Introduction
+As the Internet continues to evolve, unstructured data such as emails, social media photos, live videos, and customer service voice calls have become increasingly common. To process such data on a computer, we need embedding techniques that transform it into vectors, which can then be stored, indexed, and queried.
+
+However, when the amount of data is large, such as hundreds of millions of audio tracks, similarity search becomes difficult. An exhaustive search is feasible but very time consuming. For this scenario, this demo introduces how to build an audio similarity retrieval system using the open-source vector database Milvus.
+
+Audio retrieval (speech, music, speaker, etc.) enables querying and finding similar sounds (or the same speaker) in a large amount of audio data. An audio similarity retrieval system can be used to identify similar sound effects, minimize intellectual property infringement, quickly search a voiceprint library, and help enterprises control fraud and identity theft. Audio retrieval also plays an important role in the classification and statistical analysis of audio data.
+
+In this demo, you will learn how to build an audio retrieval system that retrieves similar sound snippets. Uploaded audio clips are converted into vectors using PaddleSpeech-based pre-trained models (audio classification model, speaker recognition model, etc.) and stored in Milvus. Milvus automatically generates a unique ID for each vector, and the ID together with the corresponding audio information (audio ID, speaker ID, etc.) is stored in MySQL to build the library. During retrieval, the user uploads a test audio clip, its vector is computed, and a vector similarity search is run in Milvus. Milvus returns vector IDs, and the corresponding audio information is looked up in MySQL by ID.
+
+
+
+Note: this demo uses the [CN-Celeb](http://openslr.org/82/) dataset of at least 650,000 audio entries from 3,000 speakers to build the audio vector library, which is then searched using a preset distance metric. Other datasets can be used instead and adjusted as needed, e.g. Librispeech, VoxCeleb, UrbanSound, GloVe, MNIST, etc.
+
+## Usage
+### 1. Prepare the MySQL and Milvus services with docker-compose
+The audio similarity search system requires the Milvus and MySQL services. We can start these containers with one click through [docker-compose.yaml](./docker-compose.yaml), so please make sure you have installed [Docker Engine](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/) before running:
+
+```bash
+docker-compose -f docker-compose.yaml up -d
+```
+
+Then you will see that all the containers have been created:
+
+```bash
+Creating network "quick_deploy_app_net" with driver "bridge"
+Creating milvus-minio ... done
+Creating milvus-etcd ... done
+Creating audio-mysql ... done
+Creating milvus-standalone ... done
+Creating audio-webclient ... done
+```
+
+You can list all containers with `docker ps`, and use `docker logs audio-mysql` to get the logs of the MySQL server container:
+
+```bash
+CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
+b2bcf279e599 milvusdb/milvus:v2.0.1 "/tini -- milvus run…" 22 hours ago Up 22 hours 0.0.0.0:19530->19530/tcp milvus-standalone
+d8ef4c84e25c mysql:5.7 "docker-entrypoint.s…" 22 hours ago Up 22 hours 0.0.0.0:3306->3306/tcp, 33060/tcp audio-mysql
+8fb501edb4f3 quay.io/coreos/etcd:v3.5.0 "etcd -advertise-cli…" 22 hours ago Up 22 hours 2379-2380/tcp milvus-etcd
+ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" 22 hours ago Up 22 hours (healthy) 9000/tcp milvus-minio
+15c84a506754 iregistry.baidu-int.com/paddlespeech/audio-search-client:1.0 "/bin/bash -c '/usr/…" 22 hours ago Up 22 hours (healthy) 0.0.0.0:8068->80/tcp audio-webclient
+```
+
+### 2. Start the API Server
+Next, start the system server, which provides the HTTP backend services. (You can also call these endpoints directly with `curl`, as sketched at the end of this section.)
+
+- Install the Python packages
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+- Set configuration
+
+ ```bash
+ vim src/config.py
+ ```
+
+  Modify the parameters according to your own environment. Here are some parameters that need to be set; for more information please refer to [config.py](./src/config.py).
+
+ | **Parameter** | **Description** | **Default setting** |
+ | ---------------- | ----------------------------------------------------- | ------------------- |
+ | MILVUS_HOST | The IP address of Milvus, you can get it by ifconfig. If running everything on one machine, most likely 127.0.0.1 | 127.0.0.1 |
+ | MILVUS_PORT | Port of Milvus. | 19530 |
+ | VECTOR_DIMENSION | Dimension of the vectors. | 2048 |
+ | MYSQL_HOST | The IP address of Mysql. | 127.0.0.1 |
+  | MYSQL_PORT       | Port of MySQL.                                         | 3306                |
+  | DEFAULT_TABLE    | The default table/collection name in Milvus and MySQL. | audio_table         |
+
+- Run the code
+
+  Then start the server built with FastAPI:
+
+ ```bash
+ export PYTHONPATH=$PYTHONPATH:./src
+ python src/main.py
+ ```
+
+  Then you will see that the application has started:
+
+ ```bash
+ INFO: Started server process [3949]
+ 2022-03-07 17:39:14,864 | INFO | server.py | serve | 75 | Started server process [3949]
+ INFO: Waiting for application startup.
+ 2022-03-07 17:39:14,865 | INFO | on.py | startup | 45 | Waiting for application startup.
+ INFO: Application startup complete.
+ 2022-03-07 17:39:14,866 | INFO | on.py | startup | 59 | Application startup complete.
+ INFO: Uvicorn running on http://0.0.0.0:8002 (Press CTRL+C to quit)
+ 2022-03-07 17:39:14,867 | INFO | server.py | _log_started_message | 206 | Uvicorn running on http://0.0.0.0:8002 (Press CTRL+C to quit)
+ ```
+
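+You can also exercise the HTTP API directly once the server is running. Below is a minimal `curl` sketch of the endpoints defined in [main.py](./src/main.py); the data directory and test file are placeholders, so substitute paths that actually exist on the server:
+
+```bash
+# Count the vectors currently stored in the default collection
+curl http://127.0.0.1:8002/audio/count
+
+# Load every audio file under a directory on the server into Milvus/MySQL
+curl -X POST http://127.0.0.1:8002/audio/load \
+     -H "Content-Type: application/json" \
+     -d '{"File": "/home/speech/data"}'
+
+# Search with an audio file that already exists on the server
+curl -X POST "http://127.0.0.1:8002/audio/search/local?query_audio_path=/home/speech/data/test.wav"
+```
+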
+### 3. Usage
+- Prepare data
+ ```bash
+ wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz
+ ```
+  Note: if you want to build a quick demo, you can use the `./src/test_main.py:download_audio_data` function, which downloads 20 audio files; the results shown below use this small collection as an example.
+
+  - Script test (recommended)
+
+    The script downloads the data, loads the PaddleSpeech model, extracts embeddings, builds the library, runs a retrieval, and finally drops the library:
+ ```bash
+ python ./src/test_main.py
+ ```
+
+ Output:
+ ```bash
+ Checkpoint path: %your model path%
+ Extracting feature from audio No. 1 , 20 audios in total
+ Extracting feature from audio No. 2 , 20 audios in total
+ ...
+ 2022-03-09 17:22:13,870 | INFO | main.py | load_audios | 85 | Successfully loaded data, total count: 20
+ 2022-03-09 17:22:13,898 | INFO | main.py | count_audio | 147 | Successfully count the number of data!
+ 2022-03-09 17:22:13,918 | INFO | main.py | audio_path | 57 | Successfully load audio: ./example_audio/test.wav
+ ...
+ 2022-03-09 17:22:32,580 | INFO | main.py | search_local_audio | 131 | search result http://testserver/data?audio_path=./example_audio/test.wav, distance 0.0
+ 2022-03-09 17:22:32,580 | INFO | main.py | search_local_audio | 131 | search result http://testserver/data?audio_path=./example_audio/knife_chopping.wav, distance 0.021805256605148315
+ 2022-03-09 17:22:32,580 | INFO | main.py | search_local_audio | 131 | search result http://testserver/data?audio_path=./example_audio/knife_cut_into_flesh.wav, distance 0.052762262523174286
+ ...
+ 2022-03-09 17:22:32,582 | INFO | main.py | search_local_audio | 135 | Successfully searched similar audio!
+ 2022-03-09 17:22:33,658 | INFO | main.py | drop_tables | 159 | Successfully drop tables in Milvus and MySQL!
+ ```
+- GUI test (optional)
+
+  Navigate to 127.0.0.1:8068 in your browser to access the front-end interface.
+
+  Note: if the browser and the service are not on the same machine, change the IP to that of the machine where the service is located, update the corresponding API_URL in docker-compose.yaml accordingly, and restart the service.
+
+ - Insert data
+
+ Download the data and decompress it to a path named /home/speech/data. Then enter /home/speech/data in the address bar of the upload page to upload the data
+
+ 
+
+ - Search for similar audio
+
+ Select the magnifying glass icon on the left side of the interface. Then, press the "Default Target Audio File" button and upload a .wav sound file you'd like to search. Results will be displayed
+
+ 
+
+### 4. Results
+
+Machine configuration:
+- OS: CentOS release 7.6
+- Kernel: 4.17.11-1.el7.elrepo.x86_64
+- CPU: Intel(R) Xeon(R) CPU E5-2620 v4 @ 2.10GHz
+- Memory: 132 GB
+
+Dataset:
+- CN-Celeb, train size 650,000, test size 10,000, dimension 192, distance metric L2
+
+Recall and elapsed time statistics are shown in the following figure:
+
+ 
+
+
+With a recall rate of 90%, retrieval with the Milvus-based framework takes about 2.9 milliseconds, and feature extraction takes about 500 milliseconds (for a test clip of roughly 5 seconds of audio), so a single query takes about 503 milliseconds in total, which meets the needs of most application scenarios.
+
+### 5. Pretrained Models
+
+Here is a list of pretrained models released by PaddleSpeech:
+
+| Model | Sample Rate
+| :--- | :---:
+| ecapa_tdnn | 16000
+| panns_cnn6| 32000
+| panns_cnn10| 32000
+| panns_cnn14| 32000
diff --git a/demos/audio_searching/README_cn.md b/demos/audio_searching/README_cn.md
new file mode 100644
index 000000000..d822c00df
--- /dev/null
+++ b/demos/audio_searching/README_cn.md
@@ -0,0 +1,172 @@
+
+(简体中文|[English](./README.md))
+
+# 音频相似性检索
+## 介绍
+
+随着互联网不断发展,电子邮件、社交媒体照片、直播视频、客服语音等非结构化数据已经变得越来越普遍。如果想要使用计算机来处理这些数据,需要使用 embedding 技术将这些数据转化为向量 vector,然后进行存储、建索引、并查询
+
+但是,当数据量很大,比如上亿条音频要做相似度搜索,就比较困难了。穷举法固然可行,但非常耗时。针对这种场景,该demo 将介绍如何使用开源向量数据库 Milvus 搭建音频相似度检索系统
+
+音频检索(如演讲、音乐、说话人等检索)实现了在海量音频数据中查询并找出相似声音(或相同说话人)片段。音频相似性检索系统可用于识别相似的音效、最大限度减少知识产权侵权等,还可以快速的检索声纹库、帮助企业控制欺诈和身份盗用等。在音频数据的分类和统计分析中,音频检索也发挥着重要作用
+
+在本 demo 中,你将学会如何构建一个音频检索系统,用来检索相似的声音片段。使用基于 PaddleSpeech 预训练模型(音频分类模型,说话人识别模型等)将上传的音频片段转换为向量数据,并存储在 Milvus 中。Milvus 自动为每个向量生成唯一的 ID,然后将 ID 和 相应的音频信息(音频id,音频的说话人id等等)存储在 MySQL,这样就完成建库的工作。用户在检索时,上传测试音频,得到向量,然后在 Milvus 中进行向量相似度搜索,Milvus 返回的检索结果为向量 ID,通过 ID 在 MySQL 内部查询相应的音频信息即可
+
+
+
+注:该 demo 使用 [CN-Celeb](http://openslr.org/82/) 数据集,包括至少 650000 条音频,3000 个说话人,来建立音频向量库(音频特征,或音频说话人特征),然后通过预设的距离计算方式进行音频(或说话人)检索,这里面数据集也可以使用其他的,根据需要调整,如Librispeech,VoxCeleb,UrbanSound,GloVe,MNIST等
+
+## 使用方法
+### 1. MySQL 和 Milvus 安装
+音频相似度搜索系统需要用到 Milvus, MySQL 服务。 我们可以通过 [docker-compose.yaml](./docker-compose.yaml) 一键启动这些容器,所以请确保在运行之前已经安装了 [Docker Engine](https://docs.docker.com/engine/install/) 和 [Docker Compose](https://docs.docker.com/compose/install/)。 即
+
+```bash
+docker-compose -f docker-compose.yaml up -d
+```
+
+然后你会看到所有的容器都被创建:
+
+```bash
+Creating network "quick_deploy_app_net" with driver "bridge"
+Creating milvus-minio ... done
+Creating milvus-etcd ... done
+Creating audio-mysql ... done
+Creating milvus-standalone ... done
+Creating audio-webclient ... done
+```
+
+可以采用 `docker ps` 来显示所有的容器,还可以使用 `docker logs audio-mysql` 来获取服务器容器的日志:
+
+```bash
+CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
+b2bcf279e599 milvusdb/milvus:v2.0.1 "/tini -- milvus run…" 22 hours ago Up 22 hours 0.0.0.0:19530->19530/tcp milvus-standalone
+d8ef4c84e25c mysql:5.7 "docker-entrypoint.s…" 22 hours ago Up 22 hours 0.0.0.0:3306->3306/tcp, 33060/tcp audio-mysql
+8fb501edb4f3 quay.io/coreos/etcd:v3.5.0 "etcd -advertise-cli…" 22 hours ago Up 22 hours 2379-2380/tcp milvus-etcd
+ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" 22 hours ago Up 22 hours (healthy) 9000/tcp milvus-minio
+15c84a506754 iregistry.baidu-int.com/paddlespeech/audio-search-client:1.0 "/bin/bash -c '/usr/…" 22 hours ago Up 22 hours (healthy) 0.0.0.0:8068->80/tcp audio-webclient
+
+```
+
+### 2. 配置并启动 API 服务
+启动系统服务程序,它会提供基于 Http 后端服务
+
+- 安装服务依赖的 python 基础包
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+- 修改配置
+
+ ```bash
+ vim src/config.py
+ ```
+
+ 请根据实际环境进行修改。 这里列出了一些需要设置的参数,更多信息请参考 [config.py](./src/config.py)
+
+ | **Parameter** | **Description** | **Default setting** |
+ | ---------------- | ----------------------------------------------------- | ------------------- |
+ | MILVUS_HOST | The IP address of Milvus, you can get it by ifconfig. If running everything on one machine, most likely 127.0.0.1 | 127.0.0.1 |
+ | MILVUS_PORT | Port of Milvus. | 19530 |
+ | VECTOR_DIMENSION | Dimension of the vectors. | 2048 |
+ | MYSQL_HOST | The IP address of Mysql. | 127.0.0.1 |
+  | MYSQL_PORT       | Port of MySQL.                                         | 3306                |
+  | DEFAULT_TABLE    | The default table/collection name in Milvus and MySQL. | audio_table         |
+
+- 运行程序
+
+ 启动用 Fastapi 构建的服务
+
+ ```bash
+ export PYTHONPATH=$PYTHONPATH:./src
+ python src/main.py
+ ```
+
+ 然后你会看到应用程序启动:
+
+ ```bash
+ INFO: Started server process [3949]
+ 2022-03-07 17:39:14,864 | INFO | server.py | serve | 75 | Started server process [3949]
+ INFO: Waiting for application startup.
+ 2022-03-07 17:39:14,865 | INFO | on.py | startup | 45 | Waiting for application startup.
+ INFO: Application startup complete.
+ 2022-03-07 17:39:14,866 | INFO | on.py | startup | 59 | Application startup complete.
+ INFO: Uvicorn running on http://0.0.0.0:8002 (Press CTRL+C to quit)
+ 2022-03-07 17:39:14,867 | INFO | server.py | _log_started_message | 206 | Uvicorn running on http://0.0.0.0:8002 (Press CTRL+C to quit)
+ ```
+
+### 3. 测试方法
+- 准备数据
+ ```bash
+ wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz
+ ```
+ 注:如果希望快速搭建 demo,可以采用 ./src/test_main.py:download_audio_data 内部的 20 条音频,另外后续结果展示以该集合为例
+
+ - 脚本测试(推荐)
+
+ ```bash
+ python ./src/test_main.py
+ ```
+ 注:内部将依次下载数据,加载 paddlespeech 模型,提取 embedding,存储建库,检索,删库
+
+ 输出:
+ ```bash
+ Checkpoint path: %your model path%
+ Extracting feature from audio No. 1 , 20 audios in total
+ Extracting feature from audio No. 2 , 20 audios in total
+ ...
+ 2022-03-09 17:22:13,870 | INFO | main.py | load_audios | 85 | Successfully loaded data, total count: 20
+ 2022-03-09 17:22:13,898 | INFO | main.py | count_audio | 147 | Successfully count the number of data!
+ 2022-03-09 17:22:13,918 | INFO | main.py | audio_path | 57 | Successfully load audio: ./example_audio/test.wav
+ ...
+ 2022-03-09 17:22:32,580 | INFO | main.py | search_local_audio | 131 | search result http://testserver/data?audio_path=./example_audio/test.wav, distance 0.0
+ 2022-03-09 17:22:32,580 | INFO | main.py | search_local_audio | 131 | search result http://testserver/data?audio_path=./example_audio/knife_chopping.wav, distance 0.021805256605148315
+ 2022-03-09 17:22:32,580 | INFO | main.py | search_local_audio | 131 | search result http://testserver/data?audio_path=./example_audio/knife_cut_into_flesh.wav, distance 0.052762262523174286
+ ...
+ 2022-03-09 17:22:32,582 | INFO | main.py | search_local_audio | 135 | Successfully searched similar audio!
+ 2022-03-09 17:22:33,658 | INFO | main.py | drop_tables | 159 | Successfully drop tables in Milvus and MySQL!
+ ```
+ - 前端测试(可选)
+
+ 在浏览器中输入 127.0.0.1:8068 访问前端页面
+
+ 注:如果浏览器和服务不在同一台机器上,那么 IP 需要修改成服务所在的机器 IP,并且docker-compose.yaml 中相应的 API_URL 也要修改,并重新起服务即可
+
+ - 上传音频
+
+ 下载数据并解压到一文件夹,假设为 /home/speech/data,那么在上传页面地址栏输入 /home/speech/data 进行数据上传
+
+ 
+
+ - 检索相似音频
+
+ 选择左上角放大镜,点击 “Default Target Audio File” 按钮,上传测试音频,接着你将看到检索结果
+
+ 
+
+### 4. 结果
+
+机器配置:
+- 操作系统: CentOS release 7.6
+- 内核:4.17.11-1.el7.elrepo.x86_64
+- 处理器:Intel(R) Xeon(R) CPU E5-2620 v4 @ 2.10GHz
+- 内存:132G
+
+数据集:
+- CN-Celeb, 训练集 65万, 测试集 1万,向量维度 192,距离计算方式 L2
+
+召回和耗时统计如下图:
+
+ 
+
+基于 milvus 的检索框架在召回率 90% 的前提下,检索耗时约 2.9 毫秒,加上特征提取(Embedding)耗时约 500毫秒(测试音频时长约 5秒),即单条音频测试总共耗时约 503 毫秒,可以满足大多数应用场景
+
+### 5. 预训练模型
+
+以下是 PaddleSpeech 提供的预训练模型列表:
+
+| 模型 | 采样率
+| :--- | :---:
+| ecapa_tdnn| 16000
+| panns_cnn6| 32000
+| panns_cnn10| 32000
+| panns_cnn14| 32000
diff --git a/demos/audio_searching/docker-compose.yaml b/demos/audio_searching/docker-compose.yaml
new file mode 100644
index 000000000..8916e76fd
--- /dev/null
+++ b/demos/audio_searching/docker-compose.yaml
@@ -0,0 +1,88 @@
+version: '3.5'
+
+services:
+ etcd:
+ container_name: milvus-etcd
+ image: quay.io/coreos/etcd:v3.5.0
+ networks:
+ app_net:
+ environment:
+ - ETCD_AUTO_COMPACTION_MODE=revision
+ - ETCD_AUTO_COMPACTION_RETENTION=1000
+ - ETCD_QUOTA_BACKEND_BYTES=4294967296
+ volumes:
+ - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
+ command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
+
+ minio:
+ container_name: milvus-minio
+ image: minio/minio:RELEASE.2020-12-03T00-03-10Z
+ networks:
+ app_net:
+ environment:
+ MINIO_ACCESS_KEY: minioadmin
+ MINIO_SECRET_KEY: minioadmin
+ volumes:
+ - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
+ command: minio server /minio_data
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
+ interval: 30s
+ timeout: 20s
+ retries: 3
+
+ standalone:
+ container_name: milvus-standalone
+ image: milvusdb/milvus:v2.0.1
+ networks:
+ app_net:
+ ipv4_address: 172.16.23.10
+ command: ["milvus", "run", "standalone"]
+ environment:
+ ETCD_ENDPOINTS: etcd:2379
+ MINIO_ADDRESS: minio:9000
+ volumes:
+ - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
+ ports:
+ - "19530:19530"
+ depends_on:
+ - "etcd"
+ - "minio"
+
+ mysql:
+ container_name: audio-mysql
+ image: mysql:5.7
+ networks:
+ app_net:
+ ipv4_address: 172.16.23.11
+ environment:
+ - MYSQL_ROOT_PASSWORD=123456
+ volumes:
+ - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/mysql:/var/lib/mysql
+ ports:
+ - "3306:3306"
+
+ webclient:
+ container_name: audio-webclient
+ image: qingen1/paddlespeech-audio-search-client:2.3
+ networks:
+ app_net:
+ ipv4_address: 172.16.23.13
+ environment:
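+      # The web client calls this URL from the browser; if you open the client from
+      # another machine, replace 127.0.0.1 with the API server's reachable IP.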
+ API_URL: 'http://127.0.0.1:8002'
+ ports:
+ - "8068:80"
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://localhost/"]
+ interval: 30s
+ timeout: 20s
+ retries: 3
+
+networks:
+ app_net:
+ driver: bridge
+ ipam:
+ driver: default
+ config:
+ - subnet: 172.16.23.0/24
+ gateway: 172.16.23.1
diff --git a/demos/audio_searching/img/audio_searching.png b/demos/audio_searching/img/audio_searching.png
new file mode 100644
index 000000000..b145dd499
Binary files /dev/null and b/demos/audio_searching/img/audio_searching.png differ
diff --git a/demos/audio_searching/img/insert.png b/demos/audio_searching/img/insert.png
new file mode 100644
index 000000000..b9e766bda
Binary files /dev/null and b/demos/audio_searching/img/insert.png differ
diff --git a/demos/audio_searching/img/result.png b/demos/audio_searching/img/result.png
new file mode 100644
index 000000000..c4efc0c7f
Binary files /dev/null and b/demos/audio_searching/img/result.png differ
diff --git a/demos/audio_searching/img/search.png b/demos/audio_searching/img/search.png
new file mode 100644
index 000000000..26bcd9bdd
Binary files /dev/null and b/demos/audio_searching/img/search.png differ
diff --git a/demos/audio_searching/requirements.txt b/demos/audio_searching/requirements.txt
new file mode 100644
index 000000000..9e73361b4
--- /dev/null
+++ b/demos/audio_searching/requirements.txt
@@ -0,0 +1,12 @@
+soundfile==0.10.3.post1
+librosa==0.8.0
+numpy
+pymysql
+fastapi
+uvicorn
+diskcache==5.2.1
+pymilvus==2.0.1
+python-multipart
+typing
+starlette
+pydantic
\ No newline at end of file
diff --git a/demos/audio_searching/src/config.py b/demos/audio_searching/src/config.py
new file mode 100644
index 000000000..72a8fb4be
--- /dev/null
+++ b/demos/audio_searching/src/config.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
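+# Each setting below can be overridden through the environment variable named in the
+# corresponding os.getenv() call, e.g. `MILVUS_HOST=<your-milvus-ip> python src/main.py`.
+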
+############### Milvus Configuration ###############
+MILVUS_HOST = os.getenv("MILVUS_HOST", "127.0.0.1")
+MILVUS_PORT = int(os.getenv("MILVUS_PORT", "19530"))
+VECTOR_DIMENSION = int(os.getenv("VECTOR_DIMENSION", "2048"))
+INDEX_FILE_SIZE = int(os.getenv("INDEX_FILE_SIZE", "1024"))
+METRIC_TYPE = os.getenv("METRIC_TYPE", "L2")
+DEFAULT_TABLE = os.getenv("DEFAULT_TABLE", "audio_table")
+TOP_K = int(os.getenv("TOP_K", "10"))
+
+############### MySQL Configuration ###############
+MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
+MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
+MYSQL_USER = os.getenv("MYSQL_USER", "root")
+MYSQL_PWD = os.getenv("MYSQL_PWD", "123456")
+MYSQL_DB = os.getenv("MYSQL_DB", "mysql")
+
+############### Data Path ###############
+UPLOAD_PATH = os.getenv("UPLOAD_PATH", "tmp/audio-data")
+
+############### Number of Log Files ###############
+LOGS_NUM = int(os.getenv("logs_num", "0"))
diff --git a/demos/audio_searching/src/encode.py b/demos/audio_searching/src/encode.py
new file mode 100644
index 000000000..eba5c48c0
--- /dev/null
+++ b/demos/audio_searching/src/encode.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import librosa
+import numpy as np
+from logs import LOGGER
+
+
+def get_audio_embedding(path):
+ """
+    Generate an embedding for the audio at `path` (currently faked with a random vector; see the TODO below)
+ """
+ try:
+ RESAMPLE_RATE = 16000
+ audio, _ = librosa.load(path, sr=RESAMPLE_RATE, mono=True)
+
+ # TODO add infer/python interface to get embedding, now fake it by rand
+ # vpr = ECAPATDNN(checkpoint_path=None, device='cuda')
+ # embedding = vpr.inference(audio)
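+        # NOTE: Python's hash() of a string is salted per interpreter process, so these
+        # placeholder embeddings are only reproducible within a single run of the service.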
+ np.random.seed(hash(os.path.basename(path)) % 1000000)
+ embedding = np.random.rand(1, 2048)
+ embedding = embedding / np.linalg.norm(embedding)
+ embedding = embedding.tolist()[0]
+ return embedding
+ except Exception as e:
+ LOGGER.error(f"Error with embedding:{e}")
+ return None
diff --git a/demos/audio_searching/src/logs.py b/demos/audio_searching/src/logs.py
new file mode 100644
index 000000000..ba3ed069c
--- /dev/null
+++ b/demos/audio_searching/src/logs.py
@@ -0,0 +1,164 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import codecs
+import datetime
+import logging
+import os
+import re
+import sys
+
+from config import LOGS_NUM
+
+
+class MultiprocessHandler(logging.FileHandler):
+ """
+ A handler class which writes formatted logging records to disk files
+ """
+
+ def __init__(self,
+ filename,
+ when='D',
+ backupCount=0,
+ encoding=None,
+ delay=False):
+ """
+ Open the specified file and use it as the stream for logging
+ """
+ self.prefix = filename
+ self.backupCount = backupCount
+ self.when = when.upper()
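+        # date pattern used to recognize the suffix of rotated log files (see get_files_to_delete)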
+ self.extMath = r"^\d{4}-\d{2}-\d{2}"
+
+ self.when_dict = {
+ 'S': "%Y-%m-%d-%H-%M-%S",
+ 'M': "%Y-%m-%d-%H-%M",
+ 'H': "%Y-%m-%d-%H",
+ 'D': "%Y-%m-%d"
+ }
+
+ self.suffix = self.when_dict.get(when)
+ if not self.suffix:
+ print('The specified date interval unit is invalid: ', self.when)
+ sys.exit(1)
+
+ self.filefmt = os.path.join('.', "logs",
+ f"{self.prefix}-{self.suffix}.log")
+
+ self.filePath = datetime.datetime.now().strftime(self.filefmt)
+
+ _dir = os.path.dirname(self.filefmt)
+ try:
+ if not os.path.exists(_dir):
+ os.makedirs(_dir)
+ except Exception as e:
+ print('Failed to create log file: ', e)
+ print("log_path:" + self.filePath)
+ sys.exit(1)
+
+ logging.FileHandler.__init__(self, self.filePath, 'a+', encoding, delay)
+
+ def should_change_file_to_write(self):
+ """
+        Check whether the log file should roll over to a new date-stamped file
+ """
+ _filePath = datetime.datetime.now().strftime(self.filefmt)
+ if _filePath != self.filePath:
+ self.filePath = _filePath
+ return True
+ return False
+
+ def do_change_file(self):
+ """
+        Switch the handler to the new date-stamped log file and remove expired backups
+ """
+ self.baseFilename = os.path.abspath(self.filePath)
+ if self.stream:
+ self.stream.close()
+ self.stream = None
+
+ if not self.delay:
+ self.stream = self._open()
+ if self.backupCount > 0:
+ for s in self.get_files_to_delete():
+ os.remove(s)
+
+ def get_files_to_delete(self):
+ """
+        Collect the expired backup log files that should be removed
+ """
+ dir_name, _ = os.path.split(self.baseFilename)
+ file_names = os.listdir(dir_name)
+ result = []
+ prefix = self.prefix + '-'
+ for file_name in file_names:
+ if file_name[:len(prefix)] == prefix:
+ suffix = file_name[len(prefix):-4]
+ if re.compile(self.extMath).match(suffix):
+ result.append(os.path.join(dir_name, file_name))
+ result.sort()
+
+ if len(result) < self.backupCount:
+ result = []
+ else:
+ result = result[:len(result) - self.backupCount]
+ return result
+
+ def emit(self, record):
+ """
+ Emit a record
+ """
+ try:
+ if self.should_change_file_to_write():
+ self.do_change_file()
+ logging.FileHandler.emit(self, record)
+ except (KeyboardInterrupt, SystemExit):
+ raise
+        except Exception:
+ self.handleError(record)
+
+
+def write_log():
+ """
+ Init a logger
+ """
+ logger = logging.getLogger()
+ logger.setLevel(logging.DEBUG)
+ # formatter = '%(asctime)s | %(levelname)s | %(filename)s | %(funcName)s | %(module)s | %(lineno)s | %(message)s'
+ fmt = logging.Formatter(
+ '%(asctime)s | %(levelname)s | %(filename)s | %(funcName)s | %(lineno)s | %(message)s'
+ )
+
+ stream_handler = logging.StreamHandler(sys.stdout)
+ stream_handler.setLevel(logging.INFO)
+ stream_handler.setFormatter(fmt)
+
+ log_name = "audio-searching"
+ file_handler = MultiprocessHandler(log_name, when='D', backupCount=LOGS_NUM)
+ file_handler.setLevel(logging.DEBUG)
+ file_handler.setFormatter(fmt)
+ file_handler.do_change_file()
+
+ logger.addHandler(stream_handler)
+ logger.addHandler(file_handler)
+
+ return logger
+
+
+LOGGER = write_log()
+
+if __name__ == "__main__":
+ message = 'test writing logs'
+ LOGGER.info(message)
+ LOGGER.debug(message)
+ LOGGER.error(message)
diff --git a/demos/audio_searching/src/main.py b/demos/audio_searching/src/main.py
new file mode 100644
index 000000000..db091a39d
--- /dev/null
+++ b/demos/audio_searching/src/main.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Optional
+
+import uvicorn
+from config import UPLOAD_PATH
+from diskcache import Cache
+from fastapi import FastAPI
+from fastapi import File
+from fastapi import UploadFile
+from logs import LOGGER
+from milvus_helpers import MilvusHelper
+from mysql_helpers import MySQLHelper
+from operations.count import do_count
+from operations.drop import do_drop
+from operations.load import do_load
+from operations.search import do_search
+from pydantic import BaseModel
+from starlette.middleware.cors import CORSMiddleware
+from starlette.requests import Request
+from starlette.responses import FileResponse
+
+app = FastAPI()
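+# Allow cross-origin requests from any origin so the bundled web client (served from a
+# different host/port) can call this API; tighten these settings for real deployments.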
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"])
+
+MODEL = None
+MILVUS_CLI = MilvusHelper()
+MYSQL_CLI = MySQLHelper()
+
+# Mkdir 'tmp/audio-data'
+if not os.path.exists(UPLOAD_PATH):
+ os.makedirs(UPLOAD_PATH)
+ LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")
+
+
+@app.get('/data')
+def audio_path(audio_path):
+ # Get the audio file
+ try:
+ LOGGER.info(f"Successfully load audio: {audio_path}")
+ return FileResponse(audio_path)
+ except Exception as e:
+ LOGGER.error(f"upload audio error: {e}")
+ return {'status': False, 'msg': e}, 400
+
+
+@app.get('/progress')
+def get_progress():
+ # Get the progress of dealing with data
+ try:
+ cache = Cache('./tmp')
+ return f"current: {cache['current']}, total: {cache['total']}"
+ except Exception as e:
+ LOGGER.error(f"Upload data error: {e}")
+ return {'status': False, 'msg': e}, 400
+
+
+class Item(BaseModel):
+ Table: Optional[str] = None
+ File: str
+
+
+@app.post('/audio/load')
+async def load_audios(item: Item):
+ # Insert all the audio files under the file path to Milvus/MySQL
+ try:
+ total_num = do_load(item.Table, item.File, MILVUS_CLI, MYSQL_CLI)
+ LOGGER.info(f"Successfully loaded data, total count: {total_num}")
+ return {'status': True, 'msg': "Successfully loaded data!"}
+ except Exception as e:
+ LOGGER.error(e)
+ return {'status': False, 'msg': e}, 400
+
+
+@app.post('/audio/search')
+async def search_audio(request: Request,
+ table_name: str=None,
+ audio: UploadFile=File(...)):
+ # Search the uploaded audio in Milvus/MySQL
+ try:
+ # Save the upload data to server.
+ content = await audio.read()
+ query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
+ with open(query_audio_path, "wb+") as f:
+ f.write(content)
+ host = request.headers['host']
+ _, paths, distances = do_search(host, table_name, query_audio_path,
+ MILVUS_CLI, MYSQL_CLI)
+ names = []
+ for path, score in zip(paths, distances):
+ names.append(os.path.basename(path))
+ LOGGER.info(f"search result {path}, score {score}")
+ res = dict(zip(paths, zip(names, distances)))
+        # Sort by the similarity score computed in do_search (higher = closer), best matches first
+ res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
+ LOGGER.info("Successfully searched similar audio!")
+ return res
+ except Exception as e:
+ LOGGER.error(e)
+ return {'status': False, 'msg': e}, 400
+
+
+@app.post('/audio/search/local')
+async def search_local_audio(request: Request,
+ query_audio_path: str,
+ table_name: str=None):
+ # Search the uploaded audio in Milvus/MySQL
+ try:
+ host = request.headers['host']
+ _, paths, distances = do_search(host, table_name, query_audio_path,
+ MILVUS_CLI, MYSQL_CLI)
+ names = []
+ for path, score in zip(paths, distances):
+ names.append(os.path.basename(path))
+ LOGGER.info(f"search result {path}, score {score}")
+ res = dict(zip(paths, zip(names, distances)))
+        # Sort by the similarity score computed in do_search (higher = closer), best matches first
+ res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
+ LOGGER.info("Successfully searched similar audio!")
+ return res
+ except Exception as e:
+ LOGGER.error(e)
+ return {'status': False, 'msg': e}, 400
+
+
+@app.get('/audio/count')
+async def count_audio(table_name: str=None):
+ # Returns the total number of vectors in the system
+ try:
+ num = do_count(table_name, MILVUS_CLI)
+ LOGGER.info("Successfully count the number of data!")
+ return num
+ except Exception as e:
+ LOGGER.error(e)
+ return {'status': False, 'msg': e}, 400
+
+
+@app.post('/audio/drop')
+async def drop_tables(table_name: str=None):
+ # Delete the collection of Milvus and MySQL
+ try:
+ status = do_drop(table_name, MILVUS_CLI, MYSQL_CLI)
+ LOGGER.info("Successfully drop tables in Milvus and MySQL!")
+ return status
+ except Exception as e:
+ LOGGER.error(e)
+ return {'status': False, 'msg': e}, 400
+
+
+if __name__ == '__main__':
+ uvicorn.run(app=app, host='0.0.0.0', port=8002)
diff --git a/demos/audio_searching/src/milvus_helpers.py b/demos/audio_searching/src/milvus_helpers.py
new file mode 100644
index 000000000..1699e892e
--- /dev/null
+++ b/demos/audio_searching/src/milvus_helpers.py
@@ -0,0 +1,185 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from config import METRIC_TYPE
+from config import MILVUS_HOST
+from config import MILVUS_PORT
+from config import VECTOR_DIMENSION
+from logs import LOGGER
+from pymilvus import Collection
+from pymilvus import CollectionSchema
+from pymilvus import connections
+from pymilvus import DataType
+from pymilvus import FieldSchema
+from pymilvus import utility
+
+
+class MilvusHelper:
+ """
+ the basic operations of PyMilvus
+
+ # This example shows how to:
+ # 1. connect to Milvus server
+ # 2. create a collection
+ # 3. insert entities
+ # 4. create index
+ # 5. search
+ # 6. delete a collection
+
+ """
+
+ def __init__(self):
+ try:
+ self.collection = None
+ connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
+ LOGGER.debug(
+ f"Successfully connect to Milvus with IP:{MILVUS_HOST} and PORT:{MILVUS_PORT}"
+ )
+ except Exception as e:
+ LOGGER.error(f"Failed to connect Milvus: {e}")
+ sys.exit(1)
+
+ def set_collection(self, collection_name):
+ try:
+ if self.has_collection(collection_name):
+ self.collection = Collection(name=collection_name)
+ else:
+ raise Exception(
+ f"There is no collection named:{collection_name}")
+ except Exception as e:
+ LOGGER.error(f"Failed to set collection in Milvus: {e}")
+ sys.exit(1)
+
+ def has_collection(self, collection_name):
+ # Return if Milvus has the collection
+ try:
+ return utility.has_collection(collection_name)
+ except Exception as e:
+ LOGGER.error(f"Failed to check state of collection in Milvus: {e}")
+ sys.exit(1)
+
+ def create_collection(self, collection_name):
+ # Create milvus collection if not exists
+ try:
+ if not self.has_collection(collection_name):
+ field1 = FieldSchema(
+ name="id",
+ dtype=DataType.INT64,
+ descrition="int64",
+ is_primary=True,
+ auto_id=True)
+ field2 = FieldSchema(
+ name="embedding",
+ dtype=DataType.FLOAT_VECTOR,
+ descrition="speaker embeddings",
+ dim=VECTOR_DIMENSION,
+ is_primary=False)
+ schema = CollectionSchema(
+ fields=[field1, field2], description="embeddings info")
+ self.collection = Collection(
+ name=collection_name, schema=schema)
+ LOGGER.debug(f"Create Milvus collection: {collection_name}")
+ else:
+ self.set_collection(collection_name)
+ return "OK"
+ except Exception as e:
+ LOGGER.error(f"Failed to create collection in Milvus: {e}")
+ sys.exit(1)
+
+ def insert(self, collection_name, vectors):
+ # Batch insert vectors to milvus collection
+ try:
+ self.create_collection(collection_name)
+ data = [vectors]
+ self.set_collection(collection_name)
+ mr = self.collection.insert(data)
+ ids = mr.primary_keys
+ self.collection.load()
+ LOGGER.debug(
+ f"Insert vectors to Milvus in collection: {collection_name} with {len(vectors)} rows"
+ )
+ return ids
+ except Exception as e:
+ LOGGER.error(f"Failed to insert data to Milvus: {e}")
+ sys.exit(1)
+
+ def create_index(self, collection_name):
+        # Create an IVF_SQ8 index on the milvus collection
+ try:
+ self.set_collection(collection_name)
+ default_index = {
+ "index_type": "IVF_SQ8",
+ "metric_type": METRIC_TYPE,
+ "params": {
+ "nlist": 16384
+ }
+ }
+ status = self.collection.create_index(
+ field_name="embedding", index_params=default_index)
+ if not status.code:
+ LOGGER.debug(
+ f"Successfully create index in collection:{collection_name} with param:{default_index}"
+ )
+ return status
+ else:
+ raise Exception(status.message)
+ except Exception as e:
+ LOGGER.error(f"Failed to create index: {e}")
+ sys.exit(1)
+
+ def delete_collection(self, collection_name):
+ # Delete Milvus collection
+ try:
+ self.set_collection(collection_name)
+ self.collection.drop()
+ LOGGER.debug("Successfully drop collection!")
+ return "ok"
+ except Exception as e:
+ LOGGER.error(f"Failed to drop collection: {e}")
+ sys.exit(1)
+
+ def search_vectors(self, collection_name, vectors, top_k):
+ # Search vector in milvus collection
+ try:
+ self.set_collection(collection_name)
+ search_params = {
+ "metric_type": METRIC_TYPE,
+ "params": {
+ "nprobe": 16
+ }
+ }
+ res = self.collection.search(
+ vectors,
+ anns_field="embedding",
+ param=search_params,
+ limit=top_k)
+ LOGGER.debug(f"Successfully search in collection: {res}")
+ return res
+ except Exception as e:
+ LOGGER.error(f"Failed to search vectors in Milvus: {e}")
+ sys.exit(1)
+
+ def count(self, collection_name):
+        # Get the number of entities in the milvus collection
+ try:
+ self.set_collection(collection_name)
+ num = self.collection.num_entities
+ LOGGER.debug(
+ f"Successfully get the num:{num} of the collection:{collection_name}"
+ )
+ return num
+ except Exception as e:
+ LOGGER.error(f"Failed to count vectors in Milvus: {e}")
+ sys.exit(1)
diff --git a/demos/audio_searching/src/mysql_helpers.py b/demos/audio_searching/src/mysql_helpers.py
new file mode 100644
index 000000000..303838399
--- /dev/null
+++ b/demos/audio_searching/src/mysql_helpers.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+import pymysql
+from config import MYSQL_DB
+from config import MYSQL_HOST
+from config import MYSQL_PORT
+from config import MYSQL_PWD
+from config import MYSQL_USER
+from logs import LOGGER
+
+
+class MySQLHelper():
+ """
+ the basic operations of PyMySQL
+
+ # This example shows how to:
+ # 1. connect to MySQL server
+ # 2. create a table
+ # 3. insert data to table
+ # 4. search by milvus ids
+ # 5. delete table
+ """
+
+ def __init__(self):
+ self.conn = pymysql.connect(
+ host=MYSQL_HOST,
+ user=MYSQL_USER,
+ port=MYSQL_PORT,
+ password=MYSQL_PWD,
+ database=MYSQL_DB,
+ local_infile=True)
+ self.cursor = self.conn.cursor()
+
+ def test_connection(self):
+ try:
+ self.conn.ping()
+ except Exception:
+ self.conn = pymysql.connect(
+ host=MYSQL_HOST,
+ user=MYSQL_USER,
+ port=MYSQL_PORT,
+ password=MYSQL_PWD,
+ database=MYSQL_DB,
+ local_infile=True)
+ self.cursor = self.conn.cursor()
+
+ def create_mysql_table(self, table_name):
+ # Create mysql table if not exists
+ self.test_connection()
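+        # NOTE: table_name is interpolated directly into the SQL string (identifiers cannot
+        # be bound parameters), so this demo assumes trusted, internally generated table names.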
+ sql = "create table if not exists " + table_name + "(milvus_id TEXT, audio_path TEXT);"
+ try:
+ self.cursor.execute(sql)
+ LOGGER.debug(f"MYSQL create table: {table_name} with sql: {sql}")
+ except Exception as e:
+ LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
+ sys.exit(1)
+
+ def load_data_to_mysql(self, table_name, data):
+        # Batch insert (milvus_id, audio_path) pairs into the mysql table
+ self.test_connection()
+ sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);"
+ try:
+ self.cursor.executemany(sql, data)
+ self.conn.commit()
+ LOGGER.debug(
+ f"MYSQL loads data to table: {table_name} successfully")
+ except Exception as e:
+ LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
+ sys.exit(1)
+
+ def search_by_milvus_ids(self, ids, table_name):
+        # Get the audio_path values for the given milvus ids, preserving their order
+ self.test_connection()
+ str_ids = str(ids).replace('[', '').replace(']', '')
+ sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");"
+ try:
+ self.cursor.execute(sql)
+ results = self.cursor.fetchall()
+ results = [res[0] for res in results]
+ LOGGER.debug("MYSQL search by milvus id.")
+ return results
+ except Exception as e:
+ LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
+ sys.exit(1)
+
+ def delete_table(self, table_name):
+ # Delete mysql table if exists
+ self.test_connection()
+ sql = "drop table if exists " + table_name + ";"
+ try:
+ self.cursor.execute(sql)
+ LOGGER.debug(f"MYSQL delete table:{table_name}")
+ except Exception as e:
+ LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
+ sys.exit(1)
+
+ def delete_all_data(self, table_name):
+ # Delete all the data in mysql table
+ self.test_connection()
+ sql = 'delete from ' + table_name + ';'
+ try:
+ self.cursor.execute(sql)
+ self.conn.commit()
+ LOGGER.debug(f"MYSQL delete all data in table:{table_name}")
+ except Exception as e:
+ LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
+ sys.exit(1)
+
+ def count_table(self, table_name):
+        # Get the number of rows in the mysql table
+ self.test_connection()
+ sql = "select count(milvus_id) from " + table_name + ";"
+ try:
+ self.cursor.execute(sql)
+ results = self.cursor.fetchall()
+ LOGGER.debug(f"MYSQL count table:{table_name}")
+ return results[0][0]
+ except Exception as e:
+ LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
+ sys.exit(1)
diff --git a/paddleaudio/__init__.py b/demos/audio_searching/src/operations/__init__.py
similarity index 82%
rename from paddleaudio/__init__.py
rename to demos/audio_searching/src/operations/__init__.py
index 2685cf57c..97043fd7b 100644
--- a/paddleaudio/__init__.py
+++ b/demos/audio_searching/src/operations/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,5 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from .backends import *
-from .features import *
diff --git a/demos/audio_searching/src/operations/count.py b/demos/audio_searching/src/operations/count.py
new file mode 100644
index 000000000..9a1f42082
--- /dev/null
+++ b/demos/audio_searching/src/operations/count.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from config import DEFAULT_TABLE
+from logs import LOGGER
+
+
+def do_count(table_name, milvus_cli):
+ """
+ Returns the total number of vectors in the system
+ """
+ if not table_name:
+ table_name = DEFAULT_TABLE
+ try:
+ if not milvus_cli.has_collection(table_name):
+ return None
+ num = milvus_cli.count(table_name)
+ return num
+ except Exception as e:
+ LOGGER.error(f"Error attempting to count table {e}")
+ sys.exit(1)
diff --git a/demos/audio_searching/src/operations/drop.py b/demos/audio_searching/src/operations/drop.py
new file mode 100644
index 000000000..f8278ddd0
--- /dev/null
+++ b/demos/audio_searching/src/operations/drop.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from config import DEFAULT_TABLE
+from logs import LOGGER
+
+
+def do_drop(table_name, milvus_cli, mysql_cli):
+ """
+ Delete the collection of Milvus and MySQL
+ """
+ if not table_name:
+ table_name = DEFAULT_TABLE
+ try:
+ if not milvus_cli.has_collection(table_name):
+ return "Collection is not exist"
+ status = milvus_cli.delete_collection(table_name)
+ mysql_cli.delete_table(table_name)
+ return status
+ except Exception as e:
+ LOGGER.error(f"Error attempting to drop table: {e}")
+ sys.exit(1)
diff --git a/demos/audio_searching/src/operations/load.py b/demos/audio_searching/src/operations/load.py
new file mode 100644
index 000000000..7a295bf34
--- /dev/null
+++ b/demos/audio_searching/src/operations/load.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+from config import DEFAULT_TABLE
+from diskcache import Cache
+from encode import get_audio_embedding
+from logs import LOGGER
+
+
+def get_audios(path):
+ """
+    List all supported audio files (wav, mp3, ogg, flac, m4a) recursively under the path folder.
+ """
+ supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
+ return [
+ item
+ for sublist in [[os.path.join(dir, file) for file in files]
+ for dir, _, files in list(os.walk(path))]
+ for item in sublist if os.path.splitext(item)[1] in supported_formats
+ ]
+
+
+def extract_features(audio_dir):
+ """
+ Get the vector of audio
+ """
+ try:
+ cache = Cache('./tmp')
+ feats = []
+ names = []
+ audio_list = get_audios(audio_dir)
+ total = len(audio_list)
+ cache['total'] = total
+ for i, audio_path in enumerate(audio_list):
+ norm_feat = get_audio_embedding(audio_path)
+ if norm_feat is None:
+ continue
+ feats.append(norm_feat)
+ names.append(audio_path.encode())
+ cache['current'] = i + 1
+ print(
+ f"Extracting feature from audio No. {i + 1} , {total} audios in total"
+ )
+ return feats, names
+ except Exception as e:
+ LOGGER.error(f"Error with extracting feature from audio {e}")
+ sys.exit(1)
+
+
+def format_data(ids, names):
+ """
+ Combine the id of the vector and the name of the audio into a list
+ """
+ data = []
+ for i in range(len(ids)):
+ value = (str(ids[i]), names[i])
+ data.append(value)
+ return data
+
+
+def do_load(table_name, audio_dir, milvus_cli, mysql_cli):
+ """
+ Import vectors to Milvus and data to Mysql respectively
+ """
+ if not table_name:
+ table_name = DEFAULT_TABLE
+ vectors, names = extract_features(audio_dir)
+ ids = milvus_cli.insert(table_name, vectors)
+ milvus_cli.create_index(table_name)
+ mysql_cli.create_mysql_table(table_name)
+ mysql_cli.load_data_to_mysql(table_name, format_data(ids, names))
+ return len(ids)
diff --git a/demos/audio_searching/src/operations/search.py b/demos/audio_searching/src/operations/search.py
new file mode 100644
index 000000000..9cf48abf9
--- /dev/null
+++ b/demos/audio_searching/src/operations/search.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from config import DEFAULT_TABLE
+from config import TOP_K
+from encode import get_audio_embedding
+from logs import LOGGER
+
+
+def do_search(host, table_name, audio_path, milvus_cli, mysql_cli):
+ """
+ Search the uploaded audio in Milvus/MySQL
+ """
+ try:
+ if not table_name:
+ table_name = DEFAULT_TABLE
+ feat = get_audio_embedding(audio_path)
+ vectors = milvus_cli.search_vectors(table_name, [feat], TOP_K)
+ vids = [str(x.id) for x in vectors[0]]
+ paths = mysql_cli.search_by_milvus_ids(vids, table_name)
+ distances = [x.distance for x in vectors[0]]
+ for i in range(len(paths)):
+ tmp = "http://" + str(host) + "/data?audio_path=" + str(paths[i])
+ paths[i] = tmp
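+            # Convert the raw L2 distance into a similarity-style score (higher = closer) for display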
+ distances[i] = (1 - distances[i]) * 100
+ return vids, paths, distances
+ except Exception as e:
+ LOGGER.error(f"Error with search: {e}")
+ sys.exit(1)
diff --git a/demos/audio_searching/src/test_main.py b/demos/audio_searching/src/test_main.py
new file mode 100644
index 000000000..331208ff1
--- /dev/null
+++ b/demos/audio_searching/src/test_main.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import zipfile
+
+import gdown
+from fastapi.testclient import TestClient
+from main import app
+
+client = TestClient(app)
+
+
+def download_audio_data():
+ """
+ download audio data
+ """
+ url = 'https://drive.google.com/uc?id=1bKu21JWBfcZBuEuzFEvPoAX6PmRrgnUp'
+ gdown.download(url)
+
+ with zipfile.ZipFile('example_audio.zip', 'r') as zip_ref:
+ zip_ref.extractall('./example_audio')
+
+
+def test_drop():
+ """
+ Delete the collection of Milvus and MySQL
+ """
+ response = client.post("/audio/drop")
+ assert response.status_code == 200
+
+
+def test_load():
+ """
+ Insert all the audio files under the file path to Milvus/MySQL
+ """
+ response = client.post("/audio/load", json={"File": "./example_audio"})
+ assert response.status_code == 200
+ assert response.json() == {
+ 'status': True,
+ 'msg': "Successfully loaded data!"
+ }
+
+
+def test_progress():
+ """
+ Get the progress of dealing with data
+ """
+ response = client.get("/progress")
+ assert response.status_code == 200
+ assert response.json() == "current: 20, total: 20"
+
+
+def test_count():
+ """
+ Returns the total number of vectors in the system
+ """
+ response = client.get("audio/count")
+ assert response.status_code == 200
+ assert response.json() == 20
+
+
+def test_search():
+ """
+ Search the uploaded audio in Milvus/MySQL
+ """
+ response = client.post(
+ "/audio/search/local?query_audio_path=.%2Fexample_audio%2Ftest.wav")
+ assert response.status_code == 200
+ assert len(response.json()) == 10
+
+
+def test_data():
+ """
+ Get the audio file
+ """
+ response = client.get("/data?audio_path=.%2Fexample_audio%2Ftest.wav")
+ assert response.status_code == 200
+
+
+if __name__ == "__main__":
+ download_audio_data()
+ test_load()
+ test_count()
+ test_search()
+ test_drop()
diff --git a/demos/speech_recognition/README.md b/demos/speech_recognition/README.md
index 5d964fcea..636548801 100644
--- a/demos/speech_recognition/README.md
+++ b/demos/speech_recognition/README.md
@@ -84,5 +84,8 @@ Here is a list of pretrained models released by PaddleSpeech that can be used by
| Model | Language | Sample Rate
| :--- | :---: | :---: |
-| conformer_wenetspeech| zh| 16000
-| transformer_librispeech| en| 16000
+| conformer_wenetspeech | zh | 16k
+| transformer_librispeech | en | 16k
+| deepspeech2offline_aishell | zh | 16k
+| deepspeech2online_aishell | zh | 16k
+| deepspeech2offline_librispeech | en | 16k
diff --git a/demos/speech_recognition/README_cn.md b/demos/speech_recognition/README_cn.md
index ba1f1d65c..8033dbd81 100644
--- a/demos/speech_recognition/README_cn.md
+++ b/demos/speech_recognition/README_cn.md
@@ -81,5 +81,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
| 模型 | 语言 | 采样率
| :--- | :---: | :---: |
-| conformer_wenetspeech| zh| 16000
-| transformer_librispeech| en| 16000
+| conformer_wenetspeech | zh | 16k
+| transformer_librispeech | en | 16k
+| deepspeech2offline_aishell | zh | 16k
+| deepspeech2online_aishell | zh | 16k
+| deepspeech2offline_librispeech | en | 16k
diff --git a/demos/speech_server/.gitignore b/demos/speech_server/.gitignore
new file mode 100644
index 000000000..d8dd7532a
--- /dev/null
+++ b/demos/speech_server/.gitignore
@@ -0,0 +1 @@
+*.wav
diff --git a/demos/speech_server/README.md b/demos/speech_server/README.md
index ac5cc4b00..10489e713 100644
--- a/demos/speech_server/README.md
+++ b/demos/speech_server/README.md
@@ -10,21 +10,15 @@ This demo is an implementation of starting the voice service and accessing the s
### 1. Installation
see [installation](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md).
-You can choose one way from easy, meduim and hard to install paddlespeech.
+It is recommended to use **paddlepaddle 2.2.1** or above.
+You can choose one way from medium and hard to install paddlespeech.
### 2. Prepare config File
-The configuration file contains the service-related configuration files and the model configuration related to the voice tasks contained in the service. They are all under the `conf` folder.
+The configuration file can be found in `conf/application.yaml` .
+Among them, `engine_list` indicates the speech engines that will be included in the service to be started, in the format of `<speech task>_<engine type>`.
+At present, the speech tasks integrated by the service include: asr (speech recognition), tts (speech synthesis) and cls (audio classification).
+Currently the engine type supports two forms: python and inference (Paddle Inference).
-**Note: The configuration of `engine_backend` in `application.yaml` represents all speech tasks included in the started service. **
-If the service you want to start contains only a certain speech task, then you need to comment out the speech tasks that do not need to be included. For example, if you only want to use the speech recognition (ASR) service, then you can comment out the speech synthesis (TTS) service, as in the following example:
-```bash
-engine_backend:
- asr: 'conf/asr/asr.yaml'
- #tts: 'conf/tts/tts.yaml'
-```
-
-**Note: The configuration file of `engine_backend` in `application.yaml` needs to match the configuration type of `engine_type`. **
-When the configuration file of `engine_backend` is `XXX.yaml`, the configuration type of `engine_type` needs to be set to `python`; when the configuration file of `engine_backend` is `XXX_pd.yaml`, the configuration of `engine_type` needs to be set type is `inference`;
The input of ASR client demo should be a WAV file(`.wav`), and the sample rate must be the same as the model.
@@ -116,21 +110,22 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
- Python API
```python
from paddlespeech.server.bin.paddlespeech_client import ASRClientExecutor
+ import json
asrclient_executor = ASRClientExecutor()
- asrclient_executor(
+ res = asrclient_executor(
input="./zh.wav",
server_ip="127.0.0.1",
port=8090,
sample_rate=16000,
lang="zh_cn",
audio_format="wav")
+ print(res.json())
```
Output:
```bash
{'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'transcription': '我认为跑步最重要的就是给我带来了身体健康'}}
- time cost 0.604353 s.
```
### 5. TTS Client Usage
@@ -152,7 +147,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
- `speed`: Audio speed, the value should be set between 0 and 3. Default: 1.0
- `volume`: Audio volume, the value should be set between 0 and 3. Default: 1.0
- `sample_rate`: Sampling rate, choice: [0, 8000, 16000], the default is the same as the model. Default: 0
- - `output`: Output wave filepath. Default: `output.wav`.
+ - `output`: Output wave filepath. Default: None, which means not to save the audio to the local.
Output:
```bash
@@ -166,9 +161,10 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
- Python API
```python
from paddlespeech.server.bin.paddlespeech_client import TTSClientExecutor
+ import json
ttsclient_executor = TTSClientExecutor()
- ttsclient_executor(
+ res = ttsclient_executor(
input="您好,欢迎使用百度飞桨语音合成服务。",
server_ip="127.0.0.1",
port=8090,
@@ -177,6 +173,11 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
volume=1.0,
sample_rate=0,
output="./output.wav")
+
+ response_dict = res.json()
+ print(response_dict["message"])
+ print("Save synthesized audio successfully on %s." % (response_dict['result']['save_path']))
+ print("Audio duration: %f s." %(response_dict['result']['duration']))
```
Output:
@@ -184,7 +185,52 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
{'description': 'success.'}
Save synthesized audio successfully on ./output.wav.
Audio duration: 3.612500 s.
- Response time: 0.388317 s.
+
+ ```
+
+### 6. CLS Client Usage
+**Note:** The response time will be slightly longer when using the client for the first time
+- Command Line (Recommended)
+ ```
+ paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
+ ```
+
+ Usage:
+
+ ```bash
+ paddlespeech_client cls --help
+ ```
+ Arguments:
+ - `server_ip`: server ip. Default: 127.0.0.1
+ - `port`: server port. Default: 8090
+ - `input`(required): Audio file to be classified.
+ - `topk`: topk scores of classification result.
+
+ Output:
+ ```bash
+ [2022-03-09 20:44:39,974] [ INFO] - {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'topk': 1, 'results': [{'class_name': 'Speech', 'prob': 0.9027184844017029}]}}
+ [2022-03-09 20:44:39,975] [ INFO] - Response time 0.104360 s.
+
+
+ ```
+
+- Python API
+ ```python
+ from paddlespeech.server.bin.paddlespeech_client import CLSClientExecutor
+ import json
+
+ clsclient_executor = CLSClientExecutor()
+ res = clsclient_executor(
+ input="./zh.wav",
+ server_ip="127.0.0.1",
+ port=8090,
+ topk=1)
+ print(res.json())
+ ```
+
+ Output:
+ ```bash
+ {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'topk': 1, 'results': [{'class_name': 'Speech', 'prob': 0.9027184844017029}]}}
```
@@ -195,3 +241,6 @@ Get all models supported by the ASR service via `paddlespeech_server stats --tas
### TTS model
Get all models supported by the TTS service via `paddlespeech_server stats --task tts`, where static models can be used for paddle inference inference.
+
+### CLS model
+Get all models supported by the CLS service via `paddlespeech_server stats --task cls`, where static models can be used for paddle inference inference.
diff --git a/demos/speech_server/README_cn.md b/demos/speech_server/README_cn.md
index f202a30cd..2bd8af6c9 100644
--- a/demos/speech_server/README_cn.md
+++ b/demos/speech_server/README_cn.md
@@ -10,19 +10,16 @@
### 1. 安装
请看 [安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md).
-你可以从 easy,medium,hard 三中方式中选择一种方式安装 PaddleSpeech。
+推荐使用 **paddlepaddle 2.2.1** 或以上版本。
+你可以从 medium,hard 两种方式中选择一种方式安装 PaddleSpeech。
+
### 2. 准备配置文件
-配置文件包含服务相关的配置文件和服务中包含的语音任务相关的模型配置。 它们都在 `conf` 文件夹下。
-**注意:`application.yaml` 中 `engine_backend` 的配置表示启动的服务中包含的所有语音任务。**
-如果你想启动的服务中只包含某项语音任务,那么你需要注释掉不需要包含的语音任务。例如你只想使用语音识别(ASR)服务,那么你可以将语音合成(TTS)服务注释掉,如下示例:
-```bash
-engine_backend:
- asr: 'conf/asr/asr.yaml'
- #tts: 'conf/tts/tts.yaml'
-```
-**注意:`application.yaml` 中 `engine_backend` 的配置文件需要和 `engine_type` 的配置类型匹配。**
-当`engine_backend` 的配置文件为`XXX.yaml`时,需要设置`engine_type`的配置类型为`python`;当`engine_backend` 的配置文件为`XXX_pd.yaml`时,需要设置`engine_type`的配置类型为`inference`;
+配置文件可参见 `conf/application.yaml` 。
+其中,`engine_list`表示即将启动的服务将会包含的语音引擎,格式为 <语音任务>_<引擎类型>。
+目前服务集成的语音任务有: asr(语音识别)、tts(语音合成)、cls(音频分类)。
+目前引擎类型支持两种形式:python 及 inference (Paddle Inference)。
+
这个 ASR client 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。
@@ -83,8 +80,8 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
```
-### 4. ASR客户端使用方法
-**注意:**初次使用客户端时响应时间会略长
+### 4. ASR 客户端使用方法
+**注意:** 初次使用客户端时响应时间会略长
- 命令行 (推荐使用)
```
paddlespeech_client asr --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
@@ -114,29 +111,32 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
- Python API
```python
from paddlespeech.server.bin.paddlespeech_client import ASRClientExecutor
+ import json
asrclient_executor = ASRClientExecutor()
- asrclient_executor(
+ res = asrclient_executor(
input="./zh.wav",
server_ip="127.0.0.1",
port=8090,
sample_rate=16000,
lang="zh_cn",
audio_format="wav")
+ print(res.json())
```
输出:
```bash
{'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'transcription': '我认为跑步最重要的就是给我带来了身体健康'}}
- time cost 0.604353 s.
```
-### 5. TTS客户端使用方法
-**注意:**初次使用客户端时响应时间会略长
- ```bash
- paddlespeech_client tts --server_ip 127.0.0.1 --port 8090 --input "您好,欢迎使用百度飞桨语音合成服务。" --output output.wav
- ```
+### 5. TTS 客户端使用方法
+**注意:** 初次使用客户端时响应时间会略长
+- 命令行 (推荐使用)
+
+ ```bash
+ paddlespeech_client tts --server_ip 127.0.0.1 --port 8090 --input "您好,欢迎使用百度飞桨语音合成服务。" --output output.wav
+ ```
使用帮助:
```bash
@@ -151,7 +151,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
- `speed`: 音频速度,该值应设置在 0 到 3 之间。 默认值:1.0
- `volume`: 音频音量,该值应设置在 0 到 3 之间。 默认值: 1.0
- `sample_rate`: 采样率,可选 [0, 8000, 16000],默认与模型相同。 默认值:0
- - `output`: 输出音频的路径, 默认值:output.wav。
+ - `output`: 输出音频的路径, 默认值:None,表示不保存音频到本地。
输出:
```bash
@@ -164,9 +164,10 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
- Python API
```python
from paddlespeech.server.bin.paddlespeech_client import TTSClientExecutor
+ import json
ttsclient_executor = TTSClientExecutor()
- ttsclient_executor(
+ res = ttsclient_executor(
input="您好,欢迎使用百度飞桨语音合成服务。",
server_ip="127.0.0.1",
port=8090,
@@ -175,6 +176,11 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
volume=1.0,
sample_rate=0,
output="./output.wav")
+
+ response_dict = res.json()
+ print(response_dict["message"])
+ print("Save synthesized audio successfully on %s." % (response_dict['result']['save_path']))
+ print("Audio duration: %f s." %(response_dict['result']['duration']))
```
输出:
@@ -182,13 +188,63 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
{'description': 'success.'}
Save synthesized audio successfully on ./output.wav.
Audio duration: 3.612500 s.
- Response time: 0.388317 s.
```
+### 6. CLS 客户端使用方法
+**注意:** 初次使用客户端时响应时间会略长
+- 命令行 (推荐使用)
+ ```
+ paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
+ ```
+
+ 使用帮助:
+
+ ```bash
+ paddlespeech_client cls --help
+ ```
+ 参数:
+ - `server_ip`: 服务端ip地址,默认: 127.0.0.1。
+ - `port`: 服务端口,默认: 8090。
+ - `input`(必须输入): 用于分类的音频文件。
+ - `topk`: 分类结果的topk。
+
+ 输出:
+ ```bash
+ [2022-03-09 20:44:39,974] [ INFO] - {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'topk': 1, 'results': [{'class_name': 'Speech', 'prob': 0.9027184844017029}]}}
+ [2022-03-09 20:44:39,975] [ INFO] - Response time 0.104360 s.
+
+
+ ```
+
+- Python API
+ ```python
+ from paddlespeech.server.bin.paddlespeech_client import CLSClientExecutor
+ import json
+
+ clsclient_executor = CLSClientExecutor()
+ res = clsclient_executor(
+ input="./zh.wav",
+ server_ip="127.0.0.1",
+ port=8090,
+ topk=1)
+ print(res.json())
+
+ ```
+
+ 输出:
+ ```bash
+ {'success': True, 'code': 200, 'message': {'description': 'success'}, 'result': {'topk': 1, 'results': [{'class_name': 'Speech', 'prob': 0.9027184844017029}]}}
+
+ ```
+
+
## 服务支持的模型
### ASR支持的模型
通过 `paddlespeech_server stats --task asr` 获取ASR服务支持的所有模型,其中静态模型可用于 paddle inference 推理。
### TTS支持的模型
通过 `paddlespeech_server stats --task tts` 获取TTS服务支持的所有模型,其中静态模型可用于 paddle inference 推理。
+
+### CLS支持的模型
+通过 `paddlespeech_server stats --task cls` 获取CLS服务支持的所有模型,其中静态模型可用于 paddle inference 推理。
diff --git a/demos/speech_server/cls_client.sh b/demos/speech_server/cls_client.sh
new file mode 100644
index 000000000..5797aa204
--- /dev/null
+++ b/demos/speech_server/cls_client.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
+paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input ./zh.wav --topk 1
diff --git a/demos/speech_server/conf/application.yaml b/demos/speech_server/conf/application.yaml
index 6dcae74a9..2b1a05998 100644
--- a/demos/speech_server/conf/application.yaml
+++ b/demos/speech_server/conf/application.yaml
@@ -1,27 +1,137 @@
# This is the parameter configuration file for PaddleSpeech Serving.
-##################################################################
-# SERVER SETTING #
-##################################################################
-host: '127.0.0.1'
+#################################################################################
+# SERVER SETTING #
+#################################################################################
+host: 127.0.0.1
port: 8090
-##################################################################
-# CONFIG FILE #
-##################################################################
-# add engine backend type (Options: asr, tts) and config file here.
-# Adding a speech task to engine_backend means starting the service.
-engine_backend:
- asr: 'conf/asr/asr.yaml'
- tts: 'conf/tts/tts.yaml'
-
-# The engine_type of speech task needs to keep the same type as the config file of speech task.
-# E.g: The engine_type of asr is 'python', the engine_backend of asr is 'XX/asr.yaml'
-# E.g: The engine_type of asr is 'inference', the engine_backend of asr is 'XX/asr_pd.yaml'
-#
-# add engine type (Options: python, inference)
-engine_type:
- asr: 'python'
- tts: 'python'
+# The task format in the engine_list is: <speech task>_<engine type>
+# task choices = ['asr_python', 'asr_inference', 'tts_python', 'tts_inference', 'cls_python', 'cls_inference']
+engine_list: ['asr_python', 'tts_python', 'cls_python']
+
+
+#################################################################################
+# ENGINE CONFIG #
+#################################################################################
+
+################################### ASR #########################################
+################### speech task: asr; engine_type: python #######################
+asr_python:
+ model: 'conformer_wenetspeech'
+ lang: 'zh'
+ sample_rate: 16000
+ cfg_path: # [optional]
+ ckpt_path: # [optional]
+ decode_method: 'attention_rescoring'
+ force_yes: True
+ device: # set 'gpu:id' or 'cpu'
+
+
+################### speech task: asr; engine_type: inference #######################
+asr_inference:
+ # model_type choices=['deepspeech2offline_aishell']
+ model_type: 'deepspeech2offline_aishell'
+ am_model: # the pdmodel file of am static model [optional]
+ am_params: # the pdiparams file of am static model [optional]
+ lang: 'zh'
+ sample_rate: 16000
+ cfg_path:
+ decode_method:
+ force_yes: True
+
+ am_predictor_conf:
+ device: # set 'gpu:id' or 'cpu'
+ switch_ir_optim: True
+ glog_info: False # True -> print glog
+ summary: True # False -> do not show predictor config
+
+
+################################### TTS #########################################
+################### speech task: tts; engine_type: python #######################
+tts_python:
+ # am (acoustic model) choices=['speedyspeech_csmsc', 'fastspeech2_csmsc',
+ # 'fastspeech2_ljspeech', 'fastspeech2_aishell3',
+ # 'fastspeech2_vctk']
+ am: 'fastspeech2_csmsc'
+ am_config:
+ am_ckpt:
+ am_stat:
+ phones_dict:
+ tones_dict:
+ speaker_dict:
+ spk_id: 0
+
+ # voc (vocoder) choices=['pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3',
+ # 'pwgan_vctk', 'mb_melgan_csmsc']
+ voc: 'pwgan_csmsc'
+ voc_config:
+ voc_ckpt:
+ voc_stat:
+
+ # others
+ lang: 'zh'
+ device: # set 'gpu:id' or 'cpu'
+
+
+################### speech task: tts; engine_type: inference #######################
+tts_inference:
+ # am (acoustic model) choices=['speedyspeech_csmsc', 'fastspeech2_csmsc']
+ am: 'fastspeech2_csmsc'
+ am_model: # the pdmodel file of your am static model (XX.pdmodel)
+    am_params: # the pdiparams file of your am static model (XX.pdiparams)
+ am_sample_rate: 24000
+ phones_dict:
+ tones_dict:
+ speaker_dict:
+ spk_id: 0
+
+ am_predictor_conf:
+ device: # set 'gpu:id' or 'cpu'
+ switch_ir_optim: True
+ glog_info: False # True -> print glog
+ summary: True # False -> do not show predictor config
+
+ # voc (vocoder) choices=['pwgan_csmsc', 'mb_melgan_csmsc','hifigan_csmsc']
+ voc: 'pwgan_csmsc'
+ voc_model: # the pdmodel file of your vocoder static model (XX.pdmodel)
+    voc_params: # the pdiparams file of your vocoder static model (XX.pdiparams)
+ voc_sample_rate: 24000
+
+ voc_predictor_conf:
+ device: # set 'gpu:id' or 'cpu'
+ switch_ir_optim: True
+ glog_info: False # True -> print glog
+ summary: True # False -> do not show predictor config
+
+ # others
+ lang: 'zh'
+
+
+################################### CLS #########################################
+################### speech task: cls; engine_type: python #######################
+cls_python:
+ # model choices=['panns_cnn14', 'panns_cnn10', 'panns_cnn6']
+ model: 'panns_cnn14'
+ cfg_path: # [optional] Config of cls task.
+ ckpt_path: # [optional] Checkpoint file of model.
+ label_file: # [optional] Label file of cls task.
+ device: # set 'gpu:id' or 'cpu'
+
+
+################### speech task: cls; engine_type: inference #######################
+cls_inference:
+ # model_type choices=['panns_cnn14', 'panns_cnn10', 'panns_cnn6']
+ model_type: 'panns_cnn14'
+ cfg_path:
+ model_path: # the pdmodel file of am static model [optional]
+ params_path: # the pdiparams file of am static model [optional]
+ label_file: # [optional] Label file of cls task.
+
+ predictor_conf:
+ device: # set 'gpu:id' or 'cpu'
+ switch_ir_optim: True
+ glog_info: False # True -> print glog
+ summary: True # False -> do not show predictor config
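To make the layout above concrete, here is a minimal sketch (not part of the patch) of how a script could read `conf/application.yaml` and resolve each `engine_list` entry to its per-engine section; it assumes only that PyYAML is installed.

```python
# Not part of the patch: reading application.yaml and resolving each engine_list
# entry (e.g. 'asr_python') to the config section of the same name.
import yaml

with open("conf/application.yaml") as f:
    config = yaml.safe_load(f)

for entry in config["engine_list"]:
    task, engine_type = entry.split("_", 1)   # 'asr_python' -> ('asr', 'python')
    engine_conf = config[entry]               # per-engine section defined above
    print(f"{task} ({engine_type}): {sorted(engine_conf)[:4]} ...")
```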
diff --git a/demos/speech_server/conf/asr/asr.yaml b/demos/speech_server/conf/asr/asr.yaml
deleted file mode 100644
index a6743b775..000000000
--- a/demos/speech_server/conf/asr/asr.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-model: 'conformer_wenetspeech'
-lang: 'zh'
-sample_rate: 16000
-cfg_path: # [optional]
-ckpt_path: # [optional]
-decode_method: 'attention_rescoring'
-force_yes: True
-device: # set 'gpu:id' or 'cpu'
diff --git a/demos/speech_server/conf/asr/asr_pd.yaml b/demos/speech_server/conf/asr/asr_pd.yaml
deleted file mode 100644
index 4c415ac79..000000000
--- a/demos/speech_server/conf/asr/asr_pd.yaml
+++ /dev/null
@@ -1,26 +0,0 @@
-# This is the parameter configuration file for ASR server.
-# These are the static models that support paddle inference.
-
-##################################################################
-# ACOUSTIC MODEL SETTING #
-# am choices=['deepspeech2offline_aishell'] TODO
-##################################################################
-model_type: 'deepspeech2offline_aishell'
-am_model: # the pdmodel file of am static model [optional]
-am_params: # the pdiparams file of am static model [optional]
-lang: 'zh'
-sample_rate: 16000
-cfg_path:
-decode_method:
-force_yes: True
-
-am_predictor_conf:
- device: # set 'gpu:id' or 'cpu'
- switch_ir_optim: True
- glog_info: False # True -> print glog
- summary: True # False -> do not show predictor config
-
-
-##################################################################
-# OTHERS #
-##################################################################
diff --git a/demos/speech_server/conf/tts/tts.yaml b/demos/speech_server/conf/tts/tts.yaml
deleted file mode 100644
index 19207f0b0..000000000
--- a/demos/speech_server/conf/tts/tts.yaml
+++ /dev/null
@@ -1,32 +0,0 @@
-# This is the parameter configuration file for TTS server.
-
-##################################################################
-# ACOUSTIC MODEL SETTING #
-# am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc',
-# 'fastspeech2_ljspeech', 'fastspeech2_aishell3',
-# 'fastspeech2_vctk']
-##################################################################
-am: 'fastspeech2_csmsc'
-am_config:
-am_ckpt:
-am_stat:
-phones_dict:
-tones_dict:
-speaker_dict:
-spk_id: 0
-
-##################################################################
-# VOCODER SETTING #
-# voc choices=['pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3',
-# 'pwgan_vctk', 'mb_melgan_csmsc']
-##################################################################
-voc: 'pwgan_csmsc'
-voc_config:
-voc_ckpt:
-voc_stat:
-
-##################################################################
-# OTHERS #
-##################################################################
-lang: 'zh'
-device: # set 'gpu:id' or 'cpu'
diff --git a/demos/speech_server/conf/tts/tts_pd.yaml b/demos/speech_server/conf/tts/tts_pd.yaml
deleted file mode 100644
index e27b9665b..000000000
--- a/demos/speech_server/conf/tts/tts_pd.yaml
+++ /dev/null
@@ -1,42 +0,0 @@
-# This is the parameter configuration file for TTS server.
-# These are the static models that support paddle inference.
-
-##################################################################
-# ACOUSTIC MODEL SETTING #
-# am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc']
-##################################################################
-am: 'fastspeech2_csmsc'
-am_model: # the pdmodel file of your am static model (XX.pdmodel)
-am_params: # the pdiparams file of your am static model (XX.pdipparams)
-am_sample_rate: 24000
-phones_dict:
-tones_dict:
-speaker_dict:
-spk_id: 0
-
-am_predictor_conf:
- device: # set 'gpu:id' or 'cpu'
- switch_ir_optim: True
- glog_info: False # True -> print glog
- summary: True # False -> do not show predictor config
-
-
-##################################################################
-# VOCODER SETTING #
-# voc choices=['pwgan_csmsc', 'mb_melgan_csmsc','hifigan_csmsc']
-##################################################################
-voc: 'pwgan_csmsc'
-voc_model: # the pdmodel file of your vocoder static model (XX.pdmodel)
-voc_params: # the pdiparams file of your vocoder static model (XX.pdipparams)
-voc_sample_rate: 24000
-
-voc_predictor_conf:
- device: # set 'gpu:id' or 'cpu'
- switch_ir_optim: True
- glog_info: False # True -> print glog
- summary: True # False -> do not show predictor config
-
-##################################################################
-# OTHERS #
-##################################################################
-lang: 'zh'
diff --git a/demos/speech_server/server.sh b/demos/speech_server/server.sh
index d9367ec06..e5961286b 100644
--- a/demos/speech_server/server.sh
+++ b/demos/speech_server/server.sh
@@ -1,3 +1,3 @@
#!/bin/bash
-paddlespeech_server start --config_file ./conf/application.yaml
\ No newline at end of file
+paddlespeech_server start --config_file ./conf/application.yaml
diff --git a/docs/source/reference.md b/docs/source/reference.md
index a8327e92e..f1a02d200 100644
--- a/docs/source/reference.md
+++ b/docs/source/reference.md
@@ -35,3 +35,7 @@ We borrowed a lot of code from these repos to build `model` and `engine`, thanks
* [librosa](https://github.com/librosa/librosa/blob/main/LICENSE.md)
- ISC License
- Audio feature
+
+* [ThreadPool](https://github.com/progschj/ThreadPool/blob/master/COPYING)
+- zlib License
+- ThreadPool
diff --git a/docs/source/released_model.md b/docs/source/released_model.md
index 8f855f7cf..52b386daf 100644
--- a/docs/source/released_model.md
+++ b/docs/source/released_model.md
@@ -49,17 +49,18 @@ Model Type | Dataset| Example Link | Pretrained Models| Static Models|Size (stat
WaveFlow| LJSpeech |[waveflow-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc0)|[waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/waveflow/waveflow_ljspeech_ckpt_0.3.zip)|||
Parallel WaveGAN| CSMSC |[PWGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc1)|[pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip)|[pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip)|5.1MB|
Parallel WaveGAN| LJSpeech |[PWGAN-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc1)|[pwg_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip)|||
-Parallel WaveGAN|AISHELL-3 |[PWGAN-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1)|[pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)|||
+Parallel WaveGAN| AISHELL-3 |[PWGAN-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1)|[pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)|||
Parallel WaveGAN| VCTK |[PWGAN-vctk](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/voc1)|[pwg_vctk_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.5.zip)|||
|Multi Band MelGAN | CSMSC |[MB MelGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc3) | [mb_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip)
[mb_melgan_baker_finetune_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_baker_finetune_ckpt_0.5.zip)|[mb_melgan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip) |8.2MB|
Style MelGAN | CSMSC |[Style MelGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc4)|[style_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip)| | |
HiFiGAN | CSMSC |[HiFiGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc5)|[hifigan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip)|[hifigan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip)|50MB|
+HiFiGAN | AISHELL-3 |[HiFiGAN-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc5)|[hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip)|||
WaveRNN | CSMSC |[WaveRNN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc6)|[wavernn_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip)|[wavernn_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_static_0.2.0.zip)|18MB|
### Voice Cloning
Model Type | Dataset| Example Link | Pretrained Models
-:-------------:| :------------:| :-----: | :-----:
+:-------------:| :------------:| :-----: | :-----: |
GE2E| AISHELL-3, etc. |[ge2e](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/ge2e)|[ge2e_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/ge2e/ge2e_ckpt_0.3.zip)
GE2E + Tactron2| AISHELL-3 |[ge2e-tactron2-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/vc0)|[tacotron2_aishell3_ckpt_vc0_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_aishell3_ckpt_vc0_0.2.0.zip)
GE2E + FastSpeech2 | AISHELL-3 |[ge2e-fastspeech2-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/vc1)|[fastspeech2_nosil_aishell3_vc1_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_vc1_ckpt_0.5.zip)
@@ -67,9 +68,9 @@ GE2E + FastSpeech2 | AISHELL-3 |[ge2e-fastspeech2-aishell3](https://github.com/
## Audio Classification Models
-Model Type | Dataset| Example Link | Pretrained Models
-:-------------:| :------------:| :-----: | :-----:
-PANN | Audioset| [audioset_tagging_cnn](https://github.com/qiuqiangkong/audioset_tagging_cnn) | [panns_cnn6.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn6.pdparams), [panns_cnn10.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn10.pdparams), [panns_cnn14.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn14.pdparams)
+Model Type | Dataset| Example Link | Pretrained Models | Static Models
+:-------------:| :------------:| :-----: | :-----: | :-----:
+PANN | Audioset| [audioset_tagging_cnn](https://github.com/qiuqiangkong/audioset_tagging_cnn) | [panns_cnn6.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn6.pdparams), [panns_cnn10.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn10.pdparams), [panns_cnn14.pdparams](https://bj.bcebos.com/paddleaudio/models/panns_cnn14.pdparams) | [panns_cnn6_static.tar.gz](https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz)(18M), [panns_cnn10_static.tar.gz](https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz)(19M), [panns_cnn14_static.tar.gz](https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz)(289M)
PANN | ESC-50 |[pann-esc50](../../examples/esc50/cls0)|[esc50_cnn6.tar.gz](https://paddlespeech.bj.bcebos.com/cls/esc50/esc50_cnn6.tar.gz), [esc50_cnn10.tar.gz](https://paddlespeech.bj.bcebos.com/cls/esc50/esc50_cnn10.tar.gz), [esc50_cnn14.tar.gz](https://paddlespeech.bj.bcebos.com/cls/esc50/esc50_cnn14.tar.gz)
## Punctuation Restoration Models
diff --git a/examples/aishell3/tts3/local/synthesize.sh b/examples/aishell3/tts3/local/synthesize.sh
index b1fc96a2d..d3978833f 100755
--- a/examples/aishell3/tts3/local/synthesize.sh
+++ b/examples/aishell3/tts3/local/synthesize.sh
@@ -4,18 +4,44 @@ config_path=$1
train_output_path=$2
ckpt_name=$3
-FLAGS_allocator_strategy=naive_best_fit \
-FLAGS_fraction_of_gpu_memory_to_use=0.01 \
-python3 ${BIN_DIR}/../synthesize.py \
- --am=fastspeech2_aishell3 \
- --am_config=${config_path} \
- --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
- --am_stat=dump/train/speech_stats.npy \
- --voc=pwgan_aishell3 \
- --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
- --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
- --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
- --test_metadata=dump/test/norm/metadata.jsonl \
- --output_dir=${train_output_path}/test \
- --phones_dict=dump/phone_id_map.txt \
- --speaker_dict=dump/speaker_id_map.txt
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+ python3 ${BIN_DIR}/../synthesize.py \
+ --am=fastspeech2_aishell3 \
+ --am_config=${config_path} \
+ --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+ --am_stat=dump/train/speech_stats.npy \
+ --voc=pwgan_aishell3 \
+ --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
+ --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
+ --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
+ --test_metadata=dump/test/norm/metadata.jsonl \
+ --output_dir=${train_output_path}/test \
+ --phones_dict=dump/phone_id_map.txt \
+ --speaker_dict=dump/speaker_id_map.txt
+fi
+
+# hifigan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+ python3 ${BIN_DIR}/../synthesize.py \
+ --am=fastspeech2_aishell3 \
+ --am_config=${config_path} \
+ --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+ --am_stat=dump/train/speech_stats.npy \
+ --voc=hifigan_aishell3 \
+ --voc_config=hifigan_aishell3_ckpt_0.2.0/default.yaml \
+        --voc_ckpt=hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz \
+ --voc_stat=hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \
+ --test_metadata=dump/test/norm/metadata.jsonl \
+ --output_dir=${train_output_path}/test \
+ --phones_dict=dump/phone_id_map.txt \
+ --speaker_dict=dump/speaker_id_map.txt
+fi
+
diff --git a/examples/aishell3/tts3/local/synthesize_e2e.sh b/examples/aishell3/tts3/local/synthesize_e2e.sh
index 60e1a5cee..ff3608be7 100755
--- a/examples/aishell3/tts3/local/synthesize_e2e.sh
+++ b/examples/aishell3/tts3/local/synthesize_e2e.sh
@@ -4,21 +4,50 @@ config_path=$1
train_output_path=$2
ckpt_name=$3
-FLAGS_allocator_strategy=naive_best_fit \
-FLAGS_fraction_of_gpu_memory_to_use=0.01 \
-python3 ${BIN_DIR}/../synthesize_e2e.py \
- --am=fastspeech2_aishell3 \
- --am_config=${config_path} \
- --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
- --am_stat=dump/train/speech_stats.npy \
- --voc=pwgan_aishell3 \
- --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
- --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
- --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
- --lang=zh \
- --text=${BIN_DIR}/../sentences.txt \
- --output_dir=${train_output_path}/test_e2e \
- --phones_dict=dump/phone_id_map.txt \
- --speaker_dict=dump/speaker_id_map.txt \
- --spk_id=0 \
- --inference_dir=${train_output_path}/inference
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+ python3 ${BIN_DIR}/../synthesize_e2e.py \
+ --am=fastspeech2_aishell3 \
+ --am_config=${config_path} \
+ --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+ --am_stat=dump/train/speech_stats.npy \
+ --voc=pwgan_aishell3 \
+ --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
+ --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
+ --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
+ --lang=zh \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/test_e2e \
+ --phones_dict=dump/phone_id_map.txt \
+ --speaker_dict=dump/speaker_id_map.txt \
+ --spk_id=0 \
+ --inference_dir=${train_output_path}/inference
+fi
+
+# hifigan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "in hifigan syn_e2e"
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+ python3 ${BIN_DIR}/../synthesize_e2e.py \
+ --am=fastspeech2_aishell3 \
+ --am_config=${config_path} \
+ --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+ --am_stat=fastspeech2_nosil_aishell3_ckpt_0.4/speech_stats.npy \
+ --voc=hifigan_aishell3 \
+ --voc_config=hifigan_aishell3_ckpt_0.2.0/default.yaml \
+ --voc_ckpt=hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz \
+ --voc_stat=hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \
+ --lang=zh \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/test_e2e \
+ --phones_dict=fastspeech2_nosil_aishell3_ckpt_0.4/phone_id_map.txt \
+ --speaker_dict=fastspeech2_nosil_aishell3_ckpt_0.4/speaker_id_map.txt \
+ --spk_id=0 \
+ --inference_dir=${train_output_path}/inference
+fi
diff --git a/examples/aishell3/vc0/local/preprocess.sh b/examples/aishell3/vc0/local/preprocess.sh
index 069cf94c4..e458c7063 100755
--- a/examples/aishell3/vc0/local/preprocess.sh
+++ b/examples/aishell3/vc0/local/preprocess.sh
@@ -1,6 +1,6 @@
#!/bin/bash
-stage=3
+stage=0
stop_stage=100
config_path=$1
diff --git a/examples/aishell3/voc1/run.sh b/examples/aishell3/voc1/run.sh
index 4f426ea02..cab1ac38b 100755
--- a/examples/aishell3/voc1/run.sh
+++ b/examples/aishell3/voc1/run.sh
@@ -3,7 +3,7 @@
set -e
source path.sh
-gpus=0
+gpus=0,1
stage=0
stop_stage=100
diff --git a/examples/aishell3/voc5/README.md b/examples/aishell3/voc5/README.md
new file mode 100644
index 000000000..ebe2530be
--- /dev/null
+++ b/examples/aishell3/voc5/README.md
@@ -0,0 +1,156 @@
+# HiFiGAN with AISHELL-3
+This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.05646) model with [AISHELL-3](http://www.aishelltech.com/aishell_3).
+
+AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus that could be used to train multi-speaker Text-to-Speech (TTS) systems.
+## Dataset
+### Download and Extract
+Download AISHELL-3.
+```bash
+wget https://www.openslr.org/resources/93/data_aishell3.tgz
+```
+Extract AISHELL-3.
+```bash
+mkdir data_aishell3
+tar zxvf data_aishell3.tgz -C data_aishell3
+```
+### Get MFA Result and Extract
+We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2.
+You can download it from here: [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (which uses MFA1.x for now) in our repo.
+
+## Get Started
+Assume the path to the dataset is `~/datasets/data_aishell3`.
+Assume the path to the MFA result of AISHELL-3 is `./aishell3_alignment_tone`.
+Run the command below to
+1. **source path**.
+2. preprocess the dataset.
+3. train the model.
+4. synthesize wavs.
+ - synthesize waveform from `metadata.jsonl`.
+```bash
+./run.sh
+```
+You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage. For example, running the following command will only preprocess the dataset.
+```bash
+./run.sh --stage 0 --stop-stage 0
+```
+### Data Preprocessing
+```bash
+./local/preprocess.sh ${conf_path}
+```
+When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below.
+
+```text
+dump
+├── dev
+│ ├── norm
+│ └── raw
+├── test
+│ ├── norm
+│ └── raw
+└── train
+ ├── norm
+ ├── raw
+ └── feats_stats.npy
+```
+
+The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the `norm` folder contains the normalized spectrogram. The statistics used to normalize the spectrogram are computed from the training set and stored in `dump/train/feats_stats.npy`.
+
+Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains the id and the path to the spectrogram of each utterance.
+
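As a rough illustration (not part of the example), the normalization described above is a per-mel-bin standardization with the statistics in `feats_stats.npy`; the file layout assumed below (row 0 = mean, row 1 = scale) is an assumption for illustration only.

```python
# Not part of the example: a sketch of per-bin normalization with the training
# statistics. Assumes feats_stats.npy stores [mean, scale] per mel bin, which is
# how the norm/ features would relate to the raw/ features.
import numpy as np

stats = np.load("dump/train/feats_stats.npy")   # assumed shape: (2, n_mels)
mean, scale = stats[0], stats[1]

log_mel = np.random.randn(120, mean.shape[0]).astype("float32")  # stand-in utterance features
norm_mel = (log_mel - mean) / scale    # what dump/*/norm would contain
restored = norm_mel * scale + mean     # inverse transform before vocoding
```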
+### Model Training
+```bash
+CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
+```
+`./local/train.sh` calls `${BIN_DIR}/train.py`.
+Here's the complete help message.
+
+```text
+usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
+ [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
+ [--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER]
+ [--run-benchmark RUN_BENCHMARK]
+ [--profiler_options PROFILER_OPTIONS]
+
+Train a ParallelWaveGAN model.
+
+optional arguments:
+ -h, --help show this help message and exit
+ --config CONFIG config file to overwrite default config.
+ --train-metadata TRAIN_METADATA
+ training data.
+ --dev-metadata DEV_METADATA
+ dev data.
+ --output-dir OUTPUT_DIR
+ output dir.
+ --ngpu NGPU if ngpu == 0, use cpu.
+
+benchmark:
+ arguments related to benchmark.
+
+ --batch-size BATCH_SIZE
+ batch size.
+ --max-iter MAX_ITER train max steps.
+ --run-benchmark RUN_BENCHMARK
+ runing benchmark or not, if True, use the --batch-size
+ and --max-iter.
+ --profiler_options PROFILER_OPTIONS
+ The option of profiler, which should be in format
+ "key1=value1;key2=value2;key3=value3".
+```
+
+1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
+2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder.
+3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
+4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
+
+### Synthesizing
+`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
+```bash
+CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
+```
+```text
+usage: synthesize.py [-h] [--generator-type GENERATOR_TYPE] [--config CONFIG]
+ [--checkpoint CHECKPOINT] [--test-metadata TEST_METADATA]
+ [--output-dir OUTPUT_DIR] [--ngpu NGPU]
+
+Synthesize with GANVocoder.
+
+optional arguments:
+ -h, --help show this help message and exit
+ --generator-type GENERATOR_TYPE
+ type of GANVocoder, should in {pwgan, mb_melgan,
+ style_melgan, } now
+ --config CONFIG GANVocoder config file.
+ --checkpoint CHECKPOINT
+ snapshot to load.
+ --test-metadata TEST_METADATA
+ dev data.
+ --output-dir OUTPUT_DIR
+ output dir.
+ --ngpu NGPU if ngpu == 0, use cpu.
+```
+
+1. `--config` config file. You should use the same config with which the model is trained.
+2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory.
+3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory.
+4. `--output-dir` is the directory to save the synthesized audio files.
+5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
+## Pretrained Models
+The pretrained model can be downloaded here [hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip).
+
+
+Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss
+:-------------:| :------------:| :-----: | :-----: | :--------:
+default| 1(gpu) x 2500000|24.060|0.1068|7.499
+
+The HiFiGAN checkpoint contains the files listed below.
+
+```text
+hifigan_aishell3_ckpt_0.2.0
+├── default.yaml # default config used to train hifigan
+├── feats_stats.npy # statistics used to normalize spectrogram when training hifigan
+└── snapshot_iter_2500000.pdz # generator parameters of hifigan
+```
+
+## Acknowledgement
+We adapted some code from https://github.com/kan-bayashi/ParallelWaveGAN.
diff --git a/examples/aishell3/voc5/conf/default.yaml b/examples/aishell3/voc5/conf/default.yaml
new file mode 100644
index 000000000..728a90369
--- /dev/null
+++ b/examples/aishell3/voc5/conf/default.yaml
@@ -0,0 +1,168 @@
+# This is the configuration file for AISHELL-3 dataset.
+# This configuration is based on HiFiGAN V1, which is
+# an official configuration. But I found that the optimizer
+# setting does not work well with my implementation.
+# So I changed optimizer settings as follows:
+# - AdamW -> Adam
+# - betas: [0.8, 0.99] -> betas: [0.5, 0.9]
+# - Scheduler: ExponentialLR -> MultiStepLR
+# To match the shift size difference, the upsample scales
+# is also modified from the original 256 shift setting.
+###########################################################
+# FEATURE EXTRACTION SETTING #
+###########################################################
+fs: 24000 # Sampling rate.
+n_fft: 2048 # FFT size (samples).
+n_shift: 300 # Hop size (samples). 12.5ms
+win_length: 1200 # Window length (samples). 50ms
+ # If set to null, it will be the same as fft_size.
+window: "hann" # Window function.
+n_mels: 80 # Number of mel basis.
+fmin: 80 # Minimum freq in mel basis calculation. (Hz)
+fmax: 7600 # Maximum frequency in mel basis calculation. (Hz)
+
+###########################################################
+# GENERATOR NETWORK ARCHITECTURE SETTING #
+###########################################################
+generator_params:
+ in_channels: 80 # Number of input channels.
+ out_channels: 1 # Number of output channels.
+ channels: 512 # Number of initial channels.
+ kernel_size: 7 # Kernel size of initial and final conv layers.
+ upsample_scales: [5, 5, 4, 3] # Upsampling scales.
+ upsample_kernel_sizes: [10, 10, 8, 6] # Kernel size for upsampling layers.
+ resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks.
+ resblock_dilations: # Dilations for residual blocks.
+ - [1, 3, 5]
+ - [1, 3, 5]
+ - [1, 3, 5]
+ use_additional_convs: True # Whether to use additional conv layer in residual blocks.
+ bias: True # Whether to use bias parameter in conv.
+ nonlinear_activation: "leakyrelu" # Nonlinear activation type.
+    nonlinear_activation_params: # Nonlinear activation parameters.
+ negative_slope: 0.1
+ use_weight_norm: True # Whether to apply weight normalization.
+
+
+###########################################################
+# DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
+###########################################################
+discriminator_params:
+ scales: 3 # Number of multi-scale discriminator.
+ scale_downsample_pooling: "AvgPool1D" # Pooling operation for scale discriminator.
+ scale_downsample_pooling_params:
+ kernel_size: 4 # Pooling kernel size.
+ stride: 2 # Pooling stride.
+ padding: 2 # Padding size.
+ scale_discriminator_params:
+ in_channels: 1 # Number of input channels.
+ out_channels: 1 # Number of output channels.
+ kernel_sizes: [15, 41, 5, 3] # List of kernel sizes.
+ channels: 128 # Initial number of channels.
+ max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
+ max_groups: 16 # Maximum number of groups in downsampling conv layers.
+ bias: True
+ downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales.
+ nonlinear_activation: "leakyrelu" # Nonlinear activation.
+ nonlinear_activation_params:
+ negative_slope: 0.1
+ follow_official_norm: True # Whether to follow the official norm setting.
+ periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator.
+ period_discriminator_params:
+ in_channels: 1 # Number of input channels.
+ out_channels: 1 # Number of output channels.
+ kernel_sizes: [5, 3] # List of kernel sizes.
+ channels: 32 # Initial number of channels.
+ downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
+ max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
+        bias: True # Whether to use bias parameter in conv layer.
+ nonlinear_activation: "leakyrelu" # Nonlinear activation.
+        nonlinear_activation_params: # Nonlinear activation parameters.
+ negative_slope: 0.1
+ use_weight_norm: True # Whether to apply weight normalization.
+ use_spectral_norm: False # Whether to apply spectral normalization.
+
+
+###########################################################
+# STFT LOSS SETTING #
+###########################################################
+use_stft_loss: False # Whether to use multi-resolution STFT loss.
+use_mel_loss: True # Whether to use Mel-spectrogram loss.
+mel_loss_params:
+ fs: 24000
+ fft_size: 2048
+ hop_size: 300
+ win_length: 1200
+ window: "hann"
+ num_mels: 80
+ fmin: 0
+ fmax: 12000
+ log_base: null
+generator_adv_loss_params:
+ average_by_discriminators: False # Whether to average loss by #discriminators.
+discriminator_adv_loss_params:
+ average_by_discriminators: False # Whether to average loss by #discriminators.
+use_feat_match_loss: True
+feat_match_loss_params:
+ average_by_discriminators: False # Whether to average loss by #discriminators.
+ average_by_layers: False # Whether to average loss by #layers in each discriminator.
+ include_final_outputs: False # Whether to include final outputs in feat match loss calculation.
+
+###########################################################
+# ADVERSARIAL LOSS SETTING #
+###########################################################
+lambda_aux: 45.0 # Loss balancing coefficient for STFT loss.
+lambda_adv: 1.0 # Loss balancing coefficient for adversarial loss.
+lambda_feat_match: 2.0 # Loss balancing coefficient for feat match loss.
+
+###########################################################
+# DATA LOADER SETTING #
+###########################################################
+batch_size: 16 # Batch size.
+batch_max_steps: 8400 # Length of each audio in batch. Make sure it is divisible by hop_size.
+num_workers: 2 # Number of workers in DataLoader.
+
+###########################################################
+# OPTIMIZER & SCHEDULER SETTING #
+###########################################################
+generator_optimizer_params:
+ beta1: 0.5
+ beta2: 0.9
+ weight_decay: 0.0 # Generator's weight decay coefficient.
+generator_scheduler_params:
+ learning_rate: 2.0e-4 # Generator's learning rate.
+ gamma: 0.5 # Generator's scheduler gamma.
+ milestones: # At each milestone, lr will be multiplied by gamma.
+ - 200000
+ - 400000
+ - 600000
+ - 800000
+generator_grad_norm: -1 # Generator's gradient norm.
+discriminator_optimizer_params:
+ beta1: 0.5
+ beta2: 0.9
+ weight_decay: 0.0 # Discriminator's weight decay coefficient.
+discriminator_scheduler_params:
+ learning_rate: 2.0e-4 # Discriminator's learning rate.
+ gamma: 0.5 # Discriminator's scheduler gamma.
+ milestones: # At each milestone, lr will be multiplied by gamma.
+ - 200000
+ - 400000
+ - 600000
+ - 800000
+discriminator_grad_norm: -1 # Discriminator's gradient norm.
+
+###########################################################
+# INTERVAL SETTING #
+###########################################################
+generator_train_start_steps: 1 # Number of steps to start to train generator.
+discriminator_train_start_steps: 0 # Number of steps to start to train discriminator.
+train_max_steps: 2500000 # Number of training steps.
+save_interval_steps: 5000 # Interval steps to save checkpoint.
+eval_interval_steps: 1000 # Interval steps to evaluate the network.
+
+###########################################################
+# OTHER SETTING #
+###########################################################
+num_snapshots: 10 # max number of snapshots to keep while training
+seed: 42 # random seed for paddle, random, and np.random
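The scheduler settings above describe a MultiStepLR-style decay; a minimal sketch (not part of the config) of the implied learning-rate curve, assuming the decay takes effect once a milestone step is reached:

```python
# Not part of the config: learning rate implied by the MultiStepLR-style settings
# above (base lr 2.0e-4, gamma 0.5, milestones every 200k steps).
def lr_at_step(step, base_lr=2.0e-4, gamma=0.5,
               milestones=(200_000, 400_000, 600_000, 800_000)):
    passed = sum(1 for m in milestones if step >= m)
    return base_lr * gamma ** passed

print(lr_at_step(100_000))    # 0.0002
print(lr_at_step(250_000))    # 0.0001
print(lr_at_step(2_500_000))  # 1.25e-05
```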
diff --git a/examples/aishell3/voc5/local/preprocess.sh b/examples/aishell3/voc5/local/preprocess.sh
new file mode 100755
index 000000000..44cc3dbe4
--- /dev/null
+++ b/examples/aishell3/voc5/local/preprocess.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+
+stage=0
+stop_stage=100
+
+config_path=$1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ # get durations from MFA's result
+ echo "Generate durations.txt from MFA results ..."
+ python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
+ --inputdir=./aishell3_alignment_tone \
+ --output=durations.txt \
+ --config=${config_path}
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ # extract features
+ echo "Extract features ..."
+ python3 ${BIN_DIR}/../preprocess.py \
+ --rootdir=~/datasets/data_aishell3/ \
+ --dataset=aishell3 \
+ --dumpdir=dump \
+ --dur-file=durations.txt \
+ --config=${config_path} \
+ --cut-sil=True \
+ --num-cpu=20
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ # get features' stats(mean and std)
+ echo "Get features' stats ..."
+ python3 ${MAIN_ROOT}/utils/compute_statistics.py \
+ --metadata=dump/train/raw/metadata.jsonl \
+ --field-name="feats"
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ # normalize, dev and test should use train's stats
+ echo "Normalize ..."
+
+ python3 ${BIN_DIR}/../normalize.py \
+ --metadata=dump/train/raw/metadata.jsonl \
+ --dumpdir=dump/train/norm \
+ --stats=dump/train/feats_stats.npy
+ python3 ${BIN_DIR}/../normalize.py \
+ --metadata=dump/dev/raw/metadata.jsonl \
+ --dumpdir=dump/dev/norm \
+ --stats=dump/train/feats_stats.npy
+
+ python3 ${BIN_DIR}/../normalize.py \
+ --metadata=dump/test/raw/metadata.jsonl \
+ --dumpdir=dump/test/norm \
+ --stats=dump/train/feats_stats.npy
+fi
diff --git a/examples/aishell3/voc5/local/synthesize.sh b/examples/aishell3/voc5/local/synthesize.sh
new file mode 100755
index 000000000..647896175
--- /dev/null
+++ b/examples/aishell3/voc5/local/synthesize.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+
+FLAGS_allocator_strategy=naive_best_fit \
+FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+python3 ${BIN_DIR}/../synthesize.py \
+ --config=${config_path} \
+ --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --test-metadata=dump/test/norm/metadata.jsonl \
+ --output-dir=${train_output_path}/test \
+ --generator-type=hifigan
diff --git a/examples/aishell3/voc5/local/train.sh b/examples/aishell3/voc5/local/train.sh
new file mode 100755
index 000000000..9695631ef
--- /dev/null
+++ b/examples/aishell3/voc5/local/train.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+
+FLAGS_cudnn_exhaustive_search=true \
+FLAGS_conv_workspace_size_limit=4000 \
+python ${BIN_DIR}/train.py \
+ --train-metadata=dump/train/norm/metadata.jsonl \
+ --dev-metadata=dump/dev/norm/metadata.jsonl \
+ --config=${config_path} \
+ --output-dir=${train_output_path} \
+ --ngpu=1
diff --git a/examples/aishell3/voc5/path.sh b/examples/aishell3/voc5/path.sh
new file mode 100755
index 000000000..7451b3218
--- /dev/null
+++ b/examples/aishell3/voc5/path.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+export MAIN_ROOT=`realpath ${PWD}/../../../`
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+export PYTHONDONTWRITEBYTECODE=1
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+MODEL=hifigan
+export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/gan_vocoder/${MODEL}
diff --git a/examples/aishell3/voc5/run.sh b/examples/aishell3/voc5/run.sh
new file mode 100755
index 000000000..4f426ea02
--- /dev/null
+++ b/examples/aishell3/voc5/run.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+set -e
+source path.sh
+
+gpus=0
+stage=0
+stop_stage=100
+
+conf_path=conf/default.yaml
+train_output_path=exp/default
+ckpt_name=snapshot_iter_5000.pdz
+
+# with the following command, you can choose the stage range you want to run
+# such as `./run.sh --stage 0 --stop-stage 0`
+# this cannot be mixed with `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ # prepare data
+ ./local/preprocess.sh ${conf_path} || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train model; all `ckpt` files are saved under the `train_output_path/checkpoints/` dir
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ # synthesize
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
+fi
diff --git a/examples/csmsc/tts0/local/synthesize_e2e.sh b/examples/csmsc/tts0/local/synthesize_e2e.sh
index f76758733..4c3b08dc1 100755
--- a/examples/csmsc/tts0/local/synthesize_e2e.sh
+++ b/examples/csmsc/tts0/local/synthesize_e2e.sh
@@ -7,7 +7,7 @@ ckpt_name=$3
stage=0
stop_stage=0
-# TODO: the tacotron2 dygraph-to-static result is not as loud as the static-graph one; some function may not be aligned between dynamic and static graphs during decode
+# TODO: the tacotron2 dygraph-to-static result is not as loud as the dynamic-graph one; some function may not be aligned between dynamic and static graphs during decode
# pwgan
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
FLAGS_allocator_strategy=naive_best_fit \
diff --git a/examples/csmsc/tts2/local/synthesize.sh b/examples/csmsc/tts2/local/synthesize.sh
index 37b298183..b8982a16d 100755
--- a/examples/csmsc/tts2/local/synthesize.sh
+++ b/examples/csmsc/tts2/local/synthesize.sh
@@ -14,7 +14,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--am=speedyspeech_csmsc \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
- --am_stat=dump/train/speech_stats.npy \
+ --am_stat=dump/train/feats_stats.npy \
--voc=pwgan_csmsc \
--voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
--voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
@@ -34,7 +34,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--am=speedyspeech_csmsc \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
- --am_stat=dump/train/speech_stats.npy \
+ --am_stat=dump/train/feats_stats.npy \
--voc=mb_melgan_csmsc \
--voc_config=mb_melgan_csmsc_ckpt_0.1.1/default.yaml \
--voc_ckpt=mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz\
@@ -53,7 +53,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
--am=speedyspeech_csmsc \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
- --am_stat=dump/train/speech_stats.npy \
+ --am_stat=dump/train/feats_stats.npy \
--voc=style_melgan_csmsc \
--voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \
--voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \
@@ -73,7 +73,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--am=speedyspeech_csmsc \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
- --am_stat=dump/train/speech_stats.npy \
+ --am_stat=dump/train/feats_stats.npy \
--voc=hifigan_csmsc \
--voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \
--voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \
@@ -93,7 +93,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--am=speedyspeech_csmsc \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
- --am_stat=dump/train/speech_stats.npy \
+ --am_stat=dump/train/feats_stats.npy \
--voc=wavernn_csmsc \
--voc_config=wavernn_csmsc_ckpt_0.2.0/default.yaml \
--voc_ckpt=wavernn_csmsc_ckpt_0.2.0/snapshot_iter_400000.pdz \
diff --git a/examples/ljspeech/voc5/README.md b/examples/ljspeech/voc5/README.md
new file mode 100644
index 000000000..210829428
--- /dev/null
+++ b/examples/ljspeech/voc5/README.md
@@ -0,0 +1,133 @@
+# HiFiGAN with LJSpeech-1.1
+This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.05646) model with [LJSpeech-1.1](https://keithito.com/LJ-Speech-Dataset/).
+## Dataset
+### Download and Extract
+Download LJSpeech-1.1 from the [official website](https://keithito.com/LJ-Speech-Dataset/).
+### Get MFA Result and Extract
+We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence at the edges of the audio.
+You can download it from [ljspeech_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) in our repo.
+
+## Get Started
+Assume the path to the dataset is `~/datasets/LJSpeech-1.1`.
+Assume the path to the MFA result of LJSpeech-1.1 is `./ljspeech_alignment`.
+Run the command below to
+1. **source path**.
+2. preprocess the dataset.
+3. train the model.
+4. synthesize wavs.
+ - synthesize waveform from `metadata.jsonl`.
+```bash
+./run.sh
+```
+You can choose a range of stages to run, or set `stage` equal to `stop-stage` to run a single stage. For example, the following command only preprocesses the dataset.
+```bash
+./run.sh --stage 0 --stop-stage 0
+```
+### Data Preprocessing
+```bash
+./local/preprocess.sh ${conf_path}
+```
+When it is done, a `dump` folder is created in the current directory. Its structure is listed below.
+
+```text
+dump
+├── dev
+│ ├── norm
+│ └── raw
+├── test
+│ ├── norm
+│ └── raw
+└── train
+ ├── norm
+ ├── raw
+ └── feats_stats.npy
+```
+
+The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and a `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the `norm` folder contains the normalized spectrogram. The statistics used for normalization are computed from the training data and stored in `dump/train/feats_stats.npy`.
+
+Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains the utterance id and the path to the spectrogram of each utterance.
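+
+A single (hypothetical) entry might look like the line below; the exact field names and paths are determined by `preprocess.py` and may differ:
+```text
+{"utt_id": "LJ001-0001", "num_frames": 812, "feats": "dump/train/raw/data_speech/LJ001-0001_feats.npy", "wave": "dump/train/raw/data_wave/LJ001-0001_wave.npy"}
+```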
+
+### Model Training
+`./local/train.sh` calls `${BIN_DIR}/train.py`.
+```bash
+CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
+```
+Here's the complete help message.
+
+```text
+usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
+ [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
+ [--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER]
+ [--run-benchmark RUN_BENCHMARK]
+ [--profiler_options PROFILER_OPTIONS]
+
+Train a ParallelWaveGAN model.
+
+optional arguments:
+ -h, --help show this help message and exit
+ --config CONFIG config file to overwrite default config.
+ --train-metadata TRAIN_METADATA
+ training data.
+ --dev-metadata DEV_METADATA
+ dev data.
+ --output-dir OUTPUT_DIR
+ output dir.
+ --ngpu NGPU if ngpu == 0, use cpu.
+
+benchmark:
+ arguments related to benchmark.
+
+ --batch-size BATCH_SIZE
+ batch size.
+ --max-iter MAX_ITER train max steps.
+ --run-benchmark RUN_BENCHMARK
+ runing benchmark or not, if True, use the --batch-size
+ and --max-iter.
+ --profiler_options PROFILER_OPTIONS
+ The option of profiler, which should be in format
+ "key1=value1;key2=value2;key3=value3".
+```
+
+1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
+2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder.
+3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
+4. `--ngpu` is the number of GPUs to use; if `ngpu` == 0, the CPU is used.
+
+### Synthesizing
+`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
+```bash
+CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
+```
+```text
+usage: synthesize.py [-h] [--generator-type GENERATOR_TYPE] [--config CONFIG]
+ [--checkpoint CHECKPOINT] [--test-metadata TEST_METADATA]
+ [--output-dir OUTPUT_DIR] [--ngpu NGPU]
+
+Synthesize with GANVocoder.
+
+optional arguments:
+ -h, --help show this help message and exit
+ --generator-type GENERATOR_TYPE
+ type of GANVocoder, should in {pwgan, mb_melgan,
+ style_melgan, } now
+ --config CONFIG GANVocoder config file.
+ --checkpoint CHECKPOINT
+ snapshot to load.
+ --test-metadata TEST_METADATA
+ dev data.
+ --output-dir OUTPUT_DIR
+ output dir.
+ --ngpu NGPU if ngpu == 0, use cpu.
+```
+
+1. `--config` is the vocoder config file. You should use the same config as the one used to train the model.
+2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory.
+3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `test/norm` subfolder from the processed directory.
+4. `--output-dir` is the directory to save the synthesized audio files.
+5. `--ngpu` is the number of GPUs to use; if `ngpu` == 0, the CPU is used.
+
+## Pretrained Model
+
+
+## Acknowledgement
+We adapted some code from https://github.com/kan-bayashi/ParallelWaveGAN.
diff --git a/examples/ljspeech/voc5/conf/default.yaml b/examples/ljspeech/voc5/conf/default.yaml
new file mode 100644
index 000000000..97c512204
--- /dev/null
+++ b/examples/ljspeech/voc5/conf/default.yaml
@@ -0,0 +1,167 @@
+# This is the configuration file for the LJSpeech dataset.
+# This configuration is based on HiFiGAN V1, which is the official configuration.
+# But I found that the optimizer setting does not work well with my implementation.
+# So I changed the optimizer settings as follows:
+# - AdamW -> Adam
+# - betas: [0.8, 0.99] -> betas: [0.5, 0.9]
+# - Scheduler: ExponentialLR -> MultiStepLR
+# To match the difference in shift size, the upsample scales are also modified from the original 256-shift setting.
+
+###########################################################
+# FEATURE EXTRACTION SETTING #
+###########################################################
+fs: 22050 # Sampling rate.
+n_fft: 1024 # FFT size (samples).
+n_shift: 256 # Hop size (samples). 11.6ms
+win_length: null # Window length (samples).
+ # If set to null, it will be the same as fft_size.
+window: "hann" # Window function.
+n_mels: 80 # Number of mel basis.
+fmin: 80 # Minimum freq in mel basis calculation. (Hz)
+fmax: 7600 # Maximum frequency in mel basis calculation. (Hz)
+
+###########################################################
+# GENERATOR NETWORK ARCHITECTURE SETTING #
+###########################################################
+generator_params:
+ in_channels: 80 # Number of input channels.
+ out_channels: 1 # Number of output channels.
+ channels: 512 # Number of initial channels.
+ kernel_size: 7 # Kernel size of initial and final conv layers.
+ upsample_scales: [8, 8, 2, 2] # Upsampling scales.
+ upsample_kernel_sizes: [16, 16, 4, 4] # Kernel size for upsampling layers.
+ resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks.
+ resblock_dilations: # Dilations for residual blocks.
+ - [1, 3, 5]
+ - [1, 3, 5]
+ - [1, 3, 5]
+ use_additional_convs: True # Whether to use additional conv layer in residual blocks.
+ bias: True # Whether to use bias parameter in conv.
+ nonlinear_activation: "leakyrelu" # Nonlinear activation type.
+    nonlinear_activation_params:         # Nonlinear activation parameters.
+ negative_slope: 0.1
+ use_weight_norm: True # Whether to apply weight normalization.
+
+
+###########################################################
+# DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
+###########################################################
+discriminator_params:
+ scales: 3 # Number of multi-scale discriminator.
+ scale_downsample_pooling: "AvgPool1D" # Pooling operation for scale discriminator.
+ scale_downsample_pooling_params:
+ kernel_size: 4 # Pooling kernel size.
+ stride: 2 # Pooling stride.
+ padding: 2 # Padding size.
+ scale_discriminator_params:
+ in_channels: 1 # Number of input channels.
+ out_channels: 1 # Number of output channels.
+ kernel_sizes: [15, 41, 5, 3] # List of kernel sizes.
+ channels: 128 # Initial number of channels.
+ max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
+ max_groups: 16 # Maximum number of groups in downsampling conv layers.
+ bias: True
+ downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales.
+ nonlinear_activation: "leakyrelu" # Nonlinear activation.
+ nonlinear_activation_params:
+ negative_slope: 0.1
+ follow_official_norm: True # Whether to follow the official norm setting.
+ periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator.
+ period_discriminator_params:
+ in_channels: 1 # Number of input channels.
+ out_channels: 1 # Number of output channels.
+ kernel_sizes: [5, 3] # List of kernel sizes.
+ channels: 32 # Initial number of channels.
+ downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
+ max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
+        bias: True                       # Whether to use bias parameter in conv layer.
+ nonlinear_activation: "leakyrelu" # Nonlinear activation.
+        nonlinear_activation_params:     # Nonlinear activation parameters.
+ negative_slope: 0.1
+ use_weight_norm: True # Whether to apply weight normalization.
+ use_spectral_norm: False # Whether to apply spectral normalization.
+
+
+###########################################################
+# STFT LOSS SETTING #
+###########################################################
+use_stft_loss: False # Whether to use multi-resolution STFT loss.
+use_mel_loss: True # Whether to use Mel-spectrogram loss.
+mel_loss_params:
+ fs: 22050
+ fft_size: 1024
+ hop_size: 256
+ win_length: null
+ window: "hann"
+ num_mels: 80
+ fmin: 0
+ fmax: 11025
+ log_base: null
+generator_adv_loss_params:
+ average_by_discriminators: False # Whether to average loss by #discriminators.
+discriminator_adv_loss_params:
+ average_by_discriminators: False # Whether to average loss by #discriminators.
+use_feat_match_loss: True
+feat_match_loss_params:
+ average_by_discriminators: False # Whether to average loss by #discriminators.
+ average_by_layers: False # Whether to average loss by #layers in each discriminator.
+ include_final_outputs: False # Whether to include final outputs in feat match loss calculation.
+
+###########################################################
+# ADVERSARIAL LOSS SETTING #
+###########################################################
+lambda_aux: 45.0 # Loss balancing coefficient for STFT loss.
+lambda_adv: 1.0 # Loss balancing coefficient for adversarial loss.
+lambda_feat_match: 2.0      # Loss balancing coefficient for feat match loss.
+
+###########################################################
+# DATA LOADER SETTING #
+###########################################################
+batch_size: 16 # Batch size.
+batch_max_steps: 8192       # Length of each audio clip in the batch. Make sure it is divisible by the hop size (n_shift); 8192 = 32 x 256 here.
+num_workers: 2 # Number of workers in DataLoader.
+
+###########################################################
+# OPTIMIZER & SCHEDULER SETTING #
+###########################################################
+generator_optimizer_params:
+ beta1: 0.5
+ beta2: 0.9
+ weight_decay: 0.0 # Generator's weight decay coefficient.
+generator_scheduler_params:
+ learning_rate: 2.0e-4 # Generator's learning rate.
+ gamma: 0.5 # Generator's scheduler gamma.
+ milestones: # At each milestone, lr will be multiplied by gamma.
+ - 200000
+ - 400000
+ - 600000
+ - 800000
+generator_grad_norm: -1 # Generator's gradient norm.
+discriminator_optimizer_params:
+ beta1: 0.5
+ beta2: 0.9
+ weight_decay: 0.0 # Discriminator's weight decay coefficient.
+discriminator_scheduler_params:
+ learning_rate: 2.0e-4 # Discriminator's learning rate.
+ gamma: 0.5 # Discriminator's scheduler gamma.
+ milestones: # At each milestone, lr will be multiplied by gamma.
+ - 200000
+ - 400000
+ - 600000
+ - 800000
+discriminator_grad_norm: -1 # Discriminator's gradient norm.
+
+###########################################################
+# INTERVAL SETTING #
+###########################################################
+generator_train_start_steps: 1 # Number of steps to start to train generator.
+discriminator_train_start_steps: 0 # Number of steps to start to train discriminator.
+train_max_steps: 2500000 # Number of training steps.
+save_interval_steps: 5000 # Interval steps to save checkpoint.
+eval_interval_steps: 1000 # Interval steps to evaluate the network.
+
+###########################################################
+# OTHER SETTING #
+###########################################################
+num_snapshots: 10 # max number of snapshots to keep while training
+seed: 42 # random seed for paddle, random, and np.random
diff --git a/examples/ljspeech/voc5/local/preprocess.sh b/examples/ljspeech/voc5/local/preprocess.sh
new file mode 100755
index 000000000..d1af60dad
--- /dev/null
+++ b/examples/ljspeech/voc5/local/preprocess.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+
+stage=0
+stop_stage=100
+
+config_path=$1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ # get durations from MFA's result
+ echo "Generate durations.txt from MFA results ..."
+ python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
+ --inputdir=./ljspeech_alignment \
+ --output=durations.txt \
+ --config=${config_path}
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ # extract features
+ echo "Extract features ..."
+ python3 ${BIN_DIR}/../preprocess.py \
+ --rootdir=~/datasets/LJSpeech-1.1/ \
+ --dataset=ljspeech \
+ --dumpdir=dump \
+ --dur-file=durations.txt \
+ --config=${config_path} \
+ --cut-sil=True \
+ --num-cpu=20
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ # get features' stats(mean and std)
+ echo "Get features' stats ..."
+ python3 ${MAIN_ROOT}/utils/compute_statistics.py \
+ --metadata=dump/train/raw/metadata.jsonl \
+ --field-name="feats"
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ # normalize, dev and test should use train's stats
+ echo "Normalize ..."
+
+ python3 ${BIN_DIR}/../normalize.py \
+ --metadata=dump/train/raw/metadata.jsonl \
+ --dumpdir=dump/train/norm \
+ --stats=dump/train/feats_stats.npy
+ python3 ${BIN_DIR}/../normalize.py \
+ --metadata=dump/dev/raw/metadata.jsonl \
+ --dumpdir=dump/dev/norm \
+ --stats=dump/train/feats_stats.npy
+
+ python3 ${BIN_DIR}/../normalize.py \
+ --metadata=dump/test/raw/metadata.jsonl \
+ --dumpdir=dump/test/norm \
+ --stats=dump/train/feats_stats.npy
+fi
diff --git a/examples/ljspeech/voc5/local/synthesize.sh b/examples/ljspeech/voc5/local/synthesize.sh
new file mode 100755
index 000000000..647896175
--- /dev/null
+++ b/examples/ljspeech/voc5/local/synthesize.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+
+FLAGS_allocator_strategy=naive_best_fit \
+FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+python3 ${BIN_DIR}/../synthesize.py \
+ --config=${config_path} \
+ --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --test-metadata=dump/test/norm/metadata.jsonl \
+ --output-dir=${train_output_path}/test \
+ --generator-type=hifigan
diff --git a/examples/ljspeech/voc5/local/train.sh b/examples/ljspeech/voc5/local/train.sh
new file mode 100755
index 000000000..9695631ef
--- /dev/null
+++ b/examples/ljspeech/voc5/local/train.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+
+FLAGS_cudnn_exhaustive_search=true \
+FLAGS_conv_workspace_size_limit=4000 \
+python ${BIN_DIR}/train.py \
+ --train-metadata=dump/train/norm/metadata.jsonl \
+ --dev-metadata=dump/dev/norm/metadata.jsonl \
+ --config=${config_path} \
+ --output-dir=${train_output_path} \
+ --ngpu=1
diff --git a/examples/ljspeech/voc5/path.sh b/examples/ljspeech/voc5/path.sh
new file mode 100755
index 000000000..7451b3218
--- /dev/null
+++ b/examples/ljspeech/voc5/path.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+export MAIN_ROOT=`realpath ${PWD}/../../../`
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+export PYTHONDONTWRITEBYTECODE=1
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+MODEL=hifigan
+export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/gan_vocoder/${MODEL}
diff --git a/examples/ljspeech/voc5/run.sh b/examples/ljspeech/voc5/run.sh
new file mode 100755
index 000000000..cab1ac38b
--- /dev/null
+++ b/examples/ljspeech/voc5/run.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+set -e
+source path.sh
+
+gpus=0,1
+stage=0
+stop_stage=100
+
+conf_path=conf/default.yaml
+train_output_path=exp/default
+ckpt_name=snapshot_iter_5000.pdz
+
+# with the following command, you can choose the stage range you want to run
+# such as `./run.sh --stage 0 --stop-stage 0`
+# this cannot be mixed with the positional arguments `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ # prepare data
+ ./local/preprocess.sh ${conf_path} || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ # train model, all `ckpt` under `train_output_path/checkpoints/` dir
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ # synthesize
+ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
+fi
diff --git a/paddleaudio/CHANGELOG.md b/paddleaudio/CHANGELOG.md
index 825c32f0d..925d77696 100644
--- a/paddleaudio/CHANGELOG.md
+++ b/paddleaudio/CHANGELOG.md
@@ -1 +1,9 @@
# Changelog
+
+Date: 2022-3-15, Author: Xiaojie Chen.
+ - kaldi and librosa mfcc, fbank, spectrogram.
+ - unit test and benchmark.
+
+Date: 2022-2-25, Author: Hui Zhang.
+ - Refactor architecture.
+ - dtw distance and mcd style dtw.
diff --git a/paddleaudio/features/augment.py b/paddleaudio/features/augment.py
deleted file mode 100644
index 6f903bdba..000000000
--- a/paddleaudio/features/augment.py
+++ /dev/null
@@ -1,170 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import List
-
-import numpy as np
-from numpy import ndarray as array
-
-from ..backends import depth_convert
-from ..utils import ParameterError
-
-__all__ = [
- 'depth_augment',
- 'spect_augment',
- 'random_crop1d',
- 'random_crop2d',
- 'adaptive_spect_augment',
-]
-
-
-def randint(high: int) -> int:
- """Generate one random integer in range [0 high)
-
- This is a helper function for random data augmentaiton
- """
- return int(np.random.randint(0, high=high))
-
-
-def rand() -> float:
- """Generate one floating-point number in range [0 1)
-
- This is a helper function for random data augmentaiton
- """
- return float(np.random.rand(1))
-
-
-def depth_augment(y: array,
- choices: List=['int8', 'int16'],
- probs: List[float]=[0.5, 0.5]) -> array:
- """ Audio depth augmentation
-
- Do audio depth augmentation to simulate the distortion brought by quantization.
- """
- assert len(probs) == len(
- choices
- ), 'number of choices {} must be equal to size of probs {}'.format(
- len(choices), len(probs))
- depth = np.random.choice(choices, p=probs)
- src_depth = y.dtype
- y1 = depth_convert(y, depth)
- y2 = depth_convert(y1, src_depth)
-
- return y2
-
-
-def adaptive_spect_augment(spect: array, tempo_axis: int=0,
- level: float=0.1) -> array:
- """Do adpative spectrogram augmentation
-
- The level of the augmentation is gowern by the paramter level,
- ranging from 0 to 1, with 0 represents no augmentation。
-
- """
- assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
- if tempo_axis == 0:
- nt, nf = spect.shape
- else:
- nf, nt = spect.shape
-
- time_mask_width = int(nt * level * 0.5)
- freq_mask_width = int(nf * level * 0.5)
-
- num_time_mask = int(10 * level)
- num_freq_mask = int(10 * level)
-
- if tempo_axis == 0:
- for _ in range(num_time_mask):
- start = randint(nt - time_mask_width)
- spect[start:start + time_mask_width, :] = 0
- for _ in range(num_freq_mask):
- start = randint(nf - freq_mask_width)
- spect[:, start:start + freq_mask_width] = 0
- else:
- for _ in range(num_time_mask):
- start = randint(nt - time_mask_width)
- spect[:, start:start + time_mask_width] = 0
- for _ in range(num_freq_mask):
- start = randint(nf - freq_mask_width)
- spect[start:start + freq_mask_width, :] = 0
-
- return spect
-
-
-def spect_augment(spect: array,
- tempo_axis: int=0,
- max_time_mask: int=3,
- max_freq_mask: int=3,
- max_time_mask_width: int=30,
- max_freq_mask_width: int=20) -> array:
- """Do spectrogram augmentation in both time and freq axis
-
- Reference:
-
- """
- assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
- if tempo_axis == 0:
- nt, nf = spect.shape
- else:
- nf, nt = spect.shape
-
- num_time_mask = randint(max_time_mask)
- num_freq_mask = randint(max_freq_mask)
-
- time_mask_width = randint(max_time_mask_width)
- freq_mask_width = randint(max_freq_mask_width)
-
- if tempo_axis == 0:
- for _ in range(num_time_mask):
- start = randint(nt - time_mask_width)
- spect[start:start + time_mask_width, :] = 0
- for _ in range(num_freq_mask):
- start = randint(nf - freq_mask_width)
- spect[:, start:start + freq_mask_width] = 0
- else:
- for _ in range(num_time_mask):
- start = randint(nt - time_mask_width)
- spect[:, start:start + time_mask_width] = 0
- for _ in range(num_freq_mask):
- start = randint(nf - freq_mask_width)
- spect[start:start + freq_mask_width, :] = 0
-
- return spect
-
-
-def random_crop1d(y: array, crop_len: int) -> array:
- """ Do random cropping on 1d input signal
-
- The input is a 1d signal, typically a sound waveform
- """
- if y.ndim != 1:
- 'only accept 1d tensor or numpy array'
- n = len(y)
- idx = randint(n - crop_len)
- return y[idx:idx + crop_len]
-
-
-def random_crop2d(s: array, crop_len: int, tempo_axis: int=0) -> array:
- """ Do random cropping for 2D array, typically a spectrogram.
-
- The cropping is done in temporal direction on the time-freq input signal.
- """
- if tempo_axis >= s.ndim:
- raise ParameterError('axis out of range')
-
- n = s.shape[tempo_axis]
- idx = randint(high=n - crop_len)
- sli = [slice(None) for i in range(s.ndim)]
- sli[tempo_axis] = slice(idx, idx + crop_len)
- out = s[tuple(sli)]
- return out
diff --git a/paddleaudio/features/spectrum.py b/paddleaudio/features/spectrum.py
deleted file mode 100644
index 154b6484c..000000000
--- a/paddleaudio/features/spectrum.py
+++ /dev/null
@@ -1,461 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-from functools import partial
-from typing import Optional
-from typing import Union
-
-import paddle
-import paddle.nn as nn
-
-from .window import get_window
-
-__all__ = [
- 'Spectrogram',
- 'MelSpectrogram',
- 'LogMelSpectrogram',
-]
-
-
-def hz_to_mel(freq: Union[paddle.Tensor, float],
- htk: bool=False) -> Union[paddle.Tensor, float]:
- """Convert Hz to Mels.
- Parameters:
- freq: the input tensor of arbitrary shape, or a single floating point number.
- htk: use HTK formula to do the conversion.
- The default value is False.
- Returns:
- The frequencies represented in Mel-scale.
- """
-
- if htk:
- if isinstance(freq, paddle.Tensor):
- return 2595.0 * paddle.log10(1.0 + freq / 700.0)
- else:
- return 2595.0 * math.log10(1.0 + freq / 700.0)
-
- # Fill in the linear part
- f_min = 0.0
- f_sp = 200.0 / 3
-
- mels = (freq - f_min) / f_sp
-
- # Fill in the log-scale part
-
- min_log_hz = 1000.0 # beginning of log region (Hz)
- min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
- logstep = math.log(6.4) / 27.0 # step size for log region
-
- if isinstance(freq, paddle.Tensor):
- target = min_log_mel + paddle.log(
- freq / min_log_hz + 1e-10) / logstep # prevent nan with 1e-10
- mask = (freq > min_log_hz).astype(freq.dtype)
- mels = target * mask + mels * (
- 1 - mask) # will replace by masked_fill OP in future
- else:
- if freq >= min_log_hz:
- mels = min_log_mel + math.log(freq / min_log_hz + 1e-10) / logstep
-
- return mels
-
-
-def mel_to_hz(mel: Union[float, paddle.Tensor],
- htk: bool=False) -> Union[float, paddle.Tensor]:
- """Convert mel bin numbers to frequencies.
- Parameters:
- mel: the mel frequency represented as a tensor of arbitrary shape, or a floating point number.
- htk: use HTK formula to do the conversion.
- Returns:
- The frequencies represented in hz.
- """
- if htk:
- return 700.0 * (10.0**(mel / 2595.0) - 1.0)
-
- f_min = 0.0
- f_sp = 200.0 / 3
- freqs = f_min + f_sp * mel
- # And now the nonlinear scale
- min_log_hz = 1000.0 # beginning of log region (Hz)
- min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
- logstep = math.log(6.4) / 27.0 # step size for log region
- if isinstance(mel, paddle.Tensor):
- target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel))
- mask = (mel > min_log_mel).astype(mel.dtype)
- freqs = target * mask + freqs * (
- 1 - mask) # will replace by masked_fill OP in future
- else:
- if mel >= min_log_mel:
- freqs = min_log_hz * math.exp(logstep * (mel - min_log_mel))
-
- return freqs
-
-
-def mel_frequencies(n_mels: int=64,
- f_min: float=0.0,
- f_max: float=11025.0,
- htk: bool=False,
- dtype: str=paddle.float32):
- """Compute mel frequencies.
- Parameters:
- n_mels(int): number of Mel bins.
- f_min(float): the lower cut-off frequency, below which the filter response is zero.
- f_max(float): the upper cut-off frequency, above which the filter response is zero.
- htk(bool): whether to use htk formula.
- dtype(str): the datatype of the return frequencies.
- Returns:
- The frequencies represented in Mel-scale
- """
- # 'Center freqs' of mel bands - uniformly spaced between limits
- min_mel = hz_to_mel(f_min, htk=htk)
- max_mel = hz_to_mel(f_max, htk=htk)
- mels = paddle.linspace(min_mel, max_mel, n_mels, dtype=dtype)
- freqs = mel_to_hz(mels, htk=htk)
- return freqs
-
-
-def fft_frequencies(sr: int, n_fft: int, dtype: str=paddle.float32):
- """Compute fourier frequencies.
- Parameters:
- sr(int): the audio sample rate.
- n_fft(float): the number of fft bins.
- dtype(str): the datatype of the return frequencies.
- Returns:
- The frequencies represented in hz.
- """
- return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2), dtype=dtype)
-
-
-def compute_fbank_matrix(sr: int,
- n_fft: int,
- n_mels: int=64,
- f_min: float=0.0,
- f_max: Optional[float]=None,
- htk: bool=False,
- norm: Union[str, float]='slaney',
- dtype: str=paddle.float32):
- """Compute fbank matrix.
- Parameters:
- sr(int): the audio sample rate.
- n_fft(int): the number of fft bins.
- n_mels(int): the number of Mel bins.
- f_min(float): the lower cut-off frequency, below which the filter response is zero.
- f_max(float): the upper cut-off frequency, above which the filter response is zero.
- htk: whether to use htk formula.
- return_complex(bool): whether to return complex matrix. If True, the matrix will
- be complex type. Otherwise, the real and image part will be stored in the last
- axis of returned tensor.
- dtype(str): the datatype of the returned fbank matrix.
- Returns:
- The fbank matrix of shape (n_mels, int(1+n_fft//2)).
- Shape:
- output: (n_mels, int(1+n_fft//2))
- """
-
- if f_max is None:
- f_max = float(sr) / 2
-
- # Initialize the weights
- weights = paddle.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
-
- # Center freqs of each FFT bin
- fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft, dtype=dtype)
-
- # 'Center freqs' of mel bands - uniformly spaced between limits
- mel_f = mel_frequencies(
- n_mels + 2, f_min=f_min, f_max=f_max, htk=htk, dtype=dtype)
-
- fdiff = mel_f[1:] - mel_f[:-1] #np.diff(mel_f)
- ramps = mel_f.unsqueeze(1) - fftfreqs.unsqueeze(0)
- #ramps = np.subtract.outer(mel_f, fftfreqs)
-
- for i in range(n_mels):
- # lower and upper slopes for all bins
- lower = -ramps[i] / fdiff[i]
- upper = ramps[i + 2] / fdiff[i + 1]
-
- # .. then intersect them with each other and zero
- weights[i] = paddle.maximum(
- paddle.zeros_like(lower), paddle.minimum(lower, upper))
-
- # Slaney-style mel is scaled to be approx constant energy per channel
- if norm == 'slaney':
- enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
- weights *= enorm.unsqueeze(1)
- elif isinstance(norm, int) or isinstance(norm, float):
- weights = paddle.nn.functional.normalize(weights, p=norm, axis=-1)
-
- return weights
-
-
-def power_to_db(magnitude: paddle.Tensor,
- ref_value: float=1.0,
- amin: float=1e-10,
- top_db: Optional[float]=None) -> paddle.Tensor:
- """Convert a power spectrogram (amplitude squared) to decibel (dB) units.
- The function computes the scaling ``10 * log10(x / ref)`` in a numerically
- stable way.
- Parameters:
- magnitude(Tensor): the input magnitude tensor of any shape.
- ref_value(float): the reference value. If smaller than 1.0, the db level
- of the signal will be pulled up accordingly. Otherwise, the db level
- is pushed down.
- amin(float): the minimum value of input magnitude, below which the input
- magnitude is clipped(to amin).
- top_db(float): the maximum db value of resulting spectrum, above which the
- spectrum is clipped(to top_db).
- Returns:
- The spectrogram in log-scale.
- shape:
- input: any shape
- output: same as input
- """
- if amin <= 0:
- raise Exception("amin must be strictly positive")
-
- if ref_value <= 0:
- raise Exception("ref_value must be strictly positive")
-
- ones = paddle.ones_like(magnitude)
- log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, magnitude))
- log_spec -= 10.0 * math.log10(max(ref_value, amin))
-
- if top_db is not None:
- if top_db < 0:
- raise Exception("top_db must be non-negative")
- log_spec = paddle.maximum(log_spec, ones * (log_spec.max() - top_db))
-
- return log_spec
-
-
-class Spectrogram(nn.Layer):
- def __init__(self,
- n_fft: int=512,
- hop_length: Optional[int]=None,
- win_length: Optional[int]=None,
- window: str='hann',
- center: bool=True,
- pad_mode: str='reflect',
- dtype: str=paddle.float32):
- """Compute spectrogram of a given signal, typically an audio waveform.
- The spectorgram is defined as the complex norm of the short-time
- Fourier transformation.
- Parameters:
- n_fft(int): the number of frequency components of the discrete Fourier transform.
- The default value is 2048,
- hop_length(int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
- The default value is None.
- win_length: the window length of the short time FFt. If None, it is set to same as n_fft.
- The default value is None.
- window(str): the name of the window function applied to the single before the Fourier transform.
- The folllowing window names are supported: 'hamming','hann','kaiser','gaussian',
- 'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
- The default value is 'hann'
- center(bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
- If False, frame t begins at x[t * hop_length]
- The default value is True
- pad_mode(str): the mode to pad the signal if necessary. The supported modes are 'reflect'
- and 'constant'. The default value is 'reflect'.
- dtype(str): the data type of input and window.
- Notes:
- The Spectrogram transform relies on STFT transform to compute the spectrogram.
- By default, the weights are not learnable. To fine-tune the Fourier coefficients,
- set stop_gradient=False before training.
- For more information, see STFT().
- """
- super(Spectrogram, self).__init__()
-
- if win_length is None:
- win_length = n_fft
-
- fft_window = get_window(window, win_length, fftbins=True, dtype=dtype)
- self._stft = partial(
- paddle.signal.stft,
- n_fft=n_fft,
- hop_length=hop_length,
- win_length=win_length,
- window=fft_window,
- center=center,
- pad_mode=pad_mode)
-
- def forward(self, x):
- stft = self._stft(x)
- spectrogram = paddle.square(paddle.abs(stft))
- return spectrogram
-
-
-class MelSpectrogram(nn.Layer):
- def __init__(self,
- sr: int=22050,
- n_fft: int=512,
- hop_length: Optional[int]=None,
- win_length: Optional[int]=None,
- window: str='hann',
- center: bool=True,
- pad_mode: str='reflect',
- n_mels: int=64,
- f_min: float=50.0,
- f_max: Optional[float]=None,
- htk: bool=False,
- norm: Union[str, float]='slaney',
- dtype: str=paddle.float32):
- """Compute the melspectrogram of a given signal, typically an audio waveform.
- The melspectrogram is also known as filterbank or fbank feature in audio community.
- It is computed by multiplying spectrogram with Mel filter bank matrix.
- Parameters:
- sr(int): the audio sample rate.
- The default value is 22050.
- n_fft(int): the number of frequency components of the discrete Fourier transform.
- The default value is 2048,
- hop_length(int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
- The default value is None.
- win_length: the window length of the short time FFt. If None, it is set to same as n_fft.
- The default value is None.
- window(str): the name of the window function applied to the single before the Fourier transform.
- The folllowing window names are supported: 'hamming','hann','kaiser','gaussian',
- 'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
- The default value is 'hann'
- center(bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
- If False, frame t begins at x[t * hop_length]
- The default value is True
- pad_mode(str): the mode to pad the signal if necessary. The supported modes are 'reflect'
- and 'constant'.
- The default value is 'reflect'.
- n_mels(int): the mel bins.
- f_min(float): the lower cut-off frequency, below which the filter response is zero.
- f_max(float): the upper cut-off frequency, above which the filter response is zeros.
- htk(bool): whether to use HTK formula in computing fbank matrix.
- norm(str|float): the normalization type in computing fbank matrix. Slaney-style is used by default.
- You can specify norm=1.0/2.0 to use customized p-norm normalization.
- dtype(str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical
- accuracy. Note that the final transform will be conducted in float32 regardless of dtype of fbank matrix.
- """
- super(MelSpectrogram, self).__init__()
-
- self._spectrogram = Spectrogram(
- n_fft=n_fft,
- hop_length=hop_length,
- win_length=win_length,
- window=window,
- center=center,
- pad_mode=pad_mode,
- dtype=dtype)
- self.n_mels = n_mels
- self.f_min = f_min
- self.f_max = f_max
- self.htk = htk
- self.norm = norm
- if f_max is None:
- f_max = sr // 2
- self.fbank_matrix = compute_fbank_matrix(
- sr=sr,
- n_fft=n_fft,
- n_mels=n_mels,
- f_min=f_min,
- f_max=f_max,
- htk=htk,
- norm=norm,
- dtype=dtype) # float64 for better numerical results
- self.register_buffer('fbank_matrix', self.fbank_matrix)
-
- def forward(self, x):
- spect_feature = self._spectrogram(x)
- mel_feature = paddle.matmul(self.fbank_matrix, spect_feature)
- return mel_feature
-
-
-class LogMelSpectrogram(nn.Layer):
- def __init__(self,
- sr: int=22050,
- n_fft: int=512,
- hop_length: Optional[int]=None,
- win_length: Optional[int]=None,
- window: str='hann',
- center: bool=True,
- pad_mode: str='reflect',
- n_mels: int=64,
- f_min: float=50.0,
- f_max: Optional[float]=None,
- htk: bool=False,
- norm: Union[str, float]='slaney',
- ref_value: float=1.0,
- amin: float=1e-10,
- top_db: Optional[float]=None,
- dtype: str=paddle.float32):
- """Compute log-mel-spectrogram(also known as LogFBank) feature of a given signal,
- typically an audio waveform.
- Parameters:
- sr(int): the audio sample rate.
- The default value is 22050.
- n_fft(int): the number of frequency components of the discrete Fourier transform.
- The default value is 2048,
- hop_length(int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
- The default value is None.
- win_length: the window length of the short time FFt. If None, it is set to same as n_fft.
- The default value is None.
- window(str): the name of the window function applied to the single before the Fourier transform.
- The folllowing window names are supported: 'hamming','hann','kaiser','gaussian',
- 'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
- The default value is 'hann'
- center(bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
- If False, frame t begins at x[t * hop_length]
- The default value is True
- pad_mode(str): the mode to pad the signal if necessary. The supported modes are 'reflect'
- and 'constant'.
- The default value is 'reflect'.
- n_mels(int): the mel bins.
- f_min(float): the lower cut-off frequency, below which the filter response is zero.
- f_max(float): the upper cut-off frequency, above which the filter response is zeros.
- ref_value(float): the reference value. If smaller than 1.0, the db level
- htk(bool): whether to use HTK formula in computing fbank matrix.
- norm(str|float): the normalization type in computing fbank matrix. Slaney-style is used by default.
- You can specify norm=1.0/2.0 to use customized p-norm normalization.
- dtype(str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical
- accuracy. Note that the final transform will be conducted in float32 regardless of dtype of fbank matrix.
- amin(float): the minimum value of input magnitude, below which the input of the signal will be pulled up accordingly.
- Otherwise, the db level is pushed down.
- magnitude is clipped(to amin). For numerical stability, set amin to a larger value,
- e.g., 1e-3.
- top_db(float): the maximum db value of resulting spectrum, above which the
- spectrum is clipped(to top_db).
- """
- super(LogMelSpectrogram, self).__init__()
-
- self._melspectrogram = MelSpectrogram(
- sr=sr,
- n_fft=n_fft,
- hop_length=hop_length,
- win_length=win_length,
- window=window,
- center=center,
- pad_mode=pad_mode,
- n_mels=n_mels,
- f_min=f_min,
- f_max=f_max,
- htk=htk,
- norm=norm,
- dtype=dtype)
-
- self.ref_value = ref_value
- self.amin = amin
- self.top_db = top_db
-
- def forward(self, x):
- # import ipdb; ipdb.set_trace()
- mel_feature = self._melspectrogram(x)
- log_mel_feature = power_to_db(
- mel_feature,
- ref_value=self.ref_value,
- amin=self.amin,
- top_db=self.top_db)
- return log_mel_feature
diff --git a/paddleaudio/paddleaudio/__init__.py b/paddleaudio/paddleaudio/__init__.py
new file mode 100644
index 000000000..6184c1dd4
--- /dev/null
+++ b/paddleaudio/paddleaudio/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from . import compliance
+from . import datasets
+from . import features
+from . import functional
+from . import io
+from . import metric
+from . import sox_effects
+from .backends import load
+from .backends import save
diff --git a/paddleaudio/paddleaudio/backends/__init__.py b/paddleaudio/paddleaudio/backends/__init__.py
new file mode 100644
index 000000000..8eae07e82
--- /dev/null
+++ b/paddleaudio/paddleaudio/backends/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .soundfile_backend import depth_convert
+from .soundfile_backend import load
+from .soundfile_backend import normalize
+from .soundfile_backend import resample
+from .soundfile_backend import save
+from .soundfile_backend import to_mono
diff --git a/paddleaudio/backends/audio.py b/paddleaudio/paddleaudio/backends/soundfile_backend.py
similarity index 93%
rename from paddleaudio/backends/audio.py
rename to paddleaudio/paddleaudio/backends/soundfile_backend.py
index 4127570ec..2b920284a 100644
--- a/paddleaudio/backends/audio.py
+++ b/paddleaudio/paddleaudio/backends/soundfile_backend.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@ __all__ = [
'to_mono',
'depth_convert',
'normalize',
- 'save_wav',
+ 'save',
'load',
]
NORMALMIZE_TYPES = ['linear', 'gaussian']
@@ -41,12 +41,9 @@ EPS = 1e-8
def resample(y: array, src_sr: int, target_sr: int,
mode: str='kaiser_fast') -> array:
""" Audio resampling
-
This function is the same as using resampy.resample().
-
Notes:
The default mode is kaiser_fast. For better audio quality, use mode = 'kaiser_fast'
-
"""
if mode == 'kaiser_best':
@@ -106,7 +103,6 @@ def to_mono(y: array, merge_type: str='average') -> array:
def _safe_cast(y: array, dtype: Union[type, str]) -> array:
""" data type casting in a safe way, i.e., prevent overflow or underflow
-
This function is used internally.
"""
return np.clip(y, np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype)
@@ -115,10 +111,8 @@ def _safe_cast(y: array, dtype: Union[type, str]) -> array:
def depth_convert(y: array, dtype: Union[type, str],
dithering: bool=True) -> array:
"""Convert audio array to target dtype safely
-
This function convert audio waveform to a target dtype, with addition steps of
preventing overflow/underflow and preserving audio range.
-
"""
SUPPORT_DTYPE = ['int16', 'int8', 'float32', 'float64']
@@ -168,12 +162,9 @@ def sound_file_load(file: str,
dtype: str='int16',
duration: Optional[int]=None) -> Tuple[array, int]:
"""Load audio using soundfile library
-
This function load audio file using libsndfile.
-
Reference:
http://www.mega-nerd.com/libsndfile/#Features
-
"""
with sf.SoundFile(file) as sf_desc:
sr_native = sf_desc.samplerate
@@ -188,33 +179,9 @@ def sound_file_load(file: str,
return y, sf_desc.samplerate
-def audio_file_load():
- """Load audio using audiofile library
-
- This function load audio file using audiofile.
-
- Reference:
- https://audiofile.68k.org/
-
- """
- raise NotImplementedError()
-
-
-def sox_file_load():
- """Load audio using sox library
-
- This function load audio file using sox.
-
- Reference:
- http://sox.sourceforge.net/
- """
- raise NotImplementedError()
-
-
def normalize(y: array, norm_type: str='linear',
mul_factor: float=1.0) -> array:
""" normalize an input audio with additional multiplier.
-
"""
if norm_type == 'linear':
@@ -232,14 +199,12 @@ def normalize(y: array, norm_type: str='linear',
return y
-def save_wav(y: array, sr: int, file: str) -> None:
+def save(y: array, sr: int, file: str) -> None:
"""Save audio file to disk.
This function saves audio to disk using scipy.io.wavfile, with additional step
to convert input waveform to int16 unless it already is int16
-
Notes:
It only support raw wav format.
-
"""
if not file.endswith('.wav'):
raise ParameterError(
@@ -274,11 +239,8 @@ def load(
resample_mode: str='kaiser_fast') -> Tuple[array, int]:
"""Load audio file from disk.
This function loads audio from disk using using audio beackend.
-
Parameters:
-
Notes:
-
"""
y, r = sound_file_load(file, offset=offset, dtype=dtype, duration=duration)
diff --git a/paddleaudio/paddleaudio/backends/sox_backend.py b/paddleaudio/paddleaudio/backends/sox_backend.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddleaudio/paddleaudio/backends/sox_backend.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddleaudio/utils/__init__.py b/paddleaudio/paddleaudio/compliance/__init__.py
similarity index 67%
rename from paddleaudio/utils/__init__.py
rename to paddleaudio/paddleaudio/compliance/__init__.py
index 1c1b4a90e..97043fd7b 100644
--- a/paddleaudio/utils/__init__.py
+++ b/paddleaudio/paddleaudio/compliance/__init__.py
@@ -1,6 +1,6 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
-# Licensed under the Apache License, Version 2.0 (the "License"
+# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
@@ -11,8 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from .download import *
-from .env import *
-from .error import *
-from .log import *
-from .time import *
diff --git a/paddleaudio/paddleaudio/compliance/kaldi.py b/paddleaudio/paddleaudio/compliance/kaldi.py
new file mode 100644
index 000000000..8cb9b6660
--- /dev/null
+++ b/paddleaudio/paddleaudio/compliance/kaldi.py
@@ -0,0 +1,638 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from torchaudio(https://github.com/pytorch/audio)
+import math
+from typing import Tuple
+
+import paddle
+from paddle import Tensor
+
+from ..functional import create_dct
+from ..functional.window import get_window
+
+__all__ = [
+ 'spectrogram',
+ 'fbank',
+ 'mfcc',
+]
+
+# window types
+HANNING = 'hann'
+HAMMING = 'hamming'
+POVEY = 'povey'
+RECTANGULAR = 'rect'
+BLACKMAN = 'blackman'
+
+
+def _get_epsilon(dtype):
+ return paddle.to_tensor(1e-07, dtype=dtype)
+
+
+def _next_power_of_2(x: int) -> int:
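+    # Smallest power of two that is >= x (e.g. 400 -> 512); used to pad the window size for the FFT.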
+ return 1 if x == 0 else 2**(x - 1).bit_length()
+
+
+def _get_strided(waveform: Tensor,
+ window_size: int,
+ window_shift: int,
+ snip_edges: bool) -> Tensor:
+ assert waveform.dim() == 1
+ num_samples = waveform.shape[0]
+
+ if snip_edges:
+ if num_samples < window_size:
+ return paddle.empty((0, 0), dtype=waveform.dtype)
+ else:
+ m = 1 + (num_samples - window_size) // window_shift
+ else:
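+        # snip_edges=False: reflect-pad the waveform so that frames are centered;
+        # the number of frames then depends only on num_samples and window_shift.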
+ reversed_waveform = paddle.flip(waveform, [0])
+ m = (num_samples + (window_shift // 2)) // window_shift
+ pad = window_size // 2 - window_shift // 2
+ pad_right = reversed_waveform
+ if pad > 0:
+ pad_left = reversed_waveform[-pad:]
+ waveform = paddle.concat((pad_left, waveform, pad_right), axis=0)
+ else:
+ waveform = paddle.concat((waveform[-pad:], pad_right), axis=0)
+
+ return paddle.signal.frame(waveform, window_size, window_shift)[:, :m].T
+
+
+def _feature_window_function(
+ window_type: str,
+ window_size: int,
+ blackman_coeff: float,
+ dtype: int, ) -> Tensor:
+ if window_type == HANNING:
+ return get_window('hann', window_size, fftbins=False, dtype=dtype)
+ elif window_type == HAMMING:
+ return get_window('hamming', window_size, fftbins=False, dtype=dtype)
+ elif window_type == POVEY:
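+        # Kaldi's "povey" window is a Hann window raised to the power of 0.85.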
+ return get_window(
+ 'hann', window_size, fftbins=False, dtype=dtype).pow(0.85)
+ elif window_type == RECTANGULAR:
+ return paddle.ones([window_size], dtype=dtype)
+ elif window_type == BLACKMAN:
+ a = 2 * math.pi / (window_size - 1)
+ window_function = paddle.arange(window_size, dtype=dtype)
+ return (blackman_coeff - 0.5 * paddle.cos(a * window_function) +
+ (0.5 - blackman_coeff) * paddle.cos(2 * a * window_function)
+ ).astype(dtype)
+ else:
+ raise Exception('Invalid window type ' + window_type)
+
+
+def _get_log_energy(strided_input: Tensor, epsilon: Tensor,
+ energy_floor: float) -> Tensor:
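+    # Log of the per-frame energy (sum of squared samples), floored at epsilon and, optionally, at log(energy_floor).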
+ log_energy = paddle.maximum(strided_input.pow(2).sum(1), epsilon).log()
+ if energy_floor == 0.0:
+ return log_energy
+ return paddle.maximum(
+ log_energy,
+ paddle.to_tensor(math.log(energy_floor), dtype=strided_input.dtype))
+
+
+def _get_waveform_and_window_properties(
+ waveform: Tensor,
+ channel: int,
+ sr: int,
+ frame_shift: float,
+ frame_length: float,
+ round_to_power_of_two: bool,
+ preemphasis_coefficient: float) -> Tuple[Tensor, int, int, int]:
+ channel = max(channel, 0)
+ assert channel < waveform.shape[0], (
+ 'Invalid channel {} for size {}'.format(channel, waveform.shape[0]))
+ waveform = waveform[channel, :] # size (n)
+ window_shift = int(
+ sr * frame_shift *
+ 0.001) # pass frame_shift and frame_length in milliseconds
+ window_size = int(sr * frame_length * 0.001)
+ padded_window_size = _next_power_of_2(
+ window_size) if round_to_power_of_two else window_size
+
+ assert 2 <= window_size <= len(waveform), (
+ 'choose a window size {} that is [2, {}]'.format(window_size,
+ len(waveform)))
+ assert 0 < window_shift, '`window_shift` must be greater than 0'
+ assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \
+ ' use `round_to_power_of_two` or change `frame_length`'
+ assert 0. <= preemphasis_coefficient <= 1.0, '`preemphasis_coefficient` must be between [0,1]'
+ assert sr > 0, '`sr` must be greater than zero'
+ return waveform, window_shift, window_size, padded_window_size
+
+
+def _get_window(waveform: Tensor,
+ padded_window_size: int,
+ window_size: int,
+ window_shift: int,
+ window_type: str,
+ blackman_coeff: float,
+ snip_edges: bool,
+ raw_energy: bool,
+ energy_floor: float,
+ dither: float,
+ remove_dc_offset: bool,
+ preemphasis_coefficient: float) -> Tuple[Tensor, Tensor]:
+ dtype = waveform.dtype
+ epsilon = _get_epsilon(dtype)
+
+ # (m, window_size)
+ strided_input = _get_strided(waveform, window_size, window_shift,
+ snip_edges)
+
+ if dither != 0.0:
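+        # Add dither: map uniform noise to approximately Gaussian noise (Box-Muller-style) and scale it by the dither amount.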
+ x = paddle.maximum(epsilon,
+ paddle.rand(strided_input.shape, dtype=dtype))
+ rand_gauss = paddle.sqrt(-2 * x.log()) * paddle.cos(2 * math.pi * x)
+ strided_input = strided_input + rand_gauss * dither
+
+ if remove_dc_offset:
+ row_means = paddle.mean(strided_input, axis=1).unsqueeze(1) # (m, 1)
+ strided_input = strided_input - row_means
+
+ if raw_energy:
+ signal_log_energy = _get_log_energy(strided_input, epsilon,
+ energy_floor) # (m)
+
+ if preemphasis_coefficient != 0.0:
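+        # Pre-emphasis: replicate-pad one sample on the left, then subtract
+        # preemphasis_coefficient * previous sample from every sample.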
+ offset_strided_input = paddle.nn.functional.pad(
+ strided_input.unsqueeze(0), (1, 0),
+ data_format='NCL',
+ mode='replicate').squeeze(0) # (m, window_size + 1)
+ strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :
+ -1]
+
+ window_function = _feature_window_function(
+ window_type, window_size, blackman_coeff,
+ dtype).unsqueeze(0) # (1, window_size)
+ strided_input = strided_input * window_function # (m, window_size)
+
+ # (m, padded_window_size)
+ if padded_window_size != window_size:
+ padding_right = padded_window_size - window_size
+ strided_input = paddle.nn.functional.pad(
+ strided_input.unsqueeze(0), (0, padding_right),
+ data_format='NCL',
+ mode='constant',
+ value=0).squeeze(0)
+
+ if not raw_energy:
+ signal_log_energy = _get_log_energy(strided_input, epsilon,
+ energy_floor) # size (m)
+
+ return strided_input, signal_log_energy
+
+
+def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
+ if subtract_mean:
+ col_means = paddle.mean(tensor, axis=0).unsqueeze(0)
+ tensor = tensor - col_means
+ return tensor
+
+
+def spectrogram(waveform: Tensor,
+ blackman_coeff: float=0.42,
+ channel: int=-1,
+ dither: float=0.0,
+ energy_floor: float=1.0,
+ frame_length: float=25.0,
+ frame_shift: float=10.0,
+ preemphasis_coefficient: float=0.97,
+ raw_energy: bool=True,
+ remove_dc_offset: bool=True,
+ round_to_power_of_two: bool=True,
+ sr: int=16000,
+ snip_edges: bool=True,
+ subtract_mean: bool=False,
+ window_type: str=POVEY) -> Tensor:
+ """Compute and return a spectrogram from a waveform. The output is identical to Kaldi's.
+
+ Args:
+ waveform (Tensor): A waveform tensor with shape [C, T].
+        blackman_coeff (float, optional): Coefficient for Blackman window. Defaults to 0.42.
+ channel (int, optional): Select the channel of waveform. Defaults to -1.
+        dither (float, optional): Dithering constant. Defaults to 0.0.
+ energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
+ frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
+ frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
+ preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
+        raw_energy (bool, optional): Whether to compute frame energy before preemphasis and windowing. Defaults to True.
+        remove_dc_offset (bool, optional): Whether to subtract the mean from the waveform on each frame. Defaults to True.
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
+ to FFT. Defaults to True.
+ sr (int, optional): Sample rate of input waveform. Defaults to 16000.
+        snip_edges (bool, optional): Drop samples at the end of the waveform that cannot fit a complete frame when it
+            is set True. Otherwise performs reflect padding to the end of the waveform. Defaults to True.
+ subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
+ window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
+
+ Returns:
+        Tensor: A spectrogram tensor with shape (m, padded_window_size // 2 + 1), where m is the number of frames
+            and depends on frame_length and frame_shift.
+ """
+ dtype = waveform.dtype
+ epsilon = _get_epsilon(dtype)
+
+ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
+ waveform, channel, sr, frame_shift, frame_length, round_to_power_of_two,
+ preemphasis_coefficient)
+
+ strided_input, signal_log_energy = _get_window(
+ waveform, padded_window_size, window_size, window_shift, window_type,
+ blackman_coeff, snip_edges, raw_energy, energy_floor, dither,
+ remove_dc_offset, preemphasis_coefficient)
+
+ # (m, padded_window_size // 2 + 1, 2)
+ fft = paddle.fft.rfft(strided_input)
+
+ power_spectrum = paddle.maximum(
+ fft.abs().pow(2.), epsilon).log() # (m, padded_window_size // 2 + 1)
+ power_spectrum[:, 0] = signal_log_energy
+
+ power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
+ return power_spectrum
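A minimal usage sketch of `spectrogram` (assuming the restructured package installs so that the module is importable as `paddleaudio.compliance.kaldi`); the input is a `[C, T]` float tensor and the shapes in the comments follow from the defaults above:

```python
import paddle
from paddleaudio.compliance import kaldi  # assumed import path after installation

# One second of a mono 16 kHz signal, shaped [C, T] as the function expects.
waveform = paddle.randn([1, 16000])

spec = kaldi.spectrogram(waveform, sr=16000, frame_length=25.0, frame_shift=10.0)
# window_size = 400 samples, padded to 512, so there are 257 frequency bins;
# with snip_edges=True there are 1 + (16000 - 400) // 160 = 98 frames.
print(spec.shape)  # [98, 257]
```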
+
+
+def _inverse_mel_scale_scalar(mel_freq: float) -> float:
+ return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
+
+
+def _inverse_mel_scale(mel_freq: Tensor) -> Tensor:
+ return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
+
+
+def _mel_scale_scalar(freq: float) -> float:
+ return 1127.0 * math.log(1.0 + freq / 700.0)
+
+
+def _mel_scale(freq: Tensor) -> Tensor:
+ return 1127.0 * (1.0 + freq / 700.0).log()
+
+
+def _vtln_warp_freq(vtln_low_cutoff: float,
+ vtln_high_cutoff: float,
+ low_freq: float,
+ high_freq: float,
+ vtln_warp_factor: float,
+ freq: Tensor) -> Tensor:
+ assert vtln_low_cutoff > low_freq, 'be sure to set the vtln_low option higher than low_freq'
+ assert vtln_high_cutoff < high_freq, 'be sure to set the vtln_high option lower than high_freq [or negative]'
+ l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
+ h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
+ scale = 1.0 / vtln_warp_factor
+ Fl = scale * l
+ Fh = scale * h
+ assert l > low_freq and h < high_freq
+ scale_left = (Fl - low_freq) / (l - low_freq)
+ scale_right = (high_freq - Fh) / (high_freq - h)
+ res = paddle.empty_like(freq)
+
+ outside_low_high_freq = paddle.less_than(freq, paddle.to_tensor(low_freq)) \
+ | paddle.greater_than(freq, paddle.to_tensor(high_freq))
+ before_l = paddle.less_than(freq, paddle.to_tensor(l))
+ before_h = paddle.less_than(freq, paddle.to_tensor(h))
+ after_h = paddle.greater_equal(freq, paddle.to_tensor(h))
+
+ res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
+ res[before_h] = scale * freq[before_h]
+ res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
+ res[outside_low_high_freq] = freq[outside_low_high_freq]
+
+ return res
+
+
+def _vtln_warp_mel_freq(vtln_low_cutoff: float,
+ vtln_high_cutoff: float,
+ low_freq,
+ high_freq: float,
+ vtln_warp_factor: float,
+ mel_freq: Tensor) -> Tensor:
+ return _mel_scale(
+ _vtln_warp_freq(vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq,
+ vtln_warp_factor, _inverse_mel_scale(mel_freq)))
+
+
+def _get_mel_banks(num_bins: int,
+ window_length_padded: int,
+ sample_freq: float,
+ low_freq: float,
+ high_freq: float,
+ vtln_low: float,
+ vtln_high: float,
+ vtln_warp_factor: float) -> Tuple[Tensor, Tensor]:
+ assert num_bins > 3, 'Must have at least 3 mel bins'
+ assert window_length_padded % 2 == 0
+ num_fft_bins = window_length_padded / 2
+ nyquist = 0.5 * sample_freq
+
+ if high_freq <= 0.0:
+ high_freq += nyquist
+
+ assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \
+ ('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist))
+
+ fft_bin_width = sample_freq / window_length_padded
+ mel_low_freq = _mel_scale_scalar(low_freq)
+ mel_high_freq = _mel_scale_scalar(high_freq)
+
+ mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
+
+ if vtln_high < 0.0:
+ vtln_high += nyquist
+
+ assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and
+ (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \
+ ('Bad values in options: vtln-low {} and vtln-high {}, versus '
+ 'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
+
+ bin = paddle.arange(num_bins).unsqueeze(1)
+ left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1)
+ center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1)
+ right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1)
+
+ if vtln_warp_factor != 1.0:
+ left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq,
+ vtln_warp_factor, left_mel)
+ center_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq,
+ high_freq, vtln_warp_factor,
+ center_mel)
+ right_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq,
+ high_freq, vtln_warp_factor, right_mel)
+
+ center_freqs = _inverse_mel_scale(center_mel) # (num_bins)
+ # (1, num_fft_bins)
+ mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins)).unsqueeze(0)
+
+ # (num_bins, num_fft_bins)
+ up_slope = (mel - left_mel) / (center_mel - left_mel)
+ down_slope = (right_mel - mel) / (right_mel - center_mel)
+
+ if vtln_warp_factor == 1.0:
+ bins = paddle.maximum(
+ paddle.zeros([1]), paddle.minimum(up_slope, down_slope))
+ else:
+ bins = paddle.zeros_like(up_slope)
+ up_idx = paddle.greater_than(mel, left_mel) & paddle.less_than(
+ mel, center_mel)
+ down_idx = paddle.greater_than(mel, center_mel) & paddle.less_than(
+ mel, right_mel)
+ bins[up_idx] = up_slope[up_idx]
+ bins[down_idx] = down_slope[down_idx]
+
+ return bins, center_freqs
+
+
+def fbank(waveform: Tensor,
+ blackman_coeff: float=0.42,
+ channel: int=-1,
+ dither: float=0.0,
+ energy_floor: float=1.0,
+ frame_length: float=25.0,
+ frame_shift: float=10.0,
+ high_freq: float=0.0,
+ htk_compat: bool=False,
+ low_freq: float=20.0,
+ n_mels: int=23,
+ preemphasis_coefficient: float=0.97,
+ raw_energy: bool=True,
+ remove_dc_offset: bool=True,
+ round_to_power_of_two: bool=True,
+ sr: int=16000,
+ snip_edges: bool=True,
+ subtract_mean: bool=False,
+ use_energy: bool=False,
+ use_log_fbank: bool=True,
+ use_power: bool=True,
+ vtln_high: float=-500.0,
+ vtln_low: float=100.0,
+ vtln_warp: float=1.0,
+ window_type: str=POVEY) -> Tensor:
+ """Compute and return filter banks from a waveform. The output is identical to Kaldi's.
+
+ Args:
+ waveform (Tensor): A waveform tensor with shape [C, T].
+        blackman_coeff (float, optional): Coefficient for Blackman window. Defaults to 0.42.
+        channel (int, optional): Select the channel of waveform. Defaults to -1.
+        dither (float, optional): Dithering constant. Defaults to 0.0.
+ energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
+ frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
+ frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
+ high_freq (float, optional): The upper cut-off frequency. Defaults to 0.0.
+        htk_compat (bool, optional): Put energy in the last column when it is set True. Defaults to False.
+ low_freq (float, optional): The lower cut-off frequency. Defaults to 20.0.
+ n_mels (int, optional): Number of output mel bins. Defaults to 23.
+ preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
+        raw_energy (bool, optional): Whether to compute frame energy before preemphasis and windowing. Defaults to True.
+        remove_dc_offset (bool, optional): Whether to subtract the mean from the waveform on each frame. Defaults to True.
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
+ to FFT. Defaults to True.
+ sr (int, optional): Sample rate of input waveform. Defaults to 16000.
+        snip_edges (bool, optional): Drop samples at the end of the waveform that cannot fit a complete frame when it
+            is set True. Otherwise performs reflect padding to the end of the waveform. Defaults to True.
+        subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
+        use_energy (bool, optional): Add a dimension with the energy of the spectrogram to the output. Defaults to False.
+ use_log_fbank (bool, optional): Return log fbank when it is set True. Defaults to True.
+ use_power (bool, optional): Whether to use power instead of magnitude. Defaults to True.
+ vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
+ vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
+ vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
+ window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
+
+ Returns:
+ Tensor: A filter banks tensor with shape (m, n_mels).
+ """
+ dtype = waveform.dtype
+
+ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
+ waveform, channel, sr, frame_shift, frame_length, round_to_power_of_two,
+ preemphasis_coefficient)
+
+ strided_input, signal_log_energy = _get_window(
+ waveform, padded_window_size, window_size, window_shift, window_type,
+ blackman_coeff, snip_edges, raw_energy, energy_floor, dither,
+ remove_dc_offset, preemphasis_coefficient)
+
+ # (m, padded_window_size // 2 + 1)
+ spectrum = paddle.fft.rfft(strided_input).abs()
+ if use_power:
+ spectrum = spectrum.pow(2.)
+
+ # (n_mels, padded_window_size // 2)
+ mel_energies, _ = _get_mel_banks(n_mels, padded_window_size, sr, low_freq,
+ high_freq, vtln_low, vtln_high, vtln_warp)
+ mel_energies = mel_energies.astype(dtype)
+
+ # (n_mels, padded_window_size // 2 + 1)
+ mel_energies = paddle.nn.functional.pad(
+ mel_energies.unsqueeze(0), (0, 1),
+ data_format='NCL',
+ mode='constant',
+ value=0).squeeze(0)
+
+ # (m, n_mels)
+ mel_energies = paddle.mm(spectrum, mel_energies.T)
+ if use_log_fbank:
+ mel_energies = paddle.maximum(mel_energies, _get_epsilon(dtype)).log()
+
+ if use_energy:
+ signal_log_energy = signal_log_energy.unsqueeze(1)
+ if htk_compat:
+ mel_energies = paddle.concat(
+ (mel_energies, signal_log_energy), axis=1)
+ else:
+ mel_energies = paddle.concat(
+ (signal_log_energy, mel_energies), axis=1)
+
+ # (m, n_mels + 1)
+ mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
+ return mel_energies
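A usage sketch for `fbank` under the same assumptions (import path and a `[C, T]` float32 input); with `use_energy=False` the output has one column per mel bin:

```python
import paddle
from paddleaudio.compliance import kaldi  # assumed import path after installation

waveform = paddle.randn([1, 16000])  # [C, T], mono, 16 kHz
feat = kaldi.fbank(waveform, sr=16000, n_mels=23, use_energy=False)
print(feat.shape)  # [98, 23] -> (frames, n_mels)

# With use_energy=True an extra energy column is prepended
# (or appended when htk_compat=True), giving n_mels + 1 columns.
feat_e = kaldi.fbank(waveform, sr=16000, n_mels=23, use_energy=True)
print(feat_e.shape)  # [98, 24]
```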
+
+
+def _get_dct_matrix(n_mfcc: int, n_mels: int) -> Tensor:
+ dct_matrix = create_dct(n_mels, n_mels, 'ortho')
+ dct_matrix[:, 0] = math.sqrt(1 / float(n_mels))
+ dct_matrix = dct_matrix[:, :n_mfcc] # (n_mels, n_mfcc)
+ return dct_matrix
+
+
+def _get_lifter_coeffs(n_mfcc: int, cepstral_lifter: float) -> Tensor:
+ i = paddle.arange(n_mfcc)
+ return 1.0 + 0.5 * cepstral_lifter * paddle.sin(math.pi * i /
+ cepstral_lifter)
+
+
+def mfcc(waveform: Tensor,
+ blackman_coeff: float=0.42,
+ cepstral_lifter: float=22.0,
+ channel: int=-1,
+ dither: float=0.0,
+ energy_floor: float=1.0,
+ frame_length: float=25.0,
+ frame_shift: float=10.0,
+ high_freq: float=0.0,
+ htk_compat: bool=False,
+ low_freq: float=20.0,
+ n_mfcc: int=13,
+ n_mels: int=23,
+ preemphasis_coefficient: float=0.97,
+ raw_energy: bool=True,
+ remove_dc_offset: bool=True,
+ round_to_power_of_two: bool=True,
+ sr: int=16000,
+ snip_edges: bool=True,
+ subtract_mean: bool=False,
+ use_energy: bool=False,
+ vtln_high: float=-500.0,
+ vtln_low: float=100.0,
+ vtln_warp: float=1.0,
+ window_type: str=POVEY) -> Tensor:
+ """Compute and return mel frequency cepstral coefficients from a waveform. The output is
+ identical to Kaldi's.
+
+ Args:
+ waveform (Tensor): A waveform tensor with shape [C, T].
+        blackman_coeff (float, optional): Coefficient for Blackman window. Defaults to 0.42.
+ cepstral_lifter (float, optional): Scaling of output mfccs. Defaults to 22.0.
+ channel (int, optional): Select the channel of waveform. Defaults to -1.
+        dither (float, optional): Dithering constant. Defaults to 0.0.
+ energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
+ frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
+ frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
+ high_freq (float, optional): The upper cut-off frequency. Defaults to 0.0.
+        htk_compat (bool, optional): Put energy in the last column when it is set True. Defaults to False.
+ low_freq (float, optional): The lower cut-off frequency. Defaults to 20.0.
+ n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 13.
+ n_mels (int, optional): Number of output mel bins. Defaults to 23.
+ preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
+        raw_energy (bool, optional): Whether to compute frame energy before preemphasis and windowing. Defaults to True.
+        remove_dc_offset (bool, optional): Whether to subtract the mean from the waveform on each frame. Defaults to True.
+ round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
+ to FFT. Defaults to True.
+ sr (int, optional): Sample rate of input waveform. Defaults to 16000.
+        snip_edges (bool, optional): Drop samples at the end of the waveform that cannot fit a complete frame when it
+            is set True. Otherwise performs reflect padding to the end of the waveform. Defaults to True.
+        subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
+        use_energy (bool, optional): Add a dimension with the energy of the spectrogram to the output. Defaults to False.
+ vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
+ vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
+ vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
+ window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
+
+ Returns:
+ Tensor: A mel frequency cepstral coefficients tensor with shape (m, n_mfcc).
+ """
+ assert n_mfcc <= n_mels, 'n_mfcc cannot be larger than n_mels: %d vs %d' % (
+ n_mfcc, n_mels)
+
+ dtype = waveform.dtype
+
+ # (m, n_mels + use_energy)
+ feature = fbank(
+ waveform=waveform,
+ blackman_coeff=blackman_coeff,
+ channel=channel,
+ dither=dither,
+ energy_floor=energy_floor,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ high_freq=high_freq,
+ htk_compat=htk_compat,
+ low_freq=low_freq,
+ n_mels=n_mels,
+ preemphasis_coefficient=preemphasis_coefficient,
+ raw_energy=raw_energy,
+ remove_dc_offset=remove_dc_offset,
+ round_to_power_of_two=round_to_power_of_two,
+ sr=sr,
+ snip_edges=snip_edges,
+ subtract_mean=False,
+ use_energy=use_energy,
+ use_log_fbank=True,
+ use_power=True,
+ vtln_high=vtln_high,
+ vtln_low=vtln_low,
+ vtln_warp=vtln_warp,
+ window_type=window_type)
+
+ if use_energy:
+ # (m)
+ signal_log_energy = feature[:, n_mels if htk_compat else 0]
+ mel_offset = int(not htk_compat)
+ feature = feature[:, mel_offset:(n_mels + mel_offset)]
+
+ # (n_mels, n_mfcc)
+ dct_matrix = _get_dct_matrix(n_mfcc, n_mels).astype(dtype=dtype)
+
+ # (m, n_mfcc)
+ feature = feature.matmul(dct_matrix)
+
+ if cepstral_lifter != 0.0:
+ # (1, n_mfcc)
+ lifter_coeffs = _get_lifter_coeffs(n_mfcc, cepstral_lifter).unsqueeze(0)
+ feature *= lifter_coeffs.astype(dtype=dtype)
+
+ if use_energy:
+ feature[:, 0] = signal_log_energy
+
+ if htk_compat:
+ energy = feature[:, 0].unsqueeze(1) # (m, 1)
+ feature = feature[:, 1:] # (m, n_mfcc - 1)
+ if not use_energy:
+ energy *= math.sqrt(2)
+
+ feature = paddle.concat((feature, energy), axis=1)
+
+ feature = _subtract_column_mean(feature, subtract_mean)
+ return feature
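A sketch for `mfcc`, which internally calls `fbank` and then applies the DCT and liftering steps above (same import-path assumption); `n_mfcc` must not exceed `n_mels`:

```python
import paddle
from paddleaudio.compliance import kaldi  # assumed import path after installation

waveform = paddle.randn([1, 16000])  # [C, T], mono, 16 kHz
feat = kaldi.mfcc(waveform, sr=16000, n_mels=23, n_mfcc=13, cepstral_lifter=22.0)
print(feat.shape)  # [98, 13] -> (frames, n_mfcc)
```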
diff --git a/paddleaudio/features/core.py b/paddleaudio/paddleaudio/compliance/librosa.py
similarity index 79%
rename from paddleaudio/features/core.py
rename to paddleaudio/paddleaudio/compliance/librosa.py
index 01925ec62..167795c37 100644
--- a/paddleaudio/features/core.py
+++ b/paddleaudio/paddleaudio/compliance/librosa.py
@@ -21,11 +21,13 @@ import numpy as np
import scipy
from numpy import ndarray as array
from numpy.lib.stride_tricks import as_strided
-from scipy.signal import get_window
+from scipy import signal
+from ..backends import depth_convert
from ..utils import ParameterError
__all__ = [
+ # dsp
'stft',
'mfcc',
'hz_to_mel',
@@ -38,6 +40,12 @@ __all__ = [
'spectrogram',
'mu_encode',
'mu_decode',
+ # augmentation
+ 'depth_augment',
+ 'spect_augment',
+ 'random_crop1d',
+ 'random_crop2d',
+ 'adaptive_spect_augment',
]
@@ -303,7 +311,7 @@ def stft(x: array,
if hop_length is None:
hop_length = int(win_length // 4)
- fft_window = get_window(window, win_length, fftbins=True)
+ fft_window = signal.get_window(window, win_length, fftbins=True)
# Pad the window out to n_fft size
fft_window = pad_center(fft_window, n_fft)
@@ -576,3 +584,145 @@ def mu_decode(y: array, mu: int=255, quantized: bool=True) -> array:
y = y * 2 / mu - 1
x = np.sign(y) / mu * ((1 + mu)**np.abs(y) - 1)
return x
+
+
+def randint(high: int) -> int:
+ """Generate one random integer in range [0 high)
+
+ This is a helper function for random data augmentaiton
+ """
+ return int(np.random.randint(0, high=high))
+
+
+def rand() -> float:
+ """Generate one floating-point number in range [0 1)
+
+ This is a helper function for random data augmentaiton
+ """
+ return float(np.random.rand(1))
+
+
+def depth_augment(y: array,
+ choices: List=['int8', 'int16'],
+ probs: List[float]=[0.5, 0.5]) -> array:
+ """ Audio depth augmentation
+
+ Do audio depth augmentation to simulate the distortion brought by quantization.
+ """
+ assert len(probs) == len(
+ choices
+ ), 'number of choices {} must be equal to size of probs {}'.format(
+ len(choices), len(probs))
+ depth = np.random.choice(choices, p=probs)
+ src_depth = y.dtype
+ y1 = depth_convert(y, depth)
+ y2 = depth_convert(y1, src_depth)
+
+ return y2
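A sketch of `depth_augment` on a float waveform (assuming the function and its `depth_convert` dependency are importable from the installed package as `paddleaudio.compliance.librosa`); the round trip through a lower bit depth keeps the dtype and length of the input:

```python
import numpy as np
from paddleaudio.compliance.librosa import depth_augment  # assumed import path

y = np.random.uniform(-1.0, 1.0, 16000).astype('float32')
# Randomly quantize to int8 or int16 and convert back, simulating quantization noise.
y_aug = depth_augment(y, choices=['int8', 'int16'], probs=[0.5, 0.5])
print(y_aug.dtype, y_aug.shape)  # float32 (16000,)
```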
+
+
+def adaptive_spect_augment(spect: array, tempo_axis: int=0,
+ level: float=0.1) -> array:
+ """Do adpative spectrogram augmentation
+
+ The level of the augmentation is gowern by the paramter level,
+ ranging from 0 to 1, with 0 represents no augmentation。
+
+ """
+ assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
+ if tempo_axis == 0:
+ nt, nf = spect.shape
+ else:
+ nf, nt = spect.shape
+
+ time_mask_width = int(nt * level * 0.5)
+ freq_mask_width = int(nf * level * 0.5)
+
+ num_time_mask = int(10 * level)
+ num_freq_mask = int(10 * level)
+
+ if tempo_axis == 0:
+ for _ in range(num_time_mask):
+ start = randint(nt - time_mask_width)
+ spect[start:start + time_mask_width, :] = 0
+ for _ in range(num_freq_mask):
+ start = randint(nf - freq_mask_width)
+ spect[:, start:start + freq_mask_width] = 0
+ else:
+ for _ in range(num_time_mask):
+ start = randint(nt - time_mask_width)
+ spect[:, start:start + time_mask_width] = 0
+ for _ in range(num_freq_mask):
+ start = randint(nf - freq_mask_width)
+ spect[start:start + freq_mask_width, :] = 0
+
+ return spect
+
+
+def spect_augment(spect: array,
+ tempo_axis: int=0,
+ max_time_mask: int=3,
+ max_freq_mask: int=3,
+ max_time_mask_width: int=30,
+ max_freq_mask_width: int=20) -> array:
+ """Do spectrogram augmentation in both time and freq axis
+
+ Reference:
+
+ """
+    assert spect.ndim == 2, 'only supports 2d tensor or numpy array'
+ if tempo_axis == 0:
+ nt, nf = spect.shape
+ else:
+ nf, nt = spect.shape
+
+ num_time_mask = randint(max_time_mask)
+ num_freq_mask = randint(max_freq_mask)
+
+ time_mask_width = randint(max_time_mask_width)
+ freq_mask_width = randint(max_freq_mask_width)
+
+ if tempo_axis == 0:
+ for _ in range(num_time_mask):
+ start = randint(nt - time_mask_width)
+ spect[start:start + time_mask_width, :] = 0
+ for _ in range(num_freq_mask):
+ start = randint(nf - freq_mask_width)
+ spect[:, start:start + freq_mask_width] = 0
+ else:
+ for _ in range(num_time_mask):
+ start = randint(nt - time_mask_width)
+ spect[:, start:start + time_mask_width] = 0
+ for _ in range(num_freq_mask):
+ start = randint(nf - freq_mask_width)
+ spect[start:start + freq_mask_width, :] = 0
+
+ return spect
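A sketch of `spect_augment` on a `(time, freq)` spectrogram (same import-path assumption); the masking is applied in place and the masked array is also returned:

```python
import numpy as np
from paddleaudio.compliance.librosa import spect_augment  # assumed import path

spect = np.random.rand(100, 80).astype('float32')  # 100 frames x 80 frequency bins
masked = spect_augment(spect, tempo_axis=0, max_time_mask=3, max_freq_mask=3,
                       max_time_mask_width=30, max_freq_mask_width=20)
print(masked.shape)  # (100, 80); randomly chosen time/frequency stripes are zeroed
```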
+
+
+def random_crop1d(y: array, crop_len: int) -> array:
+ """ Do random cropping on 1d input signal
+
+ The input is a 1d signal, typically a sound waveform
+ """
+    if y.ndim != 1:
+        raise ParameterError('only accept 1d tensor or numpy array')
+ n = len(y)
+ idx = randint(n - crop_len)
+ return y[idx:idx + crop_len]
+
+
+def random_crop2d(s: array, crop_len: int, tempo_axis: int=0) -> array:
+ """ Do random cropping for 2D array, typically a spectrogram.
+
+ The cropping is done in temporal direction on the time-freq input signal.
+ """
+ if tempo_axis >= s.ndim:
+ raise ParameterError('axis out of range')
+
+ n = s.shape[tempo_axis]
+ idx = randint(high=n - crop_len)
+ sli = [slice(None) for i in range(s.ndim)]
+ sli[tempo_axis] = slice(idx, idx + crop_len)
+ out = s[tuple(sli)]
+ return out
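The two cropping helpers can be sketched together (same import-path assumption); `random_crop1d` crops a waveform, while `random_crop2d` crops along the temporal axis of a spectrogram:

```python
import numpy as np
from paddleaudio.compliance.librosa import random_crop1d, random_crop2d  # assumed path

y = np.random.randn(16000).astype('float32')
y_crop = random_crop1d(y, crop_len=8000)
print(y_crop.shape)  # (8000,)

s = np.random.rand(100, 80).astype('float32')  # (time, freq)
s_crop = random_crop2d(s, crop_len=50, tempo_axis=0)
print(s_crop.shape)  # (50, 80)
```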
diff --git a/paddleaudio/datasets/__init__.py b/paddleaudio/paddleaudio/datasets/__init__.py
similarity index 90%
rename from paddleaudio/datasets/__init__.py
rename to paddleaudio/paddleaudio/datasets/__init__.py
index 8d2fdab46..5c5f03694 100644
--- a/paddleaudio/datasets/__init__.py
+++ b/paddleaudio/paddleaudio/datasets/__init__.py
@@ -15,10 +15,3 @@ from .esc50 import ESC50
from .gtzan import GTZAN
from .tess import TESS
from .urban_sound import UrbanSound8K
-
-__all__ = [
- 'ESC50',
- 'UrbanSound8K',
- 'GTZAN',
- 'TESS',
-]
diff --git a/paddleaudio/datasets/dataset.py b/paddleaudio/paddleaudio/datasets/dataset.py
similarity index 96%
rename from paddleaudio/datasets/dataset.py
rename to paddleaudio/paddleaudio/datasets/dataset.py
index 7a57fd6cc..06e2df6d0 100644
--- a/paddleaudio/datasets/dataset.py
+++ b/paddleaudio/paddleaudio/datasets/dataset.py
@@ -17,8 +17,8 @@ import numpy as np
import paddle
from ..backends import load as load_audio
-from ..features import melspectrogram
-from ..features import mfcc
+from ..compliance.librosa import melspectrogram
+from ..compliance.librosa import mfcc
feat_funcs = {
'raw': None,
diff --git a/paddleaudio/datasets/esc50.py b/paddleaudio/paddleaudio/datasets/esc50.py
similarity index 100%
rename from paddleaudio/datasets/esc50.py
rename to paddleaudio/paddleaudio/datasets/esc50.py
diff --git a/paddleaudio/datasets/gtzan.py b/paddleaudio/paddleaudio/datasets/gtzan.py
similarity index 100%
rename from paddleaudio/datasets/gtzan.py
rename to paddleaudio/paddleaudio/datasets/gtzan.py
diff --git a/paddleaudio/datasets/tess.py b/paddleaudio/paddleaudio/datasets/tess.py
similarity index 100%
rename from paddleaudio/datasets/tess.py
rename to paddleaudio/paddleaudio/datasets/tess.py
diff --git a/paddleaudio/datasets/urban_sound.py b/paddleaudio/paddleaudio/datasets/urban_sound.py
similarity index 100%
rename from paddleaudio/datasets/urban_sound.py
rename to paddleaudio/paddleaudio/datasets/urban_sound.py
diff --git a/paddleaudio/features/__init__.py b/paddleaudio/paddleaudio/features/__init__.py
similarity index 82%
rename from paddleaudio/features/__init__.py
rename to paddleaudio/paddleaudio/features/__init__.py
index d8ac7c4b9..00781397f 100644
--- a/paddleaudio/features/__init__.py
+++ b/paddleaudio/paddleaudio/features/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from .augment import *
-from .core import *
-from .spectrum import *
+from .layers import LogMelSpectrogram
+from .layers import MelSpectrogram
+from .layers import MFCC
+from .layers import Spectrogram
diff --git a/paddleaudio/paddleaudio/features/layers.py b/paddleaudio/paddleaudio/features/layers.py
new file mode 100644
index 000000000..6afd234a0
--- /dev/null
+++ b/paddleaudio/paddleaudio/features/layers.py
@@ -0,0 +1,350 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+from typing import Optional
+from typing import Union
+
+import paddle
+import paddle.nn as nn
+
+from ..functional import compute_fbank_matrix
+from ..functional import create_dct
+from ..functional import power_to_db
+from ..functional.window import get_window
+
+__all__ = [
+ 'Spectrogram',
+ 'MelSpectrogram',
+ 'LogMelSpectrogram',
+ 'MFCC',
+]
+
+
+class Spectrogram(nn.Layer):
+ def __init__(self,
+ n_fft: int=512,
+ hop_length: Optional[int]=None,
+ win_length: Optional[int]=None,
+ window: str='hann',
+ power: float=2.0,
+ center: bool=True,
+ pad_mode: str='reflect',
+ dtype: str=paddle.float32):
+ """Compute spectrogram of a given signal, typically an audio waveform.
+        The spectrogram is defined as the complex norm of the short-time
+        Fourier transformation.
+        Parameters:
+            n_fft (int): the number of frequency components of the discrete Fourier transform.
+                The default value is 512.
+ hop_length (int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
+ The default value is None.
+            win_length: the window length of the short time FFT. If None, it is set to the same as n_fft.
+ The default value is None.
+            window (str): the name of the window function applied to the signal before the Fourier transform.
+                The following window names are supported: 'hamming','hann','kaiser','gaussian',
+                'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
+                The default value is 'hann'.
+ power (float): Exponent for the magnitude spectrogram. The default value is 2.0.
+ center (bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
+ If False, frame t begins at x[t * hop_length]
+ The default value is True
+ pad_mode (str): the mode to pad the signal if necessary. The supported modes are 'reflect'
+ and 'constant'. The default value is 'reflect'.
+ dtype (str): the data type of input and window.
+ Notes:
+ The Spectrogram transform relies on STFT transform to compute the spectrogram.
+ By default, the weights are not learnable. To fine-tune the Fourier coefficients,
+ set stop_gradient=False before training.
+ For more information, see STFT().
+ """
+ super(Spectrogram, self).__init__()
+
+ assert power > 0, 'Power of spectrogram must be > 0.'
+ self.power = power
+
+ if win_length is None:
+ win_length = n_fft
+
+ self.fft_window = get_window(
+ window, win_length, fftbins=True, dtype=dtype)
+ self._stft = partial(
+ paddle.signal.stft,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=self.fft_window,
+ center=center,
+ pad_mode=pad_mode)
+ self.register_buffer('fft_window', self.fft_window)
+
+ def forward(self, x):
+ stft = self._stft(x)
+ spectrogram = paddle.pow(paddle.abs(stft), self.power)
+ return spectrogram
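A usage sketch of the `Spectrogram` layer on a batch of waveforms (assuming the layer is exposed as `paddleaudio.features.Spectrogram` per the updated `__init__.py`); the output shape follows from `paddle.signal.stft` with `center=True`:

```python
import paddle
from paddleaudio.features import Spectrogram  # assumed import path after installation

x = paddle.randn([2, 32000])  # batch of 2 waveforms, 2 s at 16 kHz
spectrogram = Spectrogram(n_fft=512, hop_length=160, window='hann', power=2.0)
feat = spectrogram(x)
print(feat.shape)  # [2, 257, 201] -> (batch, n_fft // 2 + 1, frames)
```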
+
+
+class MelSpectrogram(nn.Layer):
+ def __init__(self,
+ sr: int=22050,
+ n_fft: int=512,
+ hop_length: Optional[int]=None,
+ win_length: Optional[int]=None,
+ window: str='hann',
+ power: float=2.0,
+ center: bool=True,
+ pad_mode: str='reflect',
+ n_mels: int=64,
+ f_min: float=50.0,
+ f_max: Optional[float]=None,
+ htk: bool=False,
+ norm: Union[str, float]='slaney',
+ dtype: str=paddle.float32):
+ """Compute the melspectrogram of a given signal, typically an audio waveform.
+ The melspectrogram is also known as filterbank or fbank feature in audio community.
+ It is computed by multiplying spectrogram with Mel filter bank matrix.
+ Parameters:
+ sr(int): the audio sample rate.
+ The default value is 22050.
+            n_fft(int): the number of frequency components of the discrete Fourier transform.
+                The default value is 512.
+ hop_length(int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
+ The default value is None.
+            win_length: the window length of the short time FFT. If None, it is set to the same as n_fft.
+ The default value is None.
+            window(str): the name of the window function applied to the signal before the Fourier transform.
+                The following window names are supported: 'hamming','hann','kaiser','gaussian',
+                'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
+                The default value is 'hann'.
+ power (float): Exponent for the magnitude spectrogram. The default value is 2.0.
+ center(bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
+ If False, frame t begins at x[t * hop_length]
+ The default value is True
+ pad_mode(str): the mode to pad the signal if necessary. The supported modes are 'reflect'
+ and 'constant'.
+ The default value is 'reflect'.
+ n_mels(int): the mel bins.
+ f_min(float): the lower cut-off frequency, below which the filter response is zero.
+            f_max(float): the upper cut-off frequency, above which the filter response is zero.
+ htk(bool): whether to use HTK formula in computing fbank matrix.
+ norm(str|float): the normalization type in computing fbank matrix. Slaney-style is used by default.
+ You can specify norm=1.0/2.0 to use customized p-norm normalization.
+ dtype(str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical
+ accuracy. Note that the final transform will be conducted in float32 regardless of dtype of fbank matrix.
+ """
+ super(MelSpectrogram, self).__init__()
+
+ self._spectrogram = Spectrogram(
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ power=power,
+ center=center,
+ pad_mode=pad_mode,
+ dtype=dtype)
+ self.n_mels = n_mels
+ self.f_min = f_min
+ self.f_max = f_max
+ self.htk = htk
+ self.norm = norm
+ if f_max is None:
+ f_max = sr // 2
+ self.fbank_matrix = compute_fbank_matrix(
+ sr=sr,
+ n_fft=n_fft,
+ n_mels=n_mels,
+ f_min=f_min,
+ f_max=f_max,
+ htk=htk,
+ norm=norm,
+ dtype=dtype) # float64 for better numerical results
+ self.register_buffer('fbank_matrix', self.fbank_matrix)
+
+ def forward(self, x):
+ spect_feature = self._spectrogram(x)
+ mel_feature = paddle.matmul(self.fbank_matrix, spect_feature)
+ return mel_feature
+
+
+class LogMelSpectrogram(nn.Layer):
+ def __init__(self,
+ sr: int=22050,
+ n_fft: int=512,
+ hop_length: Optional[int]=None,
+ win_length: Optional[int]=None,
+ window: str='hann',
+ power: float=2.0,
+ center: bool=True,
+ pad_mode: str='reflect',
+ n_mels: int=64,
+ f_min: float=50.0,
+ f_max: Optional[float]=None,
+ htk: bool=False,
+ norm: Union[str, float]='slaney',
+ ref_value: float=1.0,
+ amin: float=1e-10,
+ top_db: Optional[float]=None,
+ dtype: str=paddle.float32):
+ """Compute log-mel-spectrogram(also known as LogFBank) feature of a given signal,
+ typically an audio waveform.
+ Parameters:
+ sr (int): the audio sample rate.
+ The default value is 22050.
+            n_fft (int): the number of frequency components of the discrete Fourier transform.
+                The default value is 512.
+ hop_length (int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
+ The default value is None.
+            win_length: the window length of the short time FFT. If None, it is set to the same as n_fft.
+ The default value is None.
+            window (str): the name of the window function applied to the signal before the Fourier transform.
+                The following window names are supported: 'hamming','hann','kaiser','gaussian',
+                'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
+                The default value is 'hann'.
+ center (bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
+ If False, frame t begins at x[t * hop_length]
+ The default value is True
+ pad_mode (str): the mode to pad the signal if necessary. The supported modes are 'reflect'
+ and 'constant'.
+ The default value is 'reflect'.
+ n_mels (int): the mel bins.
+ f_min (float): the lower cut-off frequency, below which the filter response is zero.
+            f_max (float): the upper cut-off frequency, above which the filter response is zero.
+ htk (bool): whether to use HTK formula in computing fbank matrix.
+ norm (str|float): the normalization type in computing fbank matrix. Slaney-style is used by default.
+ You can specify norm=1.0/2.0 to use customized p-norm normalization.
+ ref_value (float): the reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down.
+ amin (float): the minimum value of input magnitude, below which the input magnitude is clipped(to amin).
+ top_db (float): the maximum db value of resulting spectrum, above which the
+ spectrum is clipped(to top_db).
+ dtype (str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical
+ accuracy. Note that the final transform will be conducted in float32 regardless of dtype of fbank matrix.
+ """
+ super(LogMelSpectrogram, self).__init__()
+
+ self._melspectrogram = MelSpectrogram(
+ sr=sr,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ power=power,
+ center=center,
+ pad_mode=pad_mode,
+ n_mels=n_mels,
+ f_min=f_min,
+ f_max=f_max,
+ htk=htk,
+ norm=norm,
+ dtype=dtype)
+
+ self.ref_value = ref_value
+ self.amin = amin
+ self.top_db = top_db
+
+ def forward(self, x):
+ mel_feature = self._melspectrogram(x)
+ log_mel_feature = power_to_db(
+ mel_feature,
+ ref_value=self.ref_value,
+ amin=self.amin,
+ top_db=self.top_db)
+ return log_mel_feature
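A sketch for `LogMelSpectrogram`, which chains the mel spectrogram with `power_to_db` (same import-path assumption):

```python
import paddle
from paddleaudio.features import LogMelSpectrogram  # assumed import path

x = paddle.randn([4, 16000])  # 4 one-second waveforms at 16 kHz
extractor = LogMelSpectrogram(sr=16000, n_fft=512, hop_length=160, n_mels=64, top_db=80.0)
log_mel = extractor(x)
print(log_mel.shape)  # [4, 64, 101] -> (batch, n_mels, frames)
```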
+
+
+class MFCC(nn.Layer):
+ def __init__(self,
+ sr: int=22050,
+ n_mfcc: int=40,
+ n_fft: int=512,
+ hop_length: Optional[int]=None,
+ win_length: Optional[int]=None,
+ window: str='hann',
+ power: float=2.0,
+ center: bool=True,
+ pad_mode: str='reflect',
+ n_mels: int=64,
+ f_min: float=50.0,
+ f_max: Optional[float]=None,
+ htk: bool=False,
+ norm: Union[str, float]='slaney',
+ ref_value: float=1.0,
+ amin: float=1e-10,
+ top_db: Optional[float]=None,
+ dtype: str=paddle.float32):
+ """Compute mel frequency cepstral coefficients(MFCCs) feature of given waveforms.
+
+ Parameters:
+ sr(int): the audio sample rate.
+ The default value is 22050.
+ n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 40.
+            n_fft (int): the number of frequency components of the discrete Fourier transform.
+                The default value is 512.
+ hop_length (int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
+ The default value is None.
+            win_length: the window length of the short time FFT. If None, it is set to the same as n_fft.
+ The default value is None.
+            window (str): the name of the window function applied to the signal before the Fourier transform.
+                The following window names are supported: 'hamming','hann','kaiser','gaussian',
+                'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
+                The default value is 'hann'.
+ power (float): Exponent for the magnitude spectrogram. The default value is 2.0.
+ center (bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
+ If False, frame t begins at x[t * hop_length]
+ The default value is True
+ pad_mode (str): the mode to pad the signal if necessary. The supported modes are 'reflect'
+ and 'constant'.
+ The default value is 'reflect'.
+ n_mels (int): the mel bins.
+ f_min (float): the lower cut-off frequency, below which the filter response is zero.
+            f_max (float): the upper cut-off frequency, above which the filter response is zero.
+ htk (bool): whether to use HTK formula in computing fbank matrix.
+ norm (str|float): the normalization type in computing fbank matrix. Slaney-style is used by default.
+ You can specify norm=1.0/2.0 to use customized p-norm normalization.
+ ref_value (float): the reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down.
+ amin (float): the minimum value of input magnitude, below which the input magnitude is clipped(to amin).
+ top_db (float): the maximum db value of resulting spectrum, above which the
+ spectrum is clipped(to top_db).
+ dtype (str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical
+ accuracy. Note that the final transform will be conducted in float32 regardless of dtype of fbank matrix.
+ """
+ super(MFCC, self).__init__()
+ assert n_mfcc <= n_mels, 'n_mfcc cannot be larger than n_mels: %d vs %d' % (
+ n_mfcc, n_mels)
+ self._log_melspectrogram = LogMelSpectrogram(
+ sr=sr,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ power=power,
+ center=center,
+ pad_mode=pad_mode,
+ n_mels=n_mels,
+ f_min=f_min,
+ f_max=f_max,
+ htk=htk,
+ norm=norm,
+ ref_value=ref_value,
+ amin=amin,
+ top_db=top_db,
+ dtype=dtype)
+ self.dct_matrix = create_dct(n_mfcc=n_mfcc, n_mels=n_mels, dtype=dtype)
+ self.register_buffer('dct_matrix', self.dct_matrix)
+
+ def forward(self, x):
+ log_mel_feature = self._log_melspectrogram(x)
+ mfcc = paddle.matmul(
+ log_mel_feature.transpose((0, 2, 1)), self.dct_matrix).transpose(
+            (0, 2, 1)) # (B, n_mfcc, L)
+ return mfcc
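A sketch for the `MFCC` layer, which applies the DCT matrix to the log-mel feature along the mel axis (same import-path assumption):

```python
import paddle
from paddleaudio.features import MFCC  # assumed import path

x = paddle.randn([2, 16000])
mfcc_extractor = MFCC(sr=16000, n_mfcc=40, n_fft=512, hop_length=160, n_mels=64)
feat = mfcc_extractor(x)
print(feat.shape)  # [2, 40, 101] -> (batch, n_mfcc, frames)
```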
diff --git a/paddleaudio/paddleaudio/functional/__init__.py b/paddleaudio/paddleaudio/functional/__init__.py
new file mode 100644
index 000000000..c85232df1
--- /dev/null
+++ b/paddleaudio/paddleaudio/functional/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .functional import compute_fbank_matrix
+from .functional import create_dct
+from .functional import fft_frequencies
+from .functional import hz_to_mel
+from .functional import mel_frequencies
+from .functional import mel_to_hz
+from .functional import power_to_db
diff --git a/paddleaudio/paddleaudio/functional/functional.py b/paddleaudio/paddleaudio/functional/functional.py
new file mode 100644
index 000000000..c5ab30453
--- /dev/null
+++ b/paddleaudio/paddleaudio/functional/functional.py
@@ -0,0 +1,265 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from librosa(https://github.com/librosa/librosa)
+import math
+from typing import Optional
+from typing import Union
+
+import paddle
+
+__all__ = [
+ 'hz_to_mel',
+ 'mel_to_hz',
+ 'mel_frequencies',
+ 'fft_frequencies',
+ 'compute_fbank_matrix',
+ 'power_to_db',
+ 'create_dct',
+]
+
+
+def hz_to_mel(freq: Union[paddle.Tensor, float],
+ htk: bool=False) -> Union[paddle.Tensor, float]:
+ """Convert Hz to Mels.
+ Parameters:
+ freq: the input tensor of arbitrary shape, or a single floating point number.
+ htk: use HTK formula to do the conversion.
+ The default value is False.
+ Returns:
+ The frequencies represented in Mel-scale.
+ """
+
+ if htk:
+ if isinstance(freq, paddle.Tensor):
+ return 2595.0 * paddle.log10(1.0 + freq / 700.0)
+ else:
+ return 2595.0 * math.log10(1.0 + freq / 700.0)
+
+ # Fill in the linear part
+ f_min = 0.0
+ f_sp = 200.0 / 3
+
+ mels = (freq - f_min) / f_sp
+
+ # Fill in the log-scale part
+
+ min_log_hz = 1000.0 # beginning of log region (Hz)
+ min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
+ logstep = math.log(6.4) / 27.0 # step size for log region
+
+ if isinstance(freq, paddle.Tensor):
+ target = min_log_mel + paddle.log(
+ freq / min_log_hz + 1e-10) / logstep # prevent nan with 1e-10
+ mask = (freq > min_log_hz).astype(freq.dtype)
+ mels = target * mask + mels * (
+ 1 - mask) # will replace by masked_fill OP in future
+ else:
+ if freq >= min_log_hz:
+ mels = min_log_mel + math.log(freq / min_log_hz + 1e-10) / logstep
+
+ return mels
+
+
+def mel_to_hz(mel: Union[float, paddle.Tensor],
+ htk: bool=False) -> Union[float, paddle.Tensor]:
+ """Convert mel bin numbers to frequencies.
+ Parameters:
+ mel: the mel frequency represented as a tensor of arbitrary shape, or a floating point number.
+ htk: use HTK formula to do the conversion.
+ Returns:
+ The frequencies represented in hz.
+ """
+ if htk:
+ return 700.0 * (10.0**(mel / 2595.0) - 1.0)
+
+ f_min = 0.0
+ f_sp = 200.0 / 3
+ freqs = f_min + f_sp * mel
+ # And now the nonlinear scale
+ min_log_hz = 1000.0 # beginning of log region (Hz)
+ min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
+ logstep = math.log(6.4) / 27.0 # step size for log region
+ if isinstance(mel, paddle.Tensor):
+ target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel))
+ mask = (mel > min_log_mel).astype(mel.dtype)
+ freqs = target * mask + freqs * (
+ 1 - mask) # will replace by masked_fill OP in future
+ else:
+ if mel >= min_log_mel:
+ freqs = min_log_hz * math.exp(logstep * (mel - min_log_mel))
+
+ return freqs
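A round-trip sketch for `hz_to_mel` / `mel_to_hz` (assuming they are exposed via `paddleaudio.functional`); the default scale is Slaney-style, and `htk=True` switches to the HTK formula:

```python
import paddle
from paddleaudio.functional import hz_to_mel, mel_to_hz  # assumed import path

freqs = paddle.to_tensor([250.0, 1000.0, 4000.0])
mels = hz_to_mel(freqs)       # Slaney-style scale: linear below 1 kHz, log above
back = mel_to_hz(mels)
print(back.numpy())           # close to [250., 1000., 4000.]

print(hz_to_mel(1000.0, htk=True))  # scalar input also works: ~1000 mel on the HTK scale
```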
+
+
+def mel_frequencies(n_mels: int=64,
+ f_min: float=0.0,
+ f_max: float=11025.0,
+ htk: bool=False,
+ dtype: str=paddle.float32):
+ """Compute mel frequencies.
+ Parameters:
+ n_mels(int): number of Mel bins.
+ f_min(float): the lower cut-off frequency, below which the filter response is zero.
+ f_max(float): the upper cut-off frequency, above which the filter response is zero.
+ htk(bool): whether to use htk formula.
+ dtype(str): the datatype of the return frequencies.
+ Returns:
+ The frequencies represented in Mel-scale
+ """
+ # 'Center freqs' of mel bands - uniformly spaced between limits
+ min_mel = hz_to_mel(f_min, htk=htk)
+ max_mel = hz_to_mel(f_max, htk=htk)
+ mels = paddle.linspace(min_mel, max_mel, n_mels, dtype=dtype)
+ freqs = mel_to_hz(mels, htk=htk)
+ return freqs
+
+
+def fft_frequencies(sr: int, n_fft: int, dtype: str=paddle.float32):
+ """Compute fourier frequencies.
+ Parameters:
+ sr(int): the audio sample rate.
+        n_fft(int): the number of fft bins.
+ dtype(str): the datatype of the return frequencies.
+ Returns:
+ The frequencies represented in hz.
+ """
+ return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2), dtype=dtype)
+
+
+def compute_fbank_matrix(sr: int,
+ n_fft: int,
+ n_mels: int=64,
+ f_min: float=0.0,
+ f_max: Optional[float]=None,
+ htk: bool=False,
+ norm: Union[str, float]='slaney',
+ dtype: str=paddle.float32):
+ """Compute fbank matrix.
+ Parameters:
+ sr(int): the audio sample rate.
+ n_fft(int): the number of fft bins.
+ n_mels(int): the number of Mel bins.
+ f_min(float): the lower cut-off frequency, below which the filter response is zero.
+ f_max(float): the upper cut-off frequency, above which the filter response is zero.
+        htk(bool): whether to use HTK formula.
+        norm(str|float): the normalization type in computing fbank matrix. Slaney-style is
+            used by default. You can specify norm=1.0/2.0 to use customized p-norm normalization.
+ dtype(str): the datatype of the returned fbank matrix.
+ Returns:
+ The fbank matrix of shape (n_mels, int(1+n_fft//2)).
+ Shape:
+ output: (n_mels, int(1+n_fft//2))
+ """
+
+ if f_max is None:
+ f_max = float(sr) / 2
+
+ # Initialize the weights
+ weights = paddle.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
+
+ # Center freqs of each FFT bin
+ fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft, dtype=dtype)
+
+ # 'Center freqs' of mel bands - uniformly spaced between limits
+ mel_f = mel_frequencies(
+ n_mels + 2, f_min=f_min, f_max=f_max, htk=htk, dtype=dtype)
+
+ fdiff = mel_f[1:] - mel_f[:-1] #np.diff(mel_f)
+ ramps = mel_f.unsqueeze(1) - fftfreqs.unsqueeze(0)
+ #ramps = np.subtract.outer(mel_f, fftfreqs)
+
+ for i in range(n_mels):
+ # lower and upper slopes for all bins
+ lower = -ramps[i] / fdiff[i]
+ upper = ramps[i + 2] / fdiff[i + 1]
+
+ # .. then intersect them with each other and zero
+ weights[i] = paddle.maximum(
+ paddle.zeros_like(lower), paddle.minimum(lower, upper))
+
+ # Slaney-style mel is scaled to be approx constant energy per channel
+ if norm == 'slaney':
+ enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
+ weights *= enorm.unsqueeze(1)
+ elif isinstance(norm, int) or isinstance(norm, float):
+ weights = paddle.nn.functional.normalize(weights, p=norm, axis=-1)
+
+ return weights
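A sketch of building a Slaney-normalized filter bank and applying it to a power spectrogram (same import-path assumption):

```python
import paddle
from paddleaudio.functional import compute_fbank_matrix  # assumed import path

fbank = compute_fbank_matrix(sr=16000, n_fft=512, n_mels=64, f_min=0.0, f_max=8000.0)
print(fbank.shape)  # [64, 257] -> (n_mels, 1 + n_fft // 2)

power_spec = paddle.rand([257, 101])         # (frequency bins, frames)
mel_spec = paddle.matmul(fbank, power_spec)
print(mel_spec.shape)  # [64, 101] -> (n_mels, frames)
```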
+
+
+def power_to_db(magnitude: paddle.Tensor,
+ ref_value: float=1.0,
+ amin: float=1e-10,
+ top_db: Optional[float]=None) -> paddle.Tensor:
+ """Convert a power spectrogram (amplitude squared) to decibel (dB) units.
+ The function computes the scaling ``10 * log10(x / ref)`` in a numerically
+ stable way.
+ Parameters:
+ magnitude(Tensor): the input magnitude tensor of any shape.
+ ref_value(float): the reference value. If smaller than 1.0, the db level
+ of the signal will be pulled up accordingly. Otherwise, the db level
+ is pushed down.
+ amin(float): the minimum value of input magnitude, below which the input
+ magnitude is clipped(to amin).
+ top_db(float): the maximum db value of resulting spectrum, above which the
+ spectrum is clipped(to top_db).
+ Returns:
+ The spectrogram in log-scale.
+ shape:
+ input: any shape
+ output: same as input
+ """
+ if amin <= 0:
+ raise Exception("amin must be strictly positive")
+
+ if ref_value <= 0:
+ raise Exception("ref_value must be strictly positive")
+
+ ones = paddle.ones_like(magnitude)
+ log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, magnitude))
+ log_spec -= 10.0 * math.log10(max(ref_value, amin))
+
+ if top_db is not None:
+ if top_db < 0:
+ raise Exception("top_db must be non-negative")
+ log_spec = paddle.maximum(log_spec, ones * (log_spec.max() - top_db))
+
+ return log_spec
+
+
+def create_dct(n_mfcc: int,
+ n_mels: int,
+ norm: Optional[str]='ortho',
+ dtype: Optional[str]=paddle.float32) -> paddle.Tensor:
+ """Create a discrete cosine transform(DCT) matrix.
+
+ Parameters:
+ n_mfcc (int): Number of mel frequency cepstral coefficients.
+ n_mels (int): Number of mel filterbanks.
+        norm (str, optional): Normalization type. Defaults to 'ortho'.
+ Returns:
+ Tensor: The DCT matrix with shape (n_mels, n_mfcc).
+ """
+ n = paddle.arange(n_mels, dtype=dtype)
+ k = paddle.arange(n_mfcc, dtype=dtype).unsqueeze(1)
+ dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) *
+ k) # size (n_mfcc, n_mels)
+ if norm is None:
+ dct *= 2.0
+ else:
+ assert norm == "ortho"
+ dct[0] *= 1.0 / math.sqrt(2.0)
+ dct *= math.sqrt(2.0 / float(n_mels))
+ return dct.T
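A sketch combining `power_to_db` and `create_dct` to go from a mel power spectrogram to cepstral coefficients (same import-path assumption); this is essentially what the `MFCC` layer above does internally:

```python
import paddle
from paddleaudio.functional import create_dct, power_to_db  # assumed import path

dct = create_dct(n_mfcc=40, n_mels=64)
print(dct.shape)  # [64, 40] -> (n_mels, n_mfcc)

log_mel = power_to_db(paddle.rand([64, 101]))   # (n_mels, frames), in dB
mfcc = paddle.matmul(log_mel.T, dct).T          # (n_mfcc, frames)
print(mfcc.shape)  # [40, 101]
```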
diff --git a/paddleaudio/features/window.py b/paddleaudio/paddleaudio/functional/window.py
similarity index 98%
rename from paddleaudio/features/window.py
rename to paddleaudio/paddleaudio/functional/window.py
index 629989fc9..f321b38ef 100644
--- a/paddleaudio/features/window.py
+++ b/paddleaudio/paddleaudio/functional/window.py
@@ -20,6 +20,19 @@ from paddle import Tensor
__all__ = [
'get_window',
+
+ # windows
+ 'taylor',
+ 'hamming',
+ 'hann',
+ 'tukey',
+ 'kaiser',
+ 'gaussian',
+ 'exponential',
+ 'triang',
+ 'bohman',
+ 'blackman',
+ 'cosine',
]
@@ -73,6 +86,21 @@ def general_gaussian(M: int, p, sig, sym: bool=True,
return _truncate(w, needs_trunc)
+def general_cosine(M: int, a: float, sym: bool=True,
+ dtype: str='float64') -> Tensor:
+ """Compute a generic weighted sum of cosine terms window.
+ This function is consistent with scipy.signal.windows.general_cosine().
+ """
+ if _len_guards(M):
+ return paddle.ones((M, ), dtype=dtype)
+ M, needs_trunc = _extend(M, sym)
+ fac = paddle.linspace(-math.pi, math.pi, M, dtype=dtype)
+ w = paddle.zeros((M, ), dtype=dtype)
+ for k in range(len(a)):
+ w += a[k] * paddle.cos(k * fac)
+ return _truncate(w, needs_trunc)
+
+
def general_hamming(M: int, alpha: float, sym: bool=True,
dtype: str='float64') -> Tensor:
"""Compute a generalized Hamming window.
@@ -143,21 +171,6 @@ def taylor(M: int,
return _truncate(w, needs_trunc)
-def general_cosine(M: int, a: float, sym: bool=True,
- dtype: str='float64') -> Tensor:
- """Compute a generic weighted sum of cosine terms window.
- This function is consistent with scipy.signal.windows.general_cosine().
- """
- if _len_guards(M):
- return paddle.ones((M, ), dtype=dtype)
- M, needs_trunc = _extend(M, sym)
- fac = paddle.linspace(-math.pi, math.pi, M, dtype=dtype)
- w = paddle.zeros((M, ), dtype=dtype)
- for k in range(len(a)):
- w += a[k] * paddle.cos(k * fac)
- return _truncate(w, needs_trunc)
-
-
def hamming(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
"""Compute a Hamming window.
The Hamming window is a taper formed by using a raised cosine with
@@ -375,6 +388,7 @@ def cosine(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
return _truncate(w, needs_trunc)
+## factory function
def get_window(window: Union[str, Tuple[str, float]],
win_length: int,
fftbins: bool=True,
diff --git a/paddleaudio/backends/__init__.py b/paddleaudio/paddleaudio/io/__init__.py
similarity index 96%
rename from paddleaudio/backends/__init__.py
rename to paddleaudio/paddleaudio/io/__init__.py
index f2f77ffea..185a92b8d 100644
--- a/paddleaudio/backends/__init__.py
+++ b/paddleaudio/paddleaudio/io/__init__.py
@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from .audio import *
diff --git a/paddleaudio/paddleaudio/metric/__init__.py b/paddleaudio/paddleaudio/metric/__init__.py
new file mode 100644
index 000000000..a96530ff6
--- /dev/null
+++ b/paddleaudio/paddleaudio/metric/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .dtw import dtw_distance
+from .mcd import mcd_distance
diff --git a/paddleaudio/paddleaudio/metric/dtw.py b/paddleaudio/paddleaudio/metric/dtw.py
new file mode 100644
index 000000000..d27f56e28
--- /dev/null
+++ b/paddleaudio/paddleaudio/metric/dtw.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+from dtaidistance import dtw_ndim
+
+__all__ = [
+ 'dtw_distance',
+]
+
+
+def dtw_distance(xs: np.ndarray, ys: np.ndarray) -> float:
+ """dtw distance
+
+ Dynamic Time Warping.
+ This function keeps a compact matrix, not the full warping paths matrix.
+ Uses dynamic programming to compute:
+
+ wps[i, j] = (s1[i]-s2[j])**2 + min(
+ wps[i-1, j ] + penalty, // vertical / insertion / expansion
+ wps[i , j-1] + penalty, // horizontal / deletion / compression
+ wps[i-1, j-1]) // diagonal / match
+ dtw = sqrt(wps[-1, -1])
+
+ Args:
+ xs (np.ndarray): ref sequence, [T,D]
+ ys (np.ndarray): hyp sequence, [T,D]
+
+ Returns:
+ float: dtw distance
+ """
+ return dtw_ndim.distance(xs, ys)
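A sketch of `dtw_distance` on two feature sequences of different lengths (the `dtaidistance` package is a third-party dependency of this module):

```python
import numpy as np
from paddleaudio.metric import dtw_distance  # assumed import path after installation

xs = np.random.rand(75, 80)   # reference sequence, [T1, D]
ys = np.random.rand(100, 80)  # hypothesis sequence, [T2, D]
print(dtw_distance(xs, ys))   # a non-negative scalar; 0.0 only for identical sequences
```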
diff --git a/paddleaudio/paddleaudio/metric/mcd.py b/paddleaudio/paddleaudio/metric/mcd.py
new file mode 100644
index 000000000..465cd5a45
--- /dev/null
+++ b/paddleaudio/paddleaudio/metric/mcd.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import mcd.metrics_fast as mt
+import numpy as np
+from mcd import dtw
+
+__all__ = [
+ 'mcd_distance',
+]
+
+
+def mcd_distance(xs: np.ndarray, ys: np.ndarray, cost_fn=mt.logSpecDbDist):
+ """Mel cepstral distortion (MCD), dtw distance.
+
+ Dynamic Time Warping.
+ Uses dynamic programming to compute:
+ wps[i, j] = cost_fn(xs[i], ys[j]) + min(
+ wps[i-1, j ], // vertical / insertion / expansion
+ wps[i , j-1], // horizontal / deletion / compression
+ wps[i-1, j-1]) // diagonal / match
+ dtw = sqrt(wps[-1, -1])
+
+ Cost Function:
+ logSpecDbConst = 10.0 / math.log(10.0) * math.sqrt(2.0)
+ def logSpecDbDist(x, y):
+ diff = x - y
+ return logSpecDbConst * math.sqrt(np.inner(diff, diff))
+
+ Args:
+ xs (np.ndarray): ref sequence, [T,D]
+ ys (np.ndarray): hyp sequence, [T,D]
+
+ Returns:
+ float: dtw distance
+ """
+ min_cost, path = dtw.dtw(xs, ys, cost_fn)
+ return min_cost
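A sketch of `mcd_distance` on reference and synthesized mel-cepstral sequences (the `mcd` package is a third-party dependency); the returned value is the total cost of the best DTW alignment under `logSpecDbDist`:

```python
import numpy as np
from paddleaudio.metric import mcd_distance  # assumed import path after installation

ref = np.random.rand(120, 25)  # reference mel-cepstra, [T, D]
hyp = np.random.rand(130, 25)  # synthesized mel-cepstra, [T, D]
print(mcd_distance(ref, hyp))  # smaller means the sequences are spectrally closer
```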
diff --git a/paddleaudio/paddleaudio/sox_effects/__init__.py b/paddleaudio/paddleaudio/sox_effects/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddleaudio/paddleaudio/sox_effects/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddleaudio/paddleaudio/utils/__init__.py b/paddleaudio/paddleaudio/utils/__init__.py
new file mode 100644
index 000000000..afb9cedd8
--- /dev/null
+++ b/paddleaudio/paddleaudio/utils/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .download import decompress
+from .download import download_and_decompress
+from .download import load_state_dict_from_url
+from .env import DATA_HOME
+from .env import MODEL_HOME
+from .env import PPAUDIO_HOME
+from .env import USER_HOME
+from .error import ParameterError
+from .log import Logger
+from .log import logger
+from .time import seconds_to_hms
+from .time import Timer
diff --git a/paddleaudio/utils/download.py b/paddleaudio/paddleaudio/utils/download.py
similarity index 94%
rename from paddleaudio/utils/download.py
rename to paddleaudio/paddleaudio/utils/download.py
index 45a8e57ba..4658352f9 100644
--- a/paddleaudio/utils/download.py
+++ b/paddleaudio/paddleaudio/utils/download.py
@@ -22,6 +22,12 @@ from .log import logger
download.logger = logger
+__all__ = [
+ 'decompress',
+ 'download_and_decompress',
+ 'load_state_dict_from_url',
+]
+
def decompress(file: str):
"""
diff --git a/paddleaudio/utils/env.py b/paddleaudio/paddleaudio/utils/env.py
similarity index 95%
rename from paddleaudio/utils/env.py
rename to paddleaudio/paddleaudio/utils/env.py
index 59c6b6219..a2d14b89e 100644
--- a/paddleaudio/utils/env.py
+++ b/paddleaudio/paddleaudio/utils/env.py
@@ -20,6 +20,13 @@ PPAUDIO_HOME --> the root directory for storing PaddleAudio related data. D
'''
import os
+__all__ = [
+ 'USER_HOME',
+ 'PPAUDIO_HOME',
+ 'MODEL_HOME',
+ 'DATA_HOME',
+]
+
def _get_user_home():
return os.path.expanduser('~')
diff --git a/paddleaudio/utils/error.py b/paddleaudio/paddleaudio/utils/error.py
similarity index 100%
rename from paddleaudio/utils/error.py
rename to paddleaudio/paddleaudio/utils/error.py
diff --git a/paddleaudio/utils/log.py b/paddleaudio/paddleaudio/utils/log.py
similarity index 98%
rename from paddleaudio/utils/log.py
rename to paddleaudio/paddleaudio/utils/log.py
index 5e7db68a9..5656b286a 100644
--- a/paddleaudio/utils/log.py
+++ b/paddleaudio/paddleaudio/utils/log.py
@@ -19,7 +19,10 @@ import time
import colorlog
-loggers = {}
+__all__ = [
+ 'Logger',
+ 'logger',
+]
log_config = {
'DEBUG': {
diff --git a/paddleaudio/utils/time.py b/paddleaudio/paddleaudio/utils/time.py
similarity index 97%
rename from paddleaudio/utils/time.py
rename to paddleaudio/paddleaudio/utils/time.py
index 6f0c7585b..105208f91 100644
--- a/paddleaudio/utils/time.py
+++ b/paddleaudio/paddleaudio/utils/time.py
@@ -14,6 +14,11 @@
import math
import time
+__all__ = [
+ 'Timer',
+ 'seconds_to_hms',
+]
+
class Timer(object):
     '''Calculate running speed and estimated time of arrival (ETA)'''
diff --git a/setup_audio.py b/paddleaudio/setup.py
similarity index 57%
rename from setup_audio.py
rename to paddleaudio/setup.py
index 212049987..930f86e41 100644
--- a/setup_audio.py
+++ b/paddleaudio/setup.py
@@ -11,19 +11,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import glob
+import os
+
import setuptools
+from setuptools.command.install import install
+from setuptools.command.test import test
# set the version here
-VERSION = '0.1.0'
+VERSION = '0.2.0'
+
+
+# Inspired by the example at https://pytest.org/latest/goodpractises.html
+class TestCommand(test):
+ def finalize_options(self):
+ test.finalize_options(self)
+ self.test_args = []
+ self.test_suite = True
+
+ def run(self):
+ self.run_benchmark()
+ super(TestCommand, self).run()
+
+ def run_tests(self):
+ # Run nose ensuring that argv simulates running nosetests directly
+ import nose
+ nose.run_exit(argv=['nosetests', '-w', 'tests'])
+
+ def run_benchmark(self):
+ for benchmark_item in glob.glob('tests/benchmark/*py'):
+ os.system(f'pytest {benchmark_item}')
+
+
+class InstallCommand(install):
+ def run(self):
+ install.run(self)
def write_version_py(filename='paddleaudio/__init__.py'):
- import paddleaudio
- if hasattr(paddleaudio,
- "__version__") and paddleaudio.__version__ == VERSION:
- return
with open(filename, "a") as f:
- f.write(f"\n__version__ = '{VERSION}'\n")
+ f.write(f"__version__ = '{VERSION}'")
def remove_version_py(filename='paddleaudio/__init__.py'):
@@ -35,6 +62,7 @@ def remove_version_py(filename='paddleaudio/__init__.py'):
f.write(line)
+remove_version_py()
write_version_py()
setuptools.setup(
@@ -59,6 +87,18 @@ setuptools.setup(
'resampy >= 0.2.2',
'soundfile >= 0.9.0',
'colorlog',
- ], )
+ 'dtaidistance >= 2.3.6',
+ 'mcd >= 0.4',
+ ],
+ extras_require={
+ 'test': [
+ 'nose', 'librosa==0.8.1', 'soundfile==0.10.3.post1',
+ 'torchaudio==0.10.2', 'pytest-benchmark'
+ ],
+ },
+ cmdclass={
+ 'install': InstallCommand,
+ 'test': TestCommand,
+ }, )
remove_version_py()
diff --git a/speechx/examples/.gitkeep b/paddleaudio/tests/.gitkeep
similarity index 100%
rename from speechx/examples/.gitkeep
rename to paddleaudio/tests/.gitkeep
diff --git a/paddleaudio/tests/backends/__init__.py b/paddleaudio/tests/backends/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddleaudio/tests/backends/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddleaudio/tests/backends/base.py b/paddleaudio/tests/backends/base.py
new file mode 100644
index 000000000..a67191887
--- /dev/null
+++ b/paddleaudio/tests/backends/base.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import unittest
+import urllib.request
+
+mono_channel_wav = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
+multi_channels_wav = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav'
+
+
+class BackendTest(unittest.TestCase):
+ def setUp(self):
+ self.initWavInput()
+
+ def initWavInput(self):
+ self.files = []
+ for url in [mono_channel_wav, multi_channels_wav]:
+ if not os.path.isfile(os.path.basename(url)):
+ urllib.request.urlretrieve(url, os.path.basename(url))
+ self.files.append(os.path.basename(url))
+
+ def initParmas(self):
+ raise NotImplementedError
diff --git a/paddleaudio/tests/backends/soundfile/__init__.py b/paddleaudio/tests/backends/soundfile/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddleaudio/tests/backends/soundfile/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddleaudio/tests/backends/soundfile/test_io.py b/paddleaudio/tests/backends/soundfile/test_io.py
new file mode 100644
index 000000000..0f7580a40
--- /dev/null
+++ b/paddleaudio/tests/backends/soundfile/test_io.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import filecmp
+import os
+import unittest
+
+import numpy as np
+import soundfile as sf
+
+import paddleaudio
+from ..base import BackendTest
+
+
+class TestIO(BackendTest):
+ def test_load_mono_channel(self):
+ sf_data, sf_sr = sf.read(self.files[0])
+ pa_data, pa_sr = paddleaudio.load(
+ self.files[0], normal=False, dtype='float64')
+
+ self.assertEqual(sf_data.dtype, pa_data.dtype)
+ self.assertEqual(sf_sr, pa_sr)
+ np.testing.assert_array_almost_equal(sf_data, pa_data)
+
+ def test_load_multi_channels(self):
+ sf_data, sf_sr = sf.read(self.files[1])
+ sf_data = sf_data.T # Channel dim first
+ pa_data, pa_sr = paddleaudio.load(
+ self.files[1], mono=False, normal=False, dtype='float64')
+
+ self.assertEqual(sf_data.dtype, pa_data.dtype)
+ self.assertEqual(sf_sr, pa_sr)
+ np.testing.assert_array_almost_equal(sf_data, pa_data)
+
+ def test_save_mono_channel(self):
+ waveform, sr = np.random.randint(
+ low=-32768, high=32768, size=(48000), dtype=np.int16), 16000
+ sf_tmp_file = 'sf_tmp.wav'
+ pa_tmp_file = 'pa_tmp.wav'
+
+ sf.write(sf_tmp_file, waveform, sr)
+ paddleaudio.save(waveform, sr, pa_tmp_file)
+
+ self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file))
+ for file in [sf_tmp_file, pa_tmp_file]:
+ os.remove(file)
+
+ def test_save_multi_channels(self):
+ waveform, sr = np.random.randint(
+ low=-32768, high=32768, size=(2, 48000), dtype=np.int16), 16000
+ sf_tmp_file = 'sf_tmp.wav'
+ pa_tmp_file = 'pa_tmp.wav'
+
+ sf.write(sf_tmp_file, waveform.T, sr)
+ paddleaudio.save(waveform.T, sr, pa_tmp_file)
+
+ self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file))
+ for file in [sf_tmp_file, pa_tmp_file]:
+ os.remove(file)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddleaudio/tests/benchmark/README.md b/paddleaudio/tests/benchmark/README.md
new file mode 100644
index 000000000..b9034100d
--- /dev/null
+++ b/paddleaudio/tests/benchmark/README.md
@@ -0,0 +1,39 @@
+# 1. Prepare
+First, install `pytest-benchmark` via pip.
+```sh
+pip install pytest-benchmark
+```
+
+# 2. Run
+Run the specific script for profiling.
+```sh
+pytest melspectrogram.py
+```
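+
+To profile several scripts in one session, pass the file names to `pytest` together (the scripts compare against `torchaudio`, so the `test` extras need to be installed), e.g.:
+```sh
+pytest melspectrogram.py log_melspectrogram.py mfcc.py
+```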
+
+Result:
+```sh
+========================================================================== test session starts ==========================================================================
+platform linux -- Python 3.7.7, pytest-7.0.1, pluggy-1.0.0
+benchmark: 3.4.1 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
+rootdir: /ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddleaudio
+plugins: typeguard-2.12.1, benchmark-3.4.1, anyio-3.5.0
+collected 4 items
+
+melspectrogram.py .... [100%]
+
+
+-------------------------------------------------------------------------------------------------- benchmark: 4 tests -------------------------------------------------------------------------------------------------
+Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+test_melspect_gpu_torchaudio 202.0765 (1.0) 360.6230 (1.0) 218.1168 (1.0) 16.3022 (1.0) 214.2871 (1.0) 21.8451 (1.0) 40;3 4,584.7001 (1.0) 286 1
+test_melspect_gpu 657.8509 (3.26) 908.0470 (2.52) 724.2545 (3.32) 106.5771 (6.54) 669.9096 (3.13) 113.4719 (5.19) 1;0 1,380.7300 (0.30) 5 1
+test_melspect_cpu_torchaudio 1,247.6053 (6.17) 2,892.5799 (8.02) 1,443.2853 (6.62) 345.3732 (21.19) 1,262.7263 (5.89) 221.6385 (10.15) 56;53 692.8637 (0.15) 399 1
+test_melspect_cpu 20,326.2549 (100.59) 20,607.8682 (57.15) 20,473.4125 (93.86) 63.8654 (3.92) 20,467.0429 (95.51) 68.4294 (3.13) 8;1 48.8438 (0.01) 29 1
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+Legend:
+ Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
+ OPS: Operations Per Second, computed as 1 / Mean
+========================================================================== 4 passed in 21.12s ===========================================================================
+
+```
diff --git a/paddleaudio/tests/benchmark/log_melspectrogram.py b/paddleaudio/tests/benchmark/log_melspectrogram.py
new file mode 100644
index 000000000..5230acd42
--- /dev/null
+++ b/paddleaudio/tests/benchmark/log_melspectrogram.py
@@ -0,0 +1,124 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import urllib.request
+
+import librosa
+import numpy as np
+import paddle
+import torch
+import torchaudio
+
+import paddleaudio
+
+wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
+if not os.path.isfile(os.path.basename(wav_url)):
+ urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
+
+waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url)))
+waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
+waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
+
+# Feature conf
+mel_conf = {
+ 'sr': sr,
+ 'n_fft': 512,
+ 'hop_length': 128,
+ 'n_mels': 40,
+}
+
+mel_conf_torchaudio = {
+ 'sample_rate': sr,
+ 'n_fft': 512,
+ 'hop_length': 128,
+ 'n_mels': 40,
+ 'norm': 'slaney',
+ 'mel_scale': 'slaney',
+}
+
+
+def enable_cpu_device():
+ paddle.set_device('cpu')
+
+
+def enable_gpu_device():
+ paddle.set_device('gpu')
+
+
+log_mel_extractor = paddleaudio.features.LogMelSpectrogram(
+ **mel_conf, f_min=0.0, top_db=80.0, dtype=waveform_tensor.dtype)
+
+
+def log_melspectrogram():
+ return log_mel_extractor(waveform_tensor).squeeze(0)
+
+
+def test_log_melspect_cpu(benchmark):
+ enable_cpu_device()
+ feature_paddleaudio = benchmark(log_melspectrogram)
+ feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
+ feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_paddleaudio, decimal=3)
+
+
+def test_log_melspect_gpu(benchmark):
+ enable_gpu_device()
+ feature_paddleaudio = benchmark(log_melspectrogram)
+ feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
+ feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_paddleaudio, decimal=2)
+
+
+mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram(
+ **mel_conf_torchaudio, f_min=0.0)
+amplitude_to_DB = torchaudio.transforms.AmplitudeToDB('power', top_db=80.0)
+
+
+def melspectrogram_torchaudio():
+ return mel_extractor_torchaudio(waveform_tensor_torch).squeeze(0)
+
+
+def log_melspectrogram_torchaudio():
+ mel_specgram = mel_extractor_torchaudio(waveform_tensor_torch)
+ return amplitude_to_DB(mel_specgram).squeeze(0)
+
+
+def test_log_melspect_cpu_torchaudio(benchmark):
+ global waveform_tensor_torch, mel_extractor_torchaudio, amplitude_to_DB
+
+ mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu')
+ waveform_tensor_torch = waveform_tensor_torch.to('cpu')
+ amplitude_to_DB = amplitude_to_DB.to('cpu')
+
+ feature_paddleaudio = benchmark(log_melspectrogram_torchaudio)
+ feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
+ feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_paddleaudio, decimal=3)
+
+
+def test_log_melspect_gpu_torchaudio(benchmark):
+ global waveform_tensor_torch, mel_extractor_torchaudio, amplitude_to_DB
+
+ mel_extractor_torchaudio = mel_extractor_torchaudio.to('cuda')
+ waveform_tensor_torch = waveform_tensor_torch.to('cuda')
+ amplitude_to_DB = amplitude_to_DB.to('cuda')
+
+ feature_torchaudio = benchmark(log_melspectrogram_torchaudio)
+ feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
+ feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_torchaudio.cpu(), decimal=2)
diff --git a/paddleaudio/tests/benchmark/melspectrogram.py b/paddleaudio/tests/benchmark/melspectrogram.py
new file mode 100644
index 000000000..e0b79b45a
--- /dev/null
+++ b/paddleaudio/tests/benchmark/melspectrogram.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import urllib.request
+
+import librosa
+import numpy as np
+import paddle
+import torch
+import torchaudio
+
+import paddleaudio
+
+wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
+if not os.path.isfile(os.path.basename(wav_url)):
+ urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
+
+waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url)))
+waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
+waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
+
+# Feature conf
+mel_conf = {
+ 'sr': sr,
+ 'n_fft': 512,
+ 'hop_length': 128,
+ 'n_mels': 40,
+}
+
+mel_conf_torchaudio = {
+ 'sample_rate': sr,
+ 'n_fft': 512,
+ 'hop_length': 128,
+ 'n_mels': 40,
+ 'norm': 'slaney',
+ 'mel_scale': 'slaney',
+}
+
+
+def enable_cpu_device():
+ paddle.set_device('cpu')
+
+
+def enable_gpu_device():
+ paddle.set_device('gpu')
+
+
+mel_extractor = paddleaudio.features.MelSpectrogram(
+ **mel_conf, f_min=0.0, dtype=waveform_tensor.dtype)
+
+
+def melspectrogram():
+ return mel_extractor(waveform_tensor).squeeze(0)
+
+
+def test_melspect_cpu(benchmark):
+ enable_cpu_device()
+ feature_paddleaudio = benchmark(melspectrogram)
+ feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_paddleaudio, decimal=3)
+
+
+def test_melspect_gpu(benchmark):
+ enable_gpu_device()
+ feature_paddleaudio = benchmark(melspectrogram)
+ feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_paddleaudio, decimal=3)
+
+
+mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram(
+ **mel_conf_torchaudio, f_min=0.0)
+
+
+def melspectrogram_torchaudio():
+ return mel_extractor_torchaudio(waveform_tensor_torch).squeeze(0)
+
+
+def test_melspect_cpu_torchaudio(benchmark):
+ global waveform_tensor_torch, mel_extractor_torchaudio
+ mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu')
+ waveform_tensor_torch = waveform_tensor_torch.to('cpu')
+ feature_paddleaudio = benchmark(melspectrogram_torchaudio)
+ feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_paddleaudio, decimal=3)
+
+
+def test_melspect_gpu_torchaudio(benchmark):
+ global waveform_tensor_torch, mel_extractor_torchaudio
+ mel_extractor_torchaudio = mel_extractor_torchaudio.to('cuda')
+ waveform_tensor_torch = waveform_tensor_torch.to('cuda')
+ feature_torchaudio = benchmark(melspectrogram_torchaudio)
+ feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_torchaudio.cpu(), decimal=3)
diff --git a/paddleaudio/tests/benchmark/mfcc.py b/paddleaudio/tests/benchmark/mfcc.py
new file mode 100644
index 000000000..2572ff33d
--- /dev/null
+++ b/paddleaudio/tests/benchmark/mfcc.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import urllib.request
+
+import librosa
+import numpy as np
+import paddle
+import torch
+import torchaudio
+
+import paddleaudio
+
+wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
+if not os.path.isfile(os.path.basename(wav_url)):
+ urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
+
+waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url)))
+waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
+waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
+
+# Feature conf
+mel_conf = {
+ 'sr': sr,
+ 'n_fft': 512,
+ 'hop_length': 128,
+ 'n_mels': 40,
+}
+mfcc_conf = {
+ 'n_mfcc': 20,
+ 'top_db': 80.0,
+}
+mfcc_conf.update(mel_conf)
+
+mel_conf_torchaudio = {
+ 'sample_rate': sr,
+ 'n_fft': 512,
+ 'hop_length': 128,
+ 'n_mels': 40,
+ 'norm': 'slaney',
+ 'mel_scale': 'slaney',
+}
+mfcc_conf_torchaudio = {
+ 'sample_rate': sr,
+ 'n_mfcc': 20,
+}
+
+
+def enable_cpu_device():
+ paddle.set_device('cpu')
+
+
+def enable_gpu_device():
+ paddle.set_device('gpu')
+
+
+mfcc_extractor = paddleaudio.features.MFCC(
+ **mfcc_conf, f_min=0.0, dtype=waveform_tensor.dtype)
+
+
+def mfcc():
+ return mfcc_extractor(waveform_tensor).squeeze(0)
+
+
+def test_mfcc_cpu(benchmark):
+ enable_cpu_device()
+ feature_paddleaudio = benchmark(mfcc)
+ feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_paddleaudio, decimal=3)
+
+
+def test_mfcc_gpu(benchmark):
+ enable_gpu_device()
+ feature_paddleaudio = benchmark(mfcc)
+ feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_paddleaudio, decimal=3)
+
+
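+# torchaudio.transforms.MFCC forwards sample_rate to its internal MelSpectrogram,
+# so it must be dropped from melkwargs to avoid a duplicate keyword argument.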
+del mel_conf_torchaudio['sample_rate']
+mfcc_extractor_torchaudio = torchaudio.transforms.MFCC(
+ **mfcc_conf_torchaudio, melkwargs=mel_conf_torchaudio)
+
+
+def mfcc_torchaudio():
+ return mfcc_extractor_torchaudio(waveform_tensor_torch).squeeze(0)
+
+
+def test_mfcc_cpu_torchaudio(benchmark):
+ global waveform_tensor_torch, mfcc_extractor_torchaudio
+
+    mfcc_extractor_torchaudio = mfcc_extractor_torchaudio.to('cpu')
+ waveform_tensor_torch = waveform_tensor_torch.to('cpu')
+
+ feature_paddleaudio = benchmark(mfcc_torchaudio)
+ feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_paddleaudio, decimal=3)
+
+
+def test_mfcc_gpu_torchaudio(benchmark):
+ global waveform_tensor_torch, mfcc_extractor_torchaudio
+
+    mfcc_extractor_torchaudio = mfcc_extractor_torchaudio.to('cuda')
+ waveform_tensor_torch = waveform_tensor_torch.to('cuda')
+
+ feature_torchaudio = benchmark(mfcc_torchaudio)
+ feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_torchaudio.cpu(), decimal=3)
diff --git a/paddleaudio/tests/features/__init__.py b/paddleaudio/tests/features/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddleaudio/tests/features/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddleaudio/tests/features/base.py b/paddleaudio/tests/features/base.py
new file mode 100644
index 000000000..725e1e2e7
--- /dev/null
+++ b/paddleaudio/tests/features/base.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import unittest
+import urllib.request
+
+import numpy as np
+import paddle
+
+from paddleaudio import load
+
+wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
+
+
+class FeatTest(unittest.TestCase):
+ def setUp(self):
+ self.initParmas()
+ self.initWavInput()
+ self.setUpDevice()
+
+ def setUpDevice(self, device='cpu'):
+ paddle.set_device(device)
+
+ def initWavInput(self, url=wav_url):
+ if not os.path.isfile(os.path.basename(url)):
+ urllib.request.urlretrieve(url, os.path.basename(url))
+ self.waveform, self.sr = load(os.path.abspath(os.path.basename(url)))
+ self.waveform = self.waveform.astype(
+ np.float32
+ ) # paddlespeech.s2t.transform.spectrogram only supports float32
+ dim = len(self.waveform.shape)
+
+ assert dim in [1, 2]
+ if dim == 1:
+ self.waveform = np.expand_dims(self.waveform, 0)
+
+ def initParmas(self):
+ raise NotImplementedError
diff --git a/paddleaudio/tests/features/test_istft.py b/paddleaudio/tests/features/test_istft.py
new file mode 100644
index 000000000..23371200b
--- /dev/null
+++ b/paddleaudio/tests/features/test_istft.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+import paddle
+
+from .base import FeatTest
+from paddleaudio.functional.window import get_window
+from paddlespeech.s2t.transform.spectrogram import IStft
+from paddlespeech.s2t.transform.spectrogram import Stft
+
+
+class TestIstft(FeatTest):
+ def initParmas(self):
+ self.n_fft = 512
+ self.hop_length = 128
+ self.window_str = 'hann'
+
+ def test_istft(self):
+ ps_stft = Stft(self.n_fft, self.hop_length)
+ ps_res = ps_stft(
+            self.waveform.T).squeeze(1).T  # (n_fft//2 + 1, n_frames)
+ x = paddle.to_tensor(ps_res)
+
+ ps_istft = IStft(self.hop_length)
+ ps_res = ps_istft(ps_res.T)
+
+ window = get_window(
+ self.window_str, self.n_fft, dtype=self.waveform.dtype)
+ pd_res = paddle.signal.istft(
+ x, self.n_fft, self.hop_length, window=window)
+
+ np.testing.assert_array_almost_equal(ps_res, pd_res, decimal=5)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddleaudio/tests/features/test_kaldi.py b/paddleaudio/tests/features/test_kaldi.py
new file mode 100644
index 000000000..6e826aaa7
--- /dev/null
+++ b/paddleaudio/tests/features/test_kaldi.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+import paddle
+import torch
+import torchaudio
+
+import paddleaudio
+from .base import FeatTest
+
+
+class TestKaldi(FeatTest):
+ def initParmas(self):
+ self.window_size = 1024
+ self.dtype = 'float32'
+
+ def test_window(self):
+ t_hann_window = torch.hann_window(
+ self.window_size, periodic=False, dtype=eval(f'torch.{self.dtype}'))
+ t_hamm_window = torch.hamming_window(
+ self.window_size,
+ periodic=False,
+ alpha=0.54,
+ beta=0.46,
+ dtype=eval(f'torch.{self.dtype}'))
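+        # Kaldi's "povey" window has no direct torch equivalent; it is a Hann window raised to the power of 0.85.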
+ t_povey_window = torch.hann_window(
+ self.window_size, periodic=False,
+ dtype=eval(f'torch.{self.dtype}')).pow(0.85)
+
+ p_hann_window = paddleaudio.functional.window.get_window(
+ 'hann',
+ self.window_size,
+ fftbins=False,
+ dtype=eval(f'paddle.{self.dtype}'))
+ p_hamm_window = paddleaudio.functional.window.get_window(
+ 'hamming',
+ self.window_size,
+ fftbins=False,
+ dtype=eval(f'paddle.{self.dtype}'))
+ p_povey_window = paddleaudio.functional.window.get_window(
+ 'hann',
+ self.window_size,
+ fftbins=False,
+ dtype=eval(f'paddle.{self.dtype}')).pow(0.85)
+
+ np.testing.assert_array_almost_equal(t_hann_window, p_hann_window)
+ np.testing.assert_array_almost_equal(t_hamm_window, p_hamm_window)
+ np.testing.assert_array_almost_equal(t_povey_window, p_povey_window)
+
+ def test_fbank(self):
+ ta_features = torchaudio.compliance.kaldi.fbank(
+ torch.from_numpy(self.waveform.astype(self.dtype)))
+ pa_features = paddleaudio.compliance.kaldi.fbank(
+ paddle.to_tensor(self.waveform.astype(self.dtype)))
+ np.testing.assert_array_almost_equal(
+ ta_features, pa_features, decimal=4)
+
+ def test_mfcc(self):
+ ta_features = torchaudio.compliance.kaldi.mfcc(
+ torch.from_numpy(self.waveform.astype(self.dtype)))
+ pa_features = paddleaudio.compliance.kaldi.mfcc(
+ paddle.to_tensor(self.waveform.astype(self.dtype)))
+ np.testing.assert_array_almost_equal(
+ ta_features, pa_features, decimal=4)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddleaudio/tests/features/test_librosa.py b/paddleaudio/tests/features/test_librosa.py
new file mode 100644
index 000000000..cf0c98c72
--- /dev/null
+++ b/paddleaudio/tests/features/test_librosa.py
@@ -0,0 +1,281 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import librosa
+import numpy as np
+import paddle
+
+import paddleaudio
+from .base import FeatTest
+from paddleaudio.functional.window import get_window
+
+
+class TestLibrosa(FeatTest):
+ def initParmas(self):
+ self.n_fft = 512
+ self.hop_length = 128
+ self.n_mels = 40
+ self.n_mfcc = 20
+ self.fmin = 0.0
+ self.window_str = 'hann'
+ self.pad_mode = 'reflect'
+ self.top_db = 80.0
+
+ def test_stft(self):
+ if len(self.waveform.shape) == 2: # (C, T)
+ self.waveform = self.waveform.squeeze(
+ 0) # 1D input for librosa.feature.melspectrogram
+
+ feature_librosa = librosa.core.stft(
+ y=self.waveform,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ win_length=None,
+ window=self.window_str,
+ center=True,
+ dtype=None,
+ pad_mode=self.pad_mode, )
+ x = paddle.to_tensor(self.waveform).unsqueeze(0)
+ window = get_window(self.window_str, self.n_fft, dtype=x.dtype)
+ feature_paddle = paddle.signal.stft(
+ x=x,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ win_length=None,
+ window=window,
+ center=True,
+ pad_mode=self.pad_mode,
+ normalized=False,
+ onesided=True, ).squeeze(0)
+
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_paddle, decimal=5)
+
+ def test_istft(self):
+ if len(self.waveform.shape) == 2: # (C, T)
+ self.waveform = self.waveform.squeeze(
+ 0) # 1D input for librosa.feature.melspectrogram
+
+ # Get stft result from librosa.
+ stft_matrix = librosa.core.stft(
+ y=self.waveform,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ win_length=None,
+ window=self.window_str,
+ center=True,
+ pad_mode=self.pad_mode, )
+
+ feature_librosa = librosa.core.istft(
+ stft_matrix=stft_matrix,
+ hop_length=self.hop_length,
+ win_length=None,
+ window=self.window_str,
+ center=True,
+ dtype=None,
+ length=None, )
+
+ x = paddle.to_tensor(stft_matrix).unsqueeze(0)
+ window = get_window(
+ self.window_str,
+ self.n_fft,
+ dtype=paddle.to_tensor(self.waveform).dtype)
+ feature_paddle = paddle.signal.istft(
+ x=x,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ win_length=None,
+ window=window,
+ center=True,
+ normalized=False,
+ onesided=True,
+ length=None,
+ return_complex=False, ).squeeze(0)
+
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_paddle, decimal=5)
+
+ def test_mel(self):
+ feature_librosa = librosa.filters.mel(
+ sr=self.sr,
+ n_fft=self.n_fft,
+ n_mels=self.n_mels,
+ fmin=self.fmin,
+ fmax=None,
+ htk=False,
+ norm='slaney',
+ dtype=self.waveform.dtype, )
+ feature_compliance = paddleaudio.compliance.librosa.compute_fbank_matrix(
+ sr=self.sr,
+ n_fft=self.n_fft,
+ n_mels=self.n_mels,
+ fmin=self.fmin,
+ fmax=None,
+ htk=False,
+ norm='slaney',
+ dtype=self.waveform.dtype, )
+ x = paddle.to_tensor(self.waveform)
+ feature_functional = paddleaudio.functional.compute_fbank_matrix(
+ sr=self.sr,
+ n_fft=self.n_fft,
+ n_mels=self.n_mels,
+ f_min=self.fmin,
+ f_max=None,
+ htk=False,
+ norm='slaney',
+ dtype=x.dtype, )
+
+ np.testing.assert_array_almost_equal(feature_librosa,
+ feature_compliance)
+ np.testing.assert_array_almost_equal(feature_librosa,
+ feature_functional)
+
+ def test_melspect(self):
+ if len(self.waveform.shape) == 2: # (C, T)
+ self.waveform = self.waveform.squeeze(
+ 0) # 1D input for librosa.feature.melspectrogram
+
+ # librosa:
+ feature_librosa = librosa.feature.melspectrogram(
+ y=self.waveform,
+ sr=self.sr,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ n_mels=self.n_mels,
+ fmin=self.fmin)
+
+ # paddleaudio.compliance.librosa:
+ feature_compliance = paddleaudio.compliance.librosa.melspectrogram(
+ x=self.waveform,
+ sr=self.sr,
+ window_size=self.n_fft,
+ hop_length=self.hop_length,
+ n_mels=self.n_mels,
+ fmin=self.fmin,
+ to_db=False)
+
+ # paddleaudio.features.layer
+ x = paddle.to_tensor(
+ self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim.
+ feature_extractor = paddleaudio.features.MelSpectrogram(
+ sr=self.sr,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ n_mels=self.n_mels,
+ f_min=self.fmin,
+ dtype=x.dtype)
+ feature_layer = feature_extractor(x).squeeze(0).numpy()
+
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_compliance, decimal=5)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_layer, decimal=5)
+
+ def test_log_melspect(self):
+ if len(self.waveform.shape) == 2: # (C, T)
+ self.waveform = self.waveform.squeeze(
+ 0) # 1D input for librosa.feature.melspectrogram
+
+ # librosa:
+ feature_librosa = librosa.feature.melspectrogram(
+ y=self.waveform,
+ sr=self.sr,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ n_mels=self.n_mels,
+ fmin=self.fmin)
+ feature_librosa = librosa.power_to_db(feature_librosa, top_db=None)
+
+ # paddleaudio.compliance.librosa:
+ feature_compliance = paddleaudio.compliance.librosa.melspectrogram(
+ x=self.waveform,
+ sr=self.sr,
+ window_size=self.n_fft,
+ hop_length=self.hop_length,
+ n_mels=self.n_mels,
+ fmin=self.fmin)
+
+ # paddleaudio.features.layer
+ x = paddle.to_tensor(
+ self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim.
+ feature_extractor = paddleaudio.features.LogMelSpectrogram(
+ sr=self.sr,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ n_mels=self.n_mels,
+ f_min=self.fmin,
+ dtype=x.dtype)
+ feature_layer = feature_extractor(x).squeeze(0).numpy()
+
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_compliance, decimal=5)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_layer, decimal=4)
+
+ def test_mfcc(self):
+ if len(self.waveform.shape) == 2: # (C, T)
+ self.waveform = self.waveform.squeeze(
+ 0) # 1D input for librosa.feature.melspectrogram
+
+ # librosa:
+ feature_librosa = librosa.feature.mfcc(
+ y=self.waveform,
+ sr=self.sr,
+ S=None,
+ n_mfcc=self.n_mfcc,
+ dct_type=2,
+ norm='ortho',
+ lifter=0,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ n_mels=self.n_mels,
+ fmin=self.fmin)
+
+ # paddleaudio.compliance.librosa:
+ feature_compliance = paddleaudio.compliance.librosa.mfcc(
+ x=self.waveform,
+ sr=self.sr,
+ n_mfcc=self.n_mfcc,
+ dct_type=2,
+ norm='ortho',
+ lifter=0,
+ window_size=self.n_fft,
+ hop_length=self.hop_length,
+ n_mels=self.n_mels,
+ fmin=self.fmin,
+ top_db=self.top_db)
+
+ # paddleaudio.features.layer
+ x = paddle.to_tensor(
+ self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim.
+ feature_extractor = paddleaudio.features.MFCC(
+ sr=self.sr,
+ n_mfcc=self.n_mfcc,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ n_mels=self.n_mels,
+ f_min=self.fmin,
+ top_db=self.top_db,
+ dtype=x.dtype)
+ feature_layer = feature_extractor(x).squeeze(0).numpy()
+
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_compliance, decimal=4)
+ np.testing.assert_array_almost_equal(
+ feature_librosa, feature_layer, decimal=4)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddleaudio/tests/features/test_log_melspectrogram.py b/paddleaudio/tests/features/test_log_melspectrogram.py
new file mode 100644
index 000000000..6bae2df3f
--- /dev/null
+++ b/paddleaudio/tests/features/test_log_melspectrogram.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+import paddle
+
+import paddleaudio
+from .base import FeatTest
+from paddlespeech.s2t.transform.spectrogram import LogMelSpectrogram
+
+
+class TestLogMelSpectrogram(FeatTest):
+ def initParmas(self):
+ self.n_fft = 512
+ self.hop_length = 128
+ self.n_mels = 40
+
+ def test_log_melspect(self):
+ ps_melspect = LogMelSpectrogram(self.sr, self.n_mels, self.n_fft,
+ self.hop_length)
+ ps_res = ps_melspect(self.waveform.T).squeeze(1).T
+
+ x = paddle.to_tensor(self.waveform)
+        # The paddlespeech.s2t features mix up the magnitude spectrum and the power spectrum.
+ ps_melspect = paddleaudio.features.LogMelSpectrogram(
+ self.sr,
+ self.n_fft,
+ self.hop_length,
+ power=1.0,
+ n_mels=self.n_mels,
+ f_min=0.0)
+ pa_res = (ps_melspect(x) / 10.0).squeeze(0).numpy()
+
+ np.testing.assert_array_almost_equal(ps_res, pa_res, decimal=5)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddleaudio/tests/features/test_spectrogram.py b/paddleaudio/tests/features/test_spectrogram.py
new file mode 100644
index 000000000..50b21403b
--- /dev/null
+++ b/paddleaudio/tests/features/test_spectrogram.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+import paddle
+
+import paddleaudio
+from .base import FeatTest
+from paddlespeech.s2t.transform.spectrogram import Spectrogram
+
+
+class TestSpectrogram(FeatTest):
+ def initParmas(self):
+ self.n_fft = 512
+ self.hop_length = 128
+
+ def test_spectrogram(self):
+ ps_spect = Spectrogram(self.n_fft, self.hop_length)
+ ps_res = ps_spect(self.waveform.T).squeeze(1).T # Magnitude
+
+ x = paddle.to_tensor(self.waveform)
+ pa_spect = paddleaudio.features.Spectrogram(
+ self.n_fft, self.hop_length, power=1.0)
+ pa_res = pa_spect(x).squeeze(0).numpy()
+
+ np.testing.assert_array_almost_equal(ps_res, pa_res, decimal=5)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddleaudio/tests/features/test_stft.py b/paddleaudio/tests/features/test_stft.py
new file mode 100644
index 000000000..c64b5ebe6
--- /dev/null
+++ b/paddleaudio/tests/features/test_stft.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+import paddle
+
+from .base import FeatTest
+from paddleaudio.functional.window import get_window
+from paddlespeech.s2t.transform.spectrogram import Stft
+
+
+class TestStft(FeatTest):
+ def initParmas(self):
+ self.n_fft = 512
+ self.hop_length = 128
+ self.window_str = 'hann'
+
+ def test_stft(self):
+ ps_stft = Stft(self.n_fft, self.hop_length)
+ ps_res = ps_stft(
+            self.waveform.T).squeeze(1).T  # (n_fft//2 + 1, n_frames)
+
+ x = paddle.to_tensor(self.waveform)
+ window = get_window(self.window_str, self.n_fft, dtype=x.dtype)
+ pd_res = paddle.signal.stft(
+ x, self.n_fft, self.hop_length, window=window).squeeze(0).numpy()
+
+ np.testing.assert_array_almost_equal(ps_res, pd_res, decimal=5)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paddlespeech/__init__.py b/paddlespeech/__init__.py
index 185a92b8d..b781c4a8e 100644
--- a/paddlespeech/__init__.py
+++ b/paddlespeech/__init__.py
@@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import _locale
+
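+# Force a UTF-8 default locale so text I/O behaves consistently across platforms (assumed rationale).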
+_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])
diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py
index ab5eee6e2..f56d8a579 100644
--- a/paddlespeech/cli/cls/infer.py
+++ b/paddlespeech/cli/cls/infer.py
@@ -193,7 +193,8 @@ class CLSExecutor(BaseExecutor):
sr=feat_conf['sample_rate'],
mono=True,
dtype='float32')
- logger.info("Preprocessing audio_file:" + audio_file)
+ if isinstance(audio_file, (str, os.PathLike)):
+ logger.info("Preprocessing audio_file:" + audio_file)
# Feature extraction
feature_extractor = LogMelSpectrogram(
diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py
index d7dcc90c7..f7d64b9a9 100644
--- a/paddlespeech/cli/utils.py
+++ b/paddlespeech/cli/utils.py
@@ -192,7 +192,7 @@ class ConfigCache:
try:
cfg = yaml.load(file, Loader=yaml.FullLoader)
self._data.update(cfg)
- except:
+ except Exception as e:
self.flush()
@property
diff --git a/paddlespeech/server/__init__.py b/paddlespeech/server/__init__.py
index 384061dda..97722c0a0 100644
--- a/paddlespeech/server/__init__.py
+++ b/paddlespeech/server/__init__.py
@@ -18,6 +18,7 @@ from .base_commands import ClientHelpCommand
from .base_commands import ServerBaseCommand
from .base_commands import ServerHelpCommand
from .bin.paddlespeech_client import ASRClientExecutor
+from .bin.paddlespeech_client import CLSClientExecutor
from .bin.paddlespeech_client import TTSClientExecutor
from .bin.paddlespeech_server import ServerExecutor
diff --git a/paddlespeech/server/bin/main.py b/paddlespeech/server/bin/main.py
index 360d295ef..de5282993 100644
--- a/paddlespeech/server/bin/main.py
+++ b/paddlespeech/server/bin/main.py
@@ -34,7 +34,7 @@ def init(config):
bool:
"""
# init api
- api_list = list(config.engine_backend)
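+    # engine_list entries look like '<speech task>_<engine type>', e.g. 'asr_python'; keep only the task name for routing.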
+ api_list = list(engine.split("_")[0] for engine in config.engine_list)
api_router = setup_router(api_list)
app.include_router(api_router)
diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py
index ee6ab7ad7..40f17c63c 100644
--- a/paddlespeech/server/bin/paddlespeech_client.py
+++ b/paddlespeech/server/bin/paddlespeech_client.py
@@ -31,7 +31,7 @@ from paddlespeech.cli.log import logger
from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import wav2base64
-__all__ = ['TTSClientExecutor', 'ASRClientExecutor']
+__all__ = ['TTSClientExecutor', 'ASRClientExecutor', 'CLSClientExecutor']
@cli_client_register(
@@ -70,13 +70,9 @@ class TTSClientExecutor(BaseExecutor):
choices=[0, 8000, 16000],
help='Sampling rate, the default is the same as the model')
self.parser.add_argument(
- '--output',
- type=str,
- default="./output.wav",
- help='Synthesized audio file')
+ '--output', type=str, default=None, help='Synthesized audio file')
- def postprocess(self, response_dict: dict, outfile: str) -> float:
- wav_base64 = response_dict["result"]["audio"]
+ def postprocess(self, wav_base64: str, outfile: str) -> float:
audio_data_byte = base64.b64decode(wav_base64)
# from byte
samples, sample_rate = soundfile.read(
@@ -93,37 +89,38 @@ class TTSClientExecutor(BaseExecutor):
else:
logger.error("The format for saving audio only supports wav or pcm")
- duration = len(samples) / sample_rate
- return duration
-
def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv)
- try:
- url = 'http://' + args.server_ip + ":" + str(
- args.port) + '/paddlespeech/tts'
- request = {
- "text": args.input,
- "spk_id": args.spk_id,
- "speed": args.speed,
- "volume": args.volume,
- "sample_rate": args.sample_rate,
- "save_path": args.output
- }
- st = time.time()
- response = requests.post(url, json.dumps(request))
- time_consume = time.time() - st
-
- response_dict = response.json()
- duration = self.postprocess(response_dict, args.output)
+ input_ = args.input
+ server_ip = args.server_ip
+ port = args.port
+ spk_id = args.spk_id
+ speed = args.speed
+ volume = args.volume
+ sample_rate = args.sample_rate
+ output = args.output
+ try:
+ time_start = time.time()
+ res = self(
+ input=input_,
+ server_ip=server_ip,
+ port=port,
+ spk_id=spk_id,
+ speed=speed,
+ volume=volume,
+ sample_rate=sample_rate,
+ output=output)
+ time_end = time.time()
+ time_consume = time_end - time_start
+ response_dict = res.json()
logger.info(response_dict["message"])
- logger.info("Save synthesized audio successfully on %s." %
- (args.output))
- logger.info("Audio duration: %f s." % (duration))
+ logger.info("Save synthesized audio successfully on %s." % (output))
+ logger.info("Audio duration: %f s." %
+ (response_dict['result']['duration']))
logger.info("Response time: %f s." % (time_consume))
-
return True
- except BaseException:
+ except Exception as e:
logger.error("Failed to synthesized audio.")
return False
@@ -136,7 +133,7 @@ class TTSClientExecutor(BaseExecutor):
speed: float=1.0,
volume: float=1.0,
sample_rate: int=0,
- output: str="./output.wav"):
+ output: str=None):
"""
Python API to call an executor.
"""
@@ -151,20 +148,11 @@ class TTSClientExecutor(BaseExecutor):
"save_path": output
}
- try:
- st = time.time()
- response = requests.post(url, json.dumps(request))
- time_consume = time.time() - st
- response_dict = response.json()
- duration = self.postprocess(response_dict, output)
-
- print(response_dict["message"])
- print("Save synthesized audio successfully on %s." % (output))
- print("Audio duration: %f s." % (duration))
- print("Response time: %f s." % (time_consume))
- print("RTF: %f " % (time_consume / duration))
- except BaseException:
- print("Failed to synthesized audio.")
+ res = requests.post(url, json.dumps(request))
+ response_dict = res.json()
+        if output is not None:
+ self.postprocess(response_dict["result"]["audio"], output)
+ return res
@cli_client_register(
@@ -193,24 +181,27 @@ class ASRClientExecutor(BaseExecutor):
def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv)
- url = 'http://' + args.server_ip + ":" + str(
- args.port) + '/paddlespeech/asr'
- audio = wav2base64(args.input)
- data = {
- "audio": audio,
- "audio_format": args.audio_format,
- "sample_rate": args.sample_rate,
- "lang": args.lang,
- }
- time_start = time.time()
+ input_ = args.input
+ server_ip = args.server_ip
+ port = args.port
+ sample_rate = args.sample_rate
+ lang = args.lang
+ audio_format = args.audio_format
+
try:
- r = requests.post(url=url, data=json.dumps(data))
- # ending Timestamp
+ time_start = time.time()
+ res = self(
+ input=input_,
+ server_ip=server_ip,
+ port=port,
+ sample_rate=sample_rate,
+ lang=lang,
+ audio_format=audio_format)
time_end = time.time()
- logger.info(r.json())
- logger.info("time cost %f s." % (time_end - time_start))
+ logger.info(res.json())
+ logger.info("Response time %f s." % (time_end - time_start))
return True
- except BaseException:
+ except Exception as e:
logger.error("Failed to speech recognition.")
return False
@@ -234,12 +225,65 @@ class ASRClientExecutor(BaseExecutor):
"sample_rate": sample_rate,
"lang": lang,
}
- time_start = time.time()
+
+ res = requests.post(url=url, data=json.dumps(data))
+ return res
+
+
+@cli_client_register(
+ name='paddlespeech_client.cls', description='visit cls service')
+class CLSClientExecutor(BaseExecutor):
+ def __init__(self):
+ super(CLSClientExecutor, self).__init__()
+ self.parser = argparse.ArgumentParser(
+ prog='paddlespeech_client.cls', add_help=True)
+ self.parser.add_argument(
+ '--server_ip', type=str, default='127.0.0.1', help='server ip')
+ self.parser.add_argument(
+ '--port', type=int, default=8090, help='server port')
+ self.parser.add_argument(
+ '--input',
+ type=str,
+ default=None,
+ help='Audio file to classify.',
+ required=True)
+ self.parser.add_argument(
+ '--topk',
+ type=int,
+ default=1,
+ help='Return topk scores of classification result.')
+
+ def execute(self, argv: List[str]) -> bool:
+ args = self.parser.parse_args(argv)
+ input_ = args.input
+ server_ip = args.server_ip
+ port = args.port
+ topk = args.topk
+
try:
- r = requests.post(url=url, data=json.dumps(data))
- # ending Timestamp
+ time_start = time.time()
+ res = self(input=input_, server_ip=server_ip, port=port, topk=topk)
time_end = time.time()
- print(r.json())
- print("time cost %f s." % (time_end - time_start))
- except BaseException:
- print("Failed to speech recognition.")
+ logger.info(res.json())
+ logger.info("Response time %f s." % (time_end - time_start))
+ return True
+ except Exception as e:
+ logger.error("Failed to speech classification.")
+ return False
+
+ @stats_wrapper
+ def __call__(self,
+ input: str,
+ server_ip: str="127.0.0.1",
+ port: int=8090,
+ topk: int=1):
+ """
+ Python API to call an executor.
+ """
+
+ url = 'http://' + server_ip + ":" + str(port) + '/paddlespeech/cls'
+ audio = wav2base64(input)
+ data = {"audio": audio, "topk": topk}
+
+ res = requests.post(url=url, data=json.dumps(data))
+ return res
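+
+
+# Illustrative usage of the CLS client Python API (assumes a PaddleSpeech server
+# with a cls engine running on 127.0.0.1:8090 and a local wav file):
+#
+#   from paddlespeech.server.bin.paddlespeech_client import CLSClientExecutor
+#
+#   cls_client = CLSClientExecutor()
+#   res = cls_client(input='./zh.wav', server_ip='127.0.0.1', port=8090, topk=1)
+#   print(res.json())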
diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py
index 21fc5c65e..f6a7f4295 100644
--- a/paddlespeech/server/bin/paddlespeech_server.py
+++ b/paddlespeech/server/bin/paddlespeech_server.py
@@ -62,7 +62,7 @@ class ServerExecutor(BaseExecutor):
bool:
"""
# init api
- api_list = list(config.engine_backend)
+ api_list = list(engine.split("_")[0] for engine in config.engine_list)
api_router = setup_router(api_list)
app.include_router(api_router)
@@ -103,13 +103,14 @@ class ServerStatsExecutor():
'--task',
type=str,
default=None,
- choices=['asr', 'tts'],
+ choices=['asr', 'tts', 'cls'],
help='Choose speech task.',
required=True)
- self.task_choices = ['asr', 'tts']
+ self.task_choices = ['asr', 'tts', 'cls']
self.model_name_format = {
'asr': 'Model-Language-Sample Rate',
- 'tts': 'Model-Language'
+ 'tts': 'Model-Language',
+ 'cls': 'Model-Sample Rate'
}
def show_support_models(self, pretrained_models: dict):
@@ -174,53 +175,24 @@ class ServerStatsExecutor():
)
return False
- @stats_wrapper
- def __call__(
- self,
- task: str=None, ):
- """
- Python API to call an executor.
- """
- self.task = task
- if self.task not in self.task_choices:
- print("Please input correct speech task, choices = ['asr', 'tts']")
-
- elif self.task == 'asr':
- try:
- from paddlespeech.cli.asr.infer import pretrained_models
- print(
- "Here is the table of ASR pretrained models supported in the service."
- )
- self.show_support_models(pretrained_models)
-
- # show ASR static pretrained model
- from paddlespeech.server.engine.asr.paddleinference.asr_engine import pretrained_models
- print(
- "Here is the table of ASR static pretrained models supported in the service."
- )
- self.show_support_models(pretrained_models)
-
- except BaseException:
- print(
- "Failed to get the table of ASR pretrained models supported in the service."
- )
-
- elif self.task == 'tts':
+ elif self.task == 'cls':
try:
- from paddlespeech.cli.tts.infer import pretrained_models
- print(
- "Here is the table of TTS pretrained models supported in the service."
+ from paddlespeech.cli.cls.infer import pretrained_models
+ logger.info(
+ "Here is the table of CLS pretrained models supported in the service."
)
self.show_support_models(pretrained_models)
- # show TTS static pretrained model
- from paddlespeech.server.engine.tts.paddleinference.tts_engine import pretrained_models
- print(
- "Here is the table of TTS static pretrained models supported in the service."
+ # show CLS static pretrained model
+ from paddlespeech.server.engine.cls.paddleinference.cls_engine import pretrained_models
+ logger.info(
+ "Here is the table of CLS static pretrained models supported in the service."
)
self.show_support_models(pretrained_models)
+ return True
except BaseException:
- print(
- "Failed to get the table of TTS pretrained models supported in the service."
+ logger.error(
+ "Failed to get the table of CLS pretrained models supported in the service."
)
+ return False
diff --git a/paddlespeech/server/conf/application.yaml b/paddlespeech/server/conf/application.yaml
index 6dcae74a9..2b1a05998 100644
--- a/paddlespeech/server/conf/application.yaml
+++ b/paddlespeech/server/conf/application.yaml
@@ -1,27 +1,137 @@
# This is the parameter configuration file for PaddleSpeech Serving.
-##################################################################
-# SERVER SETTING #
-##################################################################
-host: '127.0.0.1'
+#################################################################################
+# SERVER SETTING #
+#################################################################################
+host: 127.0.0.1
port: 8090
-##################################################################
-# CONFIG FILE #
-##################################################################
-# add engine backend type (Options: asr, tts) and config file here.
-# Adding a speech task to engine_backend means starting the service.
-engine_backend:
- asr: 'conf/asr/asr.yaml'
- tts: 'conf/tts/tts.yaml'
-
-# The engine_type of speech task needs to keep the same type as the config file of speech task.
-# E.g: The engine_type of asr is 'python', the engine_backend of asr is 'XX/asr.yaml'
-# E.g: The engine_type of asr is 'inference', the engine_backend of asr is 'XX/asr_pd.yaml'
-#
-# add engine type (Options: python, inference)
-engine_type:
- asr: 'python'
- tts: 'python'
+# The task format in the engine_list is: <speech task>_<engine type>
+# task choices = ['asr_python', 'asr_inference', 'tts_python', 'tts_inference',
+#                 'cls_python', 'cls_inference']
+engine_list: ['asr_python', 'tts_python', 'cls_python']
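+# Example (illustrative): to serve ASR with the static Paddle Inference engine instead of the
+# dynamic-graph one, swap the entry and fill in the matching 'asr_inference' section below:
+# engine_list: ['asr_inference', 'tts_python', 'cls_python']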
+
+
+#################################################################################
+# ENGINE CONFIG #
+#################################################################################
+
+################################### ASR #########################################
+################### speech task: asr; engine_type: python #######################
+asr_python:
+ model: 'conformer_wenetspeech'
+ lang: 'zh'
+ sample_rate: 16000
+ cfg_path: # [optional]
+ ckpt_path: # [optional]
+ decode_method: 'attention_rescoring'
+ force_yes: True
+ device: # set 'gpu:id' or 'cpu'
+
+
+################### speech task: asr; engine_type: inference #######################
+asr_inference:
+ # model_type choices=['deepspeech2offline_aishell']
+ model_type: 'deepspeech2offline_aishell'
+ am_model: # the pdmodel file of am static model [optional]
+ am_params: # the pdiparams file of am static model [optional]
+ lang: 'zh'
+ sample_rate: 16000
+ cfg_path:
+ decode_method:
+ force_yes: True
+
+ am_predictor_conf:
+ device: # set 'gpu:id' or 'cpu'
+ switch_ir_optim: True
+ glog_info: False # True -> print glog
+ summary: True # False -> do not show predictor config
+
+
+################################### TTS #########################################
+################### speech task: tts; engine_type: python #######################
+tts_python:
+ # am (acoustic model) choices=['speedyspeech_csmsc', 'fastspeech2_csmsc',
+ # 'fastspeech2_ljspeech', 'fastspeech2_aishell3',
+ # 'fastspeech2_vctk']
+ am: 'fastspeech2_csmsc'
+ am_config:
+ am_ckpt:
+ am_stat:
+ phones_dict:
+ tones_dict:
+ speaker_dict:
+ spk_id: 0
+
+ # voc (vocoder) choices=['pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3',
+ # 'pwgan_vctk', 'mb_melgan_csmsc']
+ voc: 'pwgan_csmsc'
+ voc_config:
+ voc_ckpt:
+ voc_stat:
+
+ # others
+ lang: 'zh'
+ device: # set 'gpu:id' or 'cpu'
+
+
+################### speech task: tts; engine_type: inference #######################
+tts_inference:
+ # am (acoustic model) choices=['speedyspeech_csmsc', 'fastspeech2_csmsc']
+ am: 'fastspeech2_csmsc'
+ am_model: # the pdmodel file of your am static model (XX.pdmodel)
+    am_params: # the pdiparams file of your am static model (XX.pdiparams)
+ am_sample_rate: 24000
+ phones_dict:
+ tones_dict:
+ speaker_dict:
+ spk_id: 0
+
+ am_predictor_conf:
+ device: # set 'gpu:id' or 'cpu'
+ switch_ir_optim: True
+ glog_info: False # True -> print glog
+ summary: True # False -> do not show predictor config
+
+ # voc (vocoder) choices=['pwgan_csmsc', 'mb_melgan_csmsc','hifigan_csmsc']
+ voc: 'pwgan_csmsc'
+ voc_model: # the pdmodel file of your vocoder static model (XX.pdmodel)
+    voc_params: # the pdiparams file of your vocoder static model (XX.pdiparams)
+ voc_sample_rate: 24000
+
+ voc_predictor_conf:
+ device: # set 'gpu:id' or 'cpu'
+ switch_ir_optim: True
+ glog_info: False # True -> print glog
+ summary: True # False -> do not show predictor config
+
+ # others
+ lang: 'zh'
+
+
+################################### CLS #########################################
+################### speech task: cls; engine_type: python #######################
+cls_python:
+ # model choices=['panns_cnn14', 'panns_cnn10', 'panns_cnn6']
+ model: 'panns_cnn14'
+ cfg_path: # [optional] Config of cls task.
+ ckpt_path: # [optional] Checkpoint file of model.
+ label_file: # [optional] Label file of cls task.
+ device: # set 'gpu:id' or 'cpu'
+
+
+################### speech task: cls; engine_type: inference #######################
+cls_inference:
+ # model_type choices=['panns_cnn14', 'panns_cnn10', 'panns_cnn6']
+ model_type: 'panns_cnn14'
+ cfg_path:
+ model_path: # the pdmodel file of am static model [optional]
+ params_path: # the pdiparams file of am static model [optional]
+ label_file: # [optional] Label file of cls task.
+
+ predictor_conf:
+ device: # set 'gpu:id' or 'cpu'
+ switch_ir_optim: True
+ glog_info: False # True -> print glog
+ summary: True # False -> do not show predictor config
diff --git a/paddlespeech/server/conf/asr/asr.yaml b/paddlespeech/server/conf/asr/asr.yaml
deleted file mode 100644
index a6743b775..000000000
--- a/paddlespeech/server/conf/asr/asr.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-model: 'conformer_wenetspeech'
-lang: 'zh'
-sample_rate: 16000
-cfg_path: # [optional]
-ckpt_path: # [optional]
-decode_method: 'attention_rescoring'
-force_yes: True
-device: # set 'gpu:id' or 'cpu'
diff --git a/paddlespeech/server/conf/asr/asr_pd.yaml b/paddlespeech/server/conf/asr/asr_pd.yaml
deleted file mode 100644
index 4c415ac79..000000000
--- a/paddlespeech/server/conf/asr/asr_pd.yaml
+++ /dev/null
@@ -1,26 +0,0 @@
-# This is the parameter configuration file for ASR server.
-# These are the static models that support paddle inference.
-
-##################################################################
-# ACOUSTIC MODEL SETTING #
-# am choices=['deepspeech2offline_aishell'] TODO
-##################################################################
-model_type: 'deepspeech2offline_aishell'
-am_model: # the pdmodel file of am static model [optional]
-am_params: # the pdiparams file of am static model [optional]
-lang: 'zh'
-sample_rate: 16000
-cfg_path:
-decode_method:
-force_yes: True
-
-am_predictor_conf:
- device: # set 'gpu:id' or 'cpu'
- switch_ir_optim: True
- glog_info: False # True -> print glog
- summary: True # False -> do not show predictor config
-
-
-##################################################################
-# OTHERS #
-##################################################################
diff --git a/paddlespeech/server/conf/tts/tts.yaml b/paddlespeech/server/conf/tts/tts.yaml
deleted file mode 100644
index 19207f0b0..000000000
--- a/paddlespeech/server/conf/tts/tts.yaml
+++ /dev/null
@@ -1,32 +0,0 @@
-# This is the parameter configuration file for TTS server.
-
-##################################################################
-# ACOUSTIC MODEL SETTING #
-# am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc',
-# 'fastspeech2_ljspeech', 'fastspeech2_aishell3',
-# 'fastspeech2_vctk']
-##################################################################
-am: 'fastspeech2_csmsc'
-am_config:
-am_ckpt:
-am_stat:
-phones_dict:
-tones_dict:
-speaker_dict:
-spk_id: 0
-
-##################################################################
-# VOCODER SETTING #
-# voc choices=['pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3',
-# 'pwgan_vctk', 'mb_melgan_csmsc']
-##################################################################
-voc: 'pwgan_csmsc'
-voc_config:
-voc_ckpt:
-voc_stat:
-
-##################################################################
-# OTHERS #
-##################################################################
-lang: 'zh'
-device: # set 'gpu:id' or 'cpu'
diff --git a/paddlespeech/server/conf/tts/tts_pd.yaml b/paddlespeech/server/conf/tts/tts_pd.yaml
deleted file mode 100644
index e27b9665b..000000000
--- a/paddlespeech/server/conf/tts/tts_pd.yaml
+++ /dev/null
@@ -1,42 +0,0 @@
-# This is the parameter configuration file for TTS server.
-# These are the static models that support paddle inference.
-
-##################################################################
-# ACOUSTIC MODEL SETTING #
-# am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc']
-##################################################################
-am: 'fastspeech2_csmsc'
-am_model: # the pdmodel file of your am static model (XX.pdmodel)
-am_params: # the pdiparams file of your am static model (XX.pdipparams)
-am_sample_rate: 24000
-phones_dict:
-tones_dict:
-speaker_dict:
-spk_id: 0
-
-am_predictor_conf:
- device: # set 'gpu:id' or 'cpu'
- switch_ir_optim: True
- glog_info: False # True -> print glog
- summary: True # False -> do not show predictor config
-
-
-##################################################################
-# VOCODER SETTING #
-# voc choices=['pwgan_csmsc', 'mb_melgan_csmsc','hifigan_csmsc']
-##################################################################
-voc: 'pwgan_csmsc'
-voc_model: # the pdmodel file of your vocoder static model (XX.pdmodel)
-voc_params: # the pdiparams file of your vocoder static model (XX.pdipparams)
-voc_sample_rate: 24000
-
-voc_predictor_conf:
- device: # set 'gpu:id' or 'cpu'
- switch_ir_optim: True
- glog_info: False # True -> print glog
- summary: True # False -> do not show predictor config
-
-##################################################################
-# OTHERS #
-##################################################################
-lang: 'zh'
diff --git a/paddlespeech/server/engine/asr/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/paddleinference/asr_engine.py
index cb973e924..1925bf1d6 100644
--- a/paddlespeech/server/engine/asr/paddleinference/asr_engine.py
+++ b/paddlespeech/server/engine/asr/paddleinference/asr_engine.py
@@ -26,7 +26,6 @@ from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
-from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
@@ -184,7 +183,7 @@ class ASREngine(BaseEngine):
def __init__(self):
super(ASREngine, self).__init__()
- def init(self, config_file: str) -> bool:
+ def init(self, config: dict) -> bool:
"""init engine resource
Args:
@@ -196,7 +195,7 @@ class ASREngine(BaseEngine):
self.input = None
self.output = None
self.executor = ASRServerExecutor()
- self.config = get_config(config_file)
+ self.config = config
self.executor._init_from_path(
model_type=self.config.model_type,
diff --git a/paddlespeech/server/engine/asr/python/asr_engine.py b/paddlespeech/server/engine/asr/python/asr_engine.py
index 1e2c5cc27..e76c49a79 100644
--- a/paddlespeech/server/engine/asr/python/asr_engine.py
+++ b/paddlespeech/server/engine/asr/python/asr_engine.py
@@ -19,7 +19,6 @@ import paddle
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.base_engine import BaseEngine
-from paddlespeech.server.utils.config import get_config
__all__ = ['ASREngine']
@@ -40,7 +39,7 @@ class ASREngine(BaseEngine):
def __init__(self):
super(ASREngine, self).__init__()
- def init(self, config_file: str) -> bool:
+ def init(self, config: dict) -> bool:
"""init engine resource
Args:
@@ -52,8 +51,7 @@ class ASREngine(BaseEngine):
self.input = None
self.output = None
self.executor = ASRServerExecutor()
-
- self.config = get_config(config_file)
+ self.config = config
try:
if self.config.device:
self.device = self.config.device
diff --git a/paddlespeech/server/engine/cls/__init__.py b/paddlespeech/server/engine/cls/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddlespeech/server/engine/cls/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/engine/cls/paddleinference/__init__.py b/paddlespeech/server/engine/cls/paddleinference/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddlespeech/server/engine/cls/paddleinference/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/engine/cls/paddleinference/cls_engine.py b/paddlespeech/server/engine/cls/paddleinference/cls_engine.py
new file mode 100644
index 000000000..3982effd9
--- /dev/null
+++ b/paddlespeech/server/engine/cls/paddleinference/cls_engine.py
@@ -0,0 +1,224 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import io
+import os
+import time
+from typing import Optional
+
+import numpy as np
+import paddle
+import yaml
+
+from paddlespeech.cli.cls.infer import CLSExecutor
+from paddlespeech.cli.log import logger
+from paddlespeech.cli.utils import download_and_decompress
+from paddlespeech.cli.utils import MODEL_HOME
+from paddlespeech.server.engine.base_engine import BaseEngine
+from paddlespeech.server.utils.paddle_predictor import init_predictor
+from paddlespeech.server.utils.paddle_predictor import run_model
+
+__all__ = ['CLSEngine']
+
+pretrained_models = {
+ "panns_cnn6-32k": {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz',
+ 'md5':
+ 'da087c31046d23281d8ec5188c1967da',
+ 'cfg_path':
+ 'panns.yaml',
+ 'model_path':
+ 'inference.pdmodel',
+ 'params_path':
+ 'inference.pdiparams',
+ 'label_file':
+ 'audioset_labels.txt',
+ },
+ "panns_cnn10-32k": {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz',
+ 'md5':
+ '5460cc6eafbfaf0f261cc75b90284ae1',
+ 'cfg_path':
+ 'panns.yaml',
+ 'model_path':
+ 'inference.pdmodel',
+ 'params_path':
+ 'inference.pdiparams',
+ 'label_file':
+ 'audioset_labels.txt',
+ },
+ "panns_cnn14-32k": {
+ 'url':
+ 'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz',
+ 'md5':
+ 'ccc80b194821274da79466862b2ab00f',
+ 'cfg_path':
+ 'panns.yaml',
+ 'model_path':
+ 'inference.pdmodel',
+ 'params_path':
+ 'inference.pdiparams',
+ 'label_file':
+ 'audioset_labels.txt',
+ },
+}
+
+
+class CLSServerExecutor(CLSExecutor):
+ def __init__(self):
+ super().__init__()
+ pass
+
+ def _get_pretrained_path(self, tag: str) -> os.PathLike:
+ """
+ Download and returns pretrained resources path of current task.
+ """
+ support_models = list(pretrained_models.keys())
+ assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
+ tag, '\n\t\t'.join(support_models))
+
+ res_path = os.path.join(MODEL_HOME, tag)
+ decompressed_path = download_and_decompress(pretrained_models[tag],
+ res_path)
+ decompressed_path = os.path.abspath(decompressed_path)
+ logger.info(
+ 'Use pretrained model stored in: {}'.format(decompressed_path))
+
+ return decompressed_path
+
+ def _init_from_path(
+ self,
+ model_type: str='panns_cnn14',
+ cfg_path: Optional[os.PathLike]=None,
+ model_path: Optional[os.PathLike]=None,
+ params_path: Optional[os.PathLike]=None,
+ label_file: Optional[os.PathLike]=None,
+ predictor_conf: dict=None, ):
+ """
+ Init model and other resources from a specific path.
+ """
+
+ if cfg_path is None or model_path is None or params_path is None or label_file is None:
+ tag = model_type + '-' + '32k'
+ self.res_path = self._get_pretrained_path(tag)
+ self.cfg_path = os.path.join(self.res_path,
+ pretrained_models[tag]['cfg_path'])
+ self.model_path = os.path.join(self.res_path,
+ pretrained_models[tag]['model_path'])
+ self.params_path = os.path.join(
+ self.res_path, pretrained_models[tag]['params_path'])
+ self.label_file = os.path.join(self.res_path,
+ pretrained_models[tag]['label_file'])
+ else:
+ self.cfg_path = os.path.abspath(cfg_path)
+ self.model_path = os.path.abspath(model_path)
+ self.params_path = os.path.abspath(params_path)
+ self.label_file = os.path.abspath(label_file)
+
+ logger.info(self.cfg_path)
+ logger.info(self.model_path)
+ logger.info(self.params_path)
+ logger.info(self.label_file)
+
+ # config
+ with open(self.cfg_path, 'r') as f:
+ self._conf = yaml.safe_load(f)
+ logger.info("Read cfg file successfully.")
+
+ # labels
+ self._label_list = []
+ with open(self.label_file, 'r') as f:
+ for line in f:
+ self._label_list.append(line.strip())
+ logger.info("Read label file successfully.")
+
+ # Create predictor
+ self.predictor_conf = predictor_conf
+ self.predictor = init_predictor(
+ model_file=self.model_path,
+ params_file=self.params_path,
+ predictor_conf=self.predictor_conf)
+ logger.info("Create predictor successfully.")
+
+ @paddle.no_grad()
+ def infer(self):
+ """
+ Model inference and result stored in self.output.
+ """
+ output = run_model(self.predictor, [self._inputs['feats'].numpy()])
+ self._outputs['logits'] = output[0]
+
+
+class CLSEngine(BaseEngine):
+ """CLS server engine
+
+ Args:
+ metaclass: Defaults to Singleton.
+ """
+
+ def __init__(self):
+ super(CLSEngine, self).__init__()
+
+ def init(self, config: dict) -> bool:
+ """init engine resource
+
+ Args:
+            config (dict): config of the CLS engine
+
+ Returns:
+ bool: init failed or success
+ """
+ self.executor = CLSServerExecutor()
+ self.config = config
+ self.executor._init_from_path(
+ self.config.model_type, self.config.cfg_path,
+ self.config.model_path, self.config.params_path,
+ self.config.label_file, self.config.predictor_conf)
+
+ logger.info("Initialize CLS server engine successfully.")
+ return True
+
+ def run(self, audio_data):
+ """engine run
+
+ Args:
+ audio_data (bytes): base64.b64decode
+ """
+
+ self.executor.preprocess(io.BytesIO(audio_data))
+ st = time.time()
+ self.executor.infer()
+ infer_time = time.time() - st
+
+ logger.info("inference time: {}".format(infer_time))
+ logger.info("cls engine type: inference")
+
+ def postprocess(self, topk: int):
+ """postprocess
+ """
+ assert topk <= len(self.executor._label_list
+ ), 'Value of topk is larger than number of labels.'
+
+ result = np.squeeze(self.executor._outputs['logits'], axis=0)
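+        # argsort on the negated logits yields label indices sorted by descending score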
+ topk_idx = (-result).argsort()[:topk]
+ topk_results = []
+ for idx in topk_idx:
+ res = {}
+ label, score = self.executor._label_list[idx], result[idx]
+ res['class_name'] = label
+ res['prob'] = score
+ topk_results.append(res)
+
+ return topk_results
diff --git a/paddlespeech/server/engine/cls/python/__init__.py b/paddlespeech/server/engine/cls/python/__init__.py
new file mode 100644
index 000000000..97043fd7b
--- /dev/null
+++ b/paddlespeech/server/engine/cls/python/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/engine/cls/python/cls_engine.py b/paddlespeech/server/engine/cls/python/cls_engine.py
new file mode 100644
index 000000000..1a975b0a0
--- /dev/null
+++ b/paddlespeech/server/engine/cls/python/cls_engine.py
@@ -0,0 +1,124 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import io
+import time
+from typing import List
+
+import paddle
+
+from paddlespeech.cli.cls.infer import CLSExecutor
+from paddlespeech.cli.log import logger
+from paddlespeech.server.engine.base_engine import BaseEngine
+
+__all__ = ['CLSEngine']
+
+
+class CLSServerExecutor(CLSExecutor):
+ def __init__(self):
+ super().__init__()
+ pass
+
+ def get_topk_results(self, topk: int) -> List:
+ assert topk <= len(
+ self._label_list), 'Value of topk is larger than number of labels.'
+
+ result = self._outputs['logits'].squeeze(0).numpy()
+ topk_idx = (-result).argsort()[:topk]
+        topk_results = []
+        for idx in topk_idx:
+            res = {}
+            label, score = self._label_list[idx], result[idx]
+            res['class'] = label
+            res['prob'] = score
+            topk_results.append(res)
+ return topk_results
+
+
+class CLSEngine(BaseEngine):
+ """CLS server engine
+
+ Args:
+ metaclass: Defaults to Singleton.
+ """
+
+ def __init__(self):
+ super(CLSEngine, self).__init__()
+
+ def init(self, config: dict) -> bool:
+ """init engine resource
+
+ Args:
+            config (dict): config of the CLS engine
+
+ Returns:
+ bool: init failed or success
+ """
+ self.input = None
+ self.output = None
+ self.executor = CLSServerExecutor()
+ self.config = config
+ try:
+ if self.config.device:
+ self.device = self.config.device
+ else:
+ self.device = paddle.get_device()
+ paddle.set_device(self.device)
+ except BaseException:
+ logger.error(
+ "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
+ )
+
+ try:
+ self.executor._init_from_path(
+ self.config.model, self.config.cfg_path, self.config.ckpt_path,
+ self.config.label_file)
+ except BaseException:
+ logger.error("Initialize CLS server engine Failed.")
+ return False
+
+ logger.info("Initialize CLS server engine successfully on device: %s." %
+ (self.device))
+ return True
+
+ def run(self, audio_data):
+ """engine run
+
+ Args:
+ audio_data (bytes): base64.b64decode
+ """
+ self.executor.preprocess(io.BytesIO(audio_data))
+ st = time.time()
+ self.executor.infer()
+ infer_time = time.time() - st
+
+ logger.info("inference time: {}".format(infer_time))
+ logger.info("cls engine type: python")
+
+ def postprocess(self, topk: int):
+ """postprocess
+ """
+ assert topk <= len(self.executor._label_list
+ ), 'Value of topk is larger than number of labels.'
+
+ result = self.executor._outputs['logits'].squeeze(0).numpy()
+ topk_idx = (-result).argsort()[:topk]
+ topk_results = []
+ for idx in topk_idx:
+ res = {}
+ label, score = self.executor._label_list[idx], result[idx]
+ res['class_name'] = label
+ res['prob'] = score
+ topk_results.append(res)
+
+ return topk_results
diff --git a/paddlespeech/server/engine/engine_factory.py b/paddlespeech/server/engine/engine_factory.py
index 546541edf..c39c44cae 100644
--- a/paddlespeech/server/engine/engine_factory.py
+++ b/paddlespeech/server/engine/engine_factory.py
@@ -31,5 +31,11 @@ class EngineFactory(object):
elif engine_name == 'tts' and engine_type == 'python':
from paddlespeech.server.engine.tts.python.tts_engine import TTSEngine
return TTSEngine()
+ elif engine_name == 'cls' and engine_type == 'inference':
+ from paddlespeech.server.engine.cls.paddleinference.cls_engine import CLSEngine
+ return CLSEngine()
+ elif engine_name == 'cls' and engine_type == 'python':
+ from paddlespeech.server.engine.cls.python.cls_engine import CLSEngine
+ return CLSEngine()
else:
return None
diff --git a/paddlespeech/server/engine/engine_pool.py b/paddlespeech/server/engine/engine_pool.py
index f6a4d2aab..9de73567e 100644
--- a/paddlespeech/server/engine/engine_pool.py
+++ b/paddlespeech/server/engine/engine_pool.py
@@ -28,11 +28,13 @@ def init_engine_pool(config) -> bool:
""" Init engine pool
"""
global ENGINE_POOL
- for engine in config.engine_backend:
+
+ for engine_and_type in config.engine_list:
+ engine = engine_and_type.split("_")[0]
+ engine_type = engine_and_type.split("_")[1]
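+        # e.g. 'cls_python' -> engine name 'cls', engine type 'python'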
ENGINE_POOL[engine] = EngineFactory.get_engine(
- engine_name=engine, engine_type=config.engine_type[engine])
- if not ENGINE_POOL[engine].init(
- config_file=config.engine_backend[engine]):
+ engine_name=engine, engine_type=engine_type)
+ if not ENGINE_POOL[engine].init(config=config[engine_and_type]):
return False
return True
diff --git a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py
index 5955c1a21..db8813ba9 100644
--- a/paddlespeech/server/engine/tts/paddleinference/tts_engine.py
+++ b/paddlespeech/server/engine/tts/paddleinference/tts_engine.py
@@ -29,7 +29,6 @@ from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import change_speed
-from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.exception import ServerBaseException
from paddlespeech.server.utils.paddle_predictor import init_predictor
@@ -251,27 +250,21 @@ class TTSServerExecutor(TTSExecutor):
self.frontend = English(phone_vocab_path=self.phones_dict)
logger.info("frontend done!")
- try:
- # am predictor
- self.am_predictor_conf = am_predictor_conf
- self.am_predictor = init_predictor(
- model_file=self.am_model,
- params_file=self.am_params,
- predictor_conf=self.am_predictor_conf)
- logger.info("Create AM predictor successfully.")
- except BaseException:
- logger.error("Failed to create AM predictor.")
-
- try:
- # voc predictor
- self.voc_predictor_conf = voc_predictor_conf
- self.voc_predictor = init_predictor(
- model_file=self.voc_model,
- params_file=self.voc_params,
- predictor_conf=self.voc_predictor_conf)
- logger.info("Create Vocoder predictor successfully.")
- except BaseException:
- logger.error("Failed to create Vocoder predictor.")
+ # Create am predictor
+ self.am_predictor_conf = am_predictor_conf
+ self.am_predictor = init_predictor(
+ model_file=self.am_model,
+ params_file=self.am_params,
+ predictor_conf=self.am_predictor_conf)
+ logger.info("Create AM predictor successfully.")
+
+ # Create voc predictor
+ self.voc_predictor_conf = voc_predictor_conf
+ self.voc_predictor = init_predictor(
+ model_file=self.voc_model,
+ params_file=self.voc_params,
+ predictor_conf=self.voc_predictor_conf)
+ logger.info("Create Vocoder predictor successfully.")
@paddle.no_grad()
def infer(self,
@@ -357,30 +350,25 @@ class TTSEngine(BaseEngine):
"""
super(TTSEngine, self).__init__()
- def init(self, config_file: str) -> bool:
+ def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
- try:
- self.config = get_config(config_file)
- self.executor._init_from_path(
- am=self.config.am,
- am_model=self.config.am_model,
- am_params=self.config.am_params,
- am_sample_rate=self.config.am_sample_rate,
- phones_dict=self.config.phones_dict,
- tones_dict=self.config.tones_dict,
- speaker_dict=self.config.speaker_dict,
- voc=self.config.voc,
- voc_model=self.config.voc_model,
- voc_params=self.config.voc_params,
- voc_sample_rate=self.config.voc_sample_rate,
- lang=self.config.lang,
- am_predictor_conf=self.config.am_predictor_conf,
- voc_predictor_conf=self.config.voc_predictor_conf, )
-
- except BaseException:
- logger.error("Initialize TTS server engine Failed.")
- return False
+ self.config = config
+ self.executor._init_from_path(
+ am=self.config.am,
+ am_model=self.config.am_model,
+ am_params=self.config.am_params,
+ am_sample_rate=self.config.am_sample_rate,
+ phones_dict=self.config.phones_dict,
+ tones_dict=self.config.tones_dict,
+ speaker_dict=self.config.speaker_dict,
+ voc=self.config.voc,
+ voc_model=self.config.voc_model,
+ voc_params=self.config.voc_params,
+ voc_sample_rate=self.config.voc_sample_rate,
+ lang=self.config.lang,
+ am_predictor_conf=self.config.am_predictor_conf,
+ voc_predictor_conf=self.config.voc_predictor_conf, )
logger.info("Initialize TTS server engine successfully.")
return True
@@ -543,4 +531,4 @@ class TTSEngine(BaseEngine):
postprocess_time))
logger.info("RTF: {}".format(rtf))
- return lang, target_sample_rate, wav_base64
+ return lang, target_sample_rate, duration, wav_base64
diff --git a/paddlespeech/server/engine/tts/python/tts_engine.py b/paddlespeech/server/engine/tts/python/tts_engine.py
index 7dd576699..f153f60b9 100644
--- a/paddlespeech/server/engine/tts/python/tts_engine.py
+++ b/paddlespeech/server/engine/tts/python/tts_engine.py
@@ -25,7 +25,6 @@ from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import change_speed
-from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.exception import ServerBaseException
@@ -50,11 +49,11 @@ class TTSEngine(BaseEngine):
"""
super(TTSEngine, self).__init__()
- def init(self, config_file: str) -> bool:
+ def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
try:
- self.config = get_config(config_file)
+ self.config = config
if self.config.device:
self.device = self.config.device
else:
@@ -251,4 +250,4 @@ class TTSEngine(BaseEngine):
logger.info("RTF: {}".format(rtf))
logger.info("device: {}".format(self.device))
- return lang, target_sample_rate, wav_base64
+ return lang, target_sample_rate, duration, wav_base64
diff --git a/paddlespeech/server/restful/api.py b/paddlespeech/server/restful/api.py
index 2d69dee87..3f91a03b6 100644
--- a/paddlespeech/server/restful/api.py
+++ b/paddlespeech/server/restful/api.py
@@ -16,6 +16,7 @@ from typing import List
from fastapi import APIRouter
from paddlespeech.server.restful.asr_api import router as asr_router
+from paddlespeech.server.restful.cls_api import router as cls_router
from paddlespeech.server.restful.tts_api import router as tts_router
_router = APIRouter()
@@ -25,7 +26,7 @@ def setup_router(api_list: List):
"""setup router for fastapi
Args:
- api_list (List): [asr, tts]
+ api_list (List): [asr, tts, cls]
Returns:
APIRouter
@@ -35,6 +36,8 @@ def setup_router(api_list: List):
_router.include_router(asr_router)
elif api_name == 'tts':
_router.include_router(tts_router)
+ elif api_name == 'cls':
+ _router.include_router(cls_router)
else:
pass
diff --git a/paddlespeech/server/restful/cls_api.py b/paddlespeech/server/restful/cls_api.py
new file mode 100644
index 000000000..306d9ca9c
--- /dev/null
+++ b/paddlespeech/server/restful/cls_api.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import traceback
+from typing import Union
+
+from fastapi import APIRouter
+
+from paddlespeech.server.engine.engine_pool import get_engine_pool
+from paddlespeech.server.restful.request import CLSRequest
+from paddlespeech.server.restful.response import CLSResponse
+from paddlespeech.server.restful.response import ErrorResponse
+from paddlespeech.server.utils.errors import ErrorCode
+from paddlespeech.server.utils.errors import failed_response
+from paddlespeech.server.utils.exception import ServerBaseException
+
+router = APIRouter()
+
+
+@router.get('/paddlespeech/cls/help')
+def help():
+ """help
+
+ Returns:
+ json: [description]
+ """
+ response = {
+ "success": "True",
+ "code": 200,
+ "message": {
+ "global": "success"
+ },
+ "result": {
+ "description": "cls server",
+ "input": "base64 string of wavfile",
+ "output": "classification result"
+ }
+ }
+ return response
+
+
+@router.post(
+ "/paddlespeech/cls", response_model=Union[CLSResponse, ErrorResponse])
+def cls(request_body: CLSRequest):
+ """cls api
+
+ Args:
+ request_body (CLSRequest): [description]
+
+ Returns:
+ json: [description]
+ """
+ try:
+ audio_data = base64.b64decode(request_body.audio)
+
+ # get single engine from engine pool
+ engine_pool = get_engine_pool()
+ cls_engine = engine_pool['cls']
+
+ cls_engine.run(audio_data)
+ cls_results = cls_engine.postprocess(request_body.topk)
+
+ response = {
+ "success": True,
+ "code": 200,
+ "message": {
+ "description": "success"
+ },
+ "result": {
+ "topk": request_body.topk,
+ "results": cls_results
+ }
+ }
+
+ except ServerBaseException as e:
+ response = failed_response(e.error_code, e.msg)
+ except BaseException:
+ response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
+ traceback.print_exc()
+
+ return response
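+
+
+# Illustrative client call (a sketch; mirrors the CLS CLI client in this PR, assuming the
+# server runs on 127.0.0.1:8090 with 'cls' enabled in engine_list):
+#   import base64, json, requests
+#   audio = base64.b64encode(open("cls_input.wav", "rb").read()).decode("utf8")
+#   r = requests.post("http://127.0.0.1:8090/paddlespeech/cls",
+#                     data=json.dumps({"audio": audio, "topk": 1}))
+#   print(r.json())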
diff --git a/paddlespeech/server/restful/request.py b/paddlespeech/server/restful/request.py
index 289088019..dbac9dac8 100644
--- a/paddlespeech/server/restful/request.py
+++ b/paddlespeech/server/restful/request.py
@@ -15,7 +15,7 @@ from typing import Optional
from pydantic import BaseModel
-__all__ = ['ASRRequest', 'TTSRequest']
+__all__ = ['ASRRequest', 'TTSRequest', 'CLSRequest']
#****************************************************************************************/
@@ -63,3 +63,18 @@ class TTSRequest(BaseModel):
volume: float = 1.0
sample_rate: int = 0
save_path: str = None
+
+
+#****************************************************************************************/
+#************************************ CLS request ***************************************/
+#****************************************************************************************/
+class CLSRequest(BaseModel):
+ """
+ request body example
+ {
+ "audio": "exSI6ICJlbiIsCgkgICAgInBvc2l0aW9uIjogImZhbHNlIgoJf...",
+ "topk": 1
+ }
+ """
+ audio: str
+ topk: int = 1
diff --git a/paddlespeech/server/restful/response.py b/paddlespeech/server/restful/response.py
index 4e18ee0d7..a2a207e4f 100644
--- a/paddlespeech/server/restful/response.py
+++ b/paddlespeech/server/restful/response.py
@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List
+
from pydantic import BaseModel
-__all__ = ['ASRResponse', 'TTSResponse']
+__all__ = ['ASRResponse', 'TTSResponse', 'CLSResponse']
class Message(BaseModel):
@@ -52,10 +54,11 @@ class ASRResponse(BaseModel):
#****************************************************************************************/
class TTSResult(BaseModel):
lang: str = "zh"
- sample_rate: int
spk_id: int = 0
speed: float = 1.0
volume: float = 1.0
+ sample_rate: int
+ duration: float
save_path: str = None
audio: str
@@ -71,9 +74,11 @@ class TTSResponse(BaseModel):
},
"result": {
"lang": "zh",
- "sample_rate": 24000,
+ "spk_id": 0,
"speed": 1.0,
"volume": 1.0,
+ "sample_rate": 24000,
+ "duration": 3.6125,
"audio": "LTI1OTIuNjI1OTUwMzQsOTk2OS41NDk4...",
"save_path": "./tts.wav"
}
@@ -85,6 +90,45 @@ class TTSResponse(BaseModel):
result: TTSResult
+#****************************************************************************************/
+#************************************ CLS response **************************************/
+#****************************************************************************************/
+class CLSResults(BaseModel):
+ class_name: str
+ prob: float
+
+
+class CLSResult(BaseModel):
+ topk: int
+ results: List[CLSResults]
+
+
+class CLSResponse(BaseModel):
+ """
+ response example
+ {
+ "success": true,
+ "code": 0,
+ "message": {
+ "description": "success"
+ },
+ "result": {
+          "topk": 1,
+          "results": [
+            {
+              "class_name": "Speech",
+ "prob": 0.9027184844017029
+ }
+ ]
+ }
+ }
+ """
+ success: bool
+ code: int
+ message: Message
+ result: CLSResult
+
+
#****************************************************************************************/
#********************************** Error response **************************************/
#****************************************************************************************/
diff --git a/paddlespeech/server/restful/tts_api.py b/paddlespeech/server/restful/tts_api.py
index 0af0f6d07..4e9bbe23e 100644
--- a/paddlespeech/server/restful/tts_api.py
+++ b/paddlespeech/server/restful/tts_api.py
@@ -98,7 +98,7 @@ def tts(request_body: TTSRequest):
tts_engine = engine_pool['tts']
logger.info("Get tts engine successfully.")
- lang, target_sample_rate, wav_base64 = tts_engine.run(
+ lang, target_sample_rate, duration, wav_base64 = tts_engine.run(
text, spk_id, speed, volume, sample_rate, save_path)
response = {
@@ -113,6 +113,7 @@ def tts(request_body: TTSRequest):
"speed": speed,
"volume": volume,
"sample_rate": target_sample_rate,
+ "duration": duration,
"save_path": save_path,
"audio": wav_base64
}
diff --git a/paddlespeech/server/utils/paddle_predictor.py b/paddlespeech/server/utils/paddle_predictor.py
index 4035d48d8..16653cf37 100644
--- a/paddlespeech/server/utils/paddle_predictor.py
+++ b/paddlespeech/server/utils/paddle_predictor.py
@@ -35,10 +35,12 @@ def init_predictor(model_dir: Optional[os.PathLike]=None,
Returns:
predictor (PaddleInferPredictor): created predictor
"""
-
if model_dir is not None:
+ assert os.path.isdir(model_dir), 'Please check model dir.'
config = Config(args.model_dir)
else:
+ assert os.path.isfile(model_file) and os.path.isfile(
+ params_file), 'Please check model and parameter files.'
config = Config(model_file, params_file)
# set device
@@ -66,7 +68,6 @@ def init_predictor(model_dir: Optional[os.PathLike]=None,
config.enable_memory_optim()
predictor = create_predictor(config)
-
return predictor
@@ -84,10 +85,8 @@ def run_model(predictor, input: List) -> List:
for i, name in enumerate(input_names):
input_handle = predictor.get_input_handle(name)
input_handle.copy_from_cpu(input[i])
-
# do the inference
predictor.run()
-
results = []
# get out data from output tensor
output_names = predictor.get_output_names()
diff --git a/paddlespeech/t2s/exps/csmsc_test.txt b/paddlespeech/t2s/exps/csmsc_test.txt
new file mode 100644
index 000000000..d8cf367cd
--- /dev/null
+++ b/paddlespeech/t2s/exps/csmsc_test.txt
@@ -0,0 +1,100 @@
+009901 昨日,这名伤者与医生全部被警方依法刑事拘留。
+009902 钱伟长想到上海来办学校是经过深思熟虑的。
+009903 她见我一进门就骂,吃饭时也骂,骂得我抬不起头。
+009904 李述德在离开之前,只说了一句柱驼杀父亲了。
+009905 这种车票和保险单捆绑出售属于重复性购买。
+009906 戴佩妮的男友西米露接唱情歌,让她非常开心。
+009907 观大势,谋大局,出大策始终是该院的办院方针。
+009908 他们骑着摩托回家,正好为农忙时的父母帮忙。
+009909 但是因为还没到退休年龄,只能掰着指头捱日子。
+009910 这几天雨水不断,人们恨不得待在家里不出门。
+009911 没想到徐赟,张海翔两人就此玩起了人间蒸发。
+009912 藤村此番发言可能是为了凸显野田的领导能力。
+009913 程长庚,生在清王朝嘉庆年间,安徽的潜山小县。
+009914 南海海域综合补给基地码头项目正在论证中。
+009915 也就是说今晚成都市民极有可能再次看到飘雪。
+009916 随着天气转热,各地的游泳场所开始人头攒动。
+009917 更让徐先生纳闷的是,房客的手机也打不通了。
+009918 遇到颠簸时,应听从乘务员的安全指令,回座位坐好。
+009919 他在后面呆惯了,怕自己一插身后的人会不满,不敢排进去。
+009920 傍晚七个小人回来了,白雪公主说,你们就是我命中的七个小矮人吧。
+009921 他本想说,教育局管这个,他们是一路的,这样一管岂不是妓女起嫖客?
+009922 一种表示商品所有权的财物证券,也称商品证券,如提货单,交货单。
+009923 会有很丰富的东西留下来,说都说不完。
+009924 这句话像从天而降,吓得四周一片寂静。
+009925 记者所在的是受害人家属所在的右区。
+009926 不管哈大爷去哪,它都一步不离地跟着。
+009927 大家抬头望去,一只老鼠正趴在吊顶上。
+009928 我决定过年就辞职,接手我爸的废品站!
+009929 最终,中国男子乒乓球队获得此奖项。
+009930 防汛抗旱两手抓,抗旱相对抓的不够。
+009931 图们江下游地区开发开放的进展如何?
+009932 这要求中国必须有一个坚强的政党领导。
+009933 再说,关于利益上的事俺俩都不好开口。
+009934 明代瓦剌,鞑靼入侵明境也是通过此地。
+009935 咪咪舔着孩子,把它身上的毛舔干净。
+009936 是否这次的国标修订被大企业绑架了?
+009937 判决后,姚某妻子胡某不服,提起上诉。
+009938 由此可以看出邯钢的经济效益来自何处。
+009939 琳达说,是瑜伽改变了她和马儿的生活。
+009940 楼下的保安告诉记者,这里不租也不卖。
+009941 习近平说,中斯两国人民传统友谊深厚。
+009942 传闻越来越多,后来连老汉儿自己都怕了。
+009943 我怒吼一声冲上去,举起砖头砸了过去。
+009944 我现在还不会,这就回去问问发明我的人。
+009945 显然,洛阳性奴案不具备上述两个前提。
+009946 另外,杰克逊有文唇线,眼线,眉毛的动作。
+009947 昨晚,华西都市报记者电话采访了尹琪。
+009948 涅拉季科未透露这些航空公司的名称。
+009949 从运行轨迹上来说,它也不可能是星星。
+009950 目前看,如果继续加息也存在两难问题。
+009951 曾宝仪在节目录制现场大爆观众糗事。
+009952 但任凭周某怎么叫,男子仍酣睡不醒。
+009953 老大爷说,小子,你挡我财路了,知道不?
+009954 没料到,闯下大头佛的阿伟还不知悔改。
+009955 卡扎菲部落式统治已遭遇部落内讧。
+009956 这个孩子的生命一半来源于另一位女士捐赠的冷冻卵子。
+009957 出现这种泥鳅内阁的局面既是野田有意为之,也实属无奈。
+009958 济青高速济南,华山,章丘,邹平,周村,淄博,临淄站。
+009959 赵凌飞的话,反映了沈阳赛区所有奥运志愿者的共同心声。
+009960 因为,我们所发出的力量必会因难度加大而减弱。
+009961 发生事故的楼梯拐角处仍可看到血迹。
+009962 想过进公安,可能身高不够,老汉儿也不让我进去。
+009963 路上关卡很多,为了方便撤离,只好轻装前进。
+009964 原来比尔盖茨就是美国微软公司联合创始人呀。
+009965 之后他们一家三口将与双方父母往峇里岛旅游。
+009966 谢谢总理,也感谢广大网友的参与,我们明年再见。
+009967 事实上是,从来没有一个欺善怕恶的人能作出过稍大一点的成就。
+009968 我会打开邮件,你可以从那里继续。
+009969 美方对近期东海局势表示关切。
+009970 据悉,奥巴马一家人对这座冬季白宫极为满意。
+009971 打扫完你会很有成就感的,试一试,你就信了。
+009972 诺曼站在滑板车上,各就各位,准备出发啦!
+009973 塔河的寒夜,气温降到了零下三十多摄氏度。
+009974 其间,连破六点六,六点五,六点四,六点三五等多个重要关口。
+009975 算命其实只是人们的一种自我安慰和自我暗示而已,我们还是要相信科学才好。
+009976 这一切都令人欢欣鼓舞,阿讷西没理由不坚持到最后。
+009977 直至公元前一万一千年,它又再次出现。
+009978 尽量少玩电脑,少看电视,少打游戏。
+009979 从五到七,前后也就是六个月的时间。
+009980 一进咖啡店,他就遇见一张熟悉的脸。
+009981 好在众弟兄看到了把她追了回来。
+009982 有一个人说,哥们儿我们跑过它才能活。
+009983 捅了她以后,模糊记得她没咋动了。
+009984 从小到大,葛启义没有收到过压岁钱。
+009985 舞台下的你会对舞台上的你说什么?
+009986 但考生普遍认为,试题的怪多过难。
+009987 我希望每个人都能够尊重我们的隐私。
+009988 漫天的红霞使劲给两人增添气氛。
+009989 晚上加完班开车回家,太累了,迷迷糊糊开着车,走一半的时候,铛一声!
+009990 该车将三人撞倒后,在大雾中逃窜。
+009991 这人一哆嗦,方向盘也把不稳了,差点撞上了高速边道护栏。
+009992 那女孩儿委屈的说,我一回头见你已经进去了我不敢进去啊!
+009993 小明摇摇头说,不是,我只是美女看多了,想换个口味而已。
+009994 接下来,红娘要求记者交费,记者表示不知表姐身份证号码。
+009995 李东蓊表示,自己当时在法庭上发表了一次独特的公诉意见。
+009996 另一男子扑了上来,手里拿着明晃晃的长刀,向他胸口直刺。
+009997 今天,快递员拿着一个快递在办公室喊,秦王是哪个,有他快递?
+009998 这场抗议活动究竟是如何发展演变的,又究竟是谁伤害了谁?
+009999 因华国锋肖鸡,墓地设计根据其属相设计。
+010000 在狱中,张明宝悔恨交加,写了一份忏悔书。
diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py
index 26d7e2c08..1188ddfb1 100644
--- a/paddlespeech/t2s/exps/inference.py
+++ b/paddlespeech/t2s/exps/inference.py
@@ -17,13 +17,92 @@ from pathlib import Path
import numpy
import soundfile as sf
from paddle import inference
-
-from paddlespeech.t2s.frontend import English
-from paddlespeech.t2s.frontend.zh_frontend import Frontend
+from timer import timer
+
+from paddlespeech.t2s.exps.syn_utils import get_frontend
+from paddlespeech.t2s.exps.syn_utils import get_sentences
+from paddlespeech.t2s.utils import str2bool
+
+
+def get_predictor(args, field='am'):
+    full_name = ''
+    if field == 'am':
+        full_name = args.am
+    elif field == 'voc':
+ full_name = args.voc
+ model_name = full_name[:full_name.rindex('_')]
+ config = inference.Config(
+ str(Path(args.inference_dir) / (full_name + ".pdmodel")),
+ str(Path(args.inference_dir) / (full_name + ".pdiparams")))
+ if args.device == "gpu":
+ config.enable_use_gpu(100, 0)
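+        # 100 MB initial GPU memory pool on device id 0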
+ elif args.device == "cpu":
+ config.disable_gpu()
+    # enable_memory_optim() must be skipped for fastspeech2, otherwise it will OOM
+ if model_name != 'fastspeech2':
+ config.enable_memory_optim()
+ predictor = inference.create_predictor(config)
+ return predictor
-# only inference for models trained with csmsc now
-def main():
+def get_am_output(args, am_predictor, frontend, merge_sentences, input):
+ am_name = args.am[:args.am.rindex('_')]
+ am_dataset = args.am[args.am.rindex('_') + 1:]
+ am_input_names = am_predictor.get_input_names()
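+    # by convention here, input 0 holds phone ids; input 1 holds tone ids (speedyspeech)
+    # or spk_id (multi-speaker am), matching the handles set below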
+ get_tone_ids = False
+ get_spk_id = False
+ if am_name == 'speedyspeech':
+ get_tone_ids = True
+ if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
+ get_spk_id = True
+ spk_id = numpy.array([args.spk_id])
+ if args.lang == 'zh':
+ input_ids = frontend.get_input_ids(
+ input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
+ phone_ids = input_ids["phone_ids"]
+ elif args.lang == 'en':
+ input_ids = frontend.get_input_ids(
+ input, merge_sentences=merge_sentences)
+ phone_ids = input_ids["phone_ids"]
+ else:
+ print("lang should in {'zh', 'en'}!")
+
+ if get_tone_ids:
+ tone_ids = input_ids["tone_ids"]
+ tones = tone_ids[0].numpy()
+ tones_handle = am_predictor.get_input_handle(am_input_names[1])
+ tones_handle.reshape(tones.shape)
+ tones_handle.copy_from_cpu(tones)
+ if get_spk_id:
+ spk_id_handle = am_predictor.get_input_handle(am_input_names[1])
+ spk_id_handle.reshape(spk_id.shape)
+ spk_id_handle.copy_from_cpu(spk_id)
+ phones = phone_ids[0].numpy()
+ phones_handle = am_predictor.get_input_handle(am_input_names[0])
+ phones_handle.reshape(phones.shape)
+ phones_handle.copy_from_cpu(phones)
+
+ am_predictor.run()
+ am_output_names = am_predictor.get_output_names()
+ am_output_handle = am_predictor.get_output_handle(am_output_names[0])
+ am_output_data = am_output_handle.copy_to_cpu()
+ return am_output_data
+
+
+def get_voc_output(args, voc_predictor, input):
+ voc_input_names = voc_predictor.get_input_names()
+ mel_handle = voc_predictor.get_input_handle(voc_input_names[0])
+ mel_handle.reshape(input.shape)
+ mel_handle.copy_from_cpu(input)
+
+ voc_predictor.run()
+ voc_output_names = voc_predictor.get_output_names()
+ voc_output_handle = voc_predictor.get_output_handle(voc_output_names[0])
+ wav = voc_output_handle.copy_to_cpu()
+ return wav
+
+
+def parse_args():
parser = argparse.ArgumentParser(
description="Paddle Infernce with speedyspeech & parallel wavegan.")
# acoustic model
@@ -70,113 +149,97 @@ def main():
parser.add_argument(
"--inference_dir", type=str, help="dir to save inference models")
parser.add_argument("--output_dir", type=str, help="output dir")
+ # inference
+ parser.add_argument(
+ "--use_trt",
+ type=str2bool,
+ default=False,
+ help="Whether to use inference engin TensorRT.", )
+ parser.add_argument(
+ "--int8",
+ type=str2bool,
+ default=False,
+ help="Whether to use int8 inference.", )
+ parser.add_argument(
+ "--fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use float16 inference.", )
+ parser.add_argument(
+ "--device",
+ default="gpu",
+ choices=["gpu", "cpu"],
+ help="Device selected for inference.", )
args, _ = parser.parse_known_args()
+ return args
+
+# only inference for models trained with csmsc now
+def main():
+ args = parse_args()
# frontend
- if args.lang == 'zh':
- frontend = Frontend(
- phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
- elif args.lang == 'en':
- frontend = English(phone_vocab_path=args.phones_dict)
- print("frontend done!")
+ frontend = get_frontend(args)
+ # am_predictor
+    am_predictor = get_predictor(args, field='am')
# model: {model_name}_{dataset}
- am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
- am_config = inference.Config(
- str(Path(args.inference_dir) / (args.am + ".pdmodel")),
- str(Path(args.inference_dir) / (args.am + ".pdiparams")))
- am_config.enable_use_gpu(100, 0)
- # This line must be commented for fastspeech2, if not, it will OOM
- if am_name != 'fastspeech2':
- am_config.enable_memory_optim()
- am_predictor = inference.create_predictor(am_config)
-
- voc_config = inference.Config(
- str(Path(args.inference_dir) / (args.voc + ".pdmodel")),
- str(Path(args.inference_dir) / (args.voc + ".pdiparams")))
- voc_config.enable_use_gpu(100, 0)
- voc_config.enable_memory_optim()
- voc_predictor = inference.create_predictor(voc_config)
+ # voc_predictor
+    voc_predictor = get_predictor(args, field='voc')
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
- sentences = []
-
- print("in new inference")
-
- # construct dataset for evaluation
- sentences = []
- with open(args.text, 'rt') as f:
- for line in f:
- items = line.strip().split()
- utt_id = items[0]
- if args.lang == 'zh':
- sentence = "".join(items[1:])
- elif args.lang == 'en':
- sentence = " ".join(items[1:])
- sentences.append((utt_id, sentence))
- get_tone_ids = False
- get_spk_id = False
- if am_name == 'speedyspeech':
- get_tone_ids = True
- if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
- get_spk_id = True
- spk_id = numpy.array([args.spk_id])
+ sentences = get_sentences(args)
- am_input_names = am_predictor.get_input_names()
- print("am_input_names:", am_input_names)
merge_sentences = True
+ fs = 24000 if am_dataset != 'ljspeech' else 22050
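+    # speed below is generated samples per second of wall time, so RTF = fs / speed
+    # (synthesis time divided by audio duration)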
+ # warmup
+ for utt_id, sentence in sentences[:3]:
+ with timer() as t:
+ am_output_data = get_am_output(
+ args,
+ am_predictor=am_predictor,
+ frontend=frontend,
+ merge_sentences=merge_sentences,
+ input=sentence)
+ wav = get_voc_output(
+ args, voc_predictor=voc_predictor, input=am_output_data)
+ speed = wav.size / t.elapse
+ rtf = fs / speed
+ print(
+ f"{utt_id}, mel: {am_output_data.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+ )
+
+ print("warm up done!")
+
+ N = 0
+ T = 0
for utt_id, sentence in sentences:
- if args.lang == 'zh':
- input_ids = frontend.get_input_ids(
- sentence,
+ with timer() as t:
+ am_output_data = get_am_output(
+ args,
+ am_predictor=am_predictor,
+ frontend=frontend,
merge_sentences=merge_sentences,
- get_tone_ids=get_tone_ids)
- phone_ids = input_ids["phone_ids"]
- elif args.lang == 'en':
- input_ids = frontend.get_input_ids(
- sentence, merge_sentences=merge_sentences)
- phone_ids = input_ids["phone_ids"]
- else:
- print("lang should in {'zh', 'en'}!")
-
- if get_tone_ids:
- tone_ids = input_ids["tone_ids"]
- tones = tone_ids[0].numpy()
- tones_handle = am_predictor.get_input_handle(am_input_names[1])
- tones_handle.reshape(tones.shape)
- tones_handle.copy_from_cpu(tones)
- if get_spk_id:
- spk_id_handle = am_predictor.get_input_handle(am_input_names[1])
- spk_id_handle.reshape(spk_id.shape)
- spk_id_handle.copy_from_cpu(spk_id)
- phones = phone_ids[0].numpy()
- phones_handle = am_predictor.get_input_handle(am_input_names[0])
- phones_handle.reshape(phones.shape)
- phones_handle.copy_from_cpu(phones)
-
- am_predictor.run()
- am_output_names = am_predictor.get_output_names()
- am_output_handle = am_predictor.get_output_handle(am_output_names[0])
- am_output_data = am_output_handle.copy_to_cpu()
-
- voc_input_names = voc_predictor.get_input_names()
- mel_handle = voc_predictor.get_input_handle(voc_input_names[0])
- mel_handle.reshape(am_output_data.shape)
- mel_handle.copy_from_cpu(am_output_data)
-
- voc_predictor.run()
- voc_output_names = voc_predictor.get_output_names()
- voc_output_handle = voc_predictor.get_output_handle(voc_output_names[0])
- wav = voc_output_handle.copy_to_cpu()
+ input=sentence)
+ wav = get_voc_output(
+ args, voc_predictor=voc_predictor, input=am_output_data)
+
+ N += wav.size
+ T += t.elapse
+ speed = wav.size / t.elapse
+ rtf = fs / speed
sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000)
+ print(
+ f"{utt_id}, mel: {am_output_data.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+ )
print(f"{utt_id} done!")
+ print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }")
if __name__ == "__main__":
diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py
new file mode 100644
index 000000000..c52cb3727
--- /dev/null
+++ b/paddlespeech/t2s/exps/syn_utils.py
@@ -0,0 +1,243 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import numpy as np
+import paddle
+from paddle import jit
+from paddle.static import InputSpec
+
+from paddlespeech.s2t.utils.dynamic_import import dynamic_import
+from paddlespeech.t2s.datasets.data_table import DataTable
+from paddlespeech.t2s.frontend import English
+from paddlespeech.t2s.frontend.zh_frontend import Frontend
+from paddlespeech.t2s.modules.normalizer import ZScore
+
+model_alias = {
+ # acoustic model
+ "speedyspeech":
+ "paddlespeech.t2s.models.speedyspeech:SpeedySpeech",
+ "speedyspeech_inference":
+ "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference",
+ "fastspeech2":
+ "paddlespeech.t2s.models.fastspeech2:FastSpeech2",
+ "fastspeech2_inference":
+ "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
+ "tacotron2":
+ "paddlespeech.t2s.models.tacotron2:Tacotron2",
+ "tacotron2_inference":
+ "paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
+ # voc
+ "pwgan":
+ "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
+ "pwgan_inference":
+ "paddlespeech.t2s.models.parallel_wavegan:PWGInference",
+ "mb_melgan":
+ "paddlespeech.t2s.models.melgan:MelGANGenerator",
+ "mb_melgan_inference":
+ "paddlespeech.t2s.models.melgan:MelGANInference",
+ "style_melgan":
+ "paddlespeech.t2s.models.melgan:StyleMelGANGenerator",
+ "style_melgan_inference":
+ "paddlespeech.t2s.models.melgan:StyleMelGANInference",
+ "hifigan":
+ "paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
+ "hifigan_inference":
+ "paddlespeech.t2s.models.hifigan:HiFiGANInference",
+ "wavernn":
+ "paddlespeech.t2s.models.wavernn:WaveRNN",
+ "wavernn_inference":
+ "paddlespeech.t2s.models.wavernn:WaveRNNInference",
+}
+
+
+# input
+def get_sentences(args):
+ # construct dataset for evaluation
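+    # expected text file format: one "<utt_id> <sentence>" pair per line,
+    # as in paddlespeech/t2s/exps/csmsc_test.txt added in this PR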
+ sentences = []
+ with open(args.text, 'rt') as f:
+ for line in f:
+ items = line.strip().split()
+ utt_id = items[0]
+ if 'lang' in args and args.lang == 'zh':
+ sentence = "".join(items[1:])
+ elif 'lang' in args and args.lang == 'en':
+ sentence = " ".join(items[1:])
+ sentences.append((utt_id, sentence))
+ return sentences
+
+
+def get_test_dataset(args, test_metadata, am_name, am_dataset):
+ if am_name == 'fastspeech2':
+ fields = ["utt_id", "text"]
+ if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
+ print("multiple speaker fastspeech2!")
+ fields += ["spk_id"]
+ elif 'voice_cloning' in args and args.voice_cloning:
+ print("voice cloning!")
+ fields += ["spk_emb"]
+ else:
+ print("single speaker fastspeech2!")
+ elif am_name == 'speedyspeech':
+ fields = ["utt_id", "phones", "tones"]
+ elif am_name == 'tacotron2':
+ fields = ["utt_id", "text"]
+ if 'voice_cloning' in args and args.voice_cloning:
+ print("voice cloning!")
+ fields += ["spk_emb"]
+
+ test_dataset = DataTable(data=test_metadata, fields=fields)
+ return test_dataset
+
+
+# frontend
+def get_frontend(args):
+ if 'lang' in args and args.lang == 'zh':
+ frontend = Frontend(
+ phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
+ elif 'lang' in args and args.lang == 'en':
+ frontend = English(phone_vocab_path=args.phones_dict)
+ else:
+ print("wrong lang!")
+ print("frontend done!")
+ return frontend
+
+
+# dygraph
+def get_am_inference(args, am_config):
+ with open(args.phones_dict, "r") as f:
+ phn_id = [line.strip().split() for line in f.readlines()]
+ vocab_size = len(phn_id)
+ print("vocab_size:", vocab_size)
+
+ tone_size = None
+ if 'tones_dict' in args and args.tones_dict:
+ with open(args.tones_dict, "r") as f:
+ tone_id = [line.strip().split() for line in f.readlines()]
+ tone_size = len(tone_id)
+ print("tone_size:", tone_size)
+
+ spk_num = None
+ if 'speaker_dict' in args and args.speaker_dict:
+ with open(args.speaker_dict, 'rt') as f:
+ spk_id = [line.strip().split() for line in f.readlines()]
+ spk_num = len(spk_id)
+ print("spk_num:", spk_num)
+
+ odim = am_config.n_mels
+ # model: {model_name}_{dataset}
+ am_name = args.am[:args.am.rindex('_')]
+ am_dataset = args.am[args.am.rindex('_') + 1:]
+
+ am_class = dynamic_import(am_name, model_alias)
+ am_inference_class = dynamic_import(am_name + '_inference', model_alias)
+
+ if am_name == 'fastspeech2':
+ am = am_class(
+ idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
+ elif am_name == 'speedyspeech':
+ am = am_class(
+ vocab_size=vocab_size,
+ tone_size=tone_size,
+ spk_num=spk_num,
+ **am_config["model"])
+ elif am_name == 'tacotron2':
+ am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
+
+ am.set_state_dict(paddle.load(args.am_ckpt)["main_params"])
+ am.eval()
+ am_mu, am_std = np.load(args.am_stat)
+ am_mu = paddle.to_tensor(am_mu)
+ am_std = paddle.to_tensor(am_std)
+ am_normalizer = ZScore(am_mu, am_std)
+ am_inference = am_inference_class(am_normalizer, am)
+ am_inference.eval()
+ print("acoustic model done!")
+ return am_inference, am_name, am_dataset
+
+
+def get_voc_inference(args, voc_config):
+ # model: {model_name}_{dataset}
+ voc_name = args.voc[:args.voc.rindex('_')]
+ voc_class = dynamic_import(voc_name, model_alias)
+ voc_inference_class = dynamic_import(voc_name + '_inference', model_alias)
+ if voc_name != 'wavernn':
+ voc = voc_class(**voc_config["generator_params"])
+ voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"])
+ voc.remove_weight_norm()
+ voc.eval()
+ else:
+ voc = voc_class(**voc_config["model"])
+ voc.set_state_dict(paddle.load(args.voc_ckpt)["main_params"])
+ voc.eval()
+
+ voc_mu, voc_std = np.load(args.voc_stat)
+ voc_mu = paddle.to_tensor(voc_mu)
+ voc_std = paddle.to_tensor(voc_std)
+ voc_normalizer = ZScore(voc_mu, voc_std)
+ voc_inference = voc_inference_class(voc_normalizer, voc)
+ voc_inference.eval()
+ print("voc done!")
+ return voc_inference
+
+
+# to static
+def am_to_static(args, am_inference, am_name, am_dataset):
+ if am_name == 'fastspeech2':
+ if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
+ am_inference = jit.to_static(
+ am_inference,
+ input_spec=[
+ InputSpec([-1], dtype=paddle.int64),
+ InputSpec([1], dtype=paddle.int64),
+ ], )
+ else:
+ am_inference = jit.to_static(
+ am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
+
+ elif am_name == 'speedyspeech':
+ if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
+ am_inference = jit.to_static(
+ am_inference,
+ input_spec=[
+ InputSpec([-1], dtype=paddle.int64), # text
+ InputSpec([-1], dtype=paddle.int64), # tone
+ InputSpec([1], dtype=paddle.int64), # spk_id
+ None # duration
+ ])
+ else:
+ am_inference = jit.to_static(
+ am_inference,
+ input_spec=[
+ InputSpec([-1], dtype=paddle.int64),
+ InputSpec([-1], dtype=paddle.int64)
+ ])
+
+ elif am_name == 'tacotron2':
+ am_inference = jit.to_static(
+ am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
+
+ paddle.jit.save(am_inference, os.path.join(args.inference_dir, args.am))
+ am_inference = paddle.jit.load(os.path.join(args.inference_dir, args.am))
+ return am_inference
+
+
+def voc_to_static(args, voc_inference):
+ voc_inference = jit.to_static(
+ voc_inference, input_spec=[
+ InputSpec([-1, 80], dtype=paddle.float32),
+ ])
+ paddle.jit.save(voc_inference, os.path.join(args.inference_dir, args.voc))
+ voc_inference = paddle.jit.load(os.path.join(args.inference_dir, args.voc))
+ return voc_inference
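The new `syn_utils.py` module collects the setup code that the diffs below remove from `synthesize.py`, `synthesize_e2e.py` and `voice_cloning.py`. A rough sketch of how the helpers are meant to compose (mirroring `synthesize_e2e.py`), assuming `args` is an argparse `Namespace` carrying the usual CLI fields (`am`, `am_ckpt`, `am_stat`, `phones_dict`, `voc`, `voc_ckpt`, `voc_stat`, `lang`, `text`, `inference_dir`, ...) and `am_config` / `voc_config` are the already-loaded yacs configs:

```python
# Sketch only -- mirrors how synthesize_e2e.py below composes the new helpers.
from paddlespeech.t2s.exps.syn_utils import am_to_static
from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.exps.syn_utils import voc_to_static


def build_pipeline(args, am_config, voc_config):
    sentences = get_sentences(args)          # (utt_id, sentence) pairs from --text
    frontend = get_frontend(args)            # zh or en text frontend
    am_inference, am_name, am_dataset = get_am_inference(args, am_config)
    voc_inference = get_voc_inference(args, voc_config)
    if args.inference_dir:                   # optional dygraph -> static export
        am_inference = am_to_static(args, am_inference, am_name, am_dataset)
        voc_inference = voc_to_static(args, voc_inference)
    return sentences, frontend, am_inference, voc_inference
```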
diff --git a/paddlespeech/t2s/exps/synthesize.py b/paddlespeech/t2s/exps/synthesize.py
index 81da14f2e..abb1eb4eb 100644
--- a/paddlespeech/t2s/exps/synthesize.py
+++ b/paddlespeech/t2s/exps/synthesize.py
@@ -23,48 +23,11 @@ import yaml
from timer import timer
from yacs.config import CfgNode
-from paddlespeech.s2t.utils.dynamic_import import dynamic_import
-from paddlespeech.t2s.datasets.data_table import DataTable
-from paddlespeech.t2s.modules.normalizer import ZScore
+from paddlespeech.t2s.exps.syn_utils import get_am_inference
+from paddlespeech.t2s.exps.syn_utils import get_test_dataset
+from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.utils import str2bool
-model_alias = {
- # acoustic model
- "speedyspeech":
- "paddlespeech.t2s.models.speedyspeech:SpeedySpeech",
- "speedyspeech_inference":
- "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference",
- "fastspeech2":
- "paddlespeech.t2s.models.fastspeech2:FastSpeech2",
- "fastspeech2_inference":
- "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
- "tacotron2":
- "paddlespeech.t2s.models.tacotron2:Tacotron2",
- "tacotron2_inference":
- "paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
- # voc
- "pwgan":
- "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
- "pwgan_inference":
- "paddlespeech.t2s.models.parallel_wavegan:PWGInference",
- "mb_melgan":
- "paddlespeech.t2s.models.melgan:MelGANGenerator",
- "mb_melgan_inference":
- "paddlespeech.t2s.models.melgan:MelGANInference",
- "style_melgan":
- "paddlespeech.t2s.models.melgan:StyleMelGANGenerator",
- "style_melgan_inference":
- "paddlespeech.t2s.models.melgan:StyleMelGANInference",
- "hifigan":
- "paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
- "hifigan_inference":
- "paddlespeech.t2s.models.hifigan:HiFiGANInference",
- "wavernn":
- "paddlespeech.t2s.models.wavernn:WaveRNN",
- "wavernn_inference":
- "paddlespeech.t2s.models.wavernn:WaveRNNInference",
-}
-
def evaluate(args):
# dataloader has been too verbose
@@ -86,96 +49,12 @@ def evaluate(args):
print(am_config)
print(voc_config)
- # construct dataset for evaluation
-
- # model: {model_name}_{dataset}
- am_name = args.am[:args.am.rindex('_')]
- am_dataset = args.am[args.am.rindex('_') + 1:]
-
- if am_name == 'fastspeech2':
- fields = ["utt_id", "text"]
- spk_num = None
- if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
- print("multiple speaker fastspeech2!")
- with open(args.speaker_dict, 'rt') as f:
- spk_id = [line.strip().split() for line in f.readlines()]
- spk_num = len(spk_id)
- fields += ["spk_id"]
- elif args.voice_cloning:
- print("voice cloning!")
- fields += ["spk_emb"]
- else:
- print("single speaker fastspeech2!")
- print("spk_num:", spk_num)
- elif am_name == 'speedyspeech':
- fields = ["utt_id", "phones", "tones"]
- elif am_name == 'tacotron2':
- fields = ["utt_id", "text"]
- if args.voice_cloning:
- print("voice cloning!")
- fields += ["spk_emb"]
-
- test_dataset = DataTable(data=test_metadata, fields=fields)
-
- with open(args.phones_dict, "r") as f:
- phn_id = [line.strip().split() for line in f.readlines()]
- vocab_size = len(phn_id)
- print("vocab_size:", vocab_size)
-
- tone_size = None
- if args.tones_dict:
- with open(args.tones_dict, "r") as f:
- tone_id = [line.strip().split() for line in f.readlines()]
- tone_size = len(tone_id)
- print("tone_size:", tone_size)
-
# acoustic model
- odim = am_config.n_mels
- am_class = dynamic_import(am_name, model_alias)
- am_inference_class = dynamic_import(am_name + '_inference', model_alias)
-
- if am_name == 'fastspeech2':
- am = am_class(
- idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
- elif am_name == 'speedyspeech':
- am = am_class(
- vocab_size=vocab_size, tone_size=tone_size, **am_config["model"])
- elif am_name == 'tacotron2':
- am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
-
- am.set_state_dict(paddle.load(args.am_ckpt)["main_params"])
- am.eval()
- am_mu, am_std = np.load(args.am_stat)
- am_mu = paddle.to_tensor(am_mu)
- am_std = paddle.to_tensor(am_std)
- am_normalizer = ZScore(am_mu, am_std)
- am_inference = am_inference_class(am_normalizer, am)
- print("am_inference.training0:", am_inference.training)
- am_inference.eval()
- print("acoustic model done!")
+ am_inference, am_name, am_dataset = get_am_inference(args, am_config)
+ test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset)
# vocoder
- # model: {model_name}_{dataset}
- voc_name = args.voc[:args.voc.rindex('_')]
- voc_class = dynamic_import(voc_name, model_alias)
- voc_inference_class = dynamic_import(voc_name + '_inference', model_alias)
- if voc_name != 'wavernn':
- voc = voc_class(**voc_config["generator_params"])
- voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"])
- voc.remove_weight_norm()
- voc.eval()
- else:
- voc = voc_class(**voc_config["model"])
- voc.set_state_dict(paddle.load(args.voc_ckpt)["main_params"])
- voc.eval()
- voc_mu, voc_std = np.load(args.voc_stat)
- voc_mu = paddle.to_tensor(voc_mu)
- voc_std = paddle.to_tensor(voc_std)
- voc_normalizer = ZScore(voc_mu, voc_std)
- voc_inference = voc_inference_class(voc_normalizer, voc)
- print("voc_inference.training0:", voc_inference.training)
- voc_inference.eval()
- print("voc done!")
+ voc_inference = get_voc_inference(args, voc_config)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
@@ -227,7 +106,7 @@ def evaluate(args):
print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }")
-def main():
+def parse_args():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(
description="Synthesize with acoustic model & vocoder")
@@ -264,7 +143,6 @@ def main():
"--tones_dict", type=str, default=None, help="tone vocabulary file.")
parser.add_argument(
"--speaker_dict", type=str, default=None, help="speaker id map file.")
-
parser.add_argument(
"--voice-cloning",
type=str2bool,
@@ -278,10 +156,10 @@ def main():
choices=[
'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
'mb_melgan_csmsc', 'wavernn_csmsc', 'hifigan_csmsc',
+ 'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk',
'style_melgan_csmsc'
],
help='Choose vocoder type of tts task.')
-
parser.add_argument(
'--voc_config',
type=str,
@@ -302,7 +180,12 @@ def main():
parser.add_argument("--output_dir", type=str, help="output dir.")
args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
if args.ngpu == 0:
paddle.set_device("cpu")
elif args.ngpu > 0:
diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py
index 94180f853..f5214d4a4 100644
--- a/paddlespeech/t2s/exps/synthesize_e2e.py
+++ b/paddlespeech/t2s/exps/synthesize_e2e.py
@@ -12,59 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
-import os
from pathlib import Path
-import numpy as np
import paddle
import soundfile as sf
import yaml
-from paddle import jit
-from paddle.static import InputSpec
from timer import timer
from yacs.config import CfgNode
-from paddlespeech.s2t.utils.dynamic_import import dynamic_import
-from paddlespeech.t2s.frontend import English
-from paddlespeech.t2s.frontend.zh_frontend import Frontend
-from paddlespeech.t2s.modules.normalizer import ZScore
-
-model_alias = {
- # acoustic model
- "speedyspeech":
- "paddlespeech.t2s.models.speedyspeech:SpeedySpeech",
- "speedyspeech_inference":
- "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference",
- "fastspeech2":
- "paddlespeech.t2s.models.fastspeech2:FastSpeech2",
- "fastspeech2_inference":
- "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
- "tacotron2":
- "paddlespeech.t2s.models.tacotron2:Tacotron2",
- "tacotron2_inference":
- "paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
- # voc
- "pwgan":
- "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
- "pwgan_inference":
- "paddlespeech.t2s.models.parallel_wavegan:PWGInference",
- "mb_melgan":
- "paddlespeech.t2s.models.melgan:MelGANGenerator",
- "mb_melgan_inference":
- "paddlespeech.t2s.models.melgan:MelGANInference",
- "style_melgan":
- "paddlespeech.t2s.models.melgan:StyleMelGANGenerator",
- "style_melgan_inference":
- "paddlespeech.t2s.models.melgan:StyleMelGANInference",
- "hifigan":
- "paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
- "hifigan_inference":
- "paddlespeech.t2s.models.hifigan:HiFiGANInference",
- "wavernn":
- "paddlespeech.t2s.models.wavernn:WaveRNN",
- "wavernn_inference":
- "paddlespeech.t2s.models.wavernn:WaveRNNInference",
-}
+from paddlespeech.t2s.exps.syn_utils import am_to_static
+from paddlespeech.t2s.exps.syn_utils import get_am_inference
+from paddlespeech.t2s.exps.syn_utils import get_frontend
+from paddlespeech.t2s.exps.syn_utils import get_sentences
+from paddlespeech.t2s.exps.syn_utils import get_voc_inference
+from paddlespeech.t2s.exps.syn_utils import voc_to_static
def evaluate(args):
@@ -81,155 +42,28 @@ def evaluate(args):
print(am_config)
print(voc_config)
- # construct dataset for evaluation
- sentences = []
- with open(args.text, 'rt') as f:
- for line in f:
- items = line.strip().split()
- utt_id = items[0]
- if args.lang == 'zh':
- sentence = "".join(items[1:])
- elif args.lang == 'en':
- sentence = " ".join(items[1:])
- sentences.append((utt_id, sentence))
-
- with open(args.phones_dict, "r") as f:
- phn_id = [line.strip().split() for line in f.readlines()]
- vocab_size = len(phn_id)
- print("vocab_size:", vocab_size)
-
- tone_size = None
- if args.tones_dict:
- with open(args.tones_dict, "r") as f:
- tone_id = [line.strip().split() for line in f.readlines()]
- tone_size = len(tone_id)
- print("tone_size:", tone_size)
-
- spk_num = None
- if args.speaker_dict:
- with open(args.speaker_dict, 'rt') as f:
- spk_id = [line.strip().split() for line in f.readlines()]
- spk_num = len(spk_id)
- print("spk_num:", spk_num)
+ sentences = get_sentences(args)
# frontend
- if args.lang == 'zh':
- frontend = Frontend(
- phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
- elif args.lang == 'en':
- frontend = English(phone_vocab_path=args.phones_dict)
- print("frontend done!")
+ frontend = get_frontend(args)
# acoustic model
- odim = am_config.n_mels
- # model: {model_name}_{dataset}
- am_name = args.am[:args.am.rindex('_')]
- am_dataset = args.am[args.am.rindex('_') + 1:]
-
- am_class = dynamic_import(am_name, model_alias)
- am_inference_class = dynamic_import(am_name + '_inference', model_alias)
-
- if am_name == 'fastspeech2':
- am = am_class(
- idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
- elif am_name == 'speedyspeech':
- am = am_class(
- vocab_size=vocab_size,
- tone_size=tone_size,
- spk_num=spk_num,
- **am_config["model"])
- elif am_name == 'tacotron2':
- am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
-
- am.set_state_dict(paddle.load(args.am_ckpt)["main_params"])
- am.eval()
- am_mu, am_std = np.load(args.am_stat)
- am_mu = paddle.to_tensor(am_mu)
- am_std = paddle.to_tensor(am_std)
- am_normalizer = ZScore(am_mu, am_std)
- am_inference = am_inference_class(am_normalizer, am)
- am_inference.eval()
- print("acoustic model done!")
+ am_inference, am_name, am_dataset = get_am_inference(args, am_config)
# vocoder
- # model: {model_name}_{dataset}
- voc_name = args.voc[:args.voc.rindex('_')]
- voc_class = dynamic_import(voc_name, model_alias)
- voc_inference_class = dynamic_import(voc_name + '_inference', model_alias)
- if voc_name != 'wavernn':
- voc = voc_class(**voc_config["generator_params"])
- voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"])
- voc.remove_weight_norm()
- voc.eval()
- else:
- voc = voc_class(**voc_config["model"])
- voc.set_state_dict(paddle.load(args.voc_ckpt)["main_params"])
- voc.eval()
-
- voc_mu, voc_std = np.load(args.voc_stat)
- voc_mu = paddle.to_tensor(voc_mu)
- voc_std = paddle.to_tensor(voc_std)
- voc_normalizer = ZScore(voc_mu, voc_std)
- voc_inference = voc_inference_class(voc_normalizer, voc)
- voc_inference.eval()
- print("voc done!")
+ voc_inference = get_voc_inference(args, voc_config)
# whether dygraph to static
if args.inference_dir:
# acoustic model
- if am_name == 'fastspeech2':
- if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
- am_inference = jit.to_static(
- am_inference,
- input_spec=[
- InputSpec([-1], dtype=paddle.int64),
- InputSpec([1], dtype=paddle.int64)
- ])
- else:
- am_inference = jit.to_static(
- am_inference,
- input_spec=[InputSpec([-1], dtype=paddle.int64)])
-
- elif am_name == 'speedyspeech':
- if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
- am_inference = jit.to_static(
- am_inference,
- input_spec=[
- InputSpec([-1], dtype=paddle.int64), # text
- InputSpec([-1], dtype=paddle.int64), # tone
- InputSpec([1], dtype=paddle.int64), # spk_id
- None # duration
- ])
- else:
- am_inference = jit.to_static(
- am_inference,
- input_spec=[
- InputSpec([-1], dtype=paddle.int64),
- InputSpec([-1], dtype=paddle.int64)
- ])
-
- elif am_name == 'tacotron2':
- am_inference = jit.to_static(
- am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
-
- paddle.jit.save(am_inference, os.path.join(args.inference_dir, args.am))
- am_inference = paddle.jit.load(
- os.path.join(args.inference_dir, args.am))
+ am_inference = am_to_static(args, am_inference, am_name, am_dataset)
# vocoder
- voc_inference = jit.to_static(
- voc_inference,
- input_spec=[
- InputSpec([-1, 80], dtype=paddle.float32),
- ])
- paddle.jit.save(voc_inference,
- os.path.join(args.inference_dir, args.voc))
- voc_inference = paddle.jit.load(
- os.path.join(args.inference_dir, args.voc))
+ voc_inference = voc_to_static(args, voc_inference)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
- merge_sentences = False
+ merge_sentences = True
# Avoid not stopping at the end of a sub sentence when tacotron2_ljspeech dygraph to static graph
# but still not stopping in the end (NOTE by yuantian01 Feb 9 2022)
if am_name == 'tacotron2':
@@ -298,7 +132,7 @@ def evaluate(args):
print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }")
-def main():
+def parse_args():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(
description="Synthesize with acoustic model & vocoder")
@@ -346,12 +180,19 @@ def main():
type=str,
default='pwgan_csmsc',
choices=[
- 'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
- 'mb_melgan_csmsc', 'style_melgan_csmsc', 'hifigan_csmsc',
- 'wavernn_csmsc'
+ 'pwgan_csmsc',
+ 'pwgan_ljspeech',
+ 'pwgan_aishell3',
+ 'pwgan_vctk',
+ 'mb_melgan_csmsc',
+ 'style_melgan_csmsc',
+ 'hifigan_csmsc',
+ 'hifigan_ljspeech',
+ 'hifigan_aishell3',
+ 'hifigan_vctk',
+ 'wavernn_csmsc',
],
help='Choose vocoder type of tts task.')
-
parser.add_argument(
'--voc_config',
type=str,
@@ -386,6 +227,11 @@ def main():
parser.add_argument("--output_dir", type=str, help="output dir.")
args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
if args.ngpu == 0:
paddle.set_device("cpu")
diff --git a/paddlespeech/t2s/exps/voice_cloning.py b/paddlespeech/t2s/exps/voice_cloning.py
index 3de30774f..1afd21dff 100644
--- a/paddlespeech/t2s/exps/voice_cloning.py
+++ b/paddlespeech/t2s/exps/voice_cloning.py
@@ -21,29 +21,12 @@ import soundfile as sf
import yaml
from yacs.config import CfgNode
-from paddlespeech.s2t.utils.dynamic_import import dynamic_import
+from paddlespeech.t2s.exps.syn_utils import get_am_inference
+from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.frontend.zh_frontend import Frontend
-from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder
-model_alias = {
- # acoustic model
- "fastspeech2":
- "paddlespeech.t2s.models.fastspeech2:FastSpeech2",
- "fastspeech2_inference":
- "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
- "tacotron2":
- "paddlespeech.t2s.models.tacotron2:Tacotron2",
- "tacotron2_inference":
- "paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
- # voc
- "pwgan":
- "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
- "pwgan_inference":
- "paddlespeech.t2s.models.parallel_wavegan:PWGInference",
-}
-
def voice_cloning(args):
# Init body.
@@ -79,55 +62,14 @@ def voice_cloning(args):
speaker_encoder.eval()
print("GE2E Done!")
- with open(args.phones_dict, "r") as f:
- phn_id = [line.strip().split() for line in f.readlines()]
- vocab_size = len(phn_id)
- print("vocab_size:", vocab_size)
+ frontend = Frontend(phone_vocab_path=args.phones_dict)
+ print("frontend done!")
# acoustic model
- odim = am_config.n_mels
- # model: {model_name}_{dataset}
- am_name = args.am[:args.am.rindex('_')]
- am_dataset = args.am[args.am.rindex('_') + 1:]
-
- am_class = dynamic_import(am_name, model_alias)
- am_inference_class = dynamic_import(am_name + '_inference', model_alias)
-
- if am_name == 'fastspeech2':
- am = am_class(
- idim=vocab_size, odim=odim, spk_num=None, **am_config["model"])
- elif am_name == 'tacotron2':
- am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
-
- am.set_state_dict(paddle.load(args.am_ckpt)["main_params"])
- am.eval()
- am_mu, am_std = np.load(args.am_stat)
- am_mu = paddle.to_tensor(am_mu)
- am_std = paddle.to_tensor(am_std)
- am_normalizer = ZScore(am_mu, am_std)
- am_inference = am_inference_class(am_normalizer, am)
- am_inference.eval()
- print("acoustic model done!")
+ am_inference, *_ = get_am_inference(args, am_config)
# vocoder
- # model: {model_name}_{dataset}
- voc_name = args.voc[:args.voc.rindex('_')]
- voc_class = dynamic_import(voc_name, model_alias)
- voc_inference_class = dynamic_import(voc_name + '_inference', model_alias)
- voc = voc_class(**voc_config["generator_params"])
- voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"])
- voc.remove_weight_norm()
- voc.eval()
- voc_mu, voc_std = np.load(args.voc_stat)
- voc_mu = paddle.to_tensor(voc_mu)
- voc_std = paddle.to_tensor(voc_std)
- voc_normalizer = ZScore(voc_mu, voc_std)
- voc_inference = voc_inference_class(voc_normalizer, voc)
- voc_inference.eval()
- print("voc done!")
-
- frontend = Frontend(phone_vocab_path=args.phones_dict)
- print("frontend done!")
+ voc_inference = get_voc_inference(args, voc_config)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
@@ -170,7 +112,7 @@ def voice_cloning(args):
print(f"{utt_id} done!")
-def main():
+def parse_args():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(description="")
parser.add_argument(
@@ -240,6 +182,11 @@ def main():
parser.add_argument("--output-dir", type=str, help="output dir.")
args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
if args.ngpu == 0:
paddle.set_device("cpu")
diff --git a/paddlespeech/t2s/modules/predictor/length_regulator.py b/paddlespeech/t2s/modules/predictor/length_regulator.py
index 62d707d22..2472c413b 100644
--- a/paddlespeech/t2s/modules/predictor/length_regulator.py
+++ b/paddlespeech/t2s/modules/predictor/length_regulator.py
@@ -101,6 +101,16 @@ class LengthRegulator(nn.Layer):
assert alpha > 0
ds = paddle.round(ds.cast(dtype=paddle.float32) * alpha)
ds = ds.cast(dtype=paddle.int64)
+ '''
+ from distutils.version import LooseVersion
+ from paddlespeech.t2s.modules.nets_utils import pad_list
+        # dygraph-to-static conversion does not work for this path in paddle 2.2.2
+ # if LooseVersion(paddle.__version__) >= "2.3.0" or hasattr(paddle, 'repeat_interleave'):
+ # if LooseVersion(paddle.__version__) >= "2.3.0":
+ if hasattr(paddle, 'repeat_interleave'):
+ repeat = [paddle.repeat_interleave(x, d, axis=0) for x, d in zip(xs, ds)]
+ return pad_list(repeat, self.pad_value)
+ '''
if is_inference:
return self.expand(xs, ds)
else:
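The commented-out block keeps an alternative expansion path for reference: when `paddle.repeat_interleave` is available, each encoder output vector can simply be repeated by its predicted duration and the results padded into a batch, which is what the length regulator does conceptually. A toy illustration of that duration expansion, using numpy arrays in place of paddle tensors (shapes and values are invented):

```python
import numpy as np

# One utterance: 3 phonemes, hidden size 2, durations in frames.
xs = np.array([[0.1, 0.2],
               [0.3, 0.4],
               [0.5, 0.6]])          # (num_phones, hidden)
ds = np.array([2, 1, 3])             # predicted duration per phoneme

# Length regulation: repeat each phoneme vector by its duration along time.
hs = np.repeat(xs, ds, axis=0)       # (sum(ds), hidden) == (6, 2)
assert hs.shape == (ds.sum(), xs.shape[1])
print(hs)
```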
diff --git a/setup.py b/setup.py
index f86758bab..82ff63412 100644
--- a/setup.py
+++ b/setup.py
@@ -27,7 +27,7 @@ from setuptools.command.install import install
HERE = Path(os.path.abspath(os.path.dirname(__file__)))
-VERSION = '0.1.2'
+VERSION = '0.2.0'
base = [
"editdistance",
diff --git a/speechx/.gitignore b/speechx/.gitignore
new file mode 100644
index 000000000..e0c618470
--- /dev/null
+++ b/speechx/.gitignore
@@ -0,0 +1 @@
+tools/valgrind*
diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt
index e003136a9..f1330d1da 100644
--- a/speechx/CMakeLists.txt
+++ b/speechx/CMakeLists.txt
@@ -2,18 +2,32 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(paddlespeech VERSION 0.1)
+set(CMAKE_PROJECT_INCLUDE_BEFORE "${CMAKE_CURRENT_SOURCE_DIR}/cmake/EnableCMP0048.cmake")
+
set(CMAKE_VERBOSE_MAKEFILE on)
+
# set std-14
set(CMAKE_CXX_STANDARD 14)
-# include file
+# cmake dir
+set(speechx_cmake_dir ${PROJECT_SOURCE_DIR}/cmake)
+
+# Modules
+list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir}/external)
+list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir})
include(FetchContent)
include(ExternalProject)
+
# fc_patch dir
set(FETCHCONTENT_QUIET off)
get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_patch})
+# compiler option
+# Keep the same with openfst, -fPIC or -fpic
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g")
+SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb")
+SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall")
###############################################################################
# Option Configurations
@@ -25,91 +39,92 @@ option(TEST_DEBUG "option for debug" OFF)
###############################################################################
# Include third party
###############################################################################
-# #example for include third party
-# FetchContent_Declare()
-# # FetchContent_MakeAvailable was not added until CMake 3.14
+# example for include third party
+# FetchContent_MakeAvailable was not added until CMake 3.14
# FetchContent_MakeAvailable()
# include_directories()
+# gflags
+include(gflags)
+
+# glog
+include(glog)
+
+# gtest
+include(gtest)
+
# ABSEIL-CPP
-include(FetchContent)
-FetchContent_Declare(
- absl
- GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git"
- GIT_TAG "20210324.1"
-)
-FetchContent_MakeAvailable(absl)
+include(absl)
# libsndfile
-include(FetchContent)
-FetchContent_Declare(
- libsndfile
- GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git"
- GIT_TAG "1.0.31"
-)
-FetchContent_MakeAvailable(libsndfile)
+include(libsndfile)
-# gflags
-FetchContent_Declare(
- gflags
- URL https://github.com/gflags/gflags/archive/v2.2.1.zip
- URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a
-)
-FetchContent_MakeAvailable(gflags)
-include_directories(${gflags_BINARY_DIR}/include)
+# boost
+# include(boost) # not work
+set(boost_SOURCE_DIR ${fc_patch}/boost-src)
+set(BOOST_ROOT ${boost_SOURCE_DIR})
+# #find_package(boost REQUIRED PATHS ${BOOST_ROOT})
-# glog
-FetchContent_Declare(
- glog
- URL https://github.com/google/glog/archive/v0.4.0.zip
- URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
-)
-FetchContent_MakeAvailable(glog)
-include_directories(${glog_BINARY_DIR})
+# Eigen
+include(eigen)
+find_package(Eigen3 REQUIRED)
-# gtest
-FetchContent_Declare(googletest
- URL https://github.com/google/googletest/archive/release-1.10.0.zip
- URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91
-)
-FetchContent_MakeAvailable(googletest)
+# Kenlm
+include(kenlm)
+add_dependencies(kenlm eigen boost)
+
+#openblas
+include(openblas)
# openfst
-set(openfst_SOURCE_DIR ${fc_patch}/openfst-src)
-set(openfst_BINARY_DIR ${fc_patch}/openfst-build)
-set(openfst_PREFIX_DIR ${fc_patch}/openfst-subbuild/openfst-populate-prefix)
-ExternalProject_Add(openfst
- URL https://github.com/mjansche/openfst/archive/refs/tags/1.7.2.zip
- URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6
- SOURCE_DIR ${openfst_SOURCE_DIR}
- BINARY_DIR ${openfst_BINARY_DIR}
- CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR}
- "CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}"
- "LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}"
- "LIBS=-lgflags_nothreads -lglog -lpthread"
- BUILD_COMMAND make -j 4
-)
+include(openfst)
add_dependencies(openfst gflags glog)
-link_directories(${openfst_PREFIX_DIR}/lib)
-include_directories(${openfst_PREFIX_DIR}/include)
-add_subdirectory(speechx)
-#openblas
-#set(OpenBLAS_INSTALL_PREFIX ${fc_patch}/OpenBLAS)
-#set(OpenBLAS_SOURCE_DIR ${fc_patch}/OpenBLAS-src)
-#ExternalProject_Add(
-# OpenBLAS
-# GIT_REPOSITORY https://github.com/xianyi/OpenBLAS
-# GIT_TAG v0.3.13
-# GIT_SHALLOW TRUE
-# GIT_PROGRESS TRUE
-# CONFIGURE_COMMAND ""
-# BUILD_IN_SOURCE TRUE
-# BUILD_COMMAND make USE_LOCKING=1 USE_THREAD=0
-# INSTALL_COMMAND make PREFIX=${OpenBLAS_INSTALL_PREFIX} install
-# UPDATE_DISCONNECTED TRUE
-#)
+# paddle lib
+set(paddle_SOURCE_DIR ${fc_patch}/paddle-lib)
+set(paddle_PREFIX_DIR ${fc_patch}/paddle-lib-prefix)
+ExternalProject_Add(paddle
+ URL https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/CPU/gcc8.2_avx_mkl/paddle_inference.tgz
+ URL_HASH SHA256=7c6399e778c6554a929b5a39ba2175e702e115145e8fa690d2af974101d98873
+ PREFIX ${paddle_PREFIX_DIR}
+ SOURCE_DIR ${paddle_SOURCE_DIR}
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND ""
+ INSTALL_COMMAND ""
+)
+
+set(PADDLE_LIB ${fc_patch}/paddle-lib)
+include_directories("${PADDLE_LIB}/paddle/include")
+set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
+
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
+link_directories("${PADDLE_LIB}/paddle/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mklml/lib")
+
+##paddle with mkl
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
+set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml")
+include_directories("${MATH_LIB_PATH}/include")
+set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
+ ${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
+set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
+include_directories("${MKLDNN_PATH}/include")
+set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
+set(EXTERNAL_LIB "-lrt -ldl -lpthread")
+
+set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
+set(DEPS ${DEPS}
+ ${MATH_LIB} ${MKLDNN_LIB}
+ glog gflags protobuf xxhash cryptopp
+ ${EXTERNAL_LIB})
+
+
###############################################################################
# Add local library
@@ -121,4 +136,9 @@ add_subdirectory(speechx)
# if dir do not have CmakeLists.txt
#add_library(lib_name STATIC file.cc)
#target_link_libraries(lib_name item0 item1)
-#add_dependencies(lib_name depend-target)
\ No newline at end of file
+#add_dependencies(lib_name depend-target)
+
+set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx)
+
+add_subdirectory(speechx)
+add_subdirectory(examples)
\ No newline at end of file
diff --git a/speechx/README.md b/speechx/README.md
new file mode 100644
index 000000000..7d73b61c6
--- /dev/null
+++ b/speechx/README.md
@@ -0,0 +1,61 @@
+# SpeechX -- All in One Speech Task Inference
+
+## Environment
+
+We develop under:
+* docker - registry.baidubce.com/paddlepaddle/paddle:2.1.1-gpu-cuda10.2-cudnn7
+* os - Ubuntu 16.04.7 LTS
+* gcc/g++ - 8.2.0
+* cmake - 3.16.0
+
+> We make sure everything works fine under Docker, and recommend using it for development and deployment.
+
+* [How to Install Docker](https://docs.docker.com/engine/install/)
+* [A Docker Tutorial for Beginners](https://docker-curriculum.com/)
+* [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/overview.html)
+
+## Build
+
+1. First, launch the Docker container.
+
+```
+nvidia-docker run --privileged --net=host --ipc=host -it --rm -v $PWD:/workspace --name=dev registry.baidubce.com/paddlepaddle/paddle:2.1.1-gpu-cuda10.2-cudnn7 /bin/bash
+```
+
+* More `Paddle` Docker images can be found [here](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html).
+
+* If you only want to work on CPU, please download the corresponding [image](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html), and use `docker` instead of `nvidia-docker`.
+
+
+2. Build `speechx` and `examples`.
+
+```
+pushd /path/to/speechx
+./build.sh
+```
+
+3. Go to `examples` to have fun.
+
+For more details, please see the `README.md` under `examples`.
+
+
+## Valgrind (Optional)
+
+> If using Docker, please check that `--privileged` is set when running `docker run`.
+
+* Fatal error at startup: `a function redirection which is mandatory for this platform-tool combination cannot be set up`
+```
+apt-get install libc6-dbg
+```
+
+* Install
+
+```
+pushd tools
+./setup_valgrind.sh
+popd
+```
+
+## TODO
+
+* DecibelNormalizer: there is a small difference between the offline and online db norm. The online db norm reads the feature chunk by chunk, so the number of samples it sees differs from the offline db norm. In normalizer.cc:73, samples.size() differs, which causes the difference in results.
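The point of the TODO is that the streaming path computes its statistics over a different number of samples than the offline path. A small illustration of the effect (not the normalizer.cc code, just why chunked decibel statistics drift from the full-utterance value):

```python
import numpy as np

rng = np.random.default_rng(0)
samples = rng.standard_normal(16000).astype(np.float32)

def db_level(x):
    # decibel level from the mean square of however many samples are visible
    return 10.0 * np.log10(np.mean(np.square(x)) + 1e-20)

# offline: one pass over the whole utterance
offline_db = db_level(samples)

# online: the same signal arriving in chunks -- each chunk's statistics are
# computed over a different samples.size(), so the per-chunk levels do not
# match the offline value exactly.
chunk_dbs = [db_level(chunk) for chunk in np.array_split(samples, 10)]

print(offline_db, chunk_dbs[:3])
```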
diff --git a/speechx/build.sh b/speechx/build.sh
new file mode 100755
index 000000000..3e9600d53
--- /dev/null
+++ b/speechx/build.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+
+# the build script has been verified in the paddlepaddle docker image.
+# please follow the instruction below to install PaddlePaddle image.
+# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html
+
+boost_SOURCE_DIR=$PWD/fc_patch/boost-src
+if [ ! -d ${boost_SOURCE_DIR} ]; then wget -c https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz
+ tar xzfv boost_1_75_0.tar.gz
+ mkdir -p $PWD/fc_patch
+ mv boost_1_75_0 ${boost_SOURCE_DIR}
+ cd ${boost_SOURCE_DIR}
+ bash ./bootstrap.sh
+ ./b2
+ cd -
+ echo -e "\n"
+fi
+
+#rm -rf build
+mkdir -p build
+cd build
+
+cmake .. -DBOOST_ROOT:STRING=${boost_SOURCE_DIR}
+#cmake ..
+
+make -j1
+
+cd -
diff --git a/speechx/cmake/EnableCMP0048.cmake b/speechx/cmake/EnableCMP0048.cmake
new file mode 100644
index 000000000..1b59188fd
--- /dev/null
+++ b/speechx/cmake/EnableCMP0048.cmake
@@ -0,0 +1 @@
+cmake_policy(SET CMP0048 NEW)
\ No newline at end of file
diff --git a/speechx/cmake/external/absl.cmake b/speechx/cmake/external/absl.cmake
new file mode 100644
index 000000000..2c5e5af5c
--- /dev/null
+++ b/speechx/cmake/external/absl.cmake
@@ -0,0 +1,16 @@
+include(FetchContent)
+
+
+set(BUILD_SHARED_LIBS OFF) # up to you
+set(BUILD_TESTING OFF) # to disable abseil test, or gtest will fail.
+set(ABSL_ENABLE_INSTALL ON) # now you can enable install rules even in subproject...
+
+FetchContent_Declare(
+ absl
+ GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git"
+ GIT_TAG "20210324.1"
+)
+FetchContent_MakeAvailable(absl)
+
+set(EIGEN3_INCLUDE_DIR ${Eigen3_SOURCE_DIR})
+include_directories(${absl_SOURCE_DIR})
\ No newline at end of file
diff --git a/speechx/cmake/external/boost.cmake b/speechx/cmake/external/boost.cmake
new file mode 100644
index 000000000..6bc97aad4
--- /dev/null
+++ b/speechx/cmake/external/boost.cmake
@@ -0,0 +1,27 @@
+include(FetchContent)
+set(Boost_DEBUG ON)
+
+set(Boost_PREFIX_DIR ${fc_patch}/boost)
+set(Boost_SOURCE_DIR ${fc_patch}/boost-src)
+
+FetchContent_Declare(
+ Boost
+ URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz
+ URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
+ PREFIX ${Boost_PREFIX_DIR}
+ SOURCE_DIR ${Boost_SOURCE_DIR}
+)
+
+execute_process(COMMAND bootstrap.sh WORKING_DIRECTORY ${Boost_SOURCE_DIR})
+execute_process(COMMAND b2 WORKING_DIRECTORY ${Boost_SOURCE_DIR})
+
+FetchContent_MakeAvailable(Boost)
+
+message(STATUS "boost src dir: ${Boost_SOURCE_DIR}")
+message(STATUS "boost inc dir: ${Boost_INCLUDE_DIR}")
+message(STATUS "boost bin dir: ${Boost_BINARY_DIR}")
+
+set(BOOST_ROOT ${Boost_SOURCE_DIR})
+message(STATUS "boost root dir: ${BOOST_ROOT}")
+
+include_directories(${Boost_SOURCE_DIR})
\ No newline at end of file
diff --git a/speechx/cmake/external/eigen.cmake b/speechx/cmake/external/eigen.cmake
new file mode 100644
index 000000000..12bd3cdf5
--- /dev/null
+++ b/speechx/cmake/external/eigen.cmake
@@ -0,0 +1,27 @@
+include(FetchContent)
+
+# update eigen to the commit id f612df27 on 03/16/2021
+set(EIGEN_PREFIX_DIR ${fc_patch}/eigen3)
+
+FetchContent_Declare(
+ Eigen3
+ GIT_REPOSITORY https://gitlab.com/libeigen/eigen.git
+ GIT_TAG master
+ PREFIX ${EIGEN_PREFIX_DIR}
+ GIT_SHALLOW TRUE
+ GIT_PROGRESS TRUE)
+
+set(EIGEN_BUILD_DOC OFF)
+# note: To disable eigen tests,
+# you should put this code in a add_subdirectory to avoid to change
+# BUILD_TESTING for your own project too since variables are directory
+# scoped
+set(BUILD_TESTING OFF)
+set(EIGEN_BUILD_PKGCONFIG OFF)
+set( OFF)
+FetchContent_MakeAvailable(Eigen3)
+
+message(STATUS "eigen src dir: ${Eigen3_SOURCE_DIR}")
+message(STATUS "eigen bin dir: ${Eigen3_BINARY_DIR}")
+#include_directories(${Eigen3_SOURCE_DIR})
+#link_directories(${Eigen3_BINARY_DIR})
\ No newline at end of file
diff --git a/speechx/cmake/external/gflags.cmake b/speechx/cmake/external/gflags.cmake
new file mode 100644
index 000000000..66ae47f70
--- /dev/null
+++ b/speechx/cmake/external/gflags.cmake
@@ -0,0 +1,12 @@
+include(FetchContent)
+
+FetchContent_Declare(
+ gflags
+ URL https://github.com/gflags/gflags/archive/v2.2.1.zip
+ URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a
+)
+
+FetchContent_MakeAvailable(gflags)
+
+# openfst need
+include_directories(${gflags_BINARY_DIR}/include)
\ No newline at end of file
diff --git a/speechx/cmake/external/glog.cmake b/speechx/cmake/external/glog.cmake
new file mode 100644
index 000000000..dcfd86c3e
--- /dev/null
+++ b/speechx/cmake/external/glog.cmake
@@ -0,0 +1,8 @@
+include(FetchContent)
+FetchContent_Declare(
+ glog
+ URL https://github.com/google/glog/archive/v0.4.0.zip
+ URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
+)
+FetchContent_MakeAvailable(glog)
+include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src)
diff --git a/speechx/cmake/external/gtest.cmake b/speechx/cmake/external/gtest.cmake
new file mode 100644
index 000000000..7fe397fcb
--- /dev/null
+++ b/speechx/cmake/external/gtest.cmake
@@ -0,0 +1,9 @@
+include(FetchContent)
+FetchContent_Declare(
+ gtest
+ URL https://github.com/google/googletest/archive/release-1.10.0.zip
+ URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91
+)
+FetchContent_MakeAvailable(gtest)
+
+include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src)
\ No newline at end of file
diff --git a/speechx/cmake/external/kenlm.cmake b/speechx/cmake/external/kenlm.cmake
new file mode 100644
index 000000000..17c76c3f6
--- /dev/null
+++ b/speechx/cmake/external/kenlm.cmake
@@ -0,0 +1,10 @@
+include(FetchContent)
+FetchContent_Declare(
+ kenlm
+ GIT_REPOSITORY "https://github.com/kpu/kenlm.git"
+ GIT_TAG "df2d717e95183f79a90b2fa6e4307083a351ca6a"
+)
+# https://github.com/kpu/kenlm/blob/master/cmake/modules/FindEigen3.cmake
+set(EIGEN3_INCLUDE_DIR ${Eigen3_SOURCE_DIR})
+FetchContent_MakeAvailable(kenlm)
+include_directories(${kenlm_SOURCE_DIR})
\ No newline at end of file
diff --git a/speechx/cmake/external/libsndfile.cmake b/speechx/cmake/external/libsndfile.cmake
new file mode 100644
index 000000000..52d64bacd
--- /dev/null
+++ b/speechx/cmake/external/libsndfile.cmake
@@ -0,0 +1,56 @@
+include(FetchContent)
+
+# https://github.com/pongasoft/vst-sam-spl-64/blob/master/libsndfile.cmake
+# https://github.com/popojan/goban/blob/master/CMakeLists.txt#L38
+# https://github.com/ddiakopoulos/libnyquist/blob/master/CMakeLists.txt
+
+if(LIBSNDFILE_ROOT_DIR)
+ # instructs FetchContent to not download or update but use the location instead
+ set(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE ${LIBSNDFILE_ROOT_DIR})
+else()
+ set(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE "")
+endif()
+
+set(LIBSNDFILE_GIT_REPO "https://github.com/libsndfile/libsndfile.git" CACHE STRING "libsndfile git repository url" FORCE)
+set(LIBSNDFILE_GIT_TAG 1.0.31 CACHE STRING "libsndfile git tag" FORCE)
+
+FetchContent_Declare(libsndfile
+ GIT_REPOSITORY ${LIBSNDFILE_GIT_REPO}
+ GIT_TAG ${LIBSNDFILE_GIT_TAG}
+ GIT_CONFIG advice.detachedHead=false
+# GIT_SHALLOW true
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND ""
+ INSTALL_COMMAND ""
+ TEST_COMMAND ""
+ )
+
+FetchContent_GetProperties(libsndfile)
+if(NOT libsndfile_POPULATED)
+ if(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE)
+ message(STATUS "Using libsndfile from local ${FETCHCONTENT_SOURCE_DIR_LIBSNDFILE}")
+ else()
+ message(STATUS "Fetching libsndfile ${LIBSNDFILE_GIT_REPO}/tree/${LIBSNDFILE_GIT_TAG}")
+ endif()
+ FetchContent_Populate(libsndfile)
+endif()
+
+set(LIBSNDFILE_ROOT_DIR ${libsndfile_SOURCE_DIR})
+set(LIBSNDFILE_INCLUDE_DIR "${libsndfile_BINARY_DIR}/src")
+
+function(libsndfile_build)
+ option(BUILD_PROGRAMS "Build programs" OFF)
+ option(BUILD_EXAMPLES "Build examples" OFF)
+ option(BUILD_TESTING "Build examples" OFF)
+ option(ENABLE_CPACK "Enable CPack support" OFF)
+ option(ENABLE_PACKAGE_CONFIG "Generate and install package config file" OFF)
+ option(BUILD_REGTEST "Build regtest" OFF)
+ # finally we include libsndfile itself
+ add_subdirectory(${libsndfile_SOURCE_DIR} ${libsndfile_BINARY_DIR} EXCLUDE_FROM_ALL)
+ # copying .hh for c++ support
+ #file(COPY "${libsndfile_SOURCE_DIR}/src/sndfile.hh" DESTINATION ${LIBSNDFILE_INCLUDE_DIR})
+endfunction()
+
+libsndfile_build()
+
+include_directories(${LIBSNDFILE_INCLUDE_DIR})
\ No newline at end of file
diff --git a/speechx/cmake/external/openblas.cmake b/speechx/cmake/external/openblas.cmake
new file mode 100644
index 000000000..3c202f7f6
--- /dev/null
+++ b/speechx/cmake/external/openblas.cmake
@@ -0,0 +1,37 @@
+include(FetchContent)
+
+set(OpenBLAS_SOURCE_DIR ${fc_patch}/OpenBLAS-src)
+set(OpenBLAS_PREFIX ${fc_patch}/OpenBLAS-prefix)
+
+# ######################################################################################################################
+# OPENBLAS https://github.com/lattice/quda/blob/develop/CMakeLists.txt#L575
+# ######################################################################################################################
+enable_language(Fortran)
+#TODO: switch to CPM
+include(GNUInstallDirs)
+ExternalProject_Add(
+ OPENBLAS
+ GIT_REPOSITORY https://github.com/xianyi/OpenBLAS.git
+ GIT_TAG v0.3.10
+ GIT_SHALLOW YES
+ PREFIX ${OpenBLAS_PREFIX}
+ SOURCE_DIR ${OpenBLAS_SOURCE_DIR}
+    CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
+ CMAKE_GENERATOR "Unix Makefiles")
+
+
+# https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition
+ExternalProject_Get_Property(OPENBLAS INSTALL_DIR)
+set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR})
+add_library(openblas STATIC IMPORTED)
+add_dependencies(openblas OPENBLAS)
+set_target_properties(openblas PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES Fortran)
+# ${CMAKE_INSTALL_LIBDIR} lib
+set_target_properties(openblas PROPERTIES IMPORTED_LOCATION ${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/libopenblas.a)
+
+
+# https://cmake.org/cmake/help/latest/command/install.html?highlight=cmake_install_libdir#installing-targets
+# ${CMAKE_INSTALL_LIBDIR} lib
+# ${CMAKE_INSTALL_INCLUDEDIR} include
+link_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
+include_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR})
\ No newline at end of file
diff --git a/speechx/cmake/external/openfst.cmake b/speechx/cmake/external/openfst.cmake
new file mode 100644
index 000000000..07abb18e8
--- /dev/null
+++ b/speechx/cmake/external/openfst.cmake
@@ -0,0 +1,19 @@
+include(FetchContent)
+set(openfst_SOURCE_DIR ${fc_patch}/openfst-src)
+set(openfst_BINARY_DIR ${fc_patch}/openfst-build)
+
+ExternalProject_Add(openfst
+ URL https://github.com/mjansche/openfst/archive/refs/tags/1.7.2.zip
+ URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6
+# #PREFIX ${openfst_PREFIX_DIR}
+# SOURCE_DIR ${openfst_SOURCE_DIR}
+# BINARY_DIR ${openfst_BINARY_DIR}
+ CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR}
+ "CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}"
+ "LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}"
+ "LIBS=-lgflags_nothreads -lglog -lpthread"
+ COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
+ BUILD_COMMAND make -j 4
+)
+link_directories(${openfst_PREFIX_DIR}/lib)
+include_directories(${openfst_PREFIX_DIR}/include)
diff --git a/speechx/examples/.gitignore b/speechx/examples/.gitignore
new file mode 100644
index 000000000..b7075fa56
--- /dev/null
+++ b/speechx/examples/.gitignore
@@ -0,0 +1,2 @@
+*.ark
+paddle_asr_model/
diff --git a/speechx/examples/CMakeLists.txt b/speechx/examples/CMakeLists.txt
new file mode 100644
index 000000000..ef0a72b88
--- /dev/null
+++ b/speechx/examples/CMakeLists.txt
@@ -0,0 +1,5 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+add_subdirectory(feat)
+add_subdirectory(nnet)
+add_subdirectory(decoder)
diff --git a/speechx/examples/README.md b/speechx/examples/README.md
new file mode 100644
index 000000000..941c4272d
--- /dev/null
+++ b/speechx/examples/README.md
@@ -0,0 +1,16 @@
+# Examples
+
+* decoder - online decoder running in offline mode
+* feat - mfcc and linear spectrogram features
+* nnet - ds2 (DeepSpeech2) nnet
+
+## How to run
+
+`run.sh` is the entry point.
+
+For example, to play with `decoder`:
+
+```
+pushd decoder
+bash run.sh
+```
diff --git a/speechx/examples/decoder/CMakeLists.txt b/speechx/examples/decoder/CMakeLists.txt
new file mode 100644
index 000000000..4bd5c6cf0
--- /dev/null
+++ b/speechx/examples/decoder/CMakeLists.txt
@@ -0,0 +1,5 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
+target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
diff --git a/speechx/examples/decoder/offline_decoder_main.cc b/speechx/examples/decoder/offline_decoder_main.cc
new file mode 100644
index 000000000..44127c73b
--- /dev/null
+++ b/speechx/examples/decoder/offline_decoder_main.cc
@@ -0,0 +1,101 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TODO: refactor, replace with gtest
+
+#include "base/flags.h"
+#include "base/log.h"
+#include "decoder/ctc_beam_search_decoder.h"
+#include "frontend/raw_audio.h"
+#include "kaldi/util/table-types.h"
+#include "nnet/decodable.h"
+#include "nnet/paddle_nnet.h"
+
+DEFINE_string(feature_respecifier, "", "test feature rspecifier");
+DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
+DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
+DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
+DEFINE_string(lm_path, "lm.klm", "language model");
+
+
+using kaldi::BaseFloat;
+using kaldi::Matrix;
+using std::vector;
+
+int main(int argc, char* argv[]) {
+ gflags::ParseCommandLineFlags(&argc, &argv, false);
+ google::InitGoogleLogging(argv[0]);
+
+ kaldi::SequentialBaseFloatMatrixReader feature_reader(
+ FLAGS_feature_respecifier);
+ std::string model_graph = FLAGS_model_path;
+ std::string model_params = FLAGS_param_path;
+ std::string dict_file = FLAGS_dict_file;
+ std::string lm_path = FLAGS_lm_path;
+
+ int32 num_done = 0, num_err = 0;
+
+ ppspeech::CTCBeamSearchOptions opts;
+ opts.dict_file = dict_file;
+ opts.lm_path = lm_path;
+ ppspeech::CTCBeamSearch decoder(opts);
+
+ ppspeech::ModelOptions model_opts;
+ model_opts.model_path = model_graph;
+ model_opts.params_path = model_params;
+    std::shared_ptr<ppspeech::PaddleNnet> nnet(
+        new ppspeech::PaddleNnet(model_opts));
+    std::shared_ptr<ppspeech::RawDataCache> raw_data(
+        new ppspeech::RawDataCache());
+    std::shared_ptr<ppspeech::Decodable> decodable(
+        new ppspeech::Decodable(nnet, raw_data));
+
+ int32 chunk_size = 35;
+ decoder.InitDecoder();
+
+ for (; !feature_reader.Done(); feature_reader.Next()) {
+ string utt = feature_reader.Key();
+        const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
+ raw_data->SetDim(feature.NumCols());
+ int32 row_idx = 0;
+ int32 num_chunks = feature.NumRows() / chunk_size;
+ for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
+            kaldi::Vector<BaseFloat> feature_chunk(chunk_size *
+                                                   feature.NumCols());
+ for (int row_id = 0; row_id < chunk_size; ++row_id) {
+                kaldi::SubVector<BaseFloat> tmp(feature, row_idx);
+                kaldi::SubVector<BaseFloat> f_chunk_tmp(
+                    feature_chunk.Data() + row_id * feature.NumCols(),
+                    feature.NumCols());
+ f_chunk_tmp.CopyFromVec(tmp);
+ row_idx++;
+ }
+ raw_data->Accept(feature_chunk);
+ if (chunk_idx == num_chunks - 1) {
+ raw_data->SetFinished();
+ }
+ decoder.AdvanceDecode(decodable);
+ }
+ std::string result;
+ result = decoder.GetFinalBestPath();
+ KALDI_LOG << " the result of " << utt << " is " << result;
+ decodable->Reset();
+ decoder.Reset();
+ ++num_done;
+ }
+
+ KALDI_LOG << "Done " << num_done << " utterances, " << num_err
+ << " with errors.";
+ return (num_done != 0 ? 0 : 1);
+}
diff --git a/speechx/examples/decoder/path.sh b/speechx/examples/decoder/path.sh
new file mode 100644
index 000000000..7b4b7545b
--- /dev/null
+++ b/speechx/examples/decoder/path.sh
@@ -0,0 +1,14 @@
+# This contains the locations of the binaries required for running the examples.
+
+SPEECHX_ROOT=$PWD/../..
+SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
+
+SPEECHX_TOOLS=$SPEECHX_ROOT/tools
+TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
+
+[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project was built successfully"; }
+
+export LC_AL=C
+
+SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder
+export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
diff --git a/speechx/examples/decoder/run.sh b/speechx/examples/decoder/run.sh
new file mode 100755
index 000000000..fc5e91824
--- /dev/null
+++ b/speechx/examples/decoder/run.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+set +x
+set -e
+
+. path.sh
+
+# 1. compile
+if [ ! -d ${SPEECHX_EXAMPLES} ]; then
+ pushd ${SPEECHX_ROOT}
+ bash build.sh
+ popd
+fi
+
+
+# 2. download model
+if [ ! -d ../paddle_asr_model ]; then
+ wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
+ tar xzfv paddle_asr_model.tar.gz
+ mv ./paddle_asr_model ../
+ # produce wav scp
+ echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
+fi
+
+model_dir=../paddle_asr_model
+feat_wspecifier=./feats.ark
+cmvn=./cmvn.ark
+
+# 3. run feat
+linear_spectrogram_main \
+ --wav_rspecifier=scp:$model_dir/wav.scp \
+ --feature_wspecifier=ark,t:$feat_wspecifier \
+ --cmvn_write_path=$cmvn
+
+# 4. run decoder
+offline_decoder_main \
+ --feature_respecifier=ark:$feat_wspecifier \
+ --model_path=$model_dir/avg_1.jit.pdmodel \
+ --param_path=$model_dir/avg_1.jit.pdparams \
+ --dict_file=$model_dir/vocab.txt \
+ --lm_path=$model_dir/avg_1.jit.klm
\ No newline at end of file
diff --git a/speechx/examples/decoder/valgrind.sh b/speechx/examples/decoder/valgrind.sh
new file mode 100755
index 000000000..14efe0ba4
--- /dev/null
+++ b/speechx/examples/decoder/valgrind.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# this script is for memory check, so please run ./run.sh first.
+
+set +x
+set -e
+
+. ./path.sh
+
+if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
+    echo "please install valgrind in the speechx tools dir."
+ exit 1
+fi
+
+model_dir=../paddle_asr_model
+feat_wspecifier=./feats.ark
+cmvn=./cmvn.ark
+
+valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
+ offline_decoder_main \
+ --feature_respecifier=ark:$feat_wspecifier \
+ --model_path=$model_dir/avg_1.jit.pdmodel \
+ --param_path=$model_dir/avg_1.jit.pdparams \
+ --dict_file=$model_dir/vocab.txt \
+ --lm_path=$model_dir/avg_1.jit.klm
+
diff --git a/speechx/examples/feat/CMakeLists.txt b/speechx/examples/feat/CMakeLists.txt
new file mode 100644
index 000000000..b8f516afb
--- /dev/null
+++ b/speechx/examples/feat/CMakeLists.txt
@@ -0,0 +1,10 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+
+add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
+target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(mfcc-test kaldi-mfcc)
+
+add_executable(linear_spectrogram_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_main.cc)
+target_include_directories(linear_spectrogram_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
\ No newline at end of file
diff --git a/speechx/examples/feat/feature-mfcc-test.cc b/speechx/examples/feat/feature-mfcc-test.cc
new file mode 100644
index 000000000..ae32aba9e
--- /dev/null
+++ b/speechx/examples/feat/feature-mfcc-test.cc
@@ -0,0 +1,720 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// feat/feature-mfcc-test.cc
+
+// Copyright 2009-2011 Karel Vesely; Petr Motlicek
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include <iostream>
+
+#include "base/kaldi-math.h"
+#include "feat/feature-mfcc.h"
+#include "feat/wave-reader.h"
+#include "matrix/kaldi-matrix-inl.h"
+
+using namespace kaldi;
+
+
+static void UnitTestReadWave() {
+ std::cout << "=== UnitTestReadWave() ===\n";
+
+    Vector<BaseFloat> v, v2;
+
+ std::cout << "<<<=== Reading waveform\n";
+
+ {
+ std::ifstream is("test_data/test.wav", std::ios_base::binary);
+ WaveData wave;
+ wave.Read(is);
+        const Matrix<BaseFloat> data(wave.Data());
+ KALDI_ASSERT(data.NumRows() == 1);
+ v.Resize(data.NumCols());
+ v.CopyFromVec(data.Row(0));
+ }
+
+ std::cout
+ << "<<<=== Reading Vector waveform, prepared by matlab\n";
+ std::ifstream input("test_data/test_matlab.ascii");
+ KALDI_ASSERT(input.good());
+ v2.Read(input, false);
+ input.close();
+
+ std::cout
+ << "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
+ KALDI_ASSERT(v.Dim() == v2.Dim());
+ for (int32 i = 0; i < v.Dim(); i++) {
+ KALDI_ASSERT(v(i) == v2(i));
+ }
+ std::cout << "<<<=== Comparing done\n";
+
+ // std::cout << "== The Waveform Samples == \n";
+ // std::cout << v;
+
+ std::cout << "Test passed :)\n\n";
+}
+
+
+/**
+ */
+static void UnitTestSimple() {
+ std::cout << "=== UnitTestSimple() ===\n";
+
+    Vector<BaseFloat> v(100000);
+    Matrix<BaseFloat> m;
+
+ // init with noise
+ for (int32 i = 0; i < v.Dim(); i++) {
+ v(i) = (abs(i * 433024253) % 65535) - (65535 / 2);
+ }
+
+ std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
+ // the parametrization object
+ MfccOptions op;
+ // trying to have same opts as baseline.
+ op.frame_opts.dither = 0.0;
+ op.frame_opts.preemph_coeff = 0.0;
+ op.frame_opts.window_type = "rectangular";
+ op.frame_opts.remove_dc_offset = false;
+ op.frame_opts.round_to_power_of_two = true;
+ op.mel_opts.low_freq = 0.0;
+ op.mel_opts.htk_mode = true;
+ op.htk_compat = true;
+
+ Mfcc mfcc(op);
+ // use default parameters
+
+ // compute mfccs.
+ mfcc.Compute(v, 1.0, &m);
+
+ // possibly dump
+ // std::cout << "== Output features == \n" << m;
+ std::cout << "Test passed :)\n\n";
+}
+
+
+static void UnitTestHTKCompare1() {
+ std::cout << "=== UnitTestHTKCompare1() ===\n";
+
+ std::ifstream is("test_data/test.wav", std::ios_base::binary);
+ WaveData wave;
+ wave.Read(is);
+ KALDI_ASSERT(wave.Data().NumRows() == 1);
+    SubVector<BaseFloat> waveform(wave.Data(), 0);
+
+ // read the HTK features
+    Matrix<BaseFloat> htk_features;
+ {
+ std::ifstream is("test_data/test.wav.fea_htk.1",
+ std::ios::in | std::ios_base::binary);
+ bool ans = ReadHtk(is, &htk_features, 0);
+ KALDI_ASSERT(ans);
+ }
+
+ // use mfcc with default configuration...
+ MfccOptions op;
+ op.frame_opts.dither = 0.0;
+ op.frame_opts.preemph_coeff = 0.0;
+ op.frame_opts.window_type = "hamming";
+ op.frame_opts.remove_dc_offset = false;
+ op.frame_opts.round_to_power_of_two = true;
+ op.mel_opts.low_freq = 0.0;
+ op.mel_opts.htk_mode = true;
+ op.htk_compat = true;
+ op.use_energy = false; // C0 not energy.
+
+ Mfcc mfcc(op);
+
+ // calculate kaldi features
+    Matrix<BaseFloat> kaldi_raw_features;
+ mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
+
+ DeltaFeaturesOptions delta_opts;
+    Matrix<BaseFloat> kaldi_features;
+ ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
+
+ // compare the results
+ bool passed = true;
+ int32 i_old = -1;
+ KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
+ KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
+ // Ignore ends-- we make slightly different choices than
+ // HTK about how to treat the deltas at the ends.
+ for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
+ for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
+ BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
+ if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
+ // print the non-matching data only once per-line
+ if (i_old != i) {
+ std::cout << "\n\n\n[HTK-row: " << i << "] "
+ << htk_features.Row(i) << "\n";
+ std::cout << "[Kaldi-row: " << i << "] "
+ << kaldi_features.Row(i) << "\n\n\n";
+ i_old = i;
+ }
+ // print indices of non-matching cells
+ std::cout << "[" << i << ", " << j << "]";
+ passed = false;
+ }
+ }
+ }
+ if (!passed) KALDI_ERR << "Test failed";
+
+ // write the htk features for later inspection
+ HtkHeader header = {
+ kaldi_features.NumRows(),
+ 100000, // 10ms
+        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
+ 021406 // MFCC_D_A_0
+ };
+ {
+ std::ofstream os("tmp.test.wav.fea_kaldi.1",
+ std::ios::out | std::ios::binary);
+ WriteHtk(os, kaldi_features, header);
+ }
+
+ std::cout << "Test passed :)\n\n";
+
+ unlink("tmp.test.wav.fea_kaldi.1");
+}
+
+
+static void UnitTestHTKCompare2() {
+ std::cout << "=== UnitTestHTKCompare2() ===\n";
+
+ std::ifstream is("test_data/test.wav", std::ios_base::binary);
+ WaveData wave;
+ wave.Read(is);
+ KALDI_ASSERT(wave.Data().NumRows() == 1);
+    SubVector<BaseFloat> waveform(wave.Data(), 0);
+
+ // read the HTK features
+    Matrix<BaseFloat> htk_features;
+ {
+ std::ifstream is("test_data/test.wav.fea_htk.2",
+ std::ios::in | std::ios_base::binary);
+ bool ans = ReadHtk(is, &htk_features, 0);
+ KALDI_ASSERT(ans);
+ }
+
+ // use mfcc with default configuration...
+ MfccOptions op;
+ op.frame_opts.dither = 0.0;
+ op.frame_opts.preemph_coeff = 0.0;
+ op.frame_opts.window_type = "hamming";
+ op.frame_opts.remove_dc_offset = false;
+ op.frame_opts.round_to_power_of_two = true;
+ op.mel_opts.low_freq = 0.0;
+ op.mel_opts.htk_mode = true;
+ op.htk_compat = true;
+ op.use_energy = true; // Use energy.
+
+ Mfcc mfcc(op);
+
+ // calculate kaldi features
+    Matrix<BaseFloat> kaldi_raw_features;
+ mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
+
+ DeltaFeaturesOptions delta_opts;
+    Matrix<BaseFloat> kaldi_features;
+ ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
+
+ // compare the results
+ bool passed = true;
+ int32 i_old = -1;
+ KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
+ KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
+ // Ignore ends-- we make slightly different choices than
+ // HTK about how to treat the deltas at the ends.
+ for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
+ for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
+ BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
+ if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
+ // print the non-matching data only once per-line
+ if (i_old != i) {
+ std::cout << "\n\n\n[HTK-row: " << i << "] "
+ << htk_features.Row(i) << "\n";
+ std::cout << "[Kaldi-row: " << i << "] "
+ << kaldi_features.Row(i) << "\n\n\n";
+ i_old = i;
+ }
+ // print indices of non-matching cells
+ std::cout << "[" << i << ", " << j << "]";
+ passed = false;
+ }
+ }
+ }
+ if (!passed) KALDI_ERR << "Test failed";
+
+ // write the htk features for later inspection
+ HtkHeader header = {
+ kaldi_features.NumRows(),
+ 100000, // 10ms
+        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
+ 021406 // MFCC_D_A_0
+ };
+ {
+ std::ofstream os("tmp.test.wav.fea_kaldi.2",
+ std::ios::out | std::ios::binary);
+ WriteHtk(os, kaldi_features, header);
+ }
+
+ std::cout << "Test passed :)\n\n";
+
+ unlink("tmp.test.wav.fea_kaldi.2");
+}
+
+
+static void UnitTestHTKCompare3() {
+ std::cout << "=== UnitTestHTKCompare3() ===\n";
+
+ std::ifstream is("test_data/test.wav", std::ios_base::binary);
+ WaveData wave;
+ wave.Read(is);
+ KALDI_ASSERT(wave.Data().NumRows() == 1);
+    SubVector<BaseFloat> waveform(wave.Data(), 0);
+
+ // read the HTK features
+    Matrix<BaseFloat> htk_features;
+ {
+ std::ifstream is("test_data/test.wav.fea_htk.3",
+ std::ios::in | std::ios_base::binary);
+ bool ans = ReadHtk(is, &htk_features, 0);
+ KALDI_ASSERT(ans);
+ }
+
+ // use mfcc with default configuration...
+ MfccOptions op;
+ op.frame_opts.dither = 0.0;
+ op.frame_opts.preemph_coeff = 0.0;
+ op.frame_opts.window_type = "hamming";
+ op.frame_opts.remove_dc_offset = false;
+ op.frame_opts.round_to_power_of_two = true;
+ op.htk_compat = true;
+ op.use_energy = true; // Use energy.
+ op.mel_opts.low_freq = 20.0;
+ // op.mel_opts.debug_mel = true;
+ op.mel_opts.htk_mode = true;
+
+ Mfcc mfcc(op);
+
+ // calculate kaldi features
+    Matrix<BaseFloat> kaldi_raw_features;
+ mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
+
+ DeltaFeaturesOptions delta_opts;
+    Matrix<BaseFloat> kaldi_features;
+ ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
+
+ // compare the results
+ bool passed = true;
+ int32 i_old = -1;
+ KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
+ KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
+ // Ignore ends-- we make slightly different choices than
+ // HTK about how to treat the deltas at the ends.
+ for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
+ for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
+ BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
+ if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
+ // print the non-matching data only once per-line
+            if (static_cast<int32>(i_old) != i) {
+ std::cout << "\n\n\n[HTK-row: " << i << "] "
+ << htk_features.Row(i) << "\n";
+ std::cout << "[Kaldi-row: " << i << "] "
+ << kaldi_features.Row(i) << "\n\n\n";
+ i_old = i;
+ }
+ // print indices of non-matching cells
+ std::cout << "[" << i << ", " << j << "]";
+ passed = false;
+ }
+ }
+ }
+ if (!passed) KALDI_ERR << "Test failed";
+
+ // write the htk features for later inspection
+ HtkHeader header = {
+ kaldi_features.NumRows(),
+ 100000, // 10ms
+        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
+ 021406 // MFCC_D_A_0
+ };
+ {
+ std::ofstream os("tmp.test.wav.fea_kaldi.3",
+ std::ios::out | std::ios::binary);
+ WriteHtk(os, kaldi_features, header);
+ }
+
+ std::cout << "Test passed :)\n\n";
+
+ unlink("tmp.test.wav.fea_kaldi.3");
+}
+
+
+static void UnitTestHTKCompare4() {
+ std::cout << "=== UnitTestHTKCompare4() ===\n";
+
+ std::ifstream is("test_data/test.wav", std::ios_base::binary);
+ WaveData wave;
+ wave.Read(is);
+ KALDI_ASSERT(wave.Data().NumRows() == 1);
+    SubVector<BaseFloat> waveform(wave.Data(), 0);
+
+ // read the HTK features
+    Matrix<BaseFloat> htk_features;
+ {
+ std::ifstream is("test_data/test.wav.fea_htk.4",
+ std::ios::in | std::ios_base::binary);
+ bool ans = ReadHtk(is, &htk_features, 0);
+ KALDI_ASSERT(ans);
+ }
+
+ // use mfcc with default configuration...
+ MfccOptions op;
+ op.frame_opts.dither = 0.0;
+ op.frame_opts.window_type = "hamming";
+ op.frame_opts.remove_dc_offset = false;
+ op.frame_opts.round_to_power_of_two = true;
+ op.mel_opts.low_freq = 0.0;
+ op.htk_compat = true;
+ op.use_energy = true; // Use energy.
+ op.mel_opts.htk_mode = true;
+
+ Mfcc mfcc(op);
+
+ // calculate kaldi features
+    Matrix<BaseFloat> kaldi_raw_features;
+ mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
+
+ DeltaFeaturesOptions delta_opts;
+    Matrix<BaseFloat> kaldi_features;
+ ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
+
+ // compare the results
+ bool passed = true;
+ int32 i_old = -1;
+ KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
+ KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
+ // Ignore ends-- we make slightly different choices than
+ // HTK about how to treat the deltas at the ends.
+ for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
+ for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
+ BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
+ if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
+ // print the non-matching data only once per-line
+            if (static_cast<int32>(i_old) != i) {
+ std::cout << "\n\n\n[HTK-row: " << i << "] "
+ << htk_features.Row(i) << "\n";
+ std::cout << "[Kaldi-row: " << i << "] "
+ << kaldi_features.Row(i) << "\n\n\n";
+ i_old = i;
+ }
+ // print indices of non-matching cells
+ std::cout << "[" << i << ", " << j << "]";
+ passed = false;
+ }
+ }
+ }
+ if (!passed) KALDI_ERR << "Test failed";
+
+ // write the htk features for later inspection
+ HtkHeader header = {
+ kaldi_features.NumRows(),
+ 100000, // 10ms
+        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
+ 021406 // MFCC_D_A_0
+ };
+ {
+ std::ofstream os("tmp.test.wav.fea_kaldi.4",
+ std::ios::out | std::ios::binary);
+ WriteHtk(os, kaldi_features, header);
+ }
+
+ std::cout << "Test passed :)\n\n";
+
+ unlink("tmp.test.wav.fea_kaldi.4");
+}
+
+
+static void UnitTestHTKCompare5() {
+ std::cout << "=== UnitTestHTKCompare5() ===\n";
+
+ std::ifstream is("test_data/test.wav", std::ios_base::binary);
+ WaveData wave;
+ wave.Read(is);
+ KALDI_ASSERT(wave.Data().NumRows() == 1);
+    SubVector<BaseFloat> waveform(wave.Data(), 0);
+
+ // read the HTK features
+    Matrix<BaseFloat> htk_features;
+ {
+ std::ifstream is("test_data/test.wav.fea_htk.5",
+ std::ios::in | std::ios_base::binary);
+ bool ans = ReadHtk(is, &htk_features, 0);
+ KALDI_ASSERT(ans);
+ }
+
+ // use mfcc with default configuration...
+ MfccOptions op;
+ op.frame_opts.dither = 0.0;
+ op.frame_opts.window_type = "hamming";
+ op.frame_opts.remove_dc_offset = false;
+ op.frame_opts.round_to_power_of_two = true;
+ op.htk_compat = true;
+ op.use_energy = true; // Use energy.
+ op.mel_opts.low_freq = 0.0;
+ op.mel_opts.vtln_low = 100.0;
+ op.mel_opts.vtln_high = 7500.0;
+ op.mel_opts.htk_mode = true;
+
+ BaseFloat vtln_warp =
+ 1.1; // our approach identical to htk for warp factor >1,
+ // differs slightly for higher mel bins if warp_factor <0.9
+
+ Mfcc mfcc(op);
+
+ // calculate kaldi features
+    Matrix<BaseFloat> kaldi_raw_features;
+ mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
+
+ DeltaFeaturesOptions delta_opts;
+    Matrix<BaseFloat> kaldi_features;
+ ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
+
+ // compare the results
+ bool passed = true;
+ int32 i_old = -1;
+ KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
+ KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
+ // Ignore ends-- we make slightly different choices than
+ // HTK about how to treat the deltas at the ends.
+ for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
+ for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
+ BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
+ if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
+ // print the non-matching data only once per-line
+            if (static_cast<int32>(i_old) != i) {
+ std::cout << "\n\n\n[HTK-row: " << i << "] "
+ << htk_features.Row(i) << "\n";
+ std::cout << "[Kaldi-row: " << i << "] "
+ << kaldi_features.Row(i) << "\n\n\n";
+ i_old = i;
+ }
+ // print indices of non-matching cells
+ std::cout << "[" << i << ", " << j << "]";
+ passed = false;
+ }
+ }
+ }
+ if (!passed) KALDI_ERR << "Test failed";
+
+ // write the htk features for later inspection
+ HtkHeader header = {
+ kaldi_features.NumRows(),
+ 100000, // 10ms
+        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
+ 021406 // MFCC_D_A_0
+ };
+ {
+ std::ofstream os("tmp.test.wav.fea_kaldi.5",
+ std::ios::out | std::ios::binary);
+ WriteHtk(os, kaldi_features, header);
+ }
+
+ std::cout << "Test passed :)\n\n";
+
+ unlink("tmp.test.wav.fea_kaldi.5");
+}
+
+static void UnitTestHTKCompare6() {
+ std::cout << "=== UnitTestHTKCompare6() ===\n";
+
+
+ std::ifstream is("test_data/test.wav", std::ios_base::binary);
+ WaveData wave;
+ wave.Read(is);
+ KALDI_ASSERT(wave.Data().NumRows() == 1);
+    SubVector<BaseFloat> waveform(wave.Data(), 0);
+
+ // read the HTK features
+    Matrix<BaseFloat> htk_features;
+ {
+ std::ifstream is("test_data/test.wav.fea_htk.6",
+ std::ios::in | std::ios_base::binary);
+ bool ans = ReadHtk(is, &htk_features, 0);
+ KALDI_ASSERT(ans);
+ }
+
+ // use mfcc with default configuration...
+ MfccOptions op;
+ op.frame_opts.dither = 0.0;
+ op.frame_opts.preemph_coeff = 0.97;
+ op.frame_opts.window_type = "hamming";
+ op.frame_opts.remove_dc_offset = false;
+ op.frame_opts.round_to_power_of_two = true;
+ op.mel_opts.num_bins = 24;
+ op.mel_opts.low_freq = 125.0;
+ op.mel_opts.high_freq = 7800.0;
+ op.htk_compat = true;
+ op.use_energy = false; // C0 not energy.
+
+ Mfcc mfcc(op);
+
+ // calculate kaldi features
+    Matrix<BaseFloat> kaldi_raw_features;
+ mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
+
+ DeltaFeaturesOptions delta_opts;
+    Matrix<BaseFloat> kaldi_features;
+ ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
+
+ // compare the results
+ bool passed = true;
+ int32 i_old = -1;
+ KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
+ KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
+ // Ignore ends-- we make slightly different choices than
+ // HTK about how to treat the deltas at the ends.
+ for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
+ for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
+ BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
+ if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
+ // print the non-matching data only once per-line
+            if (static_cast<int32>(i_old) != i) {
+ std::cout << "\n\n\n[HTK-row: " << i << "] "
+ << htk_features.Row(i) << "\n";
+ std::cout << "[Kaldi-row: " << i << "] "
+ << kaldi_features.Row(i) << "\n\n\n";
+ i_old = i;
+ }
+ // print indices of non-matching cells
+ std::cout << "[" << i << ", " << j << "]";
+ passed = false;
+ }
+ }
+ }
+ if (!passed) KALDI_ERR << "Test failed";
+
+ // write the htk features for later inspection
+ HtkHeader header = {
+ kaldi_features.NumRows(),
+ 100000, // 10ms
+        static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
+ 021406 // MFCC_D_A_0
+ };
+ {
+ std::ofstream os("tmp.test.wav.fea_kaldi.6",
+ std::ios::out | std::ios::binary);
+ WriteHtk(os, kaldi_features, header);
+ }
+
+ std::cout << "Test passed :)\n\n";
+
+ unlink("tmp.test.wav.fea_kaldi.6");
+}
+
+void UnitTestVtln() {
+ // Test the function VtlnWarpFreq.
+ BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20,
+ vtln_high_cutoff = 7400;
+
+ for (size_t i = 0; i < 100; i++) {
+ BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
+ AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
+ vtln_high_cutoff,
+ low_freq,
+ high_freq,
+ warp_factor,
+ freq),
+ freq / warp_factor);
+
+ AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
+ vtln_high_cutoff,
+ low_freq,
+ high_freq,
+ warp_factor,
+ low_freq),
+ low_freq);
+ AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
+ vtln_high_cutoff,
+ low_freq,
+ high_freq,
+ warp_factor,
+ high_freq),
+ high_freq);
+ BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(),
+ freq3 = freq2 +
+ (high_freq - freq2) * RandUniform(); // freq3>=freq2
+ BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
+ vtln_high_cutoff,
+ low_freq,
+ high_freq,
+ warp_factor,
+ freq2);
+ BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
+ vtln_high_cutoff,
+ low_freq,
+ high_freq,
+ warp_factor,
+ freq3);
+ KALDI_ASSERT(w3 >= w2); // increasing function.
+ BaseFloat w3dash = MelBanks::VtlnWarpFreq(
+ vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, 1.0, freq3);
+ AssertEqual(w3dash, freq3);
+ }
+}
+
+static void UnitTestFeat() {
+ UnitTestVtln();
+ UnitTestReadWave();
+ UnitTestSimple();
+ UnitTestHTKCompare1();
+ UnitTestHTKCompare2();
+ // commenting out this one as it doesn't compare right now I normalized
+ // the way the FFT bins are treated (removed offset of 0.5)... this seems
+ // to relate to the way frequency zero behaves.
+ UnitTestHTKCompare3();
+ UnitTestHTKCompare4();
+ UnitTestHTKCompare5();
+ UnitTestHTKCompare6();
+ std::cout << "Tests succeeded.\n";
+}
+
+
+int main() {
+ try {
+ for (int i = 0; i < 5; i++) UnitTestFeat();
+ std::cout << "Tests succeeded.\n";
+ return 0;
+ } catch (const std::exception &e) {
+ std::cerr << e.what();
+ return 1;
+ }
+}
diff --git a/speechx/examples/feat/linear_spectrogram_main.cc b/speechx/examples/feat/linear_spectrogram_main.cc
new file mode 100644
index 000000000..9ed4d6f93
--- /dev/null
+++ b/speechx/examples/feat/linear_spectrogram_main.cc
@@ -0,0 +1,248 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TODO: refactor, replace with gtest
+
+#include "frontend/linear_spectrogram.h"
+#include "base/flags.h"
+#include "base/log.h"
+#include "frontend/feature_cache.h"
+#include "frontend/feature_extractor_interface.h"
+#include "frontend/normalizer.h"
+#include "frontend/raw_audio.h"
+#include "kaldi/feat/wave-reader.h"
+#include "kaldi/util/kaldi-io.h"
+#include "kaldi/util/table-types.h"
+
+DEFINE_string(wav_rspecifier, "", "test wav scp path");
+DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
+DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
+
+
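+// Precomputed global CMVN statistics (one value per feature dimension) plus
+// the frame count they were accumulated over.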
+std::vector<float> mean_{
+ -13730251.531853663, -12982852.199316509, -13673844.299583456,
+ -13089406.559646806, -12673095.524938712, -12823859.223276224,
+ -13590267.158903603, -14257618.467152044, -14374605.116185192,
+ -14490009.21822485, -14849827.158924166, -15354435.470563512,
+ -15834149.206532761, -16172971.985514281, -16348740.496746974,
+ -16423536.699409386, -16556246.263649225, -16744088.772748645,
+ -16916184.08510357, -17054034.840031497, -17165612.509455364,
+ -17255955.470915023, -17322572.527648456, -17408943.862033736,
+ -17521554.799865916, -17620623.254924215, -17699792.395918526,
+ -17723364.411134344, -17741483.4433254, -17747426.888704527,
+ -17733315.928209435, -17748780.160905756, -17808336.883775543,
+ -17895918.671983004, -18009812.59173023, -18098188.66548325,
+ -18195798.958462656, -18293617.62980999, -18397432.92077201,
+ -18505834.787318766, -18585451.8100908, -18652438.235649142,
+ -18700960.306275308, -18734944.58792185, -18737426.313365128,
+ -18735347.165987637, -18738813.444170244, -18737086.848890636,
+ -18731576.2474336, -18717405.44095871, -18703089.25545657,
+ -18691014.546456724, -18692460.568905357, -18702119.628629155,
+ -18727710.621126678, -18761582.72034647, -18806745.835547544,
+ -18850674.8692112, -18884431.510951452, -18919999.992506847,
+ -18939303.799078144, -18952946.273760635, -18980289.22996379,
+ -19011610.17803294, -19040948.61805145, -19061021.429847397,
+ -19112055.53768819, -19149667.414264943, -19201127.05091321,
+ -19270250.82564605, -19334606.883057203, -19390513.336589377,
+ -19444176.259208687, -19502755.000038862, -19544333.014549147,
+ -19612668.183176614, -19681902.19006569, -19771969.951249883,
+ -19873329.723376893, -19996752.59235844, -20110031.131400537,
+ -20231658.612529557, -20319378.894054495, -20378534.45718066,
+ -20413332.089584175, -20438147.844177883, -20443710.248040095,
+ -20465457.02238927, -20488610.969337028, -20516295.16424432,
+ -20541423.795738827, -20553192.874953747, -20573605.50701977,
+ -20577871.61936797, -20571807.008916274, -20556242.38912231,
+ -20542199.30819195, -20521239.063551214, -20519150.80004532,
+ -20527204.80248933, -20536933.769257784, -20543470.522332076,
+ -20549700.089992985, -20551525.24958494, -20554873.406493705,
+ -20564277.65794227, -20572211.740052115, -20574305.69550465,
+ -20575494.450104576, -20567092.577932164, -20549302.929608088,
+ -20545445.11878376, -20546625.326603737, -20549190.03499401,
+ -20554824.947828256, -20568341.378989458, -20577582.331383612,
+ -20577980.519402675, -20566603.03458152, -20560131.592262644,
+ -20552166.469060015, -20549063.06763577, -20544490.562339947,
+ -20539817.82346569, -20528747.715731595, -20518026.24576161,
+ -20510977.844974525, -20506874.36087992, -20506731.11977665,
+ -20510482.133420516, -20507760.92101862, -20494644.834457114,
+ -20480107.89304893, -20461312.091867123, -20442941.75080173,
+ -20426123.02834838, -20424607.675283, -20426810.369107097,
+ -20434024.50097819, -20437404.75544205, -20447688.63916367,
+ -20460893.335563846, -20482922.735127095, -20503610.119434915,
+ -20527062.76448319, -20557830.035128627, -20593274.72068722,
+ -20632528.452965066, -20673637.471334763, -20733106.97143075,
+ -20842921.0447562, -21054357.83621519, -21416569.534189366,
+ -21978460.272811692, -22753170.052172784, -23671344.10563395,
+ -24613499.293358143, -25406477.12230188, -25884377.82156489,
+ -26049040.62791664, -26996879.104431007};
+std::vector<float> variance_{
+ 213747175.10846674, 188395815.34302503, 212706429.10966414,
+ 199109025.81461075, 189235901.23864496, 194901336.53253657,
+ 217481594.29306737, 238689869.12327808, 243977501.24115244,
+ 248479623.6431067, 259766741.47116545, 275516766.7790273,
+ 291271202.3691234, 302693239.8220509, 308627358.3997694,
+ 311143911.38788426, 315446105.07731867, 321705430.9341829,
+ 327458907.4659941, 332245072.43223983, 336251717.5935284,
+ 339694069.7639722, 342188204.4322228, 345587110.31313115,
+ 349903086.2875232, 353660214.20643026, 356700344.5270885,
+ 357665362.3529641, 358493352.05658793, 358857951.620328,
+ 358375239.52774596, 358899733.6342954, 361051818.3511561,
+ 364361716.05025816, 368750322.3771452, 372047800.6462831,
+ 375655861.1349018, 379358519.1980013, 383327605.3935181,
+ 387458599.282341, 390434692.3406868, 392994486.35057056,
+ 394874418.04603153, 396230525.79763395, 396365592.0414835,
+ 396334819.8242737, 396488353.19250053, 396438877.00744957,
+ 396197980.4459586, 395590921.6672991, 395001107.62072515,
+ 394528291.7318225, 394593110.424006, 395018405.59353715,
+ 396110577.5415993, 397506704.0371068, 399400197.4657644,
+ 401243568.2468382, 402687134.7805103, 404136047.2872507,
+ 404883170.001883, 405522253.219517, 406660365.3626476,
+ 407919346.0991902, 409045348.5384909, 409759588.7889818,
+ 411974821.8564483, 413489718.78201455, 415535392.56684107,
+ 418466481.97674364, 421104678.35678065, 423405392.5200779,
+ 425550570.40798235, 427929423.9579701, 429585274.253478,
+ 432368493.55181056, 435193587.13513297, 438886855.20476013,
+ 443058876.8633751, 448181232.5093362, 452883835.6332396,
+ 458056721.77926534, 461816531.22735566, 464363620.1970998,
+ 465886343.5057493, 466928872.0651, 467180536.42647296,
+ 468111848.70714295, 469138695.3071312, 470378429.6930793,
+ 471517958.7132626, 472109050.4262365, 473087417.0177867,
+ 473381322.04648733, 473220195.85483915, 472666071.8998819,
+ 472124669.87879956, 471298571.411737, 471251033.2902761,
+ 471672676.43128747, 472177147.2193172, 472572361.7711908,
+ 472968783.7751127, 473156295.4164052, 473398034.82676554,
+ 473897703.5203811, 474328271.33112127, 474452670.98002136,
+ 474549003.99284613, 474252887.13567275, 473557462.909069,
+ 473483385.85193115, 473609738.04855174, 473746944.82085115,
+ 474016729.91696435, 474617321.94138587, 475045097.237122,
+ 475125402.586558, 474664112.9824912, 474426247.5800283,
+ 474104075.42796475, 473978219.7273978, 473773171.7798875,
+ 473578534.69508696, 473102924.16904145, 472651240.5232615,
+ 472374383.1810912, 472209479.6956096, 472202298.8921673,
+ 472370090.76781124, 472220933.99374026, 471625467.37106377,
+ 470994646.51883453, 470182428.9637543, 469348211.5939578,
+ 468570387.4467277, 468540442.7225135, 468672018.90414184,
+ 468994346.9533251, 469138757.58201426, 469553915.95710236,
+ 470134523.38582784, 471082421.62055486, 471962316.51804745,
+ 472939745.1708408, 474250621.5944825, 475773933.43199486,
+ 477465399.71087736, 479218782.61382693, 481752299.7930922,
+ 486608947.8984568, 496119403.2067917, 512730085.5704984,
+ 539048915.2641417, 576285298.3548826, 621610270.2240586,
+ 669308196.4436442, 710656993.5957186, 736344437.3725077,
+ 745481288.0241544, 801121432.9925804};
+int count_ = 912592;
+
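+// Pack the statistics into Kaldi's 2 x (dim + 1) CMVN-stats layout: row 0
+// holds mean_ with the frame count appended in the last column, row 1 holds
+// variance_.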
+void WriteMatrix() {
+    kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
+ for (size_t idx = 0; idx < mean_.size(); ++idx) {
+ cmvn_stats(0, idx) = mean_[idx];
+ cmvn_stats(1, idx) = variance_[idx];
+ }
+ cmvn_stats(0, mean_.size()) = count_;
+ kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
+}
+
+int main(int argc, char* argv[]) {
+ gflags::ParseCommandLineFlags(&argc, &argv, false);
+ google::InitGoogleLogging(argv[0]);
+
+    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
+ FLAGS_wav_rspecifier);
+ kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
+ WriteMatrix();
+
+ // test feature linear_spectorgram: wave --> decibel_normalizer --> hanning
+ // window -->linear_spectrogram --> cmvn
+ int32 num_done = 0, num_err = 0;
+    // std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(new
+    // ppspeech::RawDataCache());
+    std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(
+ new ppspeech::RawAudioCache());
+
+ ppspeech::LinearSpectrogramOptions opt;
+ opt.frame_opts.frame_length_ms = 20;
+ opt.frame_opts.frame_shift_ms = 10;
+ ppspeech::DecibelNormalizerOptions db_norm_opt;
+    std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
+ new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));
+
+    std::unique_ptr<ppspeech::FeatureExtractorInterface> linear_spectrogram(
+ new ppspeech::LinearSpectrogram(opt,
+ std::move(base_feature_extractor)));
+
+    std::unique_ptr<ppspeech::FeatureExtractorInterface> cmvn(
+ new ppspeech::CMVN(FLAGS_cmvn_write_path,
+ std::move(linear_spectrogram)));
+
+ ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
+
+ float streaming_chunk = 0.36;
+ int sample_rate = 16000;
+ int chunk_sample_size = streaming_chunk * sample_rate;
+
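+    // Simulate streaming: push each waveform through the frontend pipeline in
+    // fixed-size chunks and collect the features emitted after every chunk.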
+ for (; !wav_reader.Done(); wav_reader.Next()) {
+ std::string utt = wav_reader.Key();
+ const kaldi::WaveData& wave_data = wav_reader.Value();
+
+ int32 this_channel = 0;
+        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
+ this_channel);
+ int tot_samples = waveform.Dim();
+ int sample_offset = 0;
+        std::vector<kaldi::Vector<kaldi::BaseFloat>> feats;
+ int feature_rows = 0;
+ while (sample_offset < tot_samples) {
+ int cur_chunk_size =
+ std::min(chunk_sample_size, tot_samples - sample_offset);
+
+            kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
+ for (int i = 0; i < cur_chunk_size; ++i) {
+ wav_chunk(i) = waveform(sample_offset + i);
+ }
+            kaldi::Vector<kaldi::BaseFloat> features;
+ feature_cache.Accept(wav_chunk);
+ if (cur_chunk_size < chunk_sample_size) {
+ feature_cache.SetFinished();
+ }
+ feature_cache.Read(&features);
+ if (features.Dim() == 0) break;
+
+ feats.push_back(features);
+ sample_offset += cur_chunk_size;
+ feature_rows += features.Dim() / feature_cache.Dim();
+ }
+
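+        // Stitch the per-chunk feature vectors back into one utterance-level
+        // matrix (feature_rows frames x feature_cache.Dim() columns).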
+ int cur_idx = 0;
+        kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
+ feature_cache.Dim());
+ for (auto feat : feats) {
+ int num_rows = feat.Dim() / feature_cache.Dim();
+ for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
+ for (size_t col_idx = 0; col_idx < feature_cache.Dim();
+ ++col_idx) {
+ features(cur_idx, col_idx) =
+ feat(row_idx * feature_cache.Dim() + col_idx);
+ }
+ ++cur_idx;
+ }
+ }
+ feat_writer.Write(utt, features);
+
+ if (num_done % 50 == 0 && num_done != 0)
+ KALDI_VLOG(2) << "Processed " << num_done << " utterances";
+ num_done++;
+ }
+ KALDI_LOG << "Done " << num_done << " utterances, " << num_err
+ << " with errors.";
+ return (num_done != 0 ? 0 : 1);
+}
diff --git a/speechx/examples/feat/path.sh b/speechx/examples/feat/path.sh
new file mode 100644
index 000000000..8ab7ee299
--- /dev/null
+++ b/speechx/examples/feat/path.sh
@@ -0,0 +1,14 @@
+# This contains the locations of the binaries built for running the examples.
+
+SPEECHX_ROOT=$PWD/../..
+SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
+
+SPEECHX_TOOLS=$SPEECHX_ROOT/tools
+TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
+
+[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. Please ensure the project was built successfully."; }
+
+export LC_ALL=C
+
+SPEECHX_BIN=$SPEECHX_EXAMPLES/feat
+export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
diff --git a/speechx/examples/feat/run.sh b/speechx/examples/feat/run.sh
new file mode 100755
index 000000000..bd21bd7f4
--- /dev/null
+++ b/speechx/examples/feat/run.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+set +x
+set -e
+
+. ./path.sh
+
+# 1. compile
+if [ ! -d ${SPEECHX_EXAMPLES} ]; then
+ pushd ${SPEECHX_ROOT}
+ bash build.sh
+ popd
+fi
+
+# 2. download model
+if [ ! -d ../paddle_asr_model ]; then
+ wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
+ tar xzfv paddle_asr_model.tar.gz
+ mv ./paddle_asr_model ../
+ # produce wav scp
+ echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
+fi
+
+model_dir=../paddle_asr_model
+feat_wspecifier=./feats.ark
+cmvn=./cmvn.ark
+
+# 3. run feat
+linear_spectrogram_main \
+ --wav_rspecifier=scp:$model_dir/wav.scp \
+ --feature_wspecifier=ark,t:$feat_wspecifier \
+ --cmvn_write_path=$cmvn
diff --git a/speechx/examples/feat/valgrind.sh b/speechx/examples/feat/valgrind.sh
new file mode 100755
index 000000000..f8aab63f8
--- /dev/null
+++ b/speechx/examples/feat/valgrind.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+# this script is for memory check, so please run ./run.sh first.
+
+set +x
+set -e
+
+. ./path.sh
+
+if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
+    echo "Please install valgrind in the speechx tools dir."
+ exit 1
+fi
+
+model_dir=../paddle_asr_model
+feat_wspecifier=./feats.ark
+cmvn=./cmvn.ark
+
+valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
+ linear_spectrogram_main \
+ --wav_rspecifier=scp:$model_dir/wav.scp \
+ --feature_wspecifier=ark,t:$feat_wspecifier \
+ --cmvn_write_path=$cmvn
+
diff --git a/speechx/examples/nnet/CMakeLists.txt b/speechx/examples/nnet/CMakeLists.txt
new file mode 100644
index 000000000..20f4008ce
--- /dev/null
+++ b/speechx/examples/nnet/CMakeLists.txt
@@ -0,0 +1,5 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+add_executable(pp-model-test ${CMAKE_CURRENT_SOURCE_DIR}/pp-model-test.cc)
+target_include_directories(pp-model-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(pp-model-test PUBLIC nnet gflags ${DEPS})
\ No newline at end of file
diff --git a/speechx/examples/nnet/path.sh b/speechx/examples/nnet/path.sh
new file mode 100644
index 000000000..f70e70eea
--- /dev/null
+++ b/speechx/examples/nnet/path.sh
@@ -0,0 +1,14 @@
+# This contains the locations of the binaries built for running the examples.
+
+SPEECHX_ROOT=$PWD/../..
+SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
+
+SPEECHX_TOOLS=$SPEECHX_ROOT/tools
+TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
+
+[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. Please ensure the project was built successfully."; }
+
+export LC_ALL=C
+
+SPEECHX_BIN=$SPEECHX_EXAMPLES/nnet
+export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
diff --git a/speechx/examples/nnet/pp-model-test.cc b/speechx/examples/nnet/pp-model-test.cc
new file mode 100644
index 000000000..2db354a79
--- /dev/null
+++ b/speechx/examples/nnet/pp-model-test.cc
@@ -0,0 +1,193 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+
+#include <cstdint>
+#include <functional>
+#include <iostream>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <vector>
+#include "paddle_inference_api.h"
+
+using std::cout;
+using std::endl;
+
+DEFINE_string(model_path, "avg_1.jit.pdmodel", "xxx.pdmodel");
+DEFINE_string(param_path, "avg_1.jit.pdiparams", "xxx.pdiparams");
+
+
+void produce_data(std::vector<std::vector<float>>* data);
+void model_forward_test();
+
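+// Fabricate a dummy feature chunk (chunk_size frames x col_size dims, filled
+// with a constant value) so the forward pass can run without a real frontend.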
+void produce_data(std::vector<std::vector<float>>* data) {
+ int chunk_size = 35; // chunk_size in frame
+ int col_size = 161; // feat dim
+ cout << "chunk size: " << chunk_size << endl;
+ cout << "feat dim: " << col_size << endl;
+
+    data->reserve(chunk_size);
+    for (int row = 0; row < chunk_size; ++row) {
+        data->push_back(std::vector<float>());
+        data->back().reserve(col_size);
+        for (int col_idx = 0; col_idx < col_size; ++col_idx) {
+ data->back().push_back(0.201);
+ }
+ }
+}
+
+void model_forward_test() {
+ std::cout << "1. read the data" << std::endl;
+ std::vector> feats;
+ produce_data(&feats);
+
+ std::cout << "2. load the model" << std::endl;
+ std::string model_graph = FLAGS_model_path;
+ std::string model_params = FLAGS_param_path;
+ cout << "model path: " << model_graph << endl;
+ cout << "model param path : " << model_params << endl;
+
+ paddle_infer::Config config;
+ config.SetModel(model_graph, model_params);
+ config.SwitchIrOptim(false);
+ cout << "SwitchIrOptim: " << false << endl;
+ config.DisableFCPadding();
+ cout << "DisableFCPadding: " << endl;
+ auto predictor = paddle_infer::CreatePredictor(config);
+
+ std::cout << "3. feat shape, row=" << feats.size()
+ << ",col=" << feats[0].size() << std::endl;
+    std::vector<float> pp_input_mat;
+ for (const auto& item : feats) {
+ pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
+ }
+
+    std::cout << "4. feed the data to the model" << std::endl;
+ int row = feats.size();
+ int col = feats[0].size();
+    std::vector<std::string> input_names = predictor->GetInputNames();
+    std::vector<std::string> output_names = predictor->GetOutputNames();
+ for (auto name : input_names) {
+ cout << "model input names: " << name << endl;
+ }
+ for (auto name : output_names) {
+ cout << "model output names: " << name << endl;
+ }
+
+ // input
+    std::unique_ptr<paddle_infer::Tensor> input_tensor =
+ predictor->GetInputHandle(input_names[0]);
+    std::vector<int> INPUT_SHAPE = {1, row, col};
+ input_tensor->Reshape(INPUT_SHAPE);
+ input_tensor->CopyFromCpu(pp_input_mat.data());
+
+ // input length
+    std::unique_ptr<paddle_infer::Tensor> input_len =
+ predictor->GetInputHandle(input_names[1]);
+    std::vector<int> input_len_size = {1};
+ input_len->Reshape(input_len_size);
+    std::vector<int64_t> audio_len;
+ audio_len.push_back(row);
+ input_len->CopyFromCpu(audio_len.data());
+
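+    // Zero-initialized recurrent states for the streaming model; the shape
+    // {3, 1, 1024} is assumed to be (num_rnn_layers, batch, hidden_size).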
+ // state_h
+    std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
+ predictor->GetInputHandle(input_names[2]);
+    std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
+ chunk_state_h_box->Reshape(chunk_state_h_box_shape);
+ int chunk_state_h_box_size =
+ std::accumulate(chunk_state_h_box_shape.begin(),
+ chunk_state_h_box_shape.end(),
+ 1,
+                        std::multiplies<int>());
+    std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
+ chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
+
+ // state_c
+    std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
+ predictor->GetInputHandle(input_names[3]);
+    std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
+ chunk_state_c_box->Reshape(chunk_state_c_box_shape);
+ int chunk_state_c_box_size =
+ std::accumulate(chunk_state_c_box_shape.begin(),
+ chunk_state_c_box_shape.end(),
+ 1,
+                        std::multiplies<int>());
+    std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
+ chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
+
+ // run
+ bool success = predictor->Run();
+
+ // state_h out
+    std::unique_ptr<paddle_infer::Tensor> h_out =
+ predictor->GetOutputHandle(output_names[2]);
+    std::vector<int> h_out_shape = h_out->shape();
+ int h_out_size = std::accumulate(
+        h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
+    std::vector<float> h_out_data(h_out_size);
+ h_out->CopyToCpu(h_out_data.data());
+
+ // stage_c out
+    std::unique_ptr<paddle_infer::Tensor> c_out =
+ predictor->GetOutputHandle(output_names[3]);
+    std::vector<int> c_out_shape = c_out->shape();
+ int c_out_size = std::accumulate(
+        c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
+    std::vector<float> c_out_data(c_out_size);
+ c_out->CopyToCpu(c_out_data.data());
+
+ // output tensor
+    std::unique_ptr<paddle_infer::Tensor> output_tensor =
+ predictor->GetOutputHandle(output_names[0]);
+    std::vector<int> output_shape = output_tensor->shape();
+    std::vector<float> output_probs;
+ int output_size = std::accumulate(
+        output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
+ output_probs.resize(output_size);
+ output_tensor->CopyToCpu(output_probs.data());
+ row = output_shape[1];
+ col = output_shape[2];
+
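+    // The output tensor is assumed to be laid out as (batch, frames, output
+    // dim); copy it into a row-per-frame probability matrix below.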
+ // probs
+    std::vector<std::vector<float>> probs;
+ probs.reserve(row);
+ for (int i = 0; i < row; i++) {
+        probs.push_back(std::vector<float>());
+ probs.back().reserve(col);
+
+ for (int j = 0; j < col; j++) {
+ probs.back().push_back(output_probs[i * col + j]);
+ }
+ }
+
+    std::vector<std::vector<float>> log_feat = probs;
+ std::cout << "probs, row: " << log_feat.size()
+ << " col: " << log_feat[0].size() << std::endl;
+ for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
+ for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
+ ++col_idx) {
+ std::cout << log_feat[row_idx][col_idx] << " ";
+ }
+ std::cout << std::endl;
+ }
+}
+
+int main(int argc, char* argv[]) {
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+ model_forward_test();
+ return 0;
+}
diff --git a/speechx/examples/nnet/run.sh b/speechx/examples/nnet/run.sh
new file mode 100755
index 000000000..4d67d1988
--- /dev/null
+++ b/speechx/examples/nnet/run.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+set +x
+set -e
+
+. ./path.sh
+
+# 1. compile
+if [ ! -d ${SPEECHX_EXAMPLES} ]; then
+ pushd ${SPEECHX_ROOT}
+ bash build.sh
+ popd
+fi
+
+# 2. download model
+if [ ! -d ../paddle_asr_model ]; then
+ wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
+ tar xzfv paddle_asr_model.tar.gz
+ mv ./paddle_asr_model ../
+ # produce wav scp
+ echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
+fi
+
+model_dir=../paddle_asr_model
+
+# 3. run model forward test
+pp-model-test \
+ --model_path=$model_dir/avg_1.jit.pdmodel \
+ --param_path=$model_dir/avg_1.jit.pdparams
+
diff --git a/speechx/examples/nnet/valgrind.sh b/speechx/examples/nnet/valgrind.sh
new file mode 100755
index 000000000..2a08c6082
--- /dev/null
+++ b/speechx/examples/nnet/valgrind.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# this script is for memory check, so please run ./run.sh first.
+
+set +x
+set -e
+
+. ./path.sh
+
+if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
+    echo "Please install valgrind in the speechx tools dir."
+ exit 1
+fi
+
+model_dir=../paddle_asr_model
+
+valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
+ pp-model-test \
+ --model_path=$model_dir/avg_1.jit.pdmodel \
+ --param_path=$model_dir/avg_1.jit.pdparams
\ No newline at end of file
diff --git a/speechx/patch/CPPLINT.cfg b/speechx/patch/CPPLINT.cfg
new file mode 100644
index 000000000..51ff339c1
--- /dev/null
+++ b/speechx/patch/CPPLINT.cfg
@@ -0,0 +1 @@
+exclude_files=.*
diff --git a/speechx/patch/openfst/src/include/fst/flags.h b/speechx/patch/openfst/src/include/fst/flags.h
new file mode 100644
index 000000000..b5ec8ff74
--- /dev/null
+++ b/speechx/patch/openfst/src/include/fst/flags.h
@@ -0,0 +1,228 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// See www.openfst.org for extensive documentation on this weighted
+// finite-state transducer library.
+//
+// Google-style flag handling declarations and inline definitions.
+
+#ifndef FST_LIB_FLAGS_H_
+#define FST_LIB_FLAGS_H_
+
+#include <cstdlib>
+
+#include <iostream>
+#include <map>