diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7fb01708..09e92a66 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -50,13 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
- exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
+ exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker
name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
- exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
+ exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
hooks:
diff --git a/demos/audio_searching/README.md b/demos/audio_searching/README.md
new file mode 100644
index 00000000..2b417c0e
--- /dev/null
+++ b/demos/audio_searching/README.md
@@ -0,0 +1,171 @@
+([简体中文](./README_cn.md)|English)
+
+# Audio Searching
+
+## Introduction
+As the Internet continues to evolve, unstructured data such as emails, social media photos, live videos, and customer service voice calls has become increasingly common. To process such data on a computer, we need embedding technology to transform it into vectors that can be stored, indexed, and queried.
+
+However, when there is a large amount of data, such as hundreds of millions of audio tracks, similarity search becomes difficult. Exhaustive search is feasible but very time-consuming. For this scenario, this demo introduces how to build an audio similarity retrieval system using the open-source vector database Milvus.
+
+Audio retrieval (speech, music, speaker, etc.) enables querying and finding similar sounds (or the same speaker) in large amounts of audio data. An audio similarity retrieval system can be used to identify similar sound effects, minimize intellectual property infringement, quickly search voiceprint libraries, and help enterprises control fraud and identity theft. Audio retrieval also plays an important role in the classification and statistical analysis of audio data.
+
+In this demo, you will learn how to build an audio retrieval system to retrieve similar sound snippets. Uploaded audio clips are converted into vectors using PaddleSpeech pre-trained models (audio classification models, speaker recognition models, etc.) and stored in Milvus. Milvus automatically generates a unique ID for each vector; the ID and the corresponding audio information (audio ID, speaker ID, etc.) are then stored in MySQL to complete the library construction. During retrieval, the user uploads test audio to obtain a vector and runs a vector similarity search in Milvus. Milvus returns vector IDs, and the corresponding audio information can then be queried in MySQL by ID.
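+
+To make that flow concrete, below is a minimal sketch of the Milvus side of the pipeline (connect, create a collection, insert, search) with pymilvus, mirroring what `src/milvus_helpers.py` does. The collection name `audio_table` and the 2048-dimensional vectors follow the defaults in `src/config.py`; the random vectors stand in for real embeddings:
+
+```python
+import numpy as np
+from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
+                      connections)
+
+connections.connect(host="127.0.0.1", port=19530)
+
+# Schema mirrors src/milvus_helpers.py: auto-generated INT64 IDs plus float vectors.
+fields = [
+    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
+    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=2048),
+]
+collection = Collection("audio_table", CollectionSchema(fields))
+
+# Insert a batch of (placeholder) embeddings; Milvus returns one ID per vector.
+embeddings = np.random.rand(20, 2048).tolist()
+ids = collection.insert([embeddings]).primary_keys
+collection.load()
+
+# Search the nearest neighbours of a query vector; the returned IDs are the
+# keys later used to look up the audio metadata in MySQL.
+hits = collection.search(np.random.rand(1, 2048).tolist(),
+                         anns_field="embedding",
+                         param={"metric_type": "L2", "params": {"nprobe": 16}},
+                         limit=10)
+print([(h.id, h.distance) for h in hits[0]])
+```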
+
+![Workflow of an audio searching system](./img/audio_searching.png)
+
+Note: this demo uses the [CN-Celeb](http://openslr.org/82/) dataset, which contains at least 650,000 audio entries from 3,000 speakers, to build the audio vector library; retrieval is then performed with a preset distance metric. Other datasets can be used instead and adjusted as needed, e.g. LibriSpeech, VoxCeleb, UrbanSound, GloVe, MNIST, etc.
+
+## Usage
+### 1. Prepare MySQL and Milvus services by docker-compose
+The audio similarity search system requires the Milvus and MySQL services. We can start these containers in one step via [docker-compose.yaml](./docker-compose.yaml), so please make sure you have [installed Docker Engine](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/) before running:
+
+```bash
+docker-compose -f docker-compose.yaml up -d
+```
+
+Then you will see that all the containers are created:
+
+```bash
+Creating network "quick_deploy_app_net" with driver "bridge"
+Creating milvus-minio ... done
+Creating milvus-etcd ... done
+Creating audio-mysql ... done
+Creating milvus-standalone ... done
+Creating audio-webclient ... done
+```
+
+List all containers with `docker ps`, and use `docker logs audio-mysql` to get the logs of the MySQL server container:
+
+```bash
+CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
+b2bcf279e599 milvusdb/milvus:v2.0.1 "/tini -- milvus run…" 22 hours ago Up 22 hours 0.0.0.0:19530->19530/tcp milvus-standalone
+d8ef4c84e25c mysql:5.7 "docker-entrypoint.s…" 22 hours ago Up 22 hours 0.0.0.0:3306->3306/tcp, 33060/tcp audio-mysql
+8fb501edb4f3 quay.io/coreos/etcd:v3.5.0 "etcd -advertise-cli…" 22 hours ago Up 22 hours 2379-2380/tcp milvus-etcd
+ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" 22 hours ago Up 22 hours (healthy) 9000/tcp milvus-minio
+15c84a506754 iregistry.baidu-int.com/paddlespeech/audio-search-client:1.0 "/bin/bash -c '/usr/…" 22 hours ago Up 22 hours (healthy) 0.0.0.0:8068->80/tcp audio-webclient
+```
+
+### 2. Start API Server
+Next, start the system server, which provides the HTTP backend services.
+
+- Install the Python packages
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+- Set configuration
+
+ ```bash
+ vim src/config.py
+ ```
+
+  Modify the parameters according to your own environment. Here are some parameters that need to be set; for more information, please refer to [config.py](./src/config.py).
+
+ | **Parameter** | **Description** | **Default setting** |
+ | ---------------- | ----------------------------------------------------- | ------------------- |
+ | MILVUS_HOST | The IP address of Milvus, you can get it by ifconfig. If running everything on one machine, most likely 127.0.0.1 | 127.0.0.1 |
+ | MILVUS_PORT | Port of Milvus. | 19530 |
+ | VECTOR_DIMENSION | Dimension of the vectors. | 2048 |
+  | MYSQL_HOST       | The IP address of MySQL.                              | 127.0.0.1           |
+  | MYSQL_PORT       | Port of MySQL.                                        | 3306                |
+  | DEFAULT_TABLE    | The default collection/table name in Milvus and MySQL. | audio_table        |
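+
+  Alternatively, since [config.py](./src/config.py) reads every setting with `os.getenv`, you can override settings through environment variables instead of editing the file. A small sketch (the host addresses are hypothetical):
+
+  ```python
+  import os
+
+  # Set these before importing config (or export them in the shell before
+  # running `python src/main.py`); the names match src/config.py.
+  os.environ["MILVUS_HOST"] = "192.168.1.25"  # hypothetical Milvus host
+  os.environ["MYSQL_HOST"] = "192.168.1.26"   # hypothetical MySQL host
+
+  import config  # requires ./src on PYTHONPATH
+  print(config.MILVUS_HOST, config.MYSQL_HOST)
+  ```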
+
+- Run the code
+
+  Then start the server with FastAPI.
+
+ ```bash
+ export PYTHONPATH=$PYTHONPATH:./src
+ python src/main.py
+ ```
+
+  Then you will see that the application has started:
+
+ ```bash
+ INFO: Started server process [3949]
+ 2022-03-07 17:39:14,864 | INFO | server.py | serve | 75 | Started server process [3949]
+ INFO: Waiting for application startup.
+ 2022-03-07 17:39:14,865 | INFO | on.py | startup | 45 | Waiting for application startup.
+ INFO: Application startup complete.
+ 2022-03-07 17:39:14,866 | INFO | on.py | startup | 59 | Application startup complete.
+ INFO: Uvicorn running on http://0.0.0.0:8002 (Press CTRL+C to quit)
+ 2022-03-07 17:39:14,867 | INFO | server.py | _log_started_message | 206 | Uvicorn running on http://0.0.0.0:8002 (Press CTRL+C to quit)
+ ```
+
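+With the server running, you can also drive it over plain HTTP from any client. The sketch below uses the `requests` library against the endpoints defined in `src/main.py` (paths and payloads follow `src/test_main.py`; it assumes the example audio from the next section has been downloaded to `./example_audio`):
+
+```python
+import requests
+
+BASE = "http://127.0.0.1:8002"
+
+# Insert every audio file under a local directory (POST /audio/load).
+print(requests.post(f"{BASE}/audio/load", json={"File": "./example_audio"}).json())
+
+# Total number of vectors in the default collection (GET /audio/count).
+print(requests.get(f"{BASE}/audio/count").json())
+
+# Search with a local file (POST /audio/search/local); the response maps each
+# result URL to a (file name, similarity score) pair.
+print(requests.post(f"{BASE}/audio/search/local",
+                    params={"query_audio_path": "./example_audio/test.wav"}).json())
+```
+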
+### 3. Usage
+- Prepare data
+ ```bash
+ wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz
+ ```
+  Note: if you want to build a quick demo, you can use the `download_audio_data` function in `./src/test_main.py`, which downloads 20 audio files; the results below use this collection as an example.
+
+  - Script test (recommended)
+
+    The script will, in order, download the data, load the PaddleSpeech model, extract embeddings, build the library, run retrieval, and drop the library:
+ ```bash
+ python ./src/test_main.py
+ ```
+
+ Output:
+ ```bash
+ Checkpoint path: %your model path%
+ Extracting feature from audio No. 1 , 20 audios in total
+ Extracting feature from audio No. 2 , 20 audios in total
+ ...
+ 2022-03-09 17:22:13,870 | INFO | main.py | load_audios | 85 | Successfully loaded data, total count: 20
+ 2022-03-09 17:22:13,898 | INFO | main.py | count_audio | 147 | Successfully count the number of data!
+ 2022-03-09 17:22:13,918 | INFO | main.py | audio_path | 57 | Successfully load audio: ./example_audio/test.wav
+ ...
+ 2022-03-09 17:22:32,580 | INFO | main.py | search_local_audio | 131 | search result http://testserver/data?audio_path=./example_audio/test.wav, distance 0.0
+ 2022-03-09 17:22:32,580 | INFO | main.py | search_local_audio | 131 | search result http://testserver/data?audio_path=./example_audio/knife_chopping.wav, distance 0.021805256605148315
+ 2022-03-09 17:22:32,580 | INFO | main.py | search_local_audio | 131 | search result http://testserver/data?audio_path=./example_audio/knife_cut_into_flesh.wav, distance 0.052762262523174286
+ ...
+ 2022-03-09 17:22:32,582 | INFO | main.py | search_local_audio | 135 | Successfully searched similar audio!
+ 2022-03-09 17:22:33,658 | INFO | main.py | drop_tables | 159 | Successfully drop tables in Milvus and MySQL!
+ ```
+- GUI test (optional)
+
+  Navigate to 127.0.0.1:8068 in your browser to access the front-end interface.
+
+  Note: if the browser and the service are not on the same machine, change the IP to that of the machine where the service runs, update the corresponding API_URL in docker-compose.yaml accordingly, and restart the service.
+
+ - Insert data
+
+    Download the data and decompress it to a path named /home/speech/data. Then enter /home/speech/data in the address bar of the upload page to upload the data.
+
+ ![](./img/insert.png)
+
+ - Search for similar audio
+
+    Select the magnifying glass icon on the left side of the interface. Then press the "Default Target Audio File" button and upload a .wav sound file you'd like to search. The results will be displayed.
+
+ ![](./img/search.png)
+
+### 4. Result
+
+Machine configuration:
+- OS: CentOS release 7.6
+- Kernel: 4.17.11-1.el7.elrepo.x86_64
+- CPU: Intel(R) Xeon(R) CPU E5-2620 v4 @ 2.10GHz
+- Memory: 132 GB
+
+Dataset:
+- CN-Celeb: train size 650,000, test size 10,000, dimension 192, distance metric L2
+
+Recall and elapsed-time statistics are shown in the following figure:
+
+ ![](./img/result.png)
+
+
+With the Milvus-based retrieval framework, a search takes about 2.9 milliseconds at a recall rate of 90%, and feature extraction takes about 500 milliseconds (for a test clip of about 5 seconds), so a single audio query takes about 503 milliseconds in total, which meets the needs of most application scenarios.
+
+### 5. Pretrained Models
+
+Here is a list of pretrained models released by PaddleSpeech:
+
+| Model | Sample Rate (Hz)
+| :--- | :---:
+| ecapa_tdnn | 16000
+| panns_cnn6| 32000
+| panns_cnn10| 32000
+| panns_cnn14| 32000
diff --git a/demos/audio_searching/README_cn.md b/demos/audio_searching/README_cn.md
new file mode 100644
index 00000000..d822c00d
--- /dev/null
+++ b/demos/audio_searching/README_cn.md
@@ -0,0 +1,172 @@
+
+(简体中文|[English](./README.md))
+
+# Audio Similarity Retrieval
+## Introduction
+
+As the Internet continues to evolve, unstructured data such as emails, social media photos, live videos, and customer service voice calls has become increasingly common. To process such data on a computer, we need embedding technology to transform it into vectors that can be stored, indexed, and queried.
+
+However, when there is a large amount of data, such as hundreds of millions of audio tracks, similarity search becomes difficult. Exhaustive search is feasible but very time-consuming. For this scenario, this demo introduces how to build an audio similarity retrieval system using the open-source vector database Milvus.
+
+Audio retrieval (speech, music, speaker, etc.) enables querying and finding similar sounds (or the same speaker) in large amounts of audio data. An audio similarity retrieval system can be used to identify similar sound effects, minimize intellectual property infringement, quickly search voiceprint libraries, and help enterprises control fraud and identity theft. Audio retrieval also plays an important role in the classification and statistical analysis of audio data.
+
+In this demo, you will learn how to build an audio retrieval system to retrieve similar sound snippets. Uploaded audio clips are converted into vectors using PaddleSpeech pre-trained models (audio classification models, speaker recognition models, etc.) and stored in Milvus. Milvus automatically generates a unique ID for each vector; the ID and the corresponding audio information (audio ID, speaker ID, etc.) are then stored in MySQL to complete the library construction. During retrieval, the user uploads test audio to obtain a vector and runs a vector similarity search in Milvus. Milvus returns vector IDs, and the corresponding audio information can then be queried in MySQL by ID.
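+
+To make the ID-to-metadata mapping concrete, here is a minimal sketch of the MySQL side, mirroring `src/mysql_helpers.py`; the table layout `(milvus_id TEXT, audio_path TEXT)` and the credentials follow the defaults in `src/config.py`, and the Milvus ID shown is a hypothetical value:
+
+```python
+import pymysql
+
+conn = pymysql.connect(host="127.0.0.1", port=3306, user="root",
+                       password="123456", database="mysql")
+cursor = conn.cursor()
+
+# Same layout as src/mysql_helpers.py: one row per vector inserted into Milvus.
+cursor.execute(
+    "create table if not exists audio_table (milvus_id TEXT, audio_path TEXT);")
+cursor.executemany(
+    "insert into audio_table (milvus_id, audio_path) values (%s, %s);",
+    [("425978159828238339", "./example_audio/test.wav")])  # hypothetical ID
+conn.commit()
+
+# After a Milvus search returns IDs, look the audio paths back up.
+cursor.execute("select audio_path from audio_table where milvus_id = %s;",
+               ("425978159828238339",))
+print(cursor.fetchall())
+```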
+
+![Workflow of the audio searching system](./img/audio_searching.png)
+
+Note: this demo uses the [CN-Celeb](http://openslr.org/82/) dataset, which contains at least 650,000 audio entries from 3,000 speakers, to build the audio vector library (audio features, or speaker features); audio (or speaker) retrieval is then performed with a preset distance metric. Other datasets can be used instead and adjusted as needed, e.g. LibriSpeech, VoxCeleb, UrbanSound, GloVe, MNIST, etc.
+
+## Usage
+### 1. Prepare MySQL and Milvus services with docker-compose
+The audio similarity search system requires the Milvus and MySQL services. We can start these containers in one step via [docker-compose.yaml](./docker-compose.yaml), so please make sure you have installed [Docker Engine](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/) before running:
+
+```bash
+docker-compose -f docker-compose.yaml up -d
+```
+
+Then you will see that all the containers are created:
+
+```bash
+Creating network "quick_deploy_app_net" with driver "bridge"
+Creating milvus-minio ... done
+Creating milvus-etcd ... done
+Creating audio-mysql ... done
+Creating milvus-standalone ... done
+Creating audio-webclient ... done
+```
+
+List all containers with `docker ps`, and use `docker logs audio-mysql` to get the logs of the MySQL server container:
+
+```bash
+CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
+b2bcf279e599 milvusdb/milvus:v2.0.1 "/tini -- milvus run…" 22 hours ago Up 22 hours 0.0.0.0:19530->19530/tcp milvus-standalone
+d8ef4c84e25c mysql:5.7 "docker-entrypoint.s…" 22 hours ago Up 22 hours 0.0.0.0:3306->3306/tcp, 33060/tcp audio-mysql
+8fb501edb4f3 quay.io/coreos/etcd:v3.5.0 "etcd -advertise-cli…" 22 hours ago Up 22 hours 2379-2380/tcp milvus-etcd
+ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…" 22 hours ago Up 22 hours (healthy) 9000/tcp milvus-minio
+15c84a506754 iregistry.baidu-int.com/paddlespeech/audio-search-client:1.0 "/bin/bash -c '/usr/…" 22 hours ago Up 22 hours (healthy) 0.0.0.0:8068->80/tcp audio-webclient
+
+```
+
+### 2. Configure and start the API server
+Start the system server, which provides the HTTP backend services.
+
+- Install the Python packages the service depends on
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+- Set configuration
+
+ ```bash
+ vim src/config.py
+ ```
+
+  Modify the parameters according to your own environment. Here are some parameters that need to be set; for more information, please refer to [config.py](./src/config.py).
+
+ | **Parameter** | **Description** | **Default setting** |
+ | ---------------- | ----------------------------------------------------- | ------------------- |
+ | MILVUS_HOST | The IP address of Milvus, you can get it by ifconfig. If running everything on one machine, most likely 127.0.0.1 | 127.0.0.1 |
+ | MILVUS_PORT | Port of Milvus. | 19530 |
+ | VECTOR_DIMENSION | Dimension of the vectors. | 2048 |
+  | MYSQL_HOST       | The IP address of MySQL.                              | 127.0.0.1           |
+  | MYSQL_PORT       | Port of MySQL.                                        | 3306                |
+  | DEFAULT_TABLE    | The default collection/table name in Milvus and MySQL. | audio_table        |
+
+- Run the code
+
+  Start the server built with FastAPI.
+
+ ```bash
+ export PYTHONPATH=$PYTHONPATH:./src
+ python src/main.py
+ ```
+
+  Then you will see that the application has started:
+
+ ```bash
+ INFO: Started server process [3949]
+ 2022-03-07 17:39:14,864 | INFO | server.py | serve | 75 | Started server process [3949]
+ INFO: Waiting for application startup.
+ 2022-03-07 17:39:14,865 | INFO | on.py | startup | 45 | Waiting for application startup.
+ INFO: Application startup complete.
+ 2022-03-07 17:39:14,866 | INFO | on.py | startup | 59 | Application startup complete.
+ INFO: Uvicorn running on http://0.0.0.0:8002 (Press CTRL+C to quit)
+ 2022-03-07 17:39:14,867 | INFO | server.py | _log_started_message | 206 | Uvicorn running on http://0.0.0.0:8002 (Press CTRL+C to quit)
+ ```
+
+### 3. Testing
+- Prepare data
+ ```bash
+ wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz
+ ```
+  Note: if you want to build a quick demo, you can use the 20 audio files downloaded by the `download_audio_data` function in `./src/test_main.py`; the results below use this collection as an example.
+
+  - Script test (recommended)
+
+ ```bash
+ python ./src/test_main.py
+ ```
+    Note: the script will, in order, download the data, load the PaddleSpeech model, extract embeddings, build the library, run retrieval, and drop the library.
+
+    Output:
+ ```bash
+ Checkpoint path: %your model path%
+ Extracting feature from audio No. 1 , 20 audios in total
+ Extracting feature from audio No. 2 , 20 audios in total
+ ...
+ 2022-03-09 17:22:13,870 | INFO | main.py | load_audios | 85 | Successfully loaded data, total count: 20
+ 2022-03-09 17:22:13,898 | INFO | main.py | count_audio | 147 | Successfully count the number of data!
+ 2022-03-09 17:22:13,918 | INFO | main.py | audio_path | 57 | Successfully load audio: ./example_audio/test.wav
+ ...
+ 2022-03-09 17:22:32,580 | INFO | main.py | search_local_audio | 131 | search result http://testserver/data?audio_path=./example_audio/test.wav, distance 0.0
+ 2022-03-09 17:22:32,580 | INFO | main.py | search_local_audio | 131 | search result http://testserver/data?audio_path=./example_audio/knife_chopping.wav, distance 0.021805256605148315
+ 2022-03-09 17:22:32,580 | INFO | main.py | search_local_audio | 131 | search result http://testserver/data?audio_path=./example_audio/knife_cut_into_flesh.wav, distance 0.052762262523174286
+ ...
+ 2022-03-09 17:22:32,582 | INFO | main.py | search_local_audio | 135 | Successfully searched similar audio!
+ 2022-03-09 17:22:33,658 | INFO | main.py | drop_tables | 159 | Successfully drop tables in Milvus and MySQL!
+ ```
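+
+    You can also call individual endpoints from Python with FastAPI's `TestClient`, as `src/test_main.py` does; a short sketch (run with `./src` on `PYTHONPATH` and the example audio downloaded):
+
+    ```python
+    from fastapi.testclient import TestClient
+
+    from main import app  # src/main.py
+
+    client = TestClient(app)
+
+    # Load the example audio, then count and search, mirroring src/test_main.py.
+    print(client.post("/audio/load", json={"File": "./example_audio"}).json())
+    print(client.get("/audio/count").json())
+    print(client.post(
+        "/audio/search/local?query_audio_path=./example_audio/test.wav").json())
+    ```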
+  - GUI test (optional)
+
+    Navigate to 127.0.0.1:8068 in your browser to access the front-end interface.
+
+    Note: if the browser and the service are not on the same machine, change the IP to that of the machine where the service runs, update the corresponding API_URL in docker-compose.yaml accordingly, and restart the service.
+
+  - Insert data
+
+    Download the data and decompress it to a folder, say /home/speech/data; then enter /home/speech/data in the address bar of the upload page to upload the data.
+
+ ![](./img/insert.png)
+
+  - Search for similar audio
+
+    Select the magnifying glass in the upper left corner, press the "Default Target Audio File" button, and upload a test audio file; you will then see the search results.
+
+ ![](./img/search.png)
+
+### 4. Result
+
+Machine configuration:
+- OS: CentOS release 7.6
+- Kernel: 4.17.11-1.el7.elrepo.x86_64
+- CPU: Intel(R) Xeon(R) CPU E5-2620 v4 @ 2.10GHz
+- Memory: 132 GB
+
+Dataset:
+- CN-Celeb: train size 650,000, test size 10,000, dimension 192, distance metric L2
+
+Recall and elapsed-time statistics are shown in the following figure:
+
+ ![](./img/result.png)
+
+With the Milvus-based retrieval framework, a search takes about 2.9 milliseconds at a recall rate of 90%, and feature extraction (embedding) takes about 500 milliseconds (for a test clip of about 5 seconds), so a single audio query takes about 503 milliseconds in total, which meets the needs of most application scenarios.
+
+### 5. Pretrained Models
+
+Here is a list of pretrained models released by PaddleSpeech:
+
+| Model | Sample Rate (Hz)
+| :--- | :---:
+| ecapa_tdnn| 16000
+| panns_cnn6| 32000
+| panns_cnn10| 32000
+| panns_cnn14| 32000
diff --git a/demos/audio_searching/docker-compose.yaml b/demos/audio_searching/docker-compose.yaml
new file mode 100644
index 00000000..8916e76f
--- /dev/null
+++ b/demos/audio_searching/docker-compose.yaml
@@ -0,0 +1,88 @@
+version: '3.5'
+
+services:
+ etcd:
+ container_name: milvus-etcd
+ image: quay.io/coreos/etcd:v3.5.0
+ networks:
+ app_net:
+ environment:
+ - ETCD_AUTO_COMPACTION_MODE=revision
+ - ETCD_AUTO_COMPACTION_RETENTION=1000
+ - ETCD_QUOTA_BACKEND_BYTES=4294967296
+ volumes:
+ - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
+ command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
+
+ minio:
+ container_name: milvus-minio
+ image: minio/minio:RELEASE.2020-12-03T00-03-10Z
+ networks:
+ app_net:
+ environment:
+ MINIO_ACCESS_KEY: minioadmin
+ MINIO_SECRET_KEY: minioadmin
+ volumes:
+ - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
+ command: minio server /minio_data
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
+ interval: 30s
+ timeout: 20s
+ retries: 3
+
+ standalone:
+ container_name: milvus-standalone
+ image: milvusdb/milvus:v2.0.1
+ networks:
+ app_net:
+ ipv4_address: 172.16.23.10
+ command: ["milvus", "run", "standalone"]
+ environment:
+ ETCD_ENDPOINTS: etcd:2379
+ MINIO_ADDRESS: minio:9000
+ volumes:
+ - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
+ ports:
+ - "19530:19530"
+ depends_on:
+ - "etcd"
+ - "minio"
+
+ mysql:
+ container_name: audio-mysql
+ image: mysql:5.7
+ networks:
+ app_net:
+ ipv4_address: 172.16.23.11
+ environment:
+ - MYSQL_ROOT_PASSWORD=123456
+ volumes:
+ - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/mysql:/var/lib/mysql
+ ports:
+ - "3306:3306"
+
+ webclient:
+ container_name: audio-webclient
+ image: qingen1/paddlespeech-audio-search-client:2.3
+ networks:
+ app_net:
+ ipv4_address: 172.16.23.13
+ environment:
+ API_URL: 'http://127.0.0.1:8002'
+ ports:
+ - "8068:80"
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://localhost/"]
+ interval: 30s
+ timeout: 20s
+ retries: 3
+
+networks:
+ app_net:
+ driver: bridge
+ ipam:
+ driver: default
+ config:
+ - subnet: 172.16.23.0/24
+ gateway: 172.16.23.1
diff --git a/demos/audio_searching/img/audio_searching.png b/demos/audio_searching/img/audio_searching.png
new file mode 100644
index 00000000..b145dd49
Binary files /dev/null and b/demos/audio_searching/img/audio_searching.png differ
diff --git a/demos/audio_searching/img/insert.png b/demos/audio_searching/img/insert.png
new file mode 100644
index 00000000..b9e766bd
Binary files /dev/null and b/demos/audio_searching/img/insert.png differ
diff --git a/demos/audio_searching/img/result.png b/demos/audio_searching/img/result.png
new file mode 100644
index 00000000..c4efc0c7
Binary files /dev/null and b/demos/audio_searching/img/result.png differ
diff --git a/demos/audio_searching/img/search.png b/demos/audio_searching/img/search.png
new file mode 100644
index 00000000..26bcd9bd
Binary files /dev/null and b/demos/audio_searching/img/search.png differ
diff --git a/demos/audio_searching/requirements.txt b/demos/audio_searching/requirements.txt
new file mode 100644
index 00000000..9e73361b
--- /dev/null
+++ b/demos/audio_searching/requirements.txt
@@ -0,0 +1,12 @@
+soundfile==0.10.3.post1
+librosa==0.8.0
+numpy
+pymysql
+fastapi
+uvicorn
+diskcache==5.2.1
+pymilvus==2.0.1
+python-multipart
+typing
+starlette
+pydantic
\ No newline at end of file
diff --git a/demos/audio_searching/src/config.py b/demos/audio_searching/src/config.py
new file mode 100644
index 00000000..72a8fb4b
--- /dev/null
+++ b/demos/audio_searching/src/config.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+############### Milvus Configuration ###############
+MILVUS_HOST = os.getenv("MILVUS_HOST", "127.0.0.1")
+MILVUS_PORT = int(os.getenv("MILVUS_PORT", "19530"))
+VECTOR_DIMENSION = int(os.getenv("VECTOR_DIMENSION", "2048"))
+INDEX_FILE_SIZE = int(os.getenv("INDEX_FILE_SIZE", "1024"))
+METRIC_TYPE = os.getenv("METRIC_TYPE", "L2")
+DEFAULT_TABLE = os.getenv("DEFAULT_TABLE", "audio_table")
+TOP_K = int(os.getenv("TOP_K", "10"))
+
+############### MySQL Configuration ###############
+MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
+MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
+MYSQL_USER = os.getenv("MYSQL_USER", "root")
+MYSQL_PWD = os.getenv("MYSQL_PWD", "123456")
+MYSQL_DB = os.getenv("MYSQL_DB", "mysql")
+
+############### Data Path ###############
+UPLOAD_PATH = os.getenv("UPLOAD_PATH", "tmp/audio-data")
+
+############### Number of Log Files ###############
+LOGS_NUM = int(os.getenv("LOGS_NUM", "0"))
diff --git a/demos/audio_searching/src/encode.py b/demos/audio_searching/src/encode.py
new file mode 100644
index 00000000..eba5c48c
--- /dev/null
+++ b/demos/audio_searching/src/encode.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import librosa
+import numpy as np
+from logs import LOGGER
+
+
+def get_audio_embedding(path):
+ """
+ Use vpr_inference to generate embedding of audio
+ """
+ try:
+ RESAMPLE_RATE = 16000
+ audio, _ = librosa.load(path, sr=RESAMPLE_RATE, mono=True)
+
+ # TODO add infer/python interface to get embedding, now fake it by rand
+ # vpr = ECAPATDNN(checkpoint_path=None, device='cuda')
+ # embedding = vpr.inference(audio)
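+        # Seed with the file name so repeated calls in the same process return
+        # the same placeholder embedding for the same file.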
+ np.random.seed(hash(os.path.basename(path)) % 1000000)
+ embedding = np.random.rand(1, 2048)
+ embedding = embedding / np.linalg.norm(embedding)
+ embedding = embedding.tolist()[0]
+ return embedding
+ except Exception as e:
+ LOGGER.error(f"Error with embedding:{e}")
+ return None
diff --git a/demos/audio_searching/src/logs.py b/demos/audio_searching/src/logs.py
new file mode 100644
index 00000000..ba3ed069
--- /dev/null
+++ b/demos/audio_searching/src/logs.py
@@ -0,0 +1,164 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import datetime
+import logging
+import os
+import re
+import sys
+
+from config import LOGS_NUM
+
+
+class MultiprocessHandler(logging.FileHandler):
+ """
+ A handler class which writes formatted logging records to disk files
+ """
+
+ def __init__(self,
+ filename,
+ when='D',
+ backupCount=0,
+ encoding=None,
+ delay=False):
+ """
+ Open the specified file and use it as the stream for logging
+ """
+ self.prefix = filename
+ self.backupCount = backupCount
+ self.when = when.upper()
+ self.extMath = r"^\d{4}-\d{2}-\d{2}"
+
+ self.when_dict = {
+ 'S': "%Y-%m-%d-%H-%M-%S",
+ 'M': "%Y-%m-%d-%H-%M",
+ 'H': "%Y-%m-%d-%H",
+ 'D': "%Y-%m-%d"
+ }
+
+        self.suffix = self.when_dict.get(self.when)
+ if not self.suffix:
+ print('The specified date interval unit is invalid: ', self.when)
+ sys.exit(1)
+
+ self.filefmt = os.path.join('.', "logs",
+ f"{self.prefix}-{self.suffix}.log")
+
+ self.filePath = datetime.datetime.now().strftime(self.filefmt)
+
+ _dir = os.path.dirname(self.filefmt)
+ try:
+ if not os.path.exists(_dir):
+ os.makedirs(_dir)
+ except Exception as e:
+ print('Failed to create log file: ', e)
+ print("log_path:" + self.filePath)
+ sys.exit(1)
+
+ logging.FileHandler.__init__(self, self.filePath, 'a+', encoding, delay)
+
+ def should_change_file_to_write(self):
+ """
+        Check whether the log should roll over to a new date-stamped file
+ """
+ _filePath = datetime.datetime.now().strftime(self.filefmt)
+ if _filePath != self.filePath:
+ self.filePath = _filePath
+ return True
+ return False
+
+ def do_change_file(self):
+ """
+        Switch the stream to the current date-stamped log file
+ """
+ self.baseFilename = os.path.abspath(self.filePath)
+ if self.stream:
+ self.stream.close()
+ self.stream = None
+
+ if not self.delay:
+ self.stream = self._open()
+ if self.backupCount > 0:
+ for s in self.get_files_to_delete():
+ os.remove(s)
+
+ def get_files_to_delete(self):
+ """
+        Collect backup log files beyond backupCount for deletion
+ """
+ dir_name, _ = os.path.split(self.baseFilename)
+ file_names = os.listdir(dir_name)
+ result = []
+ prefix = self.prefix + '-'
+ for file_name in file_names:
+ if file_name[:len(prefix)] == prefix:
+ suffix = file_name[len(prefix):-4]
+ if re.compile(self.extMath).match(suffix):
+ result.append(os.path.join(dir_name, file_name))
+ result.sort()
+
+ if len(result) < self.backupCount:
+ result = []
+ else:
+ result = result[:len(result) - self.backupCount]
+ return result
+
+ def emit(self, record):
+ """
+ Emit a record
+ """
+ try:
+ if self.should_change_file_to_write():
+ self.do_change_file()
+ logging.FileHandler.emit(self, record)
+ except (KeyboardInterrupt, SystemExit):
+ raise
+        except Exception:
+ self.handleError(record)
+
+
+def write_log():
+ """
+ Init a logger
+ """
+ logger = logging.getLogger()
+ logger.setLevel(logging.DEBUG)
+ # formatter = '%(asctime)s | %(levelname)s | %(filename)s | %(funcName)s | %(module)s | %(lineno)s | %(message)s'
+ fmt = logging.Formatter(
+ '%(asctime)s | %(levelname)s | %(filename)s | %(funcName)s | %(lineno)s | %(message)s'
+ )
+
+ stream_handler = logging.StreamHandler(sys.stdout)
+ stream_handler.setLevel(logging.INFO)
+ stream_handler.setFormatter(fmt)
+
+ log_name = "audio-searching"
+ file_handler = MultiprocessHandler(log_name, when='D', backupCount=LOGS_NUM)
+ file_handler.setLevel(logging.DEBUG)
+ file_handler.setFormatter(fmt)
+ file_handler.do_change_file()
+
+ logger.addHandler(stream_handler)
+ logger.addHandler(file_handler)
+
+ return logger
+
+
+LOGGER = write_log()
+
+if __name__ == "__main__":
+ message = 'test writing logs'
+ LOGGER.info(message)
+ LOGGER.debug(message)
+ LOGGER.error(message)
diff --git a/demos/audio_searching/src/main.py b/demos/audio_searching/src/main.py
new file mode 100644
index 00000000..db091a39
--- /dev/null
+++ b/demos/audio_searching/src/main.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Optional
+
+import uvicorn
+from config import UPLOAD_PATH
+from diskcache import Cache
+from fastapi import FastAPI
+from fastapi import File
+from fastapi import UploadFile
+from logs import LOGGER
+from milvus_helpers import MilvusHelper
+from mysql_helpers import MySQLHelper
+from operations.count import do_count
+from operations.drop import do_drop
+from operations.load import do_load
+from operations.search import do_search
+from pydantic import BaseModel
+from starlette.middleware.cors import CORSMiddleware
+from starlette.requests import Request
+from starlette.responses import FileResponse
+
+app = FastAPI()
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"])
+
+MODEL = None
+MILVUS_CLI = MilvusHelper()
+MYSQL_CLI = MySQLHelper()
+
+# Mkdir 'tmp/audio-data'
+if not os.path.exists(UPLOAD_PATH):
+ os.makedirs(UPLOAD_PATH)
+ LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")
+
+
+@app.get('/data')
+def audio_path(audio_path):
+ # Get the audio file
+ try:
+ LOGGER.info(f"Successfully load audio: {audio_path}")
+ return FileResponse(audio_path)
+ except Exception as e:
+ LOGGER.error(f"upload audio error: {e}")
+ return {'status': False, 'msg': e}, 400
+
+
+@app.get('/progress')
+def get_progress():
+ # Get the progress of dealing with data
+ try:
+ cache = Cache('./tmp')
+ return f"current: {cache['current']}, total: {cache['total']}"
+ except Exception as e:
+ LOGGER.error(f"Upload data error: {e}")
+ return {'status': False, 'msg': e}, 400
+
+
+class Item(BaseModel):
+ Table: Optional[str] = None
+ File: str
+
+
+@app.post('/audio/load')
+async def load_audios(item: Item):
+ # Insert all the audio files under the file path to Milvus/MySQL
+ try:
+ total_num = do_load(item.Table, item.File, MILVUS_CLI, MYSQL_CLI)
+ LOGGER.info(f"Successfully loaded data, total count: {total_num}")
+ return {'status': True, 'msg': "Successfully loaded data!"}
+ except Exception as e:
+ LOGGER.error(e)
+ return {'status': False, 'msg': e}, 400
+
+
+@app.post('/audio/search')
+async def search_audio(request: Request,
+ table_name: str=None,
+ audio: UploadFile=File(...)):
+ # Search the uploaded audio in Milvus/MySQL
+ try:
+ # Save the upload data to server.
+ content = await audio.read()
+ query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
+ with open(query_audio_path, "wb+") as f:
+ f.write(content)
+ host = request.headers['host']
+ _, paths, distances = do_search(host, table_name, query_audio_path,
+ MILVUS_CLI, MYSQL_CLI)
+ names = []
+ for path, score in zip(paths, distances):
+ names.append(os.path.basename(path))
+ LOGGER.info(f"search result {path}, score {score}")
+ res = dict(zip(paths, zip(names, distances)))
+        # Sort results by similarity score (do_search converts distances to scores), most similar first
+ res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
+ LOGGER.info("Successfully searched similar audio!")
+ return res
+ except Exception as e:
+ LOGGER.error(e)
+ return {'status': False, 'msg': e}, 400
+
+
+@app.post('/audio/search/local')
+async def search_local_audio(request: Request,
+ query_audio_path: str,
+ table_name: str=None):
+ # Search the uploaded audio in Milvus/MySQL
+ try:
+ host = request.headers['host']
+ _, paths, distances = do_search(host, table_name, query_audio_path,
+ MILVUS_CLI, MYSQL_CLI)
+ names = []
+ for path, score in zip(paths, distances):
+ names.append(os.path.basename(path))
+ LOGGER.info(f"search result {path}, score {score}")
+ res = dict(zip(paths, zip(names, distances)))
+        # Sort results by similarity score (do_search converts distances to scores), most similar first
+ res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
+ LOGGER.info("Successfully searched similar audio!")
+ return res
+ except Exception as e:
+ LOGGER.error(e)
+ return {'status': False, 'msg': e}, 400
+
+
+@app.get('/audio/count')
+async def count_audio(table_name: str=None):
+ # Returns the total number of vectors in the system
+ try:
+ num = do_count(table_name, MILVUS_CLI)
+ LOGGER.info("Successfully count the number of data!")
+ return num
+ except Exception as e:
+ LOGGER.error(e)
+ return {'status': False, 'msg': e}, 400
+
+
+@app.post('/audio/drop')
+async def drop_tables(table_name: str=None):
+ # Delete the collection of Milvus and MySQL
+ try:
+ status = do_drop(table_name, MILVUS_CLI, MYSQL_CLI)
+ LOGGER.info("Successfully drop tables in Milvus and MySQL!")
+ return status
+ except Exception as e:
+ LOGGER.error(e)
+ return {'status': False, 'msg': e}, 400
+
+
+if __name__ == '__main__':
+ uvicorn.run(app=app, host='0.0.0.0', port=8002)
diff --git a/demos/audio_searching/src/milvus_helpers.py b/demos/audio_searching/src/milvus_helpers.py
new file mode 100644
index 00000000..1699e892
--- /dev/null
+++ b/demos/audio_searching/src/milvus_helpers.py
@@ -0,0 +1,185 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from config import METRIC_TYPE
+from config import MILVUS_HOST
+from config import MILVUS_PORT
+from config import VECTOR_DIMENSION
+from logs import LOGGER
+from pymilvus import Collection
+from pymilvus import CollectionSchema
+from pymilvus import connections
+from pymilvus import DataType
+from pymilvus import FieldSchema
+from pymilvus import utility
+
+
+class MilvusHelper:
+ """
+ the basic operations of PyMilvus
+
+ # This example shows how to:
+ # 1. connect to Milvus server
+ # 2. create a collection
+ # 3. insert entities
+ # 4. create index
+ # 5. search
+ # 6. delete a collection
+
+ """
+
+ def __init__(self):
+ try:
+ self.collection = None
+ connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
+ LOGGER.debug(
+ f"Successfully connect to Milvus with IP:{MILVUS_HOST} and PORT:{MILVUS_PORT}"
+ )
+ except Exception as e:
+ LOGGER.error(f"Failed to connect Milvus: {e}")
+ sys.exit(1)
+
+ def set_collection(self, collection_name):
+ try:
+ if self.has_collection(collection_name):
+ self.collection = Collection(name=collection_name)
+ else:
+ raise Exception(
+ f"There is no collection named:{collection_name}")
+ except Exception as e:
+ LOGGER.error(f"Failed to set collection in Milvus: {e}")
+ sys.exit(1)
+
+ def has_collection(self, collection_name):
+        # Return whether Milvus has the collection
+ try:
+ return utility.has_collection(collection_name)
+ except Exception as e:
+ LOGGER.error(f"Failed to check state of collection in Milvus: {e}")
+ sys.exit(1)
+
+ def create_collection(self, collection_name):
+ # Create milvus collection if not exists
+ try:
+ if not self.has_collection(collection_name):
+ field1 = FieldSchema(
+ name="id",
+ dtype=DataType.INT64,
+                    description="int64",
+ is_primary=True,
+ auto_id=True)
+ field2 = FieldSchema(
+ name="embedding",
+ dtype=DataType.FLOAT_VECTOR,
+                    description="speaker embeddings",
+ dim=VECTOR_DIMENSION,
+ is_primary=False)
+ schema = CollectionSchema(
+ fields=[field1, field2], description="embeddings info")
+ self.collection = Collection(
+ name=collection_name, schema=schema)
+ LOGGER.debug(f"Create Milvus collection: {collection_name}")
+ else:
+ self.set_collection(collection_name)
+ return "OK"
+ except Exception as e:
+ LOGGER.error(f"Failed to create collection in Milvus: {e}")
+ sys.exit(1)
+
+ def insert(self, collection_name, vectors):
+ # Batch insert vectors to milvus collection
+ try:
+ self.create_collection(collection_name)
+ data = [vectors]
+ self.set_collection(collection_name)
+ mr = self.collection.insert(data)
+ ids = mr.primary_keys
+ self.collection.load()
+ LOGGER.debug(
+ f"Insert vectors to Milvus in collection: {collection_name} with {len(vectors)} rows"
+ )
+ return ids
+ except Exception as e:
+ LOGGER.error(f"Failed to insert data to Milvus: {e}")
+ sys.exit(1)
+
+ def create_index(self, collection_name):
+        # Create an IVF_SQ8 index on the milvus collection
+ try:
+ self.set_collection(collection_name)
+ default_index = {
+ "index_type": "IVF_SQ8",
+ "metric_type": METRIC_TYPE,
+ "params": {
+ "nlist": 16384
+ }
+ }
+ status = self.collection.create_index(
+ field_name="embedding", index_params=default_index)
+ if not status.code:
+ LOGGER.debug(
+ f"Successfully create index in collection:{collection_name} with param:{default_index}"
+ )
+ return status
+ else:
+ raise Exception(status.message)
+ except Exception as e:
+ LOGGER.error(f"Failed to create index: {e}")
+ sys.exit(1)
+
+ def delete_collection(self, collection_name):
+ # Delete Milvus collection
+ try:
+ self.set_collection(collection_name)
+ self.collection.drop()
+ LOGGER.debug("Successfully drop collection!")
+ return "ok"
+ except Exception as e:
+ LOGGER.error(f"Failed to drop collection: {e}")
+ sys.exit(1)
+
+ def search_vectors(self, collection_name, vectors, top_k):
+ # Search vector in milvus collection
+ try:
+ self.set_collection(collection_name)
+ search_params = {
+ "metric_type": METRIC_TYPE,
+ "params": {
+ "nprobe": 16
+ }
+ }
+ res = self.collection.search(
+ vectors,
+ anns_field="embedding",
+ param=search_params,
+ limit=top_k)
+ LOGGER.debug(f"Successfully search in collection: {res}")
+ return res
+ except Exception as e:
+ LOGGER.error(f"Failed to search vectors in Milvus: {e}")
+ sys.exit(1)
+
+ def count(self, collection_name):
+        # Get the number of entities in the milvus collection
+ try:
+ self.set_collection(collection_name)
+ num = self.collection.num_entities
+ LOGGER.debug(
+ f"Successfully get the num:{num} of the collection:{collection_name}"
+ )
+ return num
+ except Exception as e:
+ LOGGER.error(f"Failed to count vectors in Milvus: {e}")
+ sys.exit(1)
diff --git a/demos/audio_searching/src/mysql_helpers.py b/demos/audio_searching/src/mysql_helpers.py
new file mode 100644
index 00000000..30383839
--- /dev/null
+++ b/demos/audio_searching/src/mysql_helpers.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+import pymysql
+from config import MYSQL_DB
+from config import MYSQL_HOST
+from config import MYSQL_PORT
+from config import MYSQL_PWD
+from config import MYSQL_USER
+from logs import LOGGER
+
+
+class MySQLHelper():
+ """
+ the basic operations of PyMySQL
+
+ # This example shows how to:
+ # 1. connect to MySQL server
+ # 2. create a table
+ # 3. insert data to table
+ # 4. search by milvus ids
+ # 5. delete table
+ """
+
+ def __init__(self):
+ self.conn = pymysql.connect(
+ host=MYSQL_HOST,
+ user=MYSQL_USER,
+ port=MYSQL_PORT,
+ password=MYSQL_PWD,
+ database=MYSQL_DB,
+ local_infile=True)
+ self.cursor = self.conn.cursor()
+
+ def test_connection(self):
+ try:
+ self.conn.ping()
+ except Exception:
+ self.conn = pymysql.connect(
+ host=MYSQL_HOST,
+ user=MYSQL_USER,
+ port=MYSQL_PORT,
+ password=MYSQL_PWD,
+ database=MYSQL_DB,
+ local_infile=True)
+ self.cursor = self.conn.cursor()
+
+ def create_mysql_table(self, table_name):
+ # Create mysql table if not exists
+ self.test_connection()
+ sql = "create table if not exists " + table_name + "(milvus_id TEXT, audio_path TEXT);"
+ try:
+ self.cursor.execute(sql)
+ LOGGER.debug(f"MYSQL create table: {table_name} with sql: {sql}")
+ except Exception as e:
+ LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
+ sys.exit(1)
+
+ def load_data_to_mysql(self, table_name, data):
+        # Batch insert (milvus_id, audio_path) rows into mysql
+ self.test_connection()
+ sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);"
+ try:
+ self.cursor.executemany(sql, data)
+ self.conn.commit()
+ LOGGER.debug(
+ f"MYSQL loads data to table: {table_name} successfully")
+ except Exception as e:
+ LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
+ sys.exit(1)
+
+ def search_by_milvus_ids(self, ids, table_name):
+        # Get the audio_path values according to the milvus ids
+ self.test_connection()
+ str_ids = str(ids).replace('[', '').replace(']', '')
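+        # "order by field" keeps the rows in the same order as the given
+        # milvus ids, so results stay ranked by search distance.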
+ sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");"
+ try:
+ self.cursor.execute(sql)
+ results = self.cursor.fetchall()
+ results = [res[0] for res in results]
+ LOGGER.debug("MYSQL search by milvus id.")
+ return results
+ except Exception as e:
+ LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
+ sys.exit(1)
+
+ def delete_table(self, table_name):
+ # Delete mysql table if exists
+ self.test_connection()
+ sql = "drop table if exists " + table_name + ";"
+ try:
+ self.cursor.execute(sql)
+ LOGGER.debug(f"MYSQL delete table:{table_name}")
+ except Exception as e:
+ LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
+ sys.exit(1)
+
+ def delete_all_data(self, table_name):
+ # Delete all the data in mysql table
+ self.test_connection()
+ sql = 'delete from ' + table_name + ';'
+ try:
+ self.cursor.execute(sql)
+ self.conn.commit()
+ LOGGER.debug(f"MYSQL delete all data in table:{table_name}")
+ except Exception as e:
+ LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
+ sys.exit(1)
+
+ def count_table(self, table_name):
+        # Get the row count of the mysql table
+ self.test_connection()
+ sql = "select count(milvus_id) from " + table_name + ";"
+ try:
+ self.cursor.execute(sql)
+ results = self.cursor.fetchall()
+ LOGGER.debug(f"MYSQL count table:{table_name}")
+ return results[0][0]
+ except Exception as e:
+ LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
+ sys.exit(1)
diff --git a/demos/audio_searching/src/operations/__init__.py b/demos/audio_searching/src/operations/__init__.py
new file mode 100644
index 00000000..97043fd7
--- /dev/null
+++ b/demos/audio_searching/src/operations/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/demos/audio_searching/src/operations/count.py b/demos/audio_searching/src/operations/count.py
new file mode 100644
index 00000000..9a1f4208
--- /dev/null
+++ b/demos/audio_searching/src/operations/count.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from config import DEFAULT_TABLE
+from logs import LOGGER
+
+
+def do_count(table_name, milvus_cli):
+ """
+ Returns the total number of vectors in the system
+ """
+ if not table_name:
+ table_name = DEFAULT_TABLE
+ try:
+ if not milvus_cli.has_collection(table_name):
+ return None
+ num = milvus_cli.count(table_name)
+ return num
+ except Exception as e:
+        LOGGER.error(f"Error attempting to count table: {e}")
+ sys.exit(1)
diff --git a/demos/audio_searching/src/operations/drop.py b/demos/audio_searching/src/operations/drop.py
new file mode 100644
index 00000000..f8278ddd
--- /dev/null
+++ b/demos/audio_searching/src/operations/drop.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from config import DEFAULT_TABLE
+from logs import LOGGER
+
+
+def do_drop(table_name, milvus_cli, mysql_cli):
+ """
+ Delete the collection of Milvus and MySQL
+ """
+ if not table_name:
+ table_name = DEFAULT_TABLE
+ try:
+ if not milvus_cli.has_collection(table_name):
+            return "Collection does not exist"
+ status = milvus_cli.delete_collection(table_name)
+ mysql_cli.delete_table(table_name)
+ return status
+ except Exception as e:
+ LOGGER.error(f"Error attempting to drop table: {e}")
+ sys.exit(1)
diff --git a/demos/audio_searching/src/operations/load.py b/demos/audio_searching/src/operations/load.py
new file mode 100644
index 00000000..7a295bf3
--- /dev/null
+++ b/demos/audio_searching/src/operations/load.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+from config import DEFAULT_TABLE
+from diskcache import Cache
+from encode import get_audio_embedding
+from logs import LOGGER
+
+
+def get_audios(path):
+ """
+    List all supported audio files (wav, mp3, ogg, flac, m4a) recursively under the path folder.
+ """
+ supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
+ return [
+ item
+ for sublist in [[os.path.join(dir, file) for file in files]
+ for dir, _, files in list(os.walk(path))]
+ for item in sublist if os.path.splitext(item)[1] in supported_formats
+ ]
+
+
+def extract_features(audio_dir):
+ """
+ Get the vector of audio
+ """
+ try:
+ cache = Cache('./tmp')
+ feats = []
+ names = []
+ audio_list = get_audios(audio_dir)
+ total = len(audio_list)
+ cache['total'] = total
+ for i, audio_path in enumerate(audio_list):
+ norm_feat = get_audio_embedding(audio_path)
+ if norm_feat is None:
+ continue
+ feats.append(norm_feat)
+ names.append(audio_path.encode())
+ cache['current'] = i + 1
+ print(
+ f"Extracting feature from audio No. {i + 1} , {total} audios in total"
+ )
+ return feats, names
+ except Exception as e:
+ LOGGER.error(f"Error with extracting feature from audio {e}")
+ sys.exit(1)
+
+
+def format_data(ids, names):
+ """
+ Combine the id of the vector and the name of the audio into a list
+ """
+ data = []
+ for i in range(len(ids)):
+ value = (str(ids[i]), names[i])
+ data.append(value)
+ return data
+
+
+def do_load(table_name, audio_dir, milvus_cli, mysql_cli):
+ """
+ Import vectors to Milvus and data to Mysql respectively
+ """
+ if not table_name:
+ table_name = DEFAULT_TABLE
+ vectors, names = extract_features(audio_dir)
+ ids = milvus_cli.insert(table_name, vectors)
+ milvus_cli.create_index(table_name)
+ mysql_cli.create_mysql_table(table_name)
+ mysql_cli.load_data_to_mysql(table_name, format_data(ids, names))
+ return len(ids)
diff --git a/demos/audio_searching/src/operations/search.py b/demos/audio_searching/src/operations/search.py
new file mode 100644
index 00000000..9cf48abf
--- /dev/null
+++ b/demos/audio_searching/src/operations/search.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from config import DEFAULT_TABLE
+from config import TOP_K
+from encode import get_audio_embedding
+from logs import LOGGER
+
+
+def do_search(host, table_name, audio_path, milvus_cli, mysql_cli):
+ """
+ Search the uploaded audio in Milvus/MySQL
+ """
+ try:
+ if not table_name:
+ table_name = DEFAULT_TABLE
+ feat = get_audio_embedding(audio_path)
+ vectors = milvus_cli.search_vectors(table_name, [feat], TOP_K)
+ vids = [str(x.id) for x in vectors[0]]
+ paths = mysql_cli.search_by_milvus_ids(vids, table_name)
+ distances = [x.distance for x in vectors[0]]
+ for i in range(len(paths)):
+ tmp = "http://" + str(host) + "/data?audio_path=" + str(paths[i])
+ paths[i] = tmp
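+            # Map the raw distance to a percentage-style similarity score
+            # (higher means more similar).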
+ distances[i] = (1 - distances[i]) * 100
+ return vids, paths, distances
+ except Exception as e:
+ LOGGER.error(f"Error with search: {e}")
+ sys.exit(1)
diff --git a/demos/audio_searching/src/test_main.py b/demos/audio_searching/src/test_main.py
new file mode 100644
index 00000000..331208ff
--- /dev/null
+++ b/demos/audio_searching/src/test_main.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import zipfile
+
+import gdown
+from fastapi.testclient import TestClient
+from main import app
+
+client = TestClient(app)
+
+
+def download_audio_data():
+ """
+ download audio data
+ """
+ url = 'https://drive.google.com/uc?id=1bKu21JWBfcZBuEuzFEvPoAX6PmRrgnUp'
+ gdown.download(url)
+
+ with zipfile.ZipFile('example_audio.zip', 'r') as zip_ref:
+ zip_ref.extractall('./example_audio')
+
+
+def test_drop():
+ """
+ Delete the collection of Milvus and MySQL
+ """
+ response = client.post("/audio/drop")
+ assert response.status_code == 200
+
+
+def test_load():
+ """
+ Insert all the audio files under the file path to Milvus/MySQL
+ """
+ response = client.post("/audio/load", json={"File": "./example_audio"})
+ assert response.status_code == 200
+ assert response.json() == {
+ 'status': True,
+ 'msg': "Successfully loaded data!"
+ }
+
+
+def test_progress():
+ """
+ Get the progress of dealing with data
+ """
+ response = client.get("/progress")
+ assert response.status_code == 200
+ assert response.json() == "current: 20, total: 20"
+
+
+def test_count():
+ """
+ Returns the total number of vectors in the system
+ """
+    response = client.get("/audio/count")
+ assert response.status_code == 200
+ assert response.json() == 20
+
+
+def test_search():
+ """
+ Search the uploaded audio in Milvus/MySQL
+ """
+ response = client.post(
+ "/audio/search/local?query_audio_path=.%2Fexample_audio%2Ftest.wav")
+ assert response.status_code == 200
+ assert len(response.json()) == 10
+
+
+def test_data():
+ """
+ Get the audio file
+ """
+ response = client.get("/data?audio_path=.%2Fexample_audio%2Ftest.wav")
+ assert response.status_code == 200
+
+
+if __name__ == "__main__":
+ download_audio_data()
+ test_load()
+ test_count()
+ test_search()
+ test_drop()
diff --git a/demos/speech_server/.gitignore b/demos/speech_server/.gitignore
new file mode 100644
index 00000000..d8dd7532
--- /dev/null
+++ b/demos/speech_server/.gitignore
@@ -0,0 +1 @@
+*.wav
diff --git a/docs/source/reference.md b/docs/source/reference.md
index a8327e92..f1a02d20 100644
--- a/docs/source/reference.md
+++ b/docs/source/reference.md
@@ -35,3 +35,7 @@ We borrowed a lot of code from these repos to build `model` and `engine`, thanks
* [librosa](https://github.com/librosa/librosa/blob/main/LICENSE.md)
- ISC License
- Audio feature
+
+* [ThreadPool](https://github.com/progschj/ThreadPool/blob/master/COPYING)
+- zlib License
+- ThreadPool
diff --git a/docs/source/released_model.md b/docs/source/released_model.md
index ffe721b8..52b386da 100644
--- a/docs/source/released_model.md
+++ b/docs/source/released_model.md
@@ -49,11 +49,12 @@ Model Type | Dataset| Example Link | Pretrained Models| Static Models|Size (stat
WaveFlow| LJSpeech |[waveflow-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc0)|[waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/waveflow/waveflow_ljspeech_ckpt_0.3.zip)|||
Parallel WaveGAN| CSMSC |[PWGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc1)|[pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip)|[pwg_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip)|5.1MB|
Parallel WaveGAN| LJSpeech |[PWGAN-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc1)|[pwg_ljspeech_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip)|||
-Parallel WaveGAN|AISHELL-3 |[PWGAN-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1)|[pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)|||
+Parallel WaveGAN| AISHELL-3 |[PWGAN-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1)|[pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip)|||
Parallel WaveGAN| VCTK |[PWGAN-vctk](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/voc1)|[pwg_vctk_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.5.zip)|||
|Multi Band MelGAN | CSMSC |[MB MelGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc3) | [mb_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip)
[mb_melgan_baker_finetune_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_baker_finetune_ckpt_0.5.zip)|[mb_melgan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip) |8.2MB|
Style MelGAN | CSMSC |[Style MelGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc4)|[style_melgan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip)| | |
HiFiGAN | CSMSC |[HiFiGAN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc5)|[hifigan_csmsc_ckpt_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip)|[hifigan_csmsc_static_0.1.1.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip)|50MB|
+HiFiGAN | AISHELL-3 |[HiFiGAN-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc5)|[hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip)|||
WaveRNN | CSMSC |[WaveRNN-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc6)|[wavernn_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip)|[wavernn_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_static_0.2.0.zip)|18MB|
diff --git a/examples/aishell3/tts3/local/synthesize.sh b/examples/aishell3/tts3/local/synthesize.sh
index b1fc96a2..d3978833 100755
--- a/examples/aishell3/tts3/local/synthesize.sh
+++ b/examples/aishell3/tts3/local/synthesize.sh
@@ -4,18 +4,44 @@ config_path=$1
train_output_path=$2
ckpt_name=$3
-FLAGS_allocator_strategy=naive_best_fit \
-FLAGS_fraction_of_gpu_memory_to_use=0.01 \
-python3 ${BIN_DIR}/../synthesize.py \
- --am=fastspeech2_aishell3 \
- --am_config=${config_path} \
- --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
- --am_stat=dump/train/speech_stats.npy \
- --voc=pwgan_aishell3 \
- --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
- --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
- --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
- --test_metadata=dump/test/norm/metadata.jsonl \
- --output_dir=${train_output_path}/test \
- --phones_dict=dump/phone_id_map.txt \
- --speaker_dict=dump/speaker_id_map.txt
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+ python3 ${BIN_DIR}/../synthesize.py \
+ --am=fastspeech2_aishell3 \
+ --am_config=${config_path} \
+ --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+ --am_stat=dump/train/speech_stats.npy \
+ --voc=pwgan_aishell3 \
+ --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
+ --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
+ --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
+ --test_metadata=dump/test/norm/metadata.jsonl \
+ --output_dir=${train_output_path}/test \
+ --phones_dict=dump/phone_id_map.txt \
+ --speaker_dict=dump/speaker_id_map.txt
+fi
+
+# hifigan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+ python3 ${BIN_DIR}/../synthesize.py \
+ --am=fastspeech2_aishell3 \
+ --am_config=${config_path} \
+ --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+ --am_stat=dump/train/speech_stats.npy \
+ --voc=hifigan_aishell3 \
+ --voc_config=hifigan_aishell3_ckpt_0.2.0/default.yaml \
+ --voc_ckpt=hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz \
+ --voc_stat=hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \
+ --test_metadata=dump/test/norm/metadata.jsonl \
+ --output_dir=${train_output_path}/test \
+ --phones_dict=dump/phone_id_map.txt \
+ --speaker_dict=dump/speaker_id_map.txt
+fi
+
diff --git a/examples/aishell3/tts3/local/synthesize_e2e.sh b/examples/aishell3/tts3/local/synthesize_e2e.sh
index 60e1a5ce..ff3608be 100755
--- a/examples/aishell3/tts3/local/synthesize_e2e.sh
+++ b/examples/aishell3/tts3/local/synthesize_e2e.sh
@@ -4,21 +4,50 @@ config_path=$1
train_output_path=$2
ckpt_name=$3
-FLAGS_allocator_strategy=naive_best_fit \
-FLAGS_fraction_of_gpu_memory_to_use=0.01 \
-python3 ${BIN_DIR}/../synthesize_e2e.py \
- --am=fastspeech2_aishell3 \
- --am_config=${config_path} \
- --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
- --am_stat=dump/train/speech_stats.npy \
- --voc=pwgan_aishell3 \
- --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
- --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
- --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
- --lang=zh \
- --text=${BIN_DIR}/../sentences.txt \
- --output_dir=${train_output_path}/test_e2e \
- --phones_dict=dump/phone_id_map.txt \
- --speaker_dict=dump/speaker_id_map.txt \
- --spk_id=0 \
- --inference_dir=${train_output_path}/inference
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+ python3 ${BIN_DIR}/../synthesize_e2e.py \
+ --am=fastspeech2_aishell3 \
+ --am_config=${config_path} \
+ --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+ --am_stat=dump/train/speech_stats.npy \
+ --voc=pwgan_aishell3 \
+ --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
+ --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
+ --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
+ --lang=zh \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/test_e2e \
+ --phones_dict=dump/phone_id_map.txt \
+ --speaker_dict=dump/speaker_id_map.txt \
+ --spk_id=0 \
+ --inference_dir=${train_output_path}/inference
+fi
+
+# hifigan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "in hifigan syn_e2e"
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+ python3 ${BIN_DIR}/../synthesize_e2e.py \
+ --am=fastspeech2_aishell3 \
+ --am_config=${config_path} \
+ --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+ --am_stat=dump/train/speech_stats.npy \
+ --voc=hifigan_aishell3 \
+ --voc_config=hifigan_aishell3_ckpt_0.2.0/default.yaml \
+ --voc_ckpt=hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz \
+ --voc_stat=hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \
+ --lang=zh \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/test_e2e \
+ --phones_dict=dump/phone_id_map.txt \
+ --speaker_dict=dump/speaker_id_map.txt \
+ --spk_id=0 \
+ --inference_dir=${train_output_path}/inference
+fi
diff --git a/examples/aishell3/vc0/local/preprocess.sh b/examples/aishell3/vc0/local/preprocess.sh
index 069cf94c..e458c706 100755
--- a/examples/aishell3/vc0/local/preprocess.sh
+++ b/examples/aishell3/vc0/local/preprocess.sh
@@ -1,6 +1,6 @@
#!/bin/bash
-stage=3
+stage=0
stop_stage=100
config_path=$1
diff --git a/examples/aishell3/voc1/run.sh b/examples/aishell3/voc1/run.sh
index 4f426ea0..cab1ac38 100755
--- a/examples/aishell3/voc1/run.sh
+++ b/examples/aishell3/voc1/run.sh
@@ -3,7 +3,7 @@
set -e
source path.sh
-gpus=0
+gpus=0,1
stage=0
stop_stage=100
diff --git a/examples/aishell3/voc5/README.md b/examples/aishell3/voc5/README.md
index 7cd0b396..ebe2530b 100644
--- a/examples/aishell3/voc5/README.md
+++ b/examples/aishell3/voc5/README.md
@@ -135,8 +135,22 @@ optional arguments:
3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory.
4. `--output-dir` is the directory to save the synthesized audio files.
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
-
## Pretrained Models
+The pretrained model can be downloaded here [hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip).
+
+
+Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss
+:-------------:| :------------:| :-----: | :-----: | :--------:
+default| 1(gpu) x 2500000|24.060|0.1068|7.499
+
+The HiFiGAN checkpoint contains the files listed below.
+
+```text
+hifigan_aishell3_ckpt_0.2.0
+├── default.yaml # default config used to train hifigan
+├── feats_stats.npy # statistics used to normalize spectrogram when training hifigan
+└── snapshot_iter_2500000.pdz # generator parameters of hifigan
+```
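+
+To try it out quickly, you can fetch and unpack the checkpoint as below (a minimal sketch; run it wherever the synthesis scripts expect the `hifigan_aishell3_ckpt_0.2.0/` folder to live):
+
+```bash
+wget https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip
+unzip hifigan_aishell3_ckpt_0.2.0.zip
+```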
## Acknowledgement
We adapted some code from https://github.com/kan-bayashi/ParallelWaveGAN.
diff --git a/paddleaudio/setup.py b/paddleaudio/setup.py
index 6c757d33..930f86e4 100644
--- a/paddleaudio/setup.py
+++ b/paddleaudio/setup.py
@@ -61,6 +61,7 @@ def remove_version_py(filename='paddleaudio/__init__.py'):
if "__version__" not in line:
f.write(line)
+
remove_version_py()
write_version_py()
diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py
index d7dcc90c..f7d64b9a 100644
--- a/paddlespeech/cli/utils.py
+++ b/paddlespeech/cli/utils.py
@@ -192,7 +192,7 @@ class ConfigCache:
try:
cfg = yaml.load(file, Loader=yaml.FullLoader)
self._data.update(cfg)
- except:
+ except Exception as e:
self.flush()
@property
diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py
index 7e7f03b2..f6a7f429 100644
--- a/paddlespeech/server/bin/paddlespeech_server.py
+++ b/paddlespeech/server/bin/paddlespeech_server.py
@@ -174,7 +174,7 @@ class ServerStatsExecutor():
"Failed to get the table of TTS pretrained models supported in the service."
)
return False
-
+
elif self.task == 'cls':
try:
from paddlespeech.cli.cls.infer import pretrained_models
diff --git a/paddlespeech/t2s/exps/synthesize.py b/paddlespeech/t2s/exps/synthesize.py
index 426b7617..abb1eb4e 100644
--- a/paddlespeech/t2s/exps/synthesize.py
+++ b/paddlespeech/t2s/exps/synthesize.py
@@ -156,6 +156,7 @@ def parse_args():
choices=[
'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
'mb_melgan_csmsc', 'wavernn_csmsc', 'hifigan_csmsc',
+ 'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk',
'style_melgan_csmsc'
],
help='Choose vocoder type of tts task.')
diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py
index 3d01bdb0..f5214d4a 100644
--- a/paddlespeech/t2s/exps/synthesize_e2e.py
+++ b/paddlespeech/t2s/exps/synthesize_e2e.py
@@ -180,9 +180,17 @@ def parse_args():
type=str,
default='pwgan_csmsc',
choices=[
- 'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
- 'mb_melgan_csmsc', 'style_melgan_csmsc', 'hifigan_csmsc',
- 'wavernn_csmsc'
+ 'pwgan_csmsc',
+ 'pwgan_ljspeech',
+ 'pwgan_aishell3',
+ 'pwgan_vctk',
+ 'mb_melgan_csmsc',
+ 'style_melgan_csmsc',
+ 'hifigan_csmsc',
+ 'hifigan_ljspeech',
+ 'hifigan_aishell3',
+ 'hifigan_vctk',
+ 'wavernn_csmsc',
],
help='Choose vocoder type of tts task.')
parser.add_argument(
diff --git a/speechx/.gitignore b/speechx/.gitignore
new file mode 100644
index 00000000..e0c61847
--- /dev/null
+++ b/speechx/.gitignore
@@ -0,0 +1 @@
+tools/valgrind*
diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt
index e003136a..f1330d1d 100644
--- a/speechx/CMakeLists.txt
+++ b/speechx/CMakeLists.txt
@@ -2,18 +2,32 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(paddlespeech VERSION 0.1)
+set(CMAKE_PROJECT_INCLUDE_BEFORE "${CMAKE_CURRENT_SOURCE_DIR}/cmake/EnableCMP0048.cmake")
+
set(CMAKE_VERBOSE_MAKEFILE on)
+
# set std-14
set(CMAKE_CXX_STANDARD 14)
-# include file
+# cmake dir
+set(speechx_cmake_dir ${PROJECT_SOURCE_DIR}/cmake)
+
+# Modules
+list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir}/external)
+list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir})
include(FetchContent)
include(ExternalProject)
+
# fc_patch dir
set(FETCHCONTENT_QUIET off)
get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_patch})
+# compiler option
+# Keep the same with openfst, -fPIC or -fpic
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g")
+SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb")
+SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall")
###############################################################################
# Option Configurations
@@ -25,91 +39,92 @@ option(TEST_DEBUG "option for debug" OFF)
###############################################################################
# Include third party
###############################################################################
-# #example for include third party
-# FetchContent_Declare()
-# # FetchContent_MakeAvailable was not added until CMake 3.14
+# example for include third party
+# FetchContent_MakeAvailable was not added until CMake 3.14
# FetchContent_MakeAvailable()
# include_directories()
+# gflags
+include(gflags)
+
+# glog
+include(glog)
+
+# gtest
+include(gtest)
+
# ABSEIL-CPP
-include(FetchContent)
-FetchContent_Declare(
- absl
- GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git"
- GIT_TAG "20210324.1"
-)
-FetchContent_MakeAvailable(absl)
+include(absl)
# libsndfile
-include(FetchContent)
-FetchContent_Declare(
- libsndfile
- GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git"
- GIT_TAG "1.0.31"
-)
-FetchContent_MakeAvailable(libsndfile)
+include(libsndfile)
-# gflags
-FetchContent_Declare(
- gflags
- URL https://github.com/gflags/gflags/archive/v2.2.1.zip
- URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a
-)
-FetchContent_MakeAvailable(gflags)
-include_directories(${gflags_BINARY_DIR}/include)
+# boost
+# include(boost) # not work
+set(boost_SOURCE_DIR ${fc_patch}/boost-src)
+set(BOOST_ROOT ${boost_SOURCE_DIR})
+# #find_package(boost REQUIRED PATHS ${BOOST_ROOT})
-# glog
-FetchContent_Declare(
- glog
- URL https://github.com/google/glog/archive/v0.4.0.zip
- URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
-)
-FetchContent_MakeAvailable(glog)
-include_directories(${glog_BINARY_DIR})
+# Eigen
+include(eigen)
+find_package(Eigen3 REQUIRED)
-# gtest
-FetchContent_Declare(googletest
- URL https://github.com/google/googletest/archive/release-1.10.0.zip
- URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91
-)
-FetchContent_MakeAvailable(googletest)
+# Kenlm
+include(kenlm)
+add_dependencies(kenlm eigen boost)
+
+#openblas
+include(openblas)
# openfst
-set(openfst_SOURCE_DIR ${fc_patch}/openfst-src)
-set(openfst_BINARY_DIR ${fc_patch}/openfst-build)
-set(openfst_PREFIX_DIR ${fc_patch}/openfst-subbuild/openfst-populate-prefix)
-ExternalProject_Add(openfst
- URL https://github.com/mjansche/openfst/archive/refs/tags/1.7.2.zip
- URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6
- SOURCE_DIR ${openfst_SOURCE_DIR}
- BINARY_DIR ${openfst_BINARY_DIR}
- CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR}
- "CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}"
- "LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}"
- "LIBS=-lgflags_nothreads -lglog -lpthread"
- BUILD_COMMAND make -j 4
-)
+include(openfst)
add_dependencies(openfst gflags glog)
-link_directories(${openfst_PREFIX_DIR}/lib)
-include_directories(${openfst_PREFIX_DIR}/include)
-add_subdirectory(speechx)
-#openblas
-#set(OpenBLAS_INSTALL_PREFIX ${fc_patch}/OpenBLAS)
-#set(OpenBLAS_SOURCE_DIR ${fc_patch}/OpenBLAS-src)
-#ExternalProject_Add(
-# OpenBLAS
-# GIT_REPOSITORY https://github.com/xianyi/OpenBLAS
-# GIT_TAG v0.3.13
-# GIT_SHALLOW TRUE
-# GIT_PROGRESS TRUE
-# CONFIGURE_COMMAND ""
-# BUILD_IN_SOURCE TRUE
-# BUILD_COMMAND make USE_LOCKING=1 USE_THREAD=0
-# INSTALL_COMMAND make PREFIX=${OpenBLAS_INSTALL_PREFIX} install
-# UPDATE_DISCONNECTED TRUE
-#)
+# paddle lib
+set(paddle_SOURCE_DIR ${fc_patch}/paddle-lib)
+set(paddle_PREFIX_DIR ${fc_patch}/paddle-lib-prefix)
+ExternalProject_Add(paddle
+ URL https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/CPU/gcc8.2_avx_mkl/paddle_inference.tgz
+ URL_HASH SHA256=7c6399e778c6554a929b5a39ba2175e702e115145e8fa690d2af974101d98873
+ PREFIX ${paddle_PREFIX_DIR}
+ SOURCE_DIR ${paddle_SOURCE_DIR}
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND ""
+ INSTALL_COMMAND ""
+)
+
+set(PADDLE_LIB ${fc_patch}/paddle-lib)
+include_directories("${PADDLE_LIB}/paddle/include")
+set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
+include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
+
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
+link_directories("${PADDLE_LIB}/paddle/lib")
+link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mklml/lib")
+
+##paddle with mkl
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
+set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml")
+include_directories("${MATH_LIB_PATH}/include")
+set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
+ ${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
+set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
+include_directories("${MKLDNN_PATH}/include")
+set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
+set(EXTERNAL_LIB "-lrt -ldl -lpthread")
+
+set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
+set(DEPS ${DEPS}
+ ${MATH_LIB} ${MKLDNN_LIB}
+ glog gflags protobuf xxhash cryptopp
+ ${EXTERNAL_LIB})
+
+
###############################################################################
# Add local library
@@ -121,4 +136,9 @@ add_subdirectory(speechx)
# if dir do not have CmakeLists.txt
#add_library(lib_name STATIC file.cc)
#target_link_libraries(lib_name item0 item1)
-#add_dependencies(lib_name depend-target)
\ No newline at end of file
+#add_dependencies(lib_name depend-target)
+
+set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx)
+
+add_subdirectory(speechx)
+add_subdirectory(examples)
\ No newline at end of file
diff --git a/speechx/README.md b/speechx/README.md
new file mode 100644
index 00000000..7d73b61c
--- /dev/null
+++ b/speechx/README.md
@@ -0,0 +1,61 @@
+# SpeechX -- All in One Speech Task Inference
+
+## Environment
+
+We develop under:
+* docker - registry.baidubce.com/paddlepaddle/paddle:2.1.1-gpu-cuda10.2-cudnn7
+* os - Ubuntu 16.04.7 LTS
+* gcc/g++ - 8.2.0
+* cmake - 3.16.0
+
+> We make sure everything works well under Docker, and we recommend using it for development and deployment.
+
+* [How to Install Docker](https://docs.docker.com/engine/install/)
+* [A Docker Tutorial for Beginners](https://docker-curriculum.com/)
+* [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/overview.html)
+
+## Build
+
+1. First, launch the Docker container.
+
+```
+nvidia-docker run --privileged --net=host --ipc=host -it --rm -v $PWD:/workspace --name=dev registry.baidubce.com/paddlepaddle/paddle:2.1.1-gpu-cuda10.2-cudnn7 /bin/bash
+```
+
+* For more `Paddle` Docker images, see [here](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html).
+
+* If you only want to work on CPU, please download the corresponding [image](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html), and use `docker` instead of `nvidia-docker`.
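+
+* A CPU-only launch might look like the following sketch (the image tag `2.1.1` is an assumption; pick the CPU image that matches your setup):
+
+```
+docker run --privileged --net=host --ipc=host -it --rm -v $PWD:/workspace --name=dev registry.baidubce.com/paddlepaddle/paddle:2.1.1 /bin/bash
+```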
+
+
+2. Build `speechx` and `examples`.
+
+```
+pushd /path/to/speechx
+./build.sh
+```
+
+3. Go to `examples` to have fun.
+
+For more details, please see the `README.md` under `examples`.
+
+
+## Valgrind (Optional)
+
+> If using Docker, please make sure `--privileged` is set when running `docker run`.
+
+* Fatal error at startup: `a function redirection which is mandatory for this platform-tool combination cannot be set up`
+```
+apt-get install libc6-dbg
+```
+
+* Install
+
+```
+pushd tools
+./setup_valgrind.sh
+popd
+```
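+
+A typical memory-check run afterwards looks like the sketch below (`your_binary` and its flags are placeholders; `examples/decoder/valgrind.sh` in this repo shows a concrete invocation):
+
+```
+valgrind --tool=memcheck --track-origins=yes --leak-check=full ./your_binary --your_flag=value
+```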
+
+## TODO
+
+* DecibelNormalizer: there is a slight difference between offline and online dB norm. Online dB norm computes features chunk by chunk, so the feature size differs from that of offline dB norm. In normalizer.cc:73, samples.size() differs, which causes the difference in results.
diff --git a/speechx/build.sh b/speechx/build.sh
new file mode 100755
index 00000000..3e9600d5
--- /dev/null
+++ b/speechx/build.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+
+# This build script has been verified in the PaddlePaddle docker image.
+# Please follow the instructions below to install the PaddlePaddle image.
+# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html
+
+boost_SOURCE_DIR=$PWD/fc_patch/boost-src
+if [ ! -d ${boost_SOURCE_DIR} ]; then
+ wget -c https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz
+ tar xzfv boost_1_75_0.tar.gz
+ mkdir -p $PWD/fc_patch
+ mv boost_1_75_0 ${boost_SOURCE_DIR}
+ cd ${boost_SOURCE_DIR}
+ bash ./bootstrap.sh
+ ./b2
+ cd -
+ echo -e "\n"
+fi
+
+#rm -rf build
+mkdir -p build
+cd build
+
+cmake .. -DBOOST_ROOT:STRING=${boost_SOURCE_DIR}
+#cmake ..
+
+make -j1
+
+cd -
diff --git a/speechx/cmake/EnableCMP0048.cmake b/speechx/cmake/EnableCMP0048.cmake
new file mode 100644
index 00000000..1b59188f
--- /dev/null
+++ b/speechx/cmake/EnableCMP0048.cmake
@@ -0,0 +1 @@
+cmake_policy(SET CMP0048 NEW)
\ No newline at end of file
diff --git a/speechx/cmake/external/absl.cmake b/speechx/cmake/external/absl.cmake
new file mode 100644
index 00000000..2c5e5af5
--- /dev/null
+++ b/speechx/cmake/external/absl.cmake
@@ -0,0 +1,16 @@
+include(FetchContent)
+
+
+set(BUILD_SHARED_LIBS OFF) # up to you
+set(BUILD_TESTING OFF) # to disable abseil test, or gtest will fail.
+set(ABSL_ENABLE_INSTALL ON) # now you can enable install rules even in subproject...
+
+FetchContent_Declare(
+ absl
+ GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git"
+ GIT_TAG "20210324.1"
+)
+FetchContent_MakeAvailable(absl)
+
+include_directories(${absl_SOURCE_DIR})
\ No newline at end of file
diff --git a/speechx/cmake/external/boost.cmake b/speechx/cmake/external/boost.cmake
new file mode 100644
index 00000000..6bc97aad
--- /dev/null
+++ b/speechx/cmake/external/boost.cmake
@@ -0,0 +1,27 @@
+include(FetchContent)
+set(Boost_DEBUG ON)
+
+set(Boost_PREFIX_DIR ${fc_patch}/boost)
+set(Boost_SOURCE_DIR ${fc_patch}/boost-src)
+
+FetchContent_Declare(
+ Boost
+ URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz
+ URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
+ PREFIX ${Boost_PREFIX_DIR}
+ SOURCE_DIR ${Boost_SOURCE_DIR}
+)
+
+execute_process(COMMAND bootstrap.sh WORKING_DIRECTORY ${Boost_SOURCE_DIR})
+execute_process(COMMAND b2 WORKING_DIRECTORY ${Boost_SOURCE_DIR})
+
+FetchContent_MakeAvailable(Boost)
+
+message(STATUS "boost src dir: ${Boost_SOURCE_DIR}")
+message(STATUS "boost inc dir: ${Boost_INCLUDE_DIR}")
+message(STATUS "boost bin dir: ${Boost_BINARY_DIR}")
+
+set(BOOST_ROOT ${Boost_SOURCE_DIR})
+message(STATUS "boost root dir: ${BOOST_ROOT}")
+
+include_directories(${Boost_SOURCE_DIR})
\ No newline at end of file
diff --git a/speechx/cmake/external/eigen.cmake b/speechx/cmake/external/eigen.cmake
new file mode 100644
index 00000000..12bd3cdf
--- /dev/null
+++ b/speechx/cmake/external/eigen.cmake
@@ -0,0 +1,27 @@
+include(FetchContent)
+
+# update eigen to the commit id f612df27 on 03/16/2021
+set(EIGEN_PREFIX_DIR ${fc_patch}/eigen3)
+
+FetchContent_Declare(
+ Eigen3
+ GIT_REPOSITORY https://gitlab.com/libeigen/eigen.git
+ GIT_TAG master
+ PREFIX ${EIGEN_PREFIX_DIR}
+ GIT_SHALLOW TRUE
+ GIT_PROGRESS TRUE)
+
+set(EIGEN_BUILD_DOC OFF)
+# note: To disable eigen tests,
+# you should put this code in a add_subdirectory to avoid to change
+# BUILD_TESTING for your own project too since variables are directory
+# scoped
+set(BUILD_TESTING OFF)
+set(EIGEN_BUILD_PKGCONFIG OFF)
+FetchContent_MakeAvailable(Eigen3)
+
+message(STATUS "eigen src dir: ${Eigen3_SOURCE_DIR}")
+message(STATUS "eigen bin dir: ${Eigen3_BINARY_DIR}")
+#include_directories(${Eigen3_SOURCE_DIR})
+#link_directories(${Eigen3_BINARY_DIR})
\ No newline at end of file
diff --git a/speechx/cmake/external/gflags.cmake b/speechx/cmake/external/gflags.cmake
new file mode 100644
index 00000000..66ae47f7
--- /dev/null
+++ b/speechx/cmake/external/gflags.cmake
@@ -0,0 +1,12 @@
+include(FetchContent)
+
+FetchContent_Declare(
+ gflags
+ URL https://github.com/gflags/gflags/archive/v2.2.1.zip
+ URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a
+)
+
+FetchContent_MakeAvailable(gflags)
+
+# needed by openfst
+include_directories(${gflags_BINARY_DIR}/include)
\ No newline at end of file
diff --git a/speechx/cmake/external/glog.cmake b/speechx/cmake/external/glog.cmake
new file mode 100644
index 00000000..dcfd86c3
--- /dev/null
+++ b/speechx/cmake/external/glog.cmake
@@ -0,0 +1,8 @@
+include(FetchContent)
+FetchContent_Declare(
+ glog
+ URL https://github.com/google/glog/archive/v0.4.0.zip
+ URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
+)
+FetchContent_MakeAvailable(glog)
+include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src)
diff --git a/speechx/cmake/external/gtest.cmake b/speechx/cmake/external/gtest.cmake
new file mode 100644
index 00000000..7fe397fc
--- /dev/null
+++ b/speechx/cmake/external/gtest.cmake
@@ -0,0 +1,9 @@
+include(FetchContent)
+FetchContent_Declare(
+ gtest
+ URL https://github.com/google/googletest/archive/release-1.10.0.zip
+ URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91
+)
+FetchContent_MakeAvailable(gtest)
+
+include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src)
\ No newline at end of file
diff --git a/speechx/cmake/external/kenlm.cmake b/speechx/cmake/external/kenlm.cmake
new file mode 100644
index 00000000..17c76c3f
--- /dev/null
+++ b/speechx/cmake/external/kenlm.cmake
@@ -0,0 +1,10 @@
+include(FetchContent)
+FetchContent_Declare(
+ kenlm
+ GIT_REPOSITORY "https://github.com/kpu/kenlm.git"
+ GIT_TAG "df2d717e95183f79a90b2fa6e4307083a351ca6a"
+)
+# https://github.com/kpu/kenlm/blob/master/cmake/modules/FindEigen3.cmake
+set(EIGEN3_INCLUDE_DIR ${Eigen3_SOURCE_DIR})
+FetchContent_MakeAvailable(kenlm)
+include_directories(${kenlm_SOURCE_DIR})
\ No newline at end of file
diff --git a/speechx/cmake/external/libsndfile.cmake b/speechx/cmake/external/libsndfile.cmake
new file mode 100644
index 00000000..52d64bac
--- /dev/null
+++ b/speechx/cmake/external/libsndfile.cmake
@@ -0,0 +1,56 @@
+include(FetchContent)
+
+# https://github.com/pongasoft/vst-sam-spl-64/blob/master/libsndfile.cmake
+# https://github.com/popojan/goban/blob/master/CMakeLists.txt#L38
+# https://github.com/ddiakopoulos/libnyquist/blob/master/CMakeLists.txt
+
+if(LIBSNDFILE_ROOT_DIR)
+ # instructs FetchContent to not download or update but use the location instead
+ set(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE ${LIBSNDFILE_ROOT_DIR})
+else()
+ set(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE "")
+endif()
+
+set(LIBSNDFILE_GIT_REPO "https://github.com/libsndfile/libsndfile.git" CACHE STRING "libsndfile git repository url" FORCE)
+set(LIBSNDFILE_GIT_TAG 1.0.31 CACHE STRING "libsndfile git tag" FORCE)
+
+FetchContent_Declare(libsndfile
+ GIT_REPOSITORY ${LIBSNDFILE_GIT_REPO}
+ GIT_TAG ${LIBSNDFILE_GIT_TAG}
+ GIT_CONFIG advice.detachedHead=false
+# GIT_SHALLOW true
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND ""
+ INSTALL_COMMAND ""
+ TEST_COMMAND ""
+ )
+
+FetchContent_GetProperties(libsndfile)
+if(NOT libsndfile_POPULATED)
+ if(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE)
+ message(STATUS "Using libsndfile from local ${FETCHCONTENT_SOURCE_DIR_LIBSNDFILE}")
+ else()
+ message(STATUS "Fetching libsndfile ${LIBSNDFILE_GIT_REPO}/tree/${LIBSNDFILE_GIT_TAG}")
+ endif()
+ FetchContent_Populate(libsndfile)
+endif()
+
+set(LIBSNDFILE_ROOT_DIR ${libsndfile_SOURCE_DIR})
+set(LIBSNDFILE_INCLUDE_DIR "${libsndfile_BINARY_DIR}/src")
+
+function(libsndfile_build)
+ option(BUILD_PROGRAMS "Build programs" OFF)
+ option(BUILD_EXAMPLES "Build examples" OFF)
+ option(BUILD_TESTING "Build tests" OFF)
+ option(ENABLE_CPACK "Enable CPack support" OFF)
+ option(ENABLE_PACKAGE_CONFIG "Generate and install package config file" OFF)
+ option(BUILD_REGTEST "Build regtest" OFF)
+ # finally we include libsndfile itself
+ add_subdirectory(${libsndfile_SOURCE_DIR} ${libsndfile_BINARY_DIR} EXCLUDE_FROM_ALL)
+ # copying .hh for c++ support
+ #file(COPY "${libsndfile_SOURCE_DIR}/src/sndfile.hh" DESTINATION ${LIBSNDFILE_INCLUDE_DIR})
+endfunction()
+
+libsndfile_build()
+
+include_directories(${LIBSNDFILE_INCLUDE_DIR})
\ No newline at end of file
diff --git a/speechx/cmake/external/openblas.cmake b/speechx/cmake/external/openblas.cmake
new file mode 100644
index 00000000..3c202f7f
--- /dev/null
+++ b/speechx/cmake/external/openblas.cmake
@@ -0,0 +1,37 @@
+include(FetchContent)
+
+set(OpenBLAS_SOURCE_DIR ${fc_patch}/OpenBLAS-src)
+set(OpenBLAS_PREFIX ${fc_patch}/OpenBLAS-prefix)
+
+# ######################################################################################################################
+# OPENBLAS https://github.com/lattice/quda/blob/develop/CMakeLists.txt#L575
+# ######################################################################################################################
+enable_language(Fortran)
+#TODO: switch to CPM
+include(GNUInstallDirs)
+ExternalProject_Add(
+ OPENBLAS
+ GIT_REPOSITORY https://github.com/xianyi/OpenBLAS.git
+ GIT_TAG v0.3.10
+ GIT_SHALLOW YES
+ PREFIX ${OpenBLAS_PREFIX}
+ SOURCE_DIR ${OpenBLAS_SOURCE_DIR}
+ CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
+ CMAKE_GENERATOR "Unix Makefiles")
+
+
+# https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition
+ExternalProject_Get_Property(OPENBLAS INSTALL_DIR)
+set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR})
+add_library(openblas STATIC IMPORTED)
+add_dependencies(openblas OPENBLAS)
+set_target_properties(openblas PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES Fortran)
+# ${CMAKE_INSTALL_LIBDIR} lib
+set_target_properties(openblas PROPERTIES IMPORTED_LOCATION ${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/libopenblas.a)
+
+
+# https://cmake.org/cmake/help/latest/command/install.html?highlight=cmake_install_libdir#installing-targets
+# ${CMAKE_INSTALL_LIBDIR} lib
+# ${CMAKE_INSTALL_INCLUDEDIR} include
+link_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
+include_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR})
\ No newline at end of file
diff --git a/speechx/cmake/external/openfst.cmake b/speechx/cmake/external/openfst.cmake
new file mode 100644
index 00000000..07abb18e
--- /dev/null
+++ b/speechx/cmake/external/openfst.cmake
@@ -0,0 +1,19 @@
+include(FetchContent)
+set(openfst_SOURCE_DIR ${fc_patch}/openfst-src)
+set(openfst_BINARY_DIR ${fc_patch}/openfst-build)
+set(openfst_PREFIX_DIR ${fc_patch}/openfst-subbuild/openfst-populate-prefix)
+
+ExternalProject_Add(openfst
+ URL https://github.com/mjansche/openfst/archive/refs/tags/1.7.2.zip
+ URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6
+# #PREFIX ${openfst_PREFIX_DIR}
+# SOURCE_DIR ${openfst_SOURCE_DIR}
+# BINARY_DIR ${openfst_BINARY_DIR}
+ CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR}
+ "CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}"
+ "LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}"
+ "LIBS=-lgflags_nothreads -lglog -lpthread"
+ COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
+ BUILD_COMMAND make -j 4
+)
+link_directories(${openfst_PREFIX_DIR}/lib)
+include_directories(${openfst_PREFIX_DIR}/include)
diff --git a/speechx/examples/.gitignore b/speechx/examples/.gitignore
new file mode 100644
index 00000000..b7075fa5
--- /dev/null
+++ b/speechx/examples/.gitignore
@@ -0,0 +1,2 @@
+*.ark
+paddle_asr_model/
diff --git a/speechx/examples/.gitkeep b/speechx/examples/.gitkeep
deleted file mode 100644
index e69de29b..00000000
diff --git a/speechx/examples/CMakeLists.txt b/speechx/examples/CMakeLists.txt
new file mode 100644
index 00000000..ef0a72b8
--- /dev/null
+++ b/speechx/examples/CMakeLists.txt
@@ -0,0 +1,5 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+add_subdirectory(feat)
+add_subdirectory(nnet)
+add_subdirectory(decoder)
diff --git a/speechx/examples/README.md b/speechx/examples/README.md
new file mode 100644
index 00000000..941c4272
--- /dev/null
+++ b/speechx/examples/README.md
@@ -0,0 +1,16 @@
+# Examples
+
+* decoder - the online decoder working in offline mode
+* feat - mfcc, linear
+* nnet - ds2 nn
+
+## How to run
+
+`run.sh` is the entry point.
+
+For example, to play with `decoder`:
+
+```
+pushd decoder
+bash run.sh
+```
diff --git a/speechx/examples/decoder/CMakeLists.txt b/speechx/examples/decoder/CMakeLists.txt
new file mode 100644
index 00000000..4bd5c6cf
--- /dev/null
+++ b/speechx/examples/decoder/CMakeLists.txt
@@ -0,0 +1,5 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
+target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
diff --git a/speechx/examples/decoder/offline_decoder_main.cc b/speechx/examples/decoder/offline_decoder_main.cc
new file mode 100644
index 00000000..44127c73
--- /dev/null
+++ b/speechx/examples/decoder/offline_decoder_main.cc
@@ -0,0 +1,101 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TODO: refactor, replace with gtest
+
+#include "base/flags.h"
+#include "base/log.h"
+#include "decoder/ctc_beam_search_decoder.h"
+#include "frontend/raw_audio.h"
+#include "kaldi/util/table-types.h"
+#include "nnet/decodable.h"
+#include "nnet/paddle_nnet.h"
+
+DEFINE_string(feature_respecifier, "", "test feature rspecifier");
+DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
+DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
+DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
+DEFINE_string(lm_path, "lm.klm", "language model");
+
+
+using kaldi::BaseFloat;
+using kaldi::Matrix;
+using std::vector;
+
+int main(int argc, char* argv[]) {
+ gflags::ParseCommandLineFlags(&argc, &argv, false);
+ google::InitGoogleLogging(argv[0]);
+
+ kaldi::SequentialBaseFloatMatrixReader feature_reader(
+ FLAGS_feature_respecifier);
+ std::string model_graph = FLAGS_model_path;
+ std::string model_params = FLAGS_param_path;
+ std::string dict_file = FLAGS_dict_file;
+ std::string lm_path = FLAGS_lm_path;
+
+ int32 num_done = 0, num_err = 0;
+
+ ppspeech::CTCBeamSearchOptions opts;
+ opts.dict_file = dict_file;
+ opts.lm_path = lm_path;
+ ppspeech::CTCBeamSearch decoder(opts);
+
+ ppspeech::ModelOptions model_opts;
+ model_opts.model_path = model_graph;
+ model_opts.params_path = model_params;
+ std::shared_ptr<ppspeech::NnetInterface> nnet(
+ new ppspeech::PaddleNnet(model_opts));
+ std::shared_ptr<ppspeech::RawDataCache> raw_data(
+ new ppspeech::RawDataCache());
+ std::shared_ptr<ppspeech::Decodable> decodable(
+ new ppspeech::Decodable(nnet, raw_data));
+
+ int32 chunk_size = 35;
+ decoder.InitDecoder();
+
+ for (; !feature_reader.Done(); feature_reader.Next()) {
+ string utt = feature_reader.Key();
+ const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
+ raw_data->SetDim(feature.NumCols());
+ int32 row_idx = 0;
+ int32 num_chunks = feature.NumRows() / chunk_size;
+ for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
+ kaldi::Vector<BaseFloat> feature_chunk(chunk_size *
+ feature.NumCols());
+ for (int row_id = 0; row_id < chunk_size; ++row_id) {
+ kaldi::SubVector<BaseFloat> tmp(feature, row_idx);
+ kaldi::SubVector<BaseFloat> f_chunk_tmp(
+ feature_chunk.Data() + row_id * feature.NumCols(),
+ feature.NumCols());
+ f_chunk_tmp.CopyFromVec(tmp);
+ row_idx++;
+ }
+ raw_data->Accept(feature_chunk);
+ if (chunk_idx == num_chunks - 1) {
+ raw_data->SetFinished();
+ }
+ decoder.AdvanceDecode(decodable);
+ }
+ std::string result;
+ result = decoder.GetFinalBestPath();
+ KALDI_LOG << " the result of " << utt << " is " << result;
+ decodable->Reset();
+ decoder.Reset();
+ ++num_done;
+ }
+
+ KALDI_LOG << "Done " << num_done << " utterances, " << num_err
+ << " with errors.";
+ return (num_done != 0 ? 0 : 1);
+}
diff --git a/speechx/examples/decoder/path.sh b/speechx/examples/decoder/path.sh
new file mode 100644
index 00000000..7b4b7545
--- /dev/null
+++ b/speechx/examples/decoder/path.sh
@@ -0,0 +1,14 @@
+# This contains the locations of the binaries built and required for running the examples.
+
+SPEECHX_ROOT=$PWD/../..
+SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
+
+SPEECHX_TOOLS=$SPEECHX_ROOT/tools
+TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
+
+[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. Please ensure that the project was built successfully."; }
+
+export LC_ALL=C
+
+SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder
+export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
diff --git a/speechx/examples/decoder/run.sh b/speechx/examples/decoder/run.sh
new file mode 100755
index 00000000..fc5e9182
--- /dev/null
+++ b/speechx/examples/decoder/run.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+set +x
+set -e
+
+. path.sh
+
+# 1. compile
+if [ ! -d ${SPEECHX_EXAMPLES} ]; then
+ pushd ${SPEECHX_ROOT}
+ bash build.sh
+ popd
+fi
+
+
+# 2. download model
+if [ ! -d ../paddle_asr_model ]; then
+ wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
+ tar xzfv paddle_asr_model.tar.gz
+ mv ./paddle_asr_model ../
+ # produce wav scp
+ echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
+fi
+
+model_dir=../paddle_asr_model
+feat_wspecifier=./feats.ark
+cmvn=./cmvn.ark
+
+# 3. run feat
+linear_spectrogram_main \
+ --wav_rspecifier=scp:$model_dir/wav.scp \
+ --feature_wspecifier=ark,t:$feat_wspecifier \
+ --cmvn_write_path=$cmvn
+
+# 4. run decoder
+offline_decoder_main \
+ --feature_respecifier=ark:$feat_wspecifier \
+ --model_path=$model_dir/avg_1.jit.pdmodel \
+ --param_path=$model_dir/avg_1.jit.pdparams \
+ --dict_file=$model_dir/vocab.txt \
+ --lm_path=$model_dir/avg_1.jit.klm
\ No newline at end of file
diff --git a/speechx/examples/decoder/valgrind.sh b/speechx/examples/decoder/valgrind.sh
new file mode 100755
index 00000000..14efe0ba
--- /dev/null
+++ b/speechx/examples/decoder/valgrind.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# this script is for memory check, so please run ./run.sh first.
+
+set +x
+set -e
+
+. ./path.sh
+
+if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
+ echo "please install valgrind in the speechx tools dir."
+ exit 1
+fi
+
+model_dir=../paddle_asr_model
+feat_wspecifier=./feats.ark
+cmvn=./cmvn.ark
+
+valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
+ offline_decoder_main \
+ --feature_respecifier=ark:$feat_wspecifier \
+ --model_path=$model_dir/avg_1.jit.pdmodel \
+ --param_path=$model_dir/avg_1.jit.pdparams \
+ --dict_file=$model_dir/vocab.txt \
+ --lm_path=$model_dir/avg_1.jit.klm
+
diff --git a/speechx/examples/feat/CMakeLists.txt b/speechx/examples/feat/CMakeLists.txt
new file mode 100644
index 00000000..b8f516af
--- /dev/null
+++ b/speechx/examples/feat/CMakeLists.txt
@@ -0,0 +1,10 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+
+add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
+target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(mfcc-test kaldi-mfcc)
+
+add_executable(linear_spectrogram_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_main.cc)
+target_include_directories(linear_spectrogram_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
\ No newline at end of file
diff --git a/speechx/examples/feat/feature-mfcc-test.cc b/speechx/examples/feat/feature-mfcc-test.cc
new file mode 100644
index 00000000..ae32aba9
--- /dev/null
+++ b/speechx/examples/feat/feature-mfcc-test.cc
@@ -0,0 +1,720 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// feat/feature-mfcc-test.cc
+
+// Copyright 2009-2011 Karel Vesely; Petr Motlicek
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include <iostream>
+
+#include "base/kaldi-math.h"
+#include "feat/feature-mfcc.h"
+#include "feat/wave-reader.h"
+#include "matrix/kaldi-matrix-inl.h"
+
+using namespace kaldi;
+
+
+static void UnitTestReadWave() {
+ std::cout << "=== UnitTestReadWave() ===\n";
+
+ Vector<BaseFloat> v, v2;
+
+ std::cout << "<<<=== Reading waveform\n";
+
+ {
+ std::ifstream is("test_data/test.wav", std::ios_base::binary);
+ WaveData wave;
+ wave.Read(is);
+ const Matrix<BaseFloat> data(wave.Data());
+ KALDI_ASSERT(data.NumRows() == 1);
+ v.Resize(data.NumCols());
+ v.CopyFromVec(data.Row(0));
+ }
+
+ std::cout
+ << "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
+ std::ifstream input("test_data/test_matlab.ascii");
+ KALDI_ASSERT(input.good());
+ v2.Read(input, false);
+ input.close();
+
+ std::cout
+ << "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
+ KALDI_ASSERT(v.Dim() == v2.Dim());
+ for (int32 i = 0; i < v.Dim(); i++) {
+ KALDI_ASSERT(v(i) == v2(i));
+ }
+ std::cout << "<<<=== Comparing done\n";
+
+ // std::cout << "== The Waveform Samples == \n";
+ // std::cout << v;
+
+ std::cout << "Test passed :)\n\n";
+}
+
+
+/**
+ */
+static void UnitTestSimple() {
+ std::cout << "=== UnitTestSimple() ===\n";
+
+ Vector<BaseFloat> v(100000);
+ Matrix<BaseFloat> m;
+
+ // init with noise
+ for (int32 i = 0; i < v.Dim(); i++) {
+ v(i) = (abs(i * 433024253) % 65535) - (65535 / 2);
+ }
+
+ std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
+ // the parametrization object
+ MfccOptions op;
+ // trying to have same opts as baseline.
+ op.frame_opts.dither = 0.0;
+ op.frame_opts.preemph_coeff = 0.0;
+ op.frame_opts.window_type = "rectangular";
+ op.frame_opts.remove_dc_offset = false;
+ op.frame_opts.round_to_power_of_two = true;
+ op.mel_opts.low_freq = 0.0;
+ op.mel_opts.htk_mode = true;
+ op.htk_compat = true;
+
+ Mfcc mfcc(op);
+ // use default parameters
+
+ // compute mfccs.
+ mfcc.Compute(v, 1.0, &m);
+
+ // possibly dump
+ // std::cout << "== Output features == \n" << m;
+ std::cout << "Test passed :)\n\n";
+}
+
+
+static void UnitTestHTKCompare1() {
+ std::cout << "=== UnitTestHTKCompare1() ===\n";
+
+ std::ifstream is("test_data/test.wav", std::ios_base::binary);
+ WaveData wave;
+ wave.Read(is);
+ KALDI_ASSERT(wave.Data().NumRows() == 1);
+ SubVector<BaseFloat> waveform(wave.Data(), 0);
+
+ // read the HTK features
+ Matrix<BaseFloat> htk_features;
+ {
+ std::ifstream is("test_data/test.wav.fea_htk.1",
+ std::ios::in | std::ios_base::binary);
+ bool ans = ReadHtk(is, &htk_features, 0);
+ KALDI_ASSERT(ans);
+ }
+
+ // use mfcc with default configuration...
+ MfccOptions op;
+ op.frame_opts.dither = 0.0;
+ op.frame_opts.preemph_coeff = 0.0;
+ op.frame_opts.window_type = "hamming";
+ op.frame_opts.remove_dc_offset = false;
+ op.frame_opts.round_to_power_of_two = true;
+ op.mel_opts.low_freq = 0.0;
+ op.mel_opts.htk_mode = true;
+ op.htk_compat = true;
+ op.use_energy = false; // C0 not energy.
+
+ Mfcc mfcc(op);
+
+ // calculate kaldi features
+ Matrix<BaseFloat> kaldi_raw_features;
+ mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
+
+ DeltaFeaturesOptions delta_opts;
+ Matrix<BaseFloat> kaldi_features;
+ ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
+
+ // compare the results
+ bool passed = true;
+ int32 i_old = -1;
+ KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
+ KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
+ // Ignore ends-- we make slightly different choices than
+ // HTK about how to treat the deltas at the ends.
+ for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
+ for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
+ BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
+ if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
+ // print the non-matching data only once per-line
+ if (i_old != i) {
+ std::cout << "\n\n\n[HTK-row: " << i << "] "
+ << htk_features.Row(i) << "\n";
+ std::cout << "[Kaldi-row: " << i << "] "
+ << kaldi_features.Row(i) << "\n\n\n";
+ i_old = i;
+ }
+ // print indices of non-matching cells
+ std::cout << "[" << i << ", " << j << "]";
+ passed = false;
+ }
+ }
+ }
+ if (!passed) KALDI_ERR << "Test failed";
+
+ // write the htk features for later inspection
+ HtkHeader header = {
+ kaldi_features.NumRows(),
+ 100000, // 10ms
+ static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
+ 021406 // MFCC_D_A_0
+ };
+ {
+ std::ofstream os("tmp.test.wav.fea_kaldi.1",
+ std::ios::out | std::ios::binary);
+ WriteHtk(os, kaldi_features, header);
+ }
+
+ std::cout << "Test passed :)\n\n";
+
+ unlink("tmp.test.wav.fea_kaldi.1");
+}
+
+
+static void UnitTestHTKCompare2() {
+ std::cout << "=== UnitTestHTKCompare2() ===\n";
+
+ std::ifstream is("test_data/test.wav", std::ios_base::binary);
+ WaveData wave;
+ wave.Read(is);
+ KALDI_ASSERT(wave.Data().NumRows() == 1);
+ SubVector<BaseFloat> waveform(wave.Data(), 0);
+
+ // read the HTK features
+ Matrix<BaseFloat> htk_features;
+ {
+ std::ifstream is("test_data/test.wav.fea_htk.2",
+ std::ios::in | std::ios_base::binary);
+ bool ans = ReadHtk(is, &htk_features, 0);
+ KALDI_ASSERT(ans);
+ }
+
+ // use mfcc with default configuration...
+ MfccOptions op;
+ op.frame_opts.dither = 0.0;
+ op.frame_opts.preemph_coeff = 0.0;
+ op.frame_opts.window_type = "hamming";
+ op.frame_opts.remove_dc_offset = false;
+ op.frame_opts.round_to_power_of_two = true;
+ op.mel_opts.low_freq = 0.0;
+ op.mel_opts.htk_mode = true;
+ op.htk_compat = true;
+ op.use_energy = true; // Use energy.
+
+ Mfcc mfcc(op);
+
+ // calculate kaldi features
+ Matrix<BaseFloat> kaldi_raw_features;
+ mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
+
+ DeltaFeaturesOptions delta_opts;
+ Matrix<BaseFloat> kaldi_features;
+ ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
+
+ // compare the results
+ bool passed = true;
+ int32 i_old = -1;
+ KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
+ KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
+ // Ignore ends-- we make slightly different choices than
+ // HTK about how to treat the deltas at the ends.
+ for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
+ for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
+ BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
+ if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
+ // print the non-matching data only once per-line
+ if (i_old != i) {
+ std::cout << "\n\n\n[HTK-row: " << i << "] "
+ << htk_features.Row(i) << "\n";
+ std::cout << "[Kaldi-row: " << i << "] "
+ << kaldi_features.Row(i) << "\n\n\n";
+ i_old = i;
+ }
+ // print indices of non-matching cells
+ std::cout << "[" << i << ", " << j << "]";
+ passed = false;
+ }
+ }
+ }
+ if (!passed) KALDI_ERR << "Test failed";
+
+ // write the htk features for later inspection
+ HtkHeader header = {
+ kaldi_features.NumRows(),
+ 100000, // 10ms
+ static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
+ 021406 // MFCC_D_A_0
+ };
+ {
+ std::ofstream os("tmp.test.wav.fea_kaldi.2",
+ std::ios::out | std::ios::binary);
+ WriteHtk(os, kaldi_features, header);
+ }
+
+ std::cout << "Test passed :)\n\n";
+
+ unlink("tmp.test.wav.fea_kaldi.2");
+}
+
+
+static void UnitTestHTKCompare3() {
+ std::cout << "=== UnitTestHTKCompare3() ===\n";
+
+ std::ifstream is("test_data/test.wav", std::ios_base::binary);
+ WaveData wave;
+ wave.Read(is);
+ KALDI_ASSERT(wave.Data().NumRows() == 1);
+ SubVector<BaseFloat> waveform(wave.Data(), 0);
+
+ // read the HTK features
+ Matrix<BaseFloat> htk_features;
+ {
+ std::ifstream is("test_data/test.wav.fea_htk.3",
+ std::ios::in | std::ios_base::binary);
+ bool ans = ReadHtk(is, &htk_features, 0);
+ KALDI_ASSERT(ans);
+ }
+
+ // use mfcc with default configuration...
+ MfccOptions op;
+ op.frame_opts.dither = 0.0;
+ op.frame_opts.preemph_coeff = 0.0;
+ op.frame_opts.window_type = "hamming";
+ op.frame_opts.remove_dc_offset = false;
+ op.frame_opts.round_to_power_of_two = true;
+ op.htk_compat = true;
+ op.use_energy = true; // Use energy.
+ op.mel_opts.low_freq = 20.0;
+ // op.mel_opts.debug_mel = true;
+ op.mel_opts.htk_mode = true;
+
+ Mfcc mfcc(op);
+
+ // calculate kaldi features
+ Matrix<BaseFloat> kaldi_raw_features;
+ mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
+
+ DeltaFeaturesOptions delta_opts;
+ Matrix<BaseFloat> kaldi_features;
+ ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
+
+ // compare the results
+ bool passed = true;
+ int32 i_old = -1;
+ KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
+ KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
+ // Ignore ends-- we make slightly different choices than
+ // HTK about how to treat the deltas at the ends.
+ for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
+ for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
+ BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
+ if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
+ // print the non-matching data only once per-line
+ if (static_cast<int32>(i_old) != i) {
+ std::cout << "\n\n\n[HTK-row: " << i << "] "
+ << htk_features.Row(i) << "\n";
+ std::cout << "[Kaldi-row: " << i << "] "
+ << kaldi_features.Row(i) << "\n\n\n";
+ i_old = i;
+ }
+ // print indices of non-matching cells
+ std::cout << "[" << i << ", " << j << "]";
+ passed = false;
+ }
+ }
+ }
+ if (!passed) KALDI_ERR << "Test failed";
+
+ // write the htk features for later inspection
+ HtkHeader header = {
+ kaldi_features.NumRows(),
+ 100000, // 10ms
+ static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
+ 021406 // MFCC_D_A_0
+ };
+ {
+ std::ofstream os("tmp.test.wav.fea_kaldi.3",
+ std::ios::out | std::ios::binary);
+ WriteHtk(os, kaldi_features, header);
+ }
+
+ std::cout << "Test passed :)\n\n";
+
+ unlink("tmp.test.wav.fea_kaldi.3");
+}
+
+
+static void UnitTestHTKCompare4() {
+ std::cout << "=== UnitTestHTKCompare4() ===\n";
+
+ std::ifstream is("test_data/test.wav", std::ios_base::binary);
+ WaveData wave;
+ wave.Read(is);
+ KALDI_ASSERT(wave.Data().NumRows() == 1);
+ SubVector<BaseFloat> waveform(wave.Data(), 0);
+
+ // read the HTK features
+ Matrix<BaseFloat> htk_features;
+ {
+ std::ifstream is("test_data/test.wav.fea_htk.4",
+ std::ios::in | std::ios_base::binary);
+ bool ans = ReadHtk(is, &htk_features, 0);
+ KALDI_ASSERT(ans);
+ }
+
+ // use mfcc with default configuration...
+ MfccOptions op;
+ op.frame_opts.dither = 0.0;
+ op.frame_opts.window_type = "hamming";
+ op.frame_opts.remove_dc_offset = false;
+ op.frame_opts.round_to_power_of_two = true;
+ op.mel_opts.low_freq = 0.0;
+ op.htk_compat = true;
+ op.use_energy = true; // Use energy.
+ op.mel_opts.htk_mode = true;
+
+ Mfcc mfcc(op);
+
+ // calculate kaldi features
+ Matrix<BaseFloat> kaldi_raw_features;
+ mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
+
+ DeltaFeaturesOptions delta_opts;
+ Matrix<BaseFloat> kaldi_features;
+ ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
+
+ // compare the results
+ bool passed = true;
+ int32 i_old = -1;
+ KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
+ KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
+ // Ignore ends-- we make slightly different choices than
+ // HTK about how to treat the deltas at the ends.
+ for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
+ for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
+ BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
+ if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
+ // print the non-matching data only once per-line
+ if (static_cast<int32>(i_old) != i) {
+ std::cout << "\n\n\n[HTK-row: " << i << "] "
+ << htk_features.Row(i) << "\n";
+ std::cout << "[Kaldi-row: " << i << "] "
+ << kaldi_features.Row(i) << "\n\n\n";
+ i_old = i;
+ }
+ // print indices of non-matching cells
+ std::cout << "[" << i << ", " << j << "]";
+ passed = false;
+ }
+ }
+ }
+ if (!passed) KALDI_ERR << "Test failed";
+
+ // write the htk features for later inspection
+ HtkHeader header = {
+ kaldi_features.NumRows(),
+ 100000, // 10ms
+ static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
+ 021406 // MFCC_D_A_0
+ };
+ {
+ std::ofstream os("tmp.test.wav.fea_kaldi.4",
+ std::ios::out | std::ios::binary);
+ WriteHtk(os, kaldi_features, header);
+ }
+
+ std::cout << "Test passed :)\n\n";
+
+ unlink("tmp.test.wav.fea_kaldi.4");
+}
+
+
+static void UnitTestHTKCompare5() {
+ std::cout << "=== UnitTestHTKCompare5() ===\n";
+
+ std::ifstream is("test_data/test.wav", std::ios_base::binary);
+ WaveData wave;
+ wave.Read(is);
+ KALDI_ASSERT(wave.Data().NumRows() == 1);
+ SubVector<BaseFloat> waveform(wave.Data(), 0);
+
+ // read the HTK features
+ Matrix<BaseFloat> htk_features;
+ {
+ std::ifstream is("test_data/test.wav.fea_htk.5",
+ std::ios::in | std::ios_base::binary);
+ bool ans = ReadHtk(is, &htk_features, 0);
+ KALDI_ASSERT(ans);
+ }
+
+ // use mfcc with default configuration...
+ MfccOptions op;
+ op.frame_opts.dither = 0.0;
+ op.frame_opts.window_type = "hamming";
+ op.frame_opts.remove_dc_offset = false;
+ op.frame_opts.round_to_power_of_two = true;
+ op.htk_compat = true;
+ op.use_energy = true; // Use energy.
+ op.mel_opts.low_freq = 0.0;
+ op.mel_opts.vtln_low = 100.0;
+ op.mel_opts.vtln_high = 7500.0;
+ op.mel_opts.htk_mode = true;
+
+ BaseFloat vtln_warp =
+ 1.1; // our approach identical to htk for warp factor >1,
+ // differs slightly for higher mel bins if warp_factor <0.9
+
+ Mfcc mfcc(op);
+
+ // calculate kaldi features
+ Matrix<BaseFloat> kaldi_raw_features;
+ mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
+
+ DeltaFeaturesOptions delta_opts;
+ Matrix<BaseFloat> kaldi_features;
+ ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
+
+ // compare the results
+ bool passed = true;
+ int32 i_old = -1;
+ KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
+ KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
+ // Ignore ends-- we make slightly different choices than
+ // HTK about how to treat the deltas at the ends.
+ for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
+ for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
+ BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
+ if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
+ // print the non-matching data only once per-line
+ if (static_cast<int32>(i_old) != i) {
+ std::cout << "\n\n\n[HTK-row: " << i << "] "
+ << htk_features.Row(i) << "\n";
+ std::cout << "[Kaldi-row: " << i << "] "
+ << kaldi_features.Row(i) << "\n\n\n";
+ i_old = i;
+ }
+ // print indices of non-matching cells
+ std::cout << "[" << i << ", " << j << "]";
+ passed = false;
+ }
+ }
+ }
+ if (!passed) KALDI_ERR << "Test failed";
+
+ // write the htk features for later inspection
+ HtkHeader header = {
+ kaldi_features.NumRows(),
+ 100000, // 10ms
+ static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
+ 021406 // MFCC_D_A_0
+ };
+ {
+ std::ofstream os("tmp.test.wav.fea_kaldi.5",
+ std::ios::out | std::ios::binary);
+ WriteHtk(os, kaldi_features, header);
+ }
+
+ std::cout << "Test passed :)\n\n";
+
+ unlink("tmp.test.wav.fea_kaldi.5");
+}
+
+static void UnitTestHTKCompare6() {
+ std::cout << "=== UnitTestHTKCompare6() ===\n";
+
+ std::ifstream is("test_data/test.wav", std::ios_base::binary);
+ WaveData wave;
+ wave.Read(is);
+ KALDI_ASSERT(wave.Data().NumRows() == 1);
+ SubVector<BaseFloat> waveform(wave.Data(), 0);
+
+ // read the HTK features
+ Matrix<BaseFloat> htk_features;
+ {
+ std::ifstream is("test_data/test.wav.fea_htk.6",
+ std::ios::in | std::ios_base::binary);
+ bool ans = ReadHtk(is, &htk_features, 0);
+ KALDI_ASSERT(ans);
+ }
+
+ // use mfcc with default configuration...
+ MfccOptions op;
+ op.frame_opts.dither = 0.0;
+ op.frame_opts.preemph_coeff = 0.97;
+ op.frame_opts.window_type = "hamming";
+ op.frame_opts.remove_dc_offset = false;
+ op.frame_opts.round_to_power_of_two = true;
+ op.mel_opts.num_bins = 24;
+ op.mel_opts.low_freq = 125.0;
+ op.mel_opts.high_freq = 7800.0;
+ op.htk_compat = true;
+ op.use_energy = false; // C0 not energy.
+
+ Mfcc mfcc(op);
+
+ // calculate kaldi features
+ Matrix<BaseFloat> kaldi_raw_features;
+ mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
+
+ DeltaFeaturesOptions delta_opts;
+ Matrix<BaseFloat> kaldi_features;
+ ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
+
+ // compare the results
+ bool passed = true;
+ int32 i_old = -1;
+ KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
+ KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
+ // Ignore ends-- we make slightly different choices than
+ // HTK about how to treat the deltas at the ends.
+ for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
+ for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
+ BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
+ if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
+ // print the non-matching data only once per-line
+ if (static_cast<int32>(i_old) != i) {
+ std::cout << "\n\n\n[HTK-row: " << i << "] "
+ << htk_features.Row(i) << "\n";
+ std::cout << "[Kaldi-row: " << i << "] "
+ << kaldi_features.Row(i) << "\n\n\n";
+ i_old = i;
+ }
+ // print indices of non-matching cells
+ std::cout << "[" << i << ", " << j << "]";
+ passed = false;
+ }
+ }
+ }
+ if (!passed) KALDI_ERR << "Test failed";
+
+ // write the htk features for later inspection
+ HtkHeader header = {
+ kaldi_features.NumRows(),
+ 100000, // 10ms
+ static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
+ 021406 // MFCC_D_A_0
+ };
+ {
+ std::ofstream os("tmp.test.wav.fea_kaldi.6",
+ std::ios::out | std::ios::binary);
+ WriteHtk(os, kaldi_features, header);
+ }
+
+ std::cout << "Test passed :)\n\n";
+
+ unlink("tmp.test.wav.fea_kaldi.6");
+}
+
+void UnitTestVtln() {
+ // Test the function VtlnWarpFreq.
+ BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20,
+ vtln_high_cutoff = 7400;
+
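+ // Checks that the warp (a) maps a mid-range frequency to
+ // freq / warp_factor in the linear region, (b) leaves the boundary
+ // frequencies fixed, (c) is monotonically increasing, and (d) reduces
+ // to the identity when warp_factor == 1.0.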
+ for (size_t i = 0; i < 100; i++) {
+ BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
+ AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
+ vtln_high_cutoff,
+ low_freq,
+ high_freq,
+ warp_factor,
+ freq),
+ freq / warp_factor);
+
+ AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
+ vtln_high_cutoff,
+ low_freq,
+ high_freq,
+ warp_factor,
+ low_freq),
+ low_freq);
+ AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
+ vtln_high_cutoff,
+ low_freq,
+ high_freq,
+ warp_factor,
+ high_freq),
+ high_freq);
+ BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(),
+ freq3 = freq2 +
+ (high_freq - freq2) * RandUniform(); // freq3>=freq2
+ BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
+ vtln_high_cutoff,
+ low_freq,
+ high_freq,
+ warp_factor,
+ freq2);
+ BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
+ vtln_high_cutoff,
+ low_freq,
+ high_freq,
+ warp_factor,
+ freq3);
+ KALDI_ASSERT(w3 >= w2); // increasing function.
+ BaseFloat w3dash = MelBanks::VtlnWarpFreq(
+ vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, 1.0, freq3);
+ AssertEqual(w3dash, freq3);
+ }
+}
+
+static void UnitTestFeat() {
+ UnitTestVtln();
+ UnitTestReadWave();
+ UnitTestSimple();
+ UnitTestHTKCompare1();
+ UnitTestHTKCompare2();
+ // Note: this comparison is not exact right now, since the way the FFT
+ // bins are treated was normalized (the offset of 0.5 was removed); this
+ // seems to relate to the way frequency zero behaves.
+ UnitTestHTKCompare3();
+ UnitTestHTKCompare4();
+ UnitTestHTKCompare5();
+ UnitTestHTKCompare6();
+ std::cout << "Tests succeeded.\n";
+}
+
+
+int main() {
+ try {
+ for (int i = 0; i < 5; i++) UnitTestFeat();
+ std::cout << "Tests succeeded.\n";
+ return 0;
+ } catch (const std::exception &e) {
+ std::cerr << e.what();
+ return 1;
+ }
+}
diff --git a/speechx/examples/feat/linear_spectrogram_main.cc b/speechx/examples/feat/linear_spectrogram_main.cc
new file mode 100644
index 00000000..9ed4d6f9
--- /dev/null
+++ b/speechx/examples/feat/linear_spectrogram_main.cc
@@ -0,0 +1,248 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TODO: refactor, replace with gtest
+
+#include "frontend/linear_spectrogram.h"
+#include "base/flags.h"
+#include "base/log.h"
+#include "frontend/feature_cache.h"
+#include "frontend/feature_extractor_interface.h"
+#include "frontend/normalizer.h"
+#include "frontend/raw_audio.h"
+#include "kaldi/feat/wave-reader.h"
+#include "kaldi/util/kaldi-io.h"
+#include "kaldi/util/table-types.h"
+
+DEFINE_string(wav_rspecifier, "", "test wav scp path");
+DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
+DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
+
+
+std::vector<float> mean_{
+ -13730251.531853663, -12982852.199316509, -13673844.299583456,
+ -13089406.559646806, -12673095.524938712, -12823859.223276224,
+ -13590267.158903603, -14257618.467152044, -14374605.116185192,
+ -14490009.21822485, -14849827.158924166, -15354435.470563512,
+ -15834149.206532761, -16172971.985514281, -16348740.496746974,
+ -16423536.699409386, -16556246.263649225, -16744088.772748645,
+ -16916184.08510357, -17054034.840031497, -17165612.509455364,
+ -17255955.470915023, -17322572.527648456, -17408943.862033736,
+ -17521554.799865916, -17620623.254924215, -17699792.395918526,
+ -17723364.411134344, -17741483.4433254, -17747426.888704527,
+ -17733315.928209435, -17748780.160905756, -17808336.883775543,
+ -17895918.671983004, -18009812.59173023, -18098188.66548325,
+ -18195798.958462656, -18293617.62980999, -18397432.92077201,
+ -18505834.787318766, -18585451.8100908, -18652438.235649142,
+ -18700960.306275308, -18734944.58792185, -18737426.313365128,
+ -18735347.165987637, -18738813.444170244, -18737086.848890636,
+ -18731576.2474336, -18717405.44095871, -18703089.25545657,
+ -18691014.546456724, -18692460.568905357, -18702119.628629155,
+ -18727710.621126678, -18761582.72034647, -18806745.835547544,
+ -18850674.8692112, -18884431.510951452, -18919999.992506847,
+ -18939303.799078144, -18952946.273760635, -18980289.22996379,
+ -19011610.17803294, -19040948.61805145, -19061021.429847397,
+ -19112055.53768819, -19149667.414264943, -19201127.05091321,
+ -19270250.82564605, -19334606.883057203, -19390513.336589377,
+ -19444176.259208687, -19502755.000038862, -19544333.014549147,
+ -19612668.183176614, -19681902.19006569, -19771969.951249883,
+ -19873329.723376893, -19996752.59235844, -20110031.131400537,
+ -20231658.612529557, -20319378.894054495, -20378534.45718066,
+ -20413332.089584175, -20438147.844177883, -20443710.248040095,
+ -20465457.02238927, -20488610.969337028, -20516295.16424432,
+ -20541423.795738827, -20553192.874953747, -20573605.50701977,
+ -20577871.61936797, -20571807.008916274, -20556242.38912231,
+ -20542199.30819195, -20521239.063551214, -20519150.80004532,
+ -20527204.80248933, -20536933.769257784, -20543470.522332076,
+ -20549700.089992985, -20551525.24958494, -20554873.406493705,
+ -20564277.65794227, -20572211.740052115, -20574305.69550465,
+ -20575494.450104576, -20567092.577932164, -20549302.929608088,
+ -20545445.11878376, -20546625.326603737, -20549190.03499401,
+ -20554824.947828256, -20568341.378989458, -20577582.331383612,
+ -20577980.519402675, -20566603.03458152, -20560131.592262644,
+ -20552166.469060015, -20549063.06763577, -20544490.562339947,
+ -20539817.82346569, -20528747.715731595, -20518026.24576161,
+ -20510977.844974525, -20506874.36087992, -20506731.11977665,
+ -20510482.133420516, -20507760.92101862, -20494644.834457114,
+ -20480107.89304893, -20461312.091867123, -20442941.75080173,
+ -20426123.02834838, -20424607.675283, -20426810.369107097,
+ -20434024.50097819, -20437404.75544205, -20447688.63916367,
+ -20460893.335563846, -20482922.735127095, -20503610.119434915,
+ -20527062.76448319, -20557830.035128627, -20593274.72068722,
+ -20632528.452965066, -20673637.471334763, -20733106.97143075,
+ -20842921.0447562, -21054357.83621519, -21416569.534189366,
+ -21978460.272811692, -22753170.052172784, -23671344.10563395,
+ -24613499.293358143, -25406477.12230188, -25884377.82156489,
+ -26049040.62791664, -26996879.104431007};
+std::vector<float> variance_{
+ 213747175.10846674, 188395815.34302503, 212706429.10966414,
+ 199109025.81461075, 189235901.23864496, 194901336.53253657,
+ 217481594.29306737, 238689869.12327808, 243977501.24115244,
+ 248479623.6431067, 259766741.47116545, 275516766.7790273,
+ 291271202.3691234, 302693239.8220509, 308627358.3997694,
+ 311143911.38788426, 315446105.07731867, 321705430.9341829,
+ 327458907.4659941, 332245072.43223983, 336251717.5935284,
+ 339694069.7639722, 342188204.4322228, 345587110.31313115,
+ 349903086.2875232, 353660214.20643026, 356700344.5270885,
+ 357665362.3529641, 358493352.05658793, 358857951.620328,
+ 358375239.52774596, 358899733.6342954, 361051818.3511561,
+ 364361716.05025816, 368750322.3771452, 372047800.6462831,
+ 375655861.1349018, 379358519.1980013, 383327605.3935181,
+ 387458599.282341, 390434692.3406868, 392994486.35057056,
+ 394874418.04603153, 396230525.79763395, 396365592.0414835,
+ 396334819.8242737, 396488353.19250053, 396438877.00744957,
+ 396197980.4459586, 395590921.6672991, 395001107.62072515,
+ 394528291.7318225, 394593110.424006, 395018405.59353715,
+ 396110577.5415993, 397506704.0371068, 399400197.4657644,
+ 401243568.2468382, 402687134.7805103, 404136047.2872507,
+ 404883170.001883, 405522253.219517, 406660365.3626476,
+ 407919346.0991902, 409045348.5384909, 409759588.7889818,
+ 411974821.8564483, 413489718.78201455, 415535392.56684107,
+ 418466481.97674364, 421104678.35678065, 423405392.5200779,
+ 425550570.40798235, 427929423.9579701, 429585274.253478,
+ 432368493.55181056, 435193587.13513297, 438886855.20476013,
+ 443058876.8633751, 448181232.5093362, 452883835.6332396,
+ 458056721.77926534, 461816531.22735566, 464363620.1970998,
+ 465886343.5057493, 466928872.0651, 467180536.42647296,
+ 468111848.70714295, 469138695.3071312, 470378429.6930793,
+ 471517958.7132626, 472109050.4262365, 473087417.0177867,
+ 473381322.04648733, 473220195.85483915, 472666071.8998819,
+ 472124669.87879956, 471298571.411737, 471251033.2902761,
+ 471672676.43128747, 472177147.2193172, 472572361.7711908,
+ 472968783.7751127, 473156295.4164052, 473398034.82676554,
+ 473897703.5203811, 474328271.33112127, 474452670.98002136,
+ 474549003.99284613, 474252887.13567275, 473557462.909069,
+ 473483385.85193115, 473609738.04855174, 473746944.82085115,
+ 474016729.91696435, 474617321.94138587, 475045097.237122,
+ 475125402.586558, 474664112.9824912, 474426247.5800283,
+ 474104075.42796475, 473978219.7273978, 473773171.7798875,
+ 473578534.69508696, 473102924.16904145, 472651240.5232615,
+ 472374383.1810912, 472209479.6956096, 472202298.8921673,
+ 472370090.76781124, 472220933.99374026, 471625467.37106377,
+ 470994646.51883453, 470182428.9637543, 469348211.5939578,
+ 468570387.4467277, 468540442.7225135, 468672018.90414184,
+ 468994346.9533251, 469138757.58201426, 469553915.95710236,
+ 470134523.38582784, 471082421.62055486, 471962316.51804745,
+ 472939745.1708408, 474250621.5944825, 475773933.43199486,
+ 477465399.71087736, 479218782.61382693, 481752299.7930922,
+ 486608947.8984568, 496119403.2067917, 512730085.5704984,
+ 539048915.2641417, 576285298.3548826, 621610270.2240586,
+ 669308196.4436442, 710656993.5957186, 736344437.3725077,
+ 745481288.0241544, 801121432.9925804};
+int count_ = 912592;
+
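+// Pack the accumulated CMVN statistics into Kaldi's 2 x (dim + 1) layout:
+// row 0 holds the per-dimension sums (with the frame count in the last
+// column), row 1 the per-dimension sums of squares, so downstream code can
+// recover mean = sum / count and var = sumsq / count - mean^2.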
+void WriteMatrix() {
+ kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
+ for (size_t idx = 0; idx < mean_.size(); ++idx) {
+ cmvn_stats(0, idx) = mean_[idx];
+ cmvn_stats(1, idx) = variance_[idx];
+ }
+ cmvn_stats(0, mean_.size()) = count_;
+ kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
+}
+
+int main(int argc, char* argv[]) {
+ gflags::ParseCommandLineFlags(&argc, &argv, false);
+ google::InitGoogleLogging(argv[0]);
+
+ kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
+ FLAGS_wav_rspecifier);
+ kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
+ WriteMatrix();
+
+ // test feature linear_spectrogram: wave --> decibel_normalizer --> hanning
+ // window --> linear_spectrogram --> cmvn
+ int32 num_done = 0, num_err = 0;
+ // std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(new
+ // ppspeech::RawDataCache());
+ std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(
+ new ppspeech::RawAudioCache());
+
+ ppspeech::LinearSpectrogramOptions opt;
+ opt.frame_opts.frame_length_ms = 20;
+ opt.frame_opts.frame_shift_ms = 10;
+ ppspeech::DecibelNormalizerOptions db_norm_opt;
+ std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
+ new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));
+
+ std::unique_ptr<ppspeech::FeatureExtractorInterface> linear_spectrogram(
+ new ppspeech::LinearSpectrogram(opt,
+ std::move(base_feature_extractor)));
+
+ std::unique_ptr<ppspeech::FeatureExtractorInterface> cmvn(
+ new ppspeech::CMVN(FLAGS_cmvn_write_path,
+ std::move(linear_spectrogram)));
+
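+ // FeatureCache buffers the extracted frames; kint16max is assumed here
+ // to be a cache capacity large enough to be effectively unbounded for
+ // this test.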
+ ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
+
+ float streaming_chunk = 0.36;
+ int sample_rate = 16000;
+ int chunk_sample_size = streaming_chunk * sample_rate;
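+ // 0.36 s of 16 kHz audio = 5760 samples per chunk; the audio is fed
+ // chunk by chunk to simulate streaming input.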
+
+ for (; !wav_reader.Done(); wav_reader.Next()) {
+ std::string utt = wav_reader.Key();
+ const kaldi::WaveData& wave_data = wav_reader.Value();
+
+ int32 this_channel = 0;
+ kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
+ this_channel);
+ int tot_samples = waveform.Dim();
+ int sample_offset = 0;
+ std::vector<kaldi::Vector<kaldi::BaseFloat>> feats;
+ int feature_rows = 0;
+ while (sample_offset < tot_samples) {
+ int cur_chunk_size =
+ std::min(chunk_sample_size, tot_samples - sample_offset);
+
+ kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
+ for (int i = 0; i < cur_chunk_size; ++i) {
+ wav_chunk(i) = waveform(sample_offset + i);
+ }
+ kaldi::Vector<kaldi::BaseFloat> features;
+ feature_cache.Accept(wav_chunk);
+ if (cur_chunk_size < chunk_sample_size) {
+ feature_cache.SetFinished();
+ }
+ feature_cache.Read(&features);
+ if (features.Dim() == 0) break;
+
+ feats.push_back(features);
+ sample_offset += cur_chunk_size;
+ feature_rows += features.Dim() / feature_cache.Dim();
+ }
+
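+ // stitch the flattened per-chunk feature vectors (frames * Dim() values
+ // each) back into one utterance-level matrix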
+ int cur_idx = 0;
+ kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
+ feature_cache.Dim());
+ for (const auto& feat : feats) {
+ int num_rows = feat.Dim() / feature_cache.Dim();
+ for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
+ for (size_t col_idx = 0; col_idx < feature_cache.Dim();
+ ++col_idx) {
+ features(cur_idx, col_idx) =
+ feat(row_idx * feature_cache.Dim() + col_idx);
+ }
+ ++cur_idx;
+ }
+ }
+ feat_writer.Write(utt, features);
+
+ if (num_done % 50 == 0 && num_done != 0)
+ KALDI_VLOG(2) << "Processed " << num_done << " utterances";
+ num_done++;
+ }
+ KALDI_LOG << "Done " << num_done << " utterances, " << num_err
+ << " with errors.";
+ return (num_done != 0 ? 0 : 1);
+}
diff --git a/speechx/examples/feat/path.sh b/speechx/examples/feat/path.sh
new file mode 100644
index 00000000..8ab7ee29
--- /dev/null
+++ b/speechx/examples/feat/path.sh
@@ -0,0 +1,14 @@
+# This contains the locations of the binaries required for running the examples.
+
+SPEECHX_ROOT=$PWD/../..
+SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
+
+SPEECHX_TOOLS=$SPEECHX_ROOT/tools
+TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
+
+[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. Please ensure that the project has been built successfully."; }
+
+export LC_ALL=C
+
+SPEECHX_BIN=$SPEECHX_EXAMPLES/feat
+export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
diff --git a/speechx/examples/feat/run.sh b/speechx/examples/feat/run.sh
new file mode 100755
index 00000000..bd21bd7f
--- /dev/null
+++ b/speechx/examples/feat/run.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+set +x
+set -e
+
+. ./path.sh
+
+# 1. compile
+if [ ! -d ${SPEECHX_EXAMPLES} ]; then
+ pushd ${SPEECHX_ROOT}
+ bash build.sh
+ popd
+fi
+
+# 2. download model
+if [ ! -d ../paddle_asr_model ]; then
+ wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
+ tar xzvf paddle_asr_model.tar.gz
+ mv ./paddle_asr_model ../
+ # produce wav scp
+ echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
+fi
+
+model_dir=../paddle_asr_model
+feat_wspecifier=./feats.ark
+cmvn=./cmvn.ark
+
+# 3. run feat
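+# writes the features to $feat_wspecifier as a text archive (ark,t) and the CMVN stats to $cmvn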
+linear_spectrogram_main \
+ --wav_rspecifier=scp:$model_dir/wav.scp \
+ --feature_wspecifier=ark,t:$feat_wspecifier \
+ --cmvn_write_path=$cmvn
diff --git a/speechx/examples/feat/valgrind.sh b/speechx/examples/feat/valgrind.sh
new file mode 100755
index 00000000..f8aab63f
--- /dev/null
+++ b/speechx/examples/feat/valgrind.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+# This script is for memory checking; please run ./run.sh first.
+
+set +x
+set -e
+
+. ./path.sh
+
+if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
+ echo "please install valgrind in the speechx tools dir.\n"
+ exit 1
+fi
+
+model_dir=../paddle_asr_model
+feat_wspecifier=./feats.ark
+cmvn=./cmvn.ark
+
+valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
+ linear_spectrogram_main \
+ --wav_rspecifier=scp:$model_dir/wav.scp \
+ --feature_wspecifier=ark,t:$feat_wspecifier \
+ --cmvn_write_path=$cmvn
+
diff --git a/speechx/examples/nnet/CMakeLists.txt b/speechx/examples/nnet/CMakeLists.txt
new file mode 100644
index 00000000..20f4008c
--- /dev/null
+++ b/speechx/examples/nnet/CMakeLists.txt
@@ -0,0 +1,5 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+add_executable(pp-model-test ${CMAKE_CURRENT_SOURCE_DIR}/pp-model-test.cc)
+target_include_directories(pp-model-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(pp-model-test PUBLIC nnet gflags ${DEPS})
\ No newline at end of file
diff --git a/speechx/examples/nnet/path.sh b/speechx/examples/nnet/path.sh
new file mode 100644
index 00000000..f70e70ee
--- /dev/null
+++ b/speechx/examples/nnet/path.sh
@@ -0,0 +1,14 @@
+# This contains the locations of the binaries required for running the examples.
+
+SPEECHX_ROOT=$PWD/../..
+SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
+
+SPEECHX_TOOLS=$SPEECHX_ROOT/tools
+TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
+
+[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. Please ensure that the project has been built successfully."; }
+
+export LC_ALL=C
+
+SPEECHX_BIN=$SPEECHX_EXAMPLES/nnet
+export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
diff --git a/speechx/examples/nnet/pp-model-test.cc b/speechx/examples/nnet/pp-model-test.cc
new file mode 100644
index 00000000..2db354a7
--- /dev/null
+++ b/speechx/examples/nnet/pp-model-test.cc
@@ -0,0 +1,193 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <algorithm>
+#include <functional>
+#include <iostream>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <vector>
+#include "paddle_inference_api.h"
+
+using std::cout;
+using std::endl;
+
+DEFINE_string(model_path, "avg_1.jit.pdmodel", "xxx.pdmodel");
+DEFINE_string(param_path, "avg_1.jit.pdiparams", "xxx.pdiparams");
+
+
+void produce_data(std::vector<std::vector<float>>* data);
+void model_forward_test();
+
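+// Fill a chunk_size x col_size matrix with a constant dummy value, just to
+// exercise the model's forward pass.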
+void produce_data(std::vector<std::vector<float>>* data) {
+ int chunk_size = 35; // chunk_size in frame
+ int col_size = 161; // feat dim
+ cout << "chunk size: " << chunk_size << endl;
+ cout << "feat dim: " << col_size << endl;
+
+ data->reserve(chunk_size);
+ for (int row = 0; row < chunk_size; ++row) {
+ data->push_back(std::vector<float>());
+ data->back().reserve(col_size); // reserve only after the row exists
+ for (int col_idx = 0; col_idx < col_size; ++col_idx) {
+ data->back().push_back(0.201);
+ }
+ }
+}
+
+void model_forward_test() {
+ std::cout << "1. read the data" << std::endl;
+ std::vector<std::vector<float>> feats;
+ produce_data(&feats);
+
+ std::cout << "2. load the model" << std::endl;
+ std::string model_graph = FLAGS_model_path;
+ std::string model_params = FLAGS_param_path;
+ cout << "model path: " << model_graph << endl;
+ cout << "model param path : " << model_params << endl;
+
+ paddle_infer::Config config;
+ config.SetModel(model_graph, model_params);
+ config.SwitchIrOptim(false);
+ cout << "SwitchIrOptim: " << false << endl;
+ config.DisableFCPadding();
+ cout << "DisableFCPadding: " << endl;
+ auto predictor = paddle_infer::CreatePredictor(config);
+
+ std::cout << "3. feat shape, row=" << feats.size()
+ << ",col=" << feats[0].size() << std::endl;
+ std::vector<float> pp_input_mat;
+ for (const auto& item : feats) {
+ pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
+ }
+
+ std::cout << "4. fead the data to model" << std::endl;
+ int row = feats.size();
+ int col = feats[0].size();
+ std::vector<std::string> input_names = predictor->GetInputNames();
+ std::vector<std::string> output_names = predictor->GetOutputNames();
+ for (auto name : input_names) {
+ cout << "model input names: " << name << endl;
+ }
+ for (auto name : output_names) {
+ cout << "model output names: " << name << endl;
+ }
+
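+ // The four inputs below match the DeepSpeech2 online (streaming) model
+ // exported by PaddleSpeech: the audio feature chunk, its length, and the
+ // LSTM h/c states (shape {3, 1, 1024} is assumed: 3 layers, batch 1,
+ // 1024 hidden units).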
+ // input
+ std::unique_ptr<paddle_infer::Tensor> input_tensor =
+ predictor->GetInputHandle(input_names[0]);
+ std::vector<int> INPUT_SHAPE = {1, row, col};
+ input_tensor->Reshape(INPUT_SHAPE);
+ input_tensor->CopyFromCpu(pp_input_mat.data());
+
+ // input length
+ std::unique_ptr<paddle_infer::Tensor> input_len =
+ predictor->GetInputHandle(input_names[1]);
+ std::vector<int> input_len_size = {1};
+ input_len->Reshape(input_len_size);
+ std::vector<int64_t> audio_len; // int64 dtype assumed for the length input
+ audio_len.push_back(row);
+ input_len->CopyFromCpu(audio_len.data());
+
+ // state_h
+ std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
+ predictor->GetInputHandle(input_names[2]);
+ std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
+ chunk_state_h_box->Reshape(chunk_state_h_box_shape);
+ int chunk_state_h_box_size =
+ std::accumulate(chunk_state_h_box_shape.begin(),
+ chunk_state_h_box_shape.end(),
+ 1,
+ std::multiplies<int>());
+ std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
+ chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
+
+ // state_c
+ std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
+ predictor->GetInputHandle(input_names[3]);
+ std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
+ chunk_state_c_box->Reshape(chunk_state_c_box_shape);
+ int chunk_state_c_box_size =
+ std::accumulate(chunk_state_c_box_shape.begin(),
+ chunk_state_c_box_shape.end(),
+ 1,
+ std::multiplies<int>());
+ std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
+ chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
+
+ // run
+ if (!predictor->Run()) {
+ std::cerr << "predictor run failed" << std::endl;
+ return;
+ }
+
+ // state_h out
+ std::unique_ptr<paddle_infer::Tensor> h_out =
+ predictor->GetOutputHandle(output_names[2]);
+ std::vector<int> h_out_shape = h_out->shape();
+ int h_out_size = std::accumulate(
+ h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
+ std::vector<float> h_out_data(h_out_size);
+ h_out->CopyToCpu(h_out_data.data());
+
+ // state_c out
+ std::unique_ptr<paddle_infer::Tensor> c_out =
+ predictor->GetOutputHandle(output_names[3]);
+ std::vector<int> c_out_shape = c_out->shape();
+ int c_out_size = std::accumulate(
+ c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
+ std::vector<float> c_out_data(c_out_size);
+ c_out->CopyToCpu(c_out_data.data());
+
+ // output tensor
+ std::unique_ptr<paddle_infer::Tensor> output_tensor =
+ predictor->GetOutputHandle(output_names[0]);
+ std::vector<int> output_shape = output_tensor->shape();
+ std::vector<float> output_probs;
+ int output_size = std::accumulate(
+ output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
+ output_probs.resize(output_size);
+ output_tensor->CopyToCpu(output_probs.data());
+ row = output_shape[1];
+ col = output_shape[2];
+
+ // probs
+ std::vector<std::vector<float>> probs;
+ probs.reserve(row);
+ for (int i = 0; i < row; i++) {
+ probs.push_back(std::vector<float>());
+ probs.back().reserve(col);
+
+ for (int j = 0; j < col; j++) {
+ probs.back().push_back(output_probs[i * col + j]);
+ }
+ }
+
+ std::vector<std::vector<float>> log_feat = probs;
+ std::cout << "probs, row: " << log_feat.size()
+ << " col: " << log_feat[0].size() << std::endl;
+ for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
+ for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
+ ++col_idx) {
+ std::cout << log_feat[row_idx][col_idx] << " ";
+ }
+ std::cout << std::endl;
+ }
+}
+
+int main(int argc, char* argv[]) {
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+ model_forward_test();
+ return 0;
+}
diff --git a/speechx/examples/nnet/run.sh b/speechx/examples/nnet/run.sh
new file mode 100755
index 00000000..4d67d198
--- /dev/null
+++ b/speechx/examples/nnet/run.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+set +x
+set -e
+
+. path.sh
+
+# 1. compile
+if [ ! -d ${SPEECHX_EXAMPLES} ]; then
+ pushd ${SPEECHX_ROOT}
+ bash build.sh
+ popd
+fi
+
+# 2. download model
+if [ ! -d ../paddle_asr_model ]; then
+ wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
+ tar xzvf paddle_asr_model.tar.gz
+ mv ./paddle_asr_model ../
+ # produce wav scp
+ echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
+fi
+
+model_dir=../paddle_asr_model
+
+# 3. run model forward test
+pp-model-test \
+ --model_path=$model_dir/avg_1.jit.pdmodel \
+ --param_path=$model_dir/avg_1.jit.pdiparams
+
diff --git a/speechx/examples/nnet/valgrind.sh b/speechx/examples/nnet/valgrind.sh
new file mode 100755
index 00000000..2a08c608
--- /dev/null
+++ b/speechx/examples/nnet/valgrind.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# This script is for memory checking; please run ./run.sh first.
+
+set +x
+set -e
+
+. ./path.sh
+
+if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
+ echo "please install valgrind in the speechx tools dir.\n"
+ exit 1
+fi
+
+model_dir=../paddle_asr_model
+
+valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
+ pp-model-test \
+ --model_path=$model_dir/avg_1.jit.pdmodel \
+ --param_path=$model_dir/avg_1.jit.pdiparams
\ No newline at end of file
diff --git a/speechx/patch/CPPLINT.cfg b/speechx/patch/CPPLINT.cfg
new file mode 100644
index 00000000..51ff339c
--- /dev/null
+++ b/speechx/patch/CPPLINT.cfg
@@ -0,0 +1 @@
+exclude_files=.*
diff --git a/speechx/patch/openfst/src/include/fst/flags.h b/speechx/patch/openfst/src/include/fst/flags.h
new file mode 100644
index 00000000..b5ec8ff7
--- /dev/null
+++ b/speechx/patch/openfst/src/include/fst/flags.h
@@ -0,0 +1,228 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// See www.openfst.org for extensive documentation on this weighted
+// finite-state transducer library.
+//
+// Google-style flag handling declarations and inline definitions.
+
+#ifndef FST_LIB_FLAGS_H_
+#define FST_LIB_FLAGS_H_
+
+#include <cstdlib>
+
+#include <iostream>
+#include <map>