commit
305bacdcf2
@@ -0,0 +1,88 @@
version: '3.5'

services:
  etcd:
    container_name: milvus-etcd
    image: quay.io/coreos/etcd:v3.5.0
    networks:
      app_net:
    environment:
      - ETCD_AUTO_COMPACTION_MODE=revision
      - ETCD_AUTO_COMPACTION_RETENTION=1000
      - ETCD_QUOTA_BACKEND_BYTES=4294967296
    volumes:
      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
    command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd

  minio:
    container_name: milvus-minio
    image: minio/minio:RELEASE.2020-12-03T00-03-10Z
    networks:
      app_net:
    environment:
      MINIO_ACCESS_KEY: minioadmin
      MINIO_SECRET_KEY: minioadmin
    volumes:
      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
    command: minio server /minio_data
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
      interval: 30s
      timeout: 20s
      retries: 3

  standalone:
    container_name: milvus-standalone
    image: milvusdb/milvus:v2.0.1
    networks:
      app_net:
        ipv4_address: 172.16.23.10
    command: ["milvus", "run", "standalone"]
    environment:
      ETCD_ENDPOINTS: etcd:2379
      MINIO_ADDRESS: minio:9000
    volumes:
      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
    ports:
      - "19530:19530"
    depends_on:
      - "etcd"
      - "minio"

  mysql:
    container_name: audio-mysql
    image: mysql:5.7
    networks:
      app_net:
        ipv4_address: 172.16.23.11
    environment:
      - MYSQL_ROOT_PASSWORD=123456
    volumes:
      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/mysql:/var/lib/mysql
    ports:
      - "3306:3306"

  webclient:
    container_name: audio-webclient
    image: qingen1/paddlespeech-audio-search-client:2.3
    networks:
      app_net:
        ipv4_address: 172.16.23.13
    environment:
      API_URL: 'http://127.0.0.1:8002'
    ports:
      - "8068:80"
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost/"]
      interval: 30s
      timeout: 20s
      retries: 3

networks:
  app_net:
    driver: bridge
    ipam:
      driver: default
      config:
        - subnet: 172.16.23.0/24
          gateway: 172.16.23.1
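
As a usage sketch (not part of this commit), the stack above can be brought up with Docker Compose and checked against the health endpoints and ports declared in the file; `DOCKER_VOLUME_DIRECTORY` is the variable referenced in the volume mappings and defaults to the current directory.

```bash
# Hypothetical quick start for the compose file above (assumes Docker Compose is installed).
export DOCKER_VOLUME_DIRECTORY=$PWD   # where etcd/MinIO/Milvus/MySQL volumes are created
docker-compose up -d

# MinIO's liveness probe and the web client port, matching the healthcheck/ports entries above.
curl -f http://localhost:9000/minio/health/live
curl -f http://localhost:8068/
```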
@@ -0,0 +1,12 @@
diskcache==5.2.1
fastapi
librosa==0.8.0
numpy
pydantic
pymilvus==2.0.1
pymysql
python-multipart
soundfile==0.10.3.post1
starlette
typing
uvicorn
@@ -0,0 +1,36 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

############### Milvus Configuration ###############
MILVUS_HOST = os.getenv("MILVUS_HOST", "127.0.0.1")
MILVUS_PORT = int(os.getenv("MILVUS_PORT", "19530"))
VECTOR_DIMENSION = int(os.getenv("VECTOR_DIMENSION", "2048"))
INDEX_FILE_SIZE = int(os.getenv("INDEX_FILE_SIZE", "1024"))
METRIC_TYPE = os.getenv("METRIC_TYPE", "L2")
DEFAULT_TABLE = os.getenv("DEFAULT_TABLE", "audio_table")
TOP_K = int(os.getenv("TOP_K", "10"))

############### MySQL Configuration ###############
MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
MYSQL_USER = os.getenv("MYSQL_USER", "root")
MYSQL_PWD = os.getenv("MYSQL_PWD", "123456")
MYSQL_DB = os.getenv("MYSQL_DB", "mysql")

############### Data Path ###############
UPLOAD_PATH = os.getenv("UPLOAD_PATH", "tmp/audio-data")

############### Number of Log Files ###############
LOGS_NUM = int(os.getenv("logs_num", "0"))
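
Every value above is read from the environment with a default, so the service can be pointed at the containers from the compose file without code changes. A small, hypothetical override sketch (the variable names are the ones defined above; `main.py` is assumed as the entry point, as the test file in this commit imports `from main import app`):

```bash
# Hypothetical override of the defaults defined in the config module above.
export MILVUS_HOST=127.0.0.1
export MILVUS_PORT=19530
export MYSQL_HOST=127.0.0.1
export MYSQL_PWD=123456
export DEFAULT_TABLE=audio_table
python3 main.py   # starts the FastAPI service added later in this commit
```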
@@ -0,0 +1,39 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import librosa
import numpy as np
from logs import LOGGER


def get_audio_embedding(path):
    """
    Use vpr_inference to generate embedding of audio
    """
    try:
        RESAMPLE_RATE = 16000
        audio, _ = librosa.load(path, sr=RESAMPLE_RATE, mono=True)

        # TODO add infer/python interface to get embedding, now fake it by rand
        # vpr = ECAPATDNN(checkpoint_path=None, device='cuda')
        # embedding = vpr.inference(audio)
        np.random.seed(hash(os.path.basename(path)) % 1000000)
        embedding = np.random.rand(1, 2048)
        embedding = embedding / np.linalg.norm(embedding)
        embedding = embedding.tolist()[0]
        return embedding
    except Exception as e:
        LOGGER.error(f"Error with embedding:{e}")
        return None
@@ -0,0 +1,168 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional

import uvicorn
from config import UPLOAD_PATH
from diskcache import Cache
from fastapi import FastAPI
from fastapi import File
from fastapi import UploadFile
from logs import LOGGER
from milvus_helpers import MilvusHelper
from mysql_helpers import MySQLHelper
from operations.count import do_count
from operations.drop import do_drop
from operations.load import do_load
from operations.search import do_search
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import FileResponse

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"])

MODEL = None
MILVUS_CLI = MilvusHelper()
MYSQL_CLI = MySQLHelper()

# Mkdir 'tmp/audio-data'
if not os.path.exists(UPLOAD_PATH):
    os.makedirs(UPLOAD_PATH)
    LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")


@app.get('/data')
def audio_path(audio_path):
    # Get the audio file
    try:
        LOGGER.info(f"Successfully load audio: {audio_path}")
        return FileResponse(audio_path)
    except Exception as e:
        LOGGER.error(f"Upload audio error: {e}")
        return {'status': False, 'msg': e}, 400


@app.get('/progress')
def get_progress():
    # Get the progress of dealing with data
    try:
        cache = Cache('./tmp')
        return f"current: {cache['current']}, total: {cache['total']}"
    except Exception as e:
        LOGGER.error(f"Upload data error: {e}")
        return {'status': False, 'msg': e}, 400


class Item(BaseModel):
    Table: Optional[str] = None
    File: str


@app.post('/audio/load')
async def load_audios(item: Item):
    # Insert all the audio files under the file path into Milvus/MySQL
    try:
        total_num = do_load(item.Table, item.File, MILVUS_CLI, MYSQL_CLI)
        LOGGER.info(f"Successfully loaded data, total count: {total_num}")
        return {'status': True, 'msg': "Successfully loaded data!"}
    except Exception as e:
        LOGGER.error(e)
        return {'status': False, 'msg': e}, 400


@app.post('/audio/search')
async def search_audio(request: Request,
                       table_name: str=None,
                       audio: UploadFile=File(...)):
    # Search the uploaded audio in Milvus/MySQL
    try:
        # Save the uploaded data on the server.
        content = await audio.read()
        query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
        with open(query_audio_path, "wb+") as f:
            f.write(content)
        host = request.headers['host']
        _, paths, distances = do_search(host, table_name, query_audio_path,
                                        MILVUS_CLI, MYSQL_CLI)
        names = []
        for path, score in zip(paths, distances):
            names.append(os.path.basename(path))
            LOGGER.info(f"search result {path}, score {score}")
        res = dict(zip(paths, zip(names, distances)))
        # Sort results by score, best match first
        res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
        LOGGER.info("Successfully searched similar audio!")
        return res
    except Exception as e:
        LOGGER.error(e)
        return {'status': False, 'msg': e}, 400


@app.post('/audio/search/local')
async def search_local_audio(request: Request,
                             query_audio_path: str,
                             table_name: str=None):
    # Search a local audio file in Milvus/MySQL
    try:
        host = request.headers['host']
        _, paths, distances = do_search(host, table_name, query_audio_path,
                                        MILVUS_CLI, MYSQL_CLI)
        names = []
        for path, score in zip(paths, distances):
            names.append(os.path.basename(path))
            LOGGER.info(f"search result {path}, score {score}")
        res = dict(zip(paths, zip(names, distances)))
        # Sort results by score, best match first
        res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
        LOGGER.info("Successfully searched similar audio!")
        return res
    except Exception as e:
        LOGGER.error(e)
        return {'status': False, 'msg': e}, 400


@app.get('/audio/count')
async def count_audio(table_name: str=None):
    # Returns the total number of vectors in the system
    try:
        num = do_count(table_name, MILVUS_CLI)
        LOGGER.info("Successfully counted the number of vectors!")
        return num
    except Exception as e:
        LOGGER.error(e)
        return {'status': False, 'msg': e}, 400


@app.post('/audio/drop')
async def drop_tables(table_name: str=None):
    # Delete the collection in Milvus and the table in MySQL
    try:
        status = do_drop(table_name, MILVUS_CLI, MYSQL_CLI)
        LOGGER.info("Successfully dropped tables in Milvus and MySQL!")
        return status
    except Exception as e:
        LOGGER.error(e)
        return {'status': False, 'msg': e}, 400


if __name__ == '__main__':
    uvicorn.run(app=app, host='0.0.0.0', port=8002)
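
A hedged sketch of exercising the endpoints above with curl once the service is listening on port 8002; the routes, query parameters, and JSON body fields come from the handlers in this file, while the example paths are placeholders.

```bash
# Load every audio file under a local directory ("Table" is optional and falls back to DEFAULT_TABLE).
curl -X POST http://127.0.0.1:8002/audio/load \
     -H "Content-Type: application/json" \
     -d '{"File": "./example_audio"}'

# Check indexing progress and the number of indexed vectors.
curl http://127.0.0.1:8002/progress
curl http://127.0.0.1:8002/audio/count

# Search with an audio file that is already on the server's filesystem.
curl -X POST "http://127.0.0.1:8002/audio/search/local?query_audio_path=./example_audio/test.wav"
```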
@@ -0,0 +1,185 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

from config import METRIC_TYPE
from config import MILVUS_HOST
from config import MILVUS_PORT
from config import VECTOR_DIMENSION
from logs import LOGGER
from pymilvus import Collection
from pymilvus import CollectionSchema
from pymilvus import connections
from pymilvus import DataType
from pymilvus import FieldSchema
from pymilvus import utility


class MilvusHelper:
    """
    The basic operations of PyMilvus.

    This example shows how to:
      1. connect to the Milvus server
      2. create a collection
      3. insert entities
      4. create an index
      5. search
      6. delete a collection
    """

    def __init__(self):
        try:
            self.collection = None
            connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
            LOGGER.debug(
                f"Successfully connect to Milvus with IP:{MILVUS_HOST} and PORT:{MILVUS_PORT}"
            )
        except Exception as e:
            LOGGER.error(f"Failed to connect to Milvus: {e}")
            sys.exit(1)

    def set_collection(self, collection_name):
        try:
            if self.has_collection(collection_name):
                self.collection = Collection(name=collection_name)
            else:
                raise Exception(
                    f"There is no collection named: {collection_name}")
        except Exception as e:
            LOGGER.error(f"Failed to set collection in Milvus: {e}")
            sys.exit(1)

    def has_collection(self, collection_name):
        # Return whether Milvus has the collection
        try:
            return utility.has_collection(collection_name)
        except Exception as e:
            LOGGER.error(f"Failed to check state of collection in Milvus: {e}")
            sys.exit(1)

    def create_collection(self, collection_name):
        # Create the milvus collection if it does not exist
        try:
            if not self.has_collection(collection_name):
                field1 = FieldSchema(
                    name="id",
                    dtype=DataType.INT64,
                    description="int64",
                    is_primary=True,
                    auto_id=True)
                field2 = FieldSchema(
                    name="embedding",
                    dtype=DataType.FLOAT_VECTOR,
                    description="speaker embeddings",
                    dim=VECTOR_DIMENSION,
                    is_primary=False)
                schema = CollectionSchema(
                    fields=[field1, field2], description="embeddings info")
                self.collection = Collection(
                    name=collection_name, schema=schema)
                LOGGER.debug(f"Create Milvus collection: {collection_name}")
            else:
                self.set_collection(collection_name)
            return "OK"
        except Exception as e:
            LOGGER.error(f"Failed to create collection in Milvus: {e}")
            sys.exit(1)

    def insert(self, collection_name, vectors):
        # Batch insert vectors into the milvus collection
        try:
            self.create_collection(collection_name)
            data = [vectors]
            self.set_collection(collection_name)
            mr = self.collection.insert(data)
            ids = mr.primary_keys
            self.collection.load()
            LOGGER.debug(
                f"Insert vectors to Milvus in collection: {collection_name} with {len(vectors)} rows"
            )
            return ids
        except Exception as e:
            LOGGER.error(f"Failed to insert data to Milvus: {e}")
            sys.exit(1)

    def create_index(self, collection_name):
        # Create an IVF_SQ8 index on the milvus collection
        try:
            self.set_collection(collection_name)
            default_index = {
                "index_type": "IVF_SQ8",
                "metric_type": METRIC_TYPE,
                "params": {
                    "nlist": 16384
                }
            }
            status = self.collection.create_index(
                field_name="embedding", index_params=default_index)
            if not status.code:
                LOGGER.debug(
                    f"Successfully create index in collection:{collection_name} with param:{default_index}"
                )
                return status
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error(f"Failed to create index: {e}")
            sys.exit(1)

    def delete_collection(self, collection_name):
        # Delete the Milvus collection
        try:
            self.set_collection(collection_name)
            self.collection.drop()
            LOGGER.debug("Successfully drop collection!")
            return "ok"
        except Exception as e:
            LOGGER.error(f"Failed to drop collection: {e}")
            sys.exit(1)

    def search_vectors(self, collection_name, vectors, top_k):
        # Search vectors in the milvus collection
        try:
            self.set_collection(collection_name)
            search_params = {
                "metric_type": METRIC_TYPE,
                "params": {
                    "nprobe": 16
                }
            }
            res = self.collection.search(
                vectors,
                anns_field="embedding",
                param=search_params,
                limit=top_k)
            LOGGER.debug(f"Successfully search in collection: {res}")
            return res
        except Exception as e:
            LOGGER.error(f"Failed to search vectors in Milvus: {e}")
            sys.exit(1)

    def count(self, collection_name):
        # Get the number of entities in the milvus collection
        try:
            self.set_collection(collection_name)
            num = self.collection.num_entities
            LOGGER.debug(
                f"Successfully get the num:{num} of the collection:{collection_name}"
            )
            return num
        except Exception as e:
            LOGGER.error(f"Failed to count vectors in Milvus: {e}")
            sys.exit(1)
@@ -0,0 +1,133 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import pymysql
from config import MYSQL_DB
from config import MYSQL_HOST
from config import MYSQL_PORT
from config import MYSQL_PWD
from config import MYSQL_USER
from logs import LOGGER


class MySQLHelper():
    """
    The basic operations of PyMySQL.

    This example shows how to:
      1. connect to the MySQL server
      2. create a table
      3. insert data into the table
      4. search by milvus ids
      5. delete the table
    """

    def __init__(self):
        self.conn = pymysql.connect(
            host=MYSQL_HOST,
            user=MYSQL_USER,
            port=MYSQL_PORT,
            password=MYSQL_PWD,
            database=MYSQL_DB,
            local_infile=True)
        self.cursor = self.conn.cursor()

    def test_connection(self):
        try:
            self.conn.ping()
        except Exception:
            self.conn = pymysql.connect(
                host=MYSQL_HOST,
                user=MYSQL_USER,
                port=MYSQL_PORT,
                password=MYSQL_PWD,
                database=MYSQL_DB,
                local_infile=True)
            self.cursor = self.conn.cursor()

    def create_mysql_table(self, table_name):
        # Create the mysql table if it does not exist
        self.test_connection()
        sql = "create table if not exists " + table_name + "(milvus_id TEXT, audio_path TEXT);"
        try:
            self.cursor.execute(sql)
            LOGGER.debug(f"MYSQL create table: {table_name} with sql: {sql}")
        except Exception as e:
            LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
            sys.exit(1)

    def load_data_to_mysql(self, table_name, data):
        # Batch insert (milvus_id, audio_path) pairs into the mysql table
        self.test_connection()
        sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);"
        try:
            self.cursor.executemany(sql, data)
            self.conn.commit()
            LOGGER.debug(
                f"MYSQL loads data to table: {table_name} successfully")
        except Exception as e:
            LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
            sys.exit(1)

    def search_by_milvus_ids(self, ids, table_name):
        # Get the audio_path values according to the milvus ids
        self.test_connection()
        str_ids = str(ids).replace('[', '').replace(']', '')
        sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");"
        try:
            self.cursor.execute(sql)
            results = self.cursor.fetchall()
            results = [res[0] for res in results]
            LOGGER.debug("MYSQL search by milvus id.")
            return results
        except Exception as e:
            LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
            sys.exit(1)

    def delete_table(self, table_name):
        # Delete the mysql table if it exists
        self.test_connection()
        sql = "drop table if exists " + table_name + ";"
        try:
            self.cursor.execute(sql)
            LOGGER.debug(f"MYSQL delete table:{table_name}")
        except Exception as e:
            LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
            sys.exit(1)

    def delete_all_data(self, table_name):
        # Delete all the data in the mysql table
        self.test_connection()
        sql = 'delete from ' + table_name + ';'
        try:
            self.cursor.execute(sql)
            self.conn.commit()
            LOGGER.debug(f"MYSQL delete all data in table:{table_name}")
        except Exception as e:
            LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
            sys.exit(1)

    def count_table(self, table_name):
        # Get the number of rows in the mysql table
        self.test_connection()
        sql = "select count(milvus_id) from " + table_name + ";"
        try:
            self.cursor.execute(sql)
            results = self.cursor.fetchall()
            LOGGER.debug(f"MYSQL count table:{table_name}")
            return results[0][0]
        except Exception as e:
            LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
            sys.exit(1)
@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,33 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

from config import DEFAULT_TABLE
from logs import LOGGER


def do_count(table_name, milvus_cli):
    """
    Returns the total number of vectors in the system
    """
    if not table_name:
        table_name = DEFAULT_TABLE
    try:
        if not milvus_cli.has_collection(table_name):
            return None
        num = milvus_cli.count(table_name)
        return num
    except Exception as e:
        LOGGER.error(f"Error attempting to count table {e}")
        sys.exit(1)
@@ -0,0 +1,34 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

from config import DEFAULT_TABLE
from logs import LOGGER


def do_drop(table_name, milvus_cli, mysql_cli):
    """
    Delete the collection in Milvus and the table in MySQL
    """
    if not table_name:
        table_name = DEFAULT_TABLE
    try:
        if not milvus_cli.has_collection(table_name):
            return "Collection does not exist"
        status = milvus_cli.delete_collection(table_name)
        mysql_cli.delete_table(table_name)
        return status
    except Exception as e:
        LOGGER.error(f"Error attempting to drop table: {e}")
        sys.exit(1)
@@ -0,0 +1,84 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

from config import DEFAULT_TABLE
from diskcache import Cache
from encode import get_audio_embedding
from logs import LOGGER


def get_audios(path):
    """
    List all supported audio files (wav, mp3, ogg, flac, m4a) recursively under the path folder.
    """
    supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
    return [
        item for sublist in [[os.path.join(dir, file) for file in files]
                             for dir, _, files in list(os.walk(path))]
        for item in sublist if os.path.splitext(item)[1] in supported_formats
    ]


def extract_features(audio_dir):
    """
    Get the embedding vector of each audio file
    """
    try:
        cache = Cache('./tmp')
        feats = []
        names = []
        audio_list = get_audios(audio_dir)
        total = len(audio_list)
        cache['total'] = total
        for i, audio_path in enumerate(audio_list):
            norm_feat = get_audio_embedding(audio_path)
            if norm_feat is None:
                continue
            feats.append(norm_feat)
            names.append(audio_path.encode())
            cache['current'] = i + 1
            print(
                f"Extracting feature from audio No. {i + 1}, {total} audios in total"
            )
        return feats, names
    except Exception as e:
        LOGGER.error(f"Error with extracting feature from audio {e}")
        sys.exit(1)


def format_data(ids, names):
    """
    Combine the id of the vector and the name of the audio into a list
    """
    data = []
    for i in range(len(ids)):
        value = (str(ids[i]), names[i])
        data.append(value)
    return data


def do_load(table_name, audio_dir, milvus_cli, mysql_cli):
    """
    Import vectors to Milvus and data to MySQL respectively
    """
    if not table_name:
        table_name = DEFAULT_TABLE
    vectors, names = extract_features(audio_dir)
    ids = milvus_cli.insert(table_name, vectors)
    milvus_cli.create_index(table_name)
    mysql_cli.create_mysql_table(table_name)
    mysql_cli.load_data_to_mysql(table_name, format_data(ids, names))
    return len(ids)
@@ -0,0 +1,41 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

from config import DEFAULT_TABLE
from config import TOP_K
from encode import get_audio_embedding
from logs import LOGGER


def do_search(host, table_name, audio_path, milvus_cli, mysql_cli):
    """
    Search the uploaded audio in Milvus/MySQL
    """
    try:
        if not table_name:
            table_name = DEFAULT_TABLE
        feat = get_audio_embedding(audio_path)
        vectors = milvus_cli.search_vectors(table_name, [feat], TOP_K)
        vids = [str(x.id) for x in vectors[0]]
        paths = mysql_cli.search_by_milvus_ids(vids, table_name)
        distances = [x.distance for x in vectors[0]]
        for i in range(len(paths)):
            tmp = "http://" + str(host) + "/data?audio_path=" + str(paths[i])
            paths[i] = tmp
            distances[i] = (1 - distances[i]) * 100
        return vids, paths, distances
    except Exception as e:
        LOGGER.error(f"Error with search: {e}")
        sys.exit(1)
@@ -0,0 +1,95 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zipfile

import gdown
from fastapi.testclient import TestClient
from main import app

client = TestClient(app)


def download_audio_data():
    """
    download audio data
    """
    url = 'https://drive.google.com/uc?id=1bKu21JWBfcZBuEuzFEvPoAX6PmRrgnUp'
    gdown.download(url)

    with zipfile.ZipFile('example_audio.zip', 'r') as zip_ref:
        zip_ref.extractall('./example_audio')


def test_drop():
    """
    Delete the collection of Milvus and MySQL
    """
    response = client.post("/audio/drop")
    assert response.status_code == 200


def test_load():
    """
    Insert all the audio files under the file path to Milvus/MySQL
    """
    response = client.post("/audio/load", json={"File": "./example_audio"})
    assert response.status_code == 200
    assert response.json() == {
        'status': True,
        'msg': "Successfully loaded data!"
    }


def test_progress():
    """
    Get the progress of dealing with data
    """
    response = client.get("/progress")
    assert response.status_code == 200
    assert response.json() == "current: 20, total: 20"


def test_count():
    """
    Returns the total number of vectors in the system
    """
    response = client.get("audio/count")
    assert response.status_code == 200
    assert response.json() == 20


def test_search():
    """
    Search the uploaded audio in Milvus/MySQL
    """
    response = client.post(
        "/audio/search/local?query_audio_path=.%2Fexample_audio%2Ftest.wav")
    assert response.status_code == 200
    assert len(response.json()) == 10


def test_data():
    """
    Get the audio file
    """
    response = client.get("/data?audio_path=.%2Fexample_audio%2Ftest.wav")
    assert response.status_code == 200


if __name__ == "__main__":
    download_audio_data()
    test_load()
    test_count()
    test_search()
    test_drop()
@@ -0,0 +1 @@
*.wav
@@ -0,0 +1,4 @@
#!/bin/bash

wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input ./zh.wav --topk 1
@@ -0,0 +1,148 @@
# HiFiGAN with the LJSpeech-1.1
This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.05646) model with [LJSpeech-1.1](https://keithito.com/LJ-Speech-Dataset/).
## Dataset
### Download and Extract
Download LJSpeech-1.1 from the [official website](https://keithito.com/LJ-Speech-Dataset/).
### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence at the edges of the audio.
You can download it from [ljspeech_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz), or train a MFA model yourself by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) in our repo.

## Get Started
Assume the path to the dataset is `~/datasets/LJSpeech-1.1`.
Assume the path to the MFA result of LJSpeech-1.1 is `./ljspeech_alignment`.
Run the command below to
1. **source path**.
2. preprocess the dataset.
3. train the model.
4. synthesize wavs.
   - synthesize waveform from `metadata.jsonl`.
```bash
./run.sh
```
You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage; for example, running the following command will only preprocess the dataset.
```bash
./run.sh --stage 0 --stop-stage 0
```
### Data Preprocessing
```bash
./local/preprocess.sh ${conf_path}
```
When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below.

```text
dump
├── dev
│   ├── norm
│   └── raw
├── test
│   ├── norm
│   └── raw
└── train
    ├── norm
    ├── raw
    └── feats_stats.npy
```

The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the `norm` folder contains the normalized spectrogram. The statistics used to normalize the spectrogram are computed from the training set, which is located in `dump/train/feats_stats.npy`.
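
As a quick sanity check (an assumption, not part of this example), the statistics file can be inspected with NumPy; it is expected to hold the per-dimension statistics (mean and std) computed by `compute_statistics.py` over the training features.

```bash
# Hypothetical check of the normalization statistics produced during preprocessing.
python3 -c "import numpy as np; stats = np.load('dump/train/feats_stats.npy'); print(stats.shape)"
```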

Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains utterance ids and paths to the spectrogram of each utterance.

### Model Training
`./local/train.sh` calls `${BIN_DIR}/train.py`.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
```
Here's the complete help message.

```text
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
                [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
                [--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER]
                [--run-benchmark RUN_BENCHMARK]
                [--profiler_options PROFILER_OPTIONS]

Train a ParallelWaveGAN model.

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       config file to overwrite default config.
  --train-metadata TRAIN_METADATA
                        training data.
  --dev-metadata DEV_METADATA
                        dev data.
  --output-dir OUTPUT_DIR
                        output dir.
  --ngpu NGPU           if ngpu == 0, use cpu.

benchmark:
  arguments related to benchmark.

  --batch-size BATCH_SIZE
                        batch size.
  --max-iter MAX_ITER   train max steps.
  --run-benchmark RUN_BENCHMARK
                        runing benchmark or not, if True, use the --batch-size
                        and --max-iter.
  --profiler_options PROFILER_OPTIONS
                        The option of profiler, which should be in format
                        "key1=value1;key2=value2;key3=value3".
```

1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder.
3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.

### Synthesizing
`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize.py [-h] [--generator-type GENERATOR_TYPE] [--config CONFIG]
                     [--checkpoint CHECKPOINT] [--test-metadata TEST_METADATA]
                     [--output-dir OUTPUT_DIR] [--ngpu NGPU]

Synthesize with GANVocoder.

optional arguments:
  -h, --help            show this help message and exit
  --generator-type GENERATOR_TYPE
                        type of GANVocoder, should in {pwgan, mb_melgan,
                        style_melgan, } now
  --config CONFIG       GANVocoder config file.
  --checkpoint CHECKPOINT
                        snapshot to load.
  --test-metadata TEST_METADATA
                        dev data.
  --output-dir OUTPUT_DIR
                        output dir.
  --ngpu NGPU           if ngpu == 0, use cpu.
```

1. `--config` is the config file. You should use the same config with which the model is trained.
2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory.
3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory.
4. `--output-dir` is the directory to save the synthesized audio files.
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.

## Pretrained Model
The pretrained model can be downloaded here: [hifigan_ljspeech_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip).


Model | Step | eval/generator_loss | eval/mel_loss | eval/feature_matching_loss
:-------------:| :------------:| :-----: | :-----: | :--------:
default | 1(gpu) x 2500000 | 24.492 | 0.115 | 7.227

The HiFiGAN checkpoint contains the files listed below.

```text
hifigan_ljspeech_ckpt_0.2.0
├── default.yaml               # default config used to train hifigan
├── feats_stats.npy            # statistics used to normalize spectrogram when training hifigan
└── snapshot_iter_2500000.pdz  # generator parameters of hifigan
```
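
With the pretrained checkpoint above, a waveform can be synthesized from `metadata.jsonl` using the same `synthesize.py` options documented earlier; this is a sketch and the output directory is arbitrary.

```bash
source path.sh
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
    --config=hifigan_ljspeech_ckpt_0.2.0/default.yaml \
    --checkpoint=hifigan_ljspeech_ckpt_0.2.0/snapshot_iter_2500000.pdz \
    --test-metadata=dump/test/norm/metadata.jsonl \
    --output-dir=exp/default/test \
    --generator-type=hifigan
```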

## Acknowledgement
We adapted some code from https://github.com/kan-bayashi/ParallelWaveGAN.
@@ -0,0 +1,167 @@
# This is the configuration file for LJSpeech dataset.
# This configuration is based on HiFiGAN V1, which is an official configuration.
# But I found that the optimizer setting does not work well with my implementation.
# So I changed the optimizer settings as follows:
# - AdamW -> Adam
# - betas: [0.8, 0.99] -> betas: [0.5, 0.9]
# - Scheduler: ExponentialLR -> MultiStepLR
# To match the shift size difference, the upsample scales are also modified from the original 256 shift setting.

###########################################################
#                FEATURE EXTRACTION SETTING               #
###########################################################
fs: 22050            # Sampling rate.
n_fft: 1024          # FFT size (samples).
n_shift: 256         # Hop size (samples). 11.6ms
win_length: null     # Window length (samples).
                     # If set to null, it will be the same as fft_size.
window: "hann"       # Window function.
n_mels: 80           # Number of mel basis.
fmin: 80             # Minimum freq in mel basis calculation. (Hz)
fmax: 7600           # Maximum frequency in mel basis calculation. (Hz)

###########################################################
#         GENERATOR NETWORK ARCHITECTURE SETTING          #
###########################################################
generator_params:
    in_channels: 80                       # Number of input channels.
    out_channels: 1                       # Number of output channels.
    channels: 512                         # Number of initial channels.
    kernel_size: 7                        # Kernel size of initial and final conv layers.
    upsample_scales: [8, 8, 2, 2]         # Upsampling scales.
    upsample_kernel_sizes: [16, 16, 4, 4] # Kernel size for upsampling layers.
    resblock_kernel_sizes: [3, 7, 11]     # Kernel size for residual blocks.
    resblock_dilations:                   # Dilations for residual blocks.
        - [1, 3, 5]
        - [1, 3, 5]
        - [1, 3, 5]
    use_additional_convs: True            # Whether to use additional conv layer in residual blocks.
    bias: True                            # Whether to use bias parameter in conv.
    nonlinear_activation: "leakyrelu"     # Nonlinear activation type.
    nonlinear_activation_params:          # Nonlinear activation parameters.
        negative_slope: 0.1
    use_weight_norm: True                 # Whether to apply weight normalization.


###########################################################
#       DISCRIMINATOR NETWORK ARCHITECTURE SETTING        #
###########################################################
discriminator_params:
    scales: 3                              # Number of multi-scale discriminator.
    scale_downsample_pooling: "AvgPool1D"  # Pooling operation for scale discriminator.
    scale_downsample_pooling_params:
        kernel_size: 4                     # Pooling kernel size.
        stride: 2                          # Pooling stride.
        padding: 2                         # Padding size.
    scale_discriminator_params:
        in_channels: 1                     # Number of input channels.
        out_channels: 1                    # Number of output channels.
        kernel_sizes: [15, 41, 5, 3]       # List of kernel sizes.
        channels: 128                      # Initial number of channels.
        max_downsample_channels: 1024      # Maximum number of channels in downsampling conv layers.
        max_groups: 16                     # Maximum number of groups in downsampling conv layers.
        bias: True
        downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales.
        nonlinear_activation: "leakyrelu"  # Nonlinear activation.
        nonlinear_activation_params:
            negative_slope: 0.1
    follow_official_norm: True             # Whether to follow the official norm setting.
    periods: [2, 3, 5, 7, 11]              # List of period for multi-period discriminator.
    period_discriminator_params:
        in_channels: 1                     # Number of input channels.
        out_channels: 1                    # Number of output channels.
        kernel_sizes: [5, 3]               # List of kernel sizes.
        channels: 32                       # Initial number of channels.
        downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
        max_downsample_channels: 1024      # Maximum number of channels in downsampling conv layers.
        bias: True                         # Whether to use bias parameter in conv layer.
        nonlinear_activation: "leakyrelu"  # Nonlinear activation.
        nonlinear_activation_params:       # Nonlinear activation parameters.
            negative_slope: 0.1
        use_weight_norm: True              # Whether to apply weight normalization.
        use_spectral_norm: False           # Whether to apply spectral normalization.


###########################################################
#                    STFT LOSS SETTING                    #
###########################################################
use_stft_loss: False     # Whether to use multi-resolution STFT loss.
use_mel_loss: True       # Whether to use Mel-spectrogram loss.
mel_loss_params:
    fs: 22050
    fft_size: 1024
    hop_size: 256
    win_length: null
    window: "hann"
    num_mels: 80
    fmin: 0
    fmax: 11025
    log_base: null
generator_adv_loss_params:
    average_by_discriminators: False # Whether to average loss by #discriminators.
discriminator_adv_loss_params:
    average_by_discriminators: False # Whether to average loss by #discriminators.
use_feat_match_loss: True
feat_match_loss_params:
    average_by_discriminators: False # Whether to average loss by #discriminators.
    average_by_layers: False         # Whether to average loss by #layers in each discriminator.
    include_final_outputs: False     # Whether to include final outputs in feat match loss calculation.

###########################################################
#                ADVERSARIAL LOSS SETTING                 #
###########################################################
lambda_aux: 45.0        # Loss balancing coefficient for STFT loss.
lambda_adv: 1.0         # Loss balancing coefficient for adversarial loss.
lambda_feat_match: 2.0  # Loss balancing coefficient for feat match loss.

###########################################################
#                   DATA LOADER SETTING                   #
###########################################################
batch_size: 16          # Batch size.
batch_max_steps: 8192   # Length of each audio in batch. Make sure dividable by hop_size.
num_workers: 2          # Number of workers in DataLoader.

###########################################################
#             OPTIMIZER & SCHEDULER SETTING               #
###########################################################
generator_optimizer_params:
    beta1: 0.5
    beta2: 0.9
    weight_decay: 0.0         # Generator's weight decay coefficient.
generator_scheduler_params:
    learning_rate: 2.0e-4     # Generator's learning rate.
    gamma: 0.5                # Generator's scheduler gamma.
    milestones:               # At each milestone, lr will be multiplied by gamma.
        - 200000
        - 400000
        - 600000
        - 800000
generator_grad_norm: -1       # Generator's gradient norm.
discriminator_optimizer_params:
    beta1: 0.5
    beta2: 0.9
    weight_decay: 0.0         # Discriminator's weight decay coefficient.
discriminator_scheduler_params:
    learning_rate: 2.0e-4     # Discriminator's learning rate.
    gamma: 0.5                # Discriminator's scheduler gamma.
    milestones:               # At each milestone, lr will be multiplied by gamma.
        - 200000
        - 400000
        - 600000
        - 800000
discriminator_grad_norm: -1   # Discriminator's gradient norm.

###########################################################
#                     INTERVAL SETTING                    #
###########################################################
generator_train_start_steps: 1     # Number of steps to start to train generator.
discriminator_train_start_steps: 0 # Number of steps to start to train discriminator.
train_max_steps: 2500000           # Number of training steps.
save_interval_steps: 5000          # Interval steps to save checkpoint.
eval_interval_steps: 1000          # Interval steps to evaluate the network.

###########################################################
#                      OTHER SETTING                      #
###########################################################
num_snapshots: 10 # max number of snapshots to keep while training
seed: 42          # random seed for paddle, random, and np.random
@@ -0,0 +1,55 @@
#!/bin/bash

stage=0
stop_stage=100

config_path=$1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # get durations from MFA's result
    echo "Generate durations.txt from MFA results ..."
    python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
        --inputdir=./ljspeech_alignment \
        --output=durations.txt \
        --config=${config_path}
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # extract features
    echo "Extract features ..."
    python3 ${BIN_DIR}/../preprocess.py \
        --rootdir=~/datasets/LJSpeech-1.1/ \
        --dataset=ljspeech \
        --dumpdir=dump \
        --dur-file=durations.txt \
        --config=${config_path} \
        --cut-sil=True \
        --num-cpu=20
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # get features' stats(mean and std)
    echo "Get features' stats ..."
    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --field-name="feats"
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # normalize, dev and test should use train's stats
    echo "Normalize ..."

    python3 ${BIN_DIR}/../normalize.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --dumpdir=dump/train/norm \
        --stats=dump/train/feats_stats.npy
    python3 ${BIN_DIR}/../normalize.py \
        --metadata=dump/dev/raw/metadata.jsonl \
        --dumpdir=dump/dev/norm \
        --stats=dump/train/feats_stats.npy

    python3 ${BIN_DIR}/../normalize.py \
        --metadata=dump/test/raw/metadata.jsonl \
        --dumpdir=dump/test/norm \
        --stats=dump/train/feats_stats.npy
fi
@@ -0,0 +1,14 @@
#!/bin/bash

config_path=$1
train_output_path=$2
ckpt_name=$3

FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
    --config=${config_path} \
    --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
    --test-metadata=dump/test/norm/metadata.jsonl \
    --output-dir=${train_output_path}/test \
    --generator-type=hifigan
@@ -0,0 +1,13 @@
#!/bin/bash

config_path=$1
train_output_path=$2

FLAGS_cudnn_exhaustive_search=true \
FLAGS_conv_workspace_size_limit=4000 \
python ${BIN_DIR}/train.py \
    --train-metadata=dump/train/norm/metadata.jsonl \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --config=${config_path} \
    --output-dir=${train_output_path} \
    --ngpu=1
@@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

MODEL=hifigan
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/gan_vocoder/${MODEL}
@@ -0,0 +1,32 @@
#!/bin/bash

set -e
source path.sh

gpus=0,1
stage=0
stop_stage=100

conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_5000.pdz

# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this cannot be mixed with the use of `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    ./local/preprocess.sh ${conf_path} || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train model, all `ckpt` under `train_output_path/checkpoints/` dir
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # synthesize
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
@@ -0,0 +1,2 @@
.eggs
*.wav
@@ -1,5 +1,9 @@
# Changelog

Date: 2022-3-15, Author: Xiaojie Chen.
- kaldi and librosa mfcc, fbank, spectrogram.
- unit test and benchmark.

Date: 2022-2-25, Author: Hui Zhang.
- Refactor architecture.
- dtw distance and mcd style dtw.
@@ -0,0 +1,19 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS    =
SPHINXBUILD   = sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@@ -0,0 +1,18 @@
# Build docs for PaddleAudio

## 1. Install

`pip install Sphinx`
`pip install sphinx_rtd_theme`


## 2. Generate API docs

Exclude `paddleaudio.utils`

`sphinx-apidoc -fMeT -o source ../paddleaudio ../paddleaudio/utils --templatedir source/_templates`


## 3. Build

`sphinx-build source _html`
@@ -0,0 +1,35 @@
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%

:end
popd
@@ -0,0 +1,5 @@
.wy-nav-content {
    max-width: 80%;
}
.table table{ background:#b9b9b9}
.table table td{ background:#FFF; }
@@ -0,0 +1,9 @@
{%- if show_headings %}
{{- basename | e | heading }}

{% endif -%}
.. automodule:: {{ qualname }}
{%- for option in automodule_options %}
   :{{ option }}:
{%- endfor %}

@ -0,0 +1,57 @@
|
||||
{%- macro automodule(modname, options) -%}
|
||||
.. automodule:: {{ modname }}
|
||||
{%- for option in options %}
|
||||
:{{ option }}:
|
||||
{%- endfor %}
|
||||
{%- endmacro %}
|
||||
|
||||
{%- macro toctree(docnames) -%}
|
||||
.. toctree::
|
||||
:maxdepth: {{ maxdepth }}
|
||||
{% for docname in docnames %}
|
||||
{{ docname }}
|
||||
{%- endfor %}
|
||||
{%- endmacro %}
|
||||
|
||||
{%- if is_namespace %}
|
||||
{{- [pkgname, "namespace"] | join(" ") | e | heading }}
|
||||
{% else %}
|
||||
{{- pkgname | e | heading }}
|
||||
{% endif %}
|
||||
|
||||
{%- if is_namespace %}
|
||||
.. py:module:: {{ pkgname }}
|
||||
{% endif %}
|
||||
|
||||
{%- if modulefirst and not is_namespace %}
|
||||
{{ automodule(pkgname, automodule_options) }}
|
||||
{% endif %}
|
||||
|
||||
{%- if subpackages %}
|
||||
Subpackages
|
||||
-----------
|
||||
|
||||
{{ toctree(subpackages) }}
|
||||
{% endif %}
|
||||
|
||||
{%- if submodules %}
|
||||
Submodules
|
||||
----------
|
||||
{% if separatemodules %}
|
||||
{{ toctree(submodules) }}
|
||||
{% else %}
|
||||
{%- for submodule in submodules %}
|
||||
{% if show_headings %}
|
||||
{{- submodule | e | heading(2) }}
|
||||
{% endif %}
|
||||
{{ automodule(submodule, automodule_options) }}
|
||||
{% endfor %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
|
||||
{%- if not modulefirst and not is_namespace %}
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
{{ automodule(pkgname, automodule_options) }}
|
||||
{% endif %}
|
@ -0,0 +1,8 @@
|
||||
{{ header | heading }}
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: {{ maxdepth }}
|
||||
{% for docname in docnames %}
|
||||
{{ docname }}
|
||||
{%- endfor %}
|
||||
|
@ -0,0 +1,181 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# This file does only contain a selection of the most common options. For a
|
||||
# full list see the documentation:
|
||||
# http://www.sphinx-doc.org/en/master/config
|
||||
# -- Path setup --------------------------------------------------------------
|
||||
# If extensions (or modules to document with autodoc) are in another directory,
|
||||
# add these directories to sys.path here. If the directory is relative to the
|
||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = 'PaddleAudio'
|
||||
copyright = '2022, PaddlePaddle'
|
||||
author = 'PaddlePaddle'
|
||||
|
||||
# The short X.Y version
|
||||
version = ''
|
||||
# The full version, including alpha/beta/rc tags
|
||||
release = '0.2.0'
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
# If your documentation needs a minimal Sphinx version, state it here.
|
||||
#
|
||||
# needs_sphinx = '1.0'
|
||||
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||
# ones.
|
||||
extensions = [
|
||||
'sphinx.ext.autodoc',
|
||||
'sphinx.ext.intersphinx',
|
||||
'sphinx.ext.mathjax',
|
||||
'sphinx.ext.viewcode',
|
||||
'sphinx.ext.napoleon',
|
||||
]
|
||||
|
||||
napoleon_google_docstring = True
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ['_templates']
|
||||
|
||||
# The suffix(es) of source filenames.
|
||||
# You can specify multiple suffix as a list of string:
|
||||
#
|
||||
# source_suffix = ['.rst', '.md']
|
||||
source_suffix = '.rst'
|
||||
|
||||
# The master toctree document.
|
||||
master_doc = 'index'
|
||||
|
||||
# The language for content autogenerated by Sphinx. Refer to documentation
|
||||
# for a list of supported languages.
|
||||
#
|
||||
# This is also used if you do content translation via gettext catalogs.
|
||||
# Usually you set "language" from the command line for these cases.
|
||||
language = None
|
||||
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# This pattern also affects html_static_path and html_extra_path.
|
||||
exclude_patterns = []
|
||||
|
||||
# The name of the Pygments (syntax highlighting) style to use.
|
||||
pygments_style = None
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
|
||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||
# a list of builtin themes.
|
||||
#
|
||||
|
||||
import sphinx_rtd_theme
|
||||
html_theme = 'sphinx_rtd_theme'
|
||||
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
|
||||
smartquotes = False
|
||||
|
||||
# Theme options are theme-specific and customize the look and feel of a theme
|
||||
# further. For a list of options available for each theme, see the
|
||||
# documentation.
|
||||
#
|
||||
# html_theme_options = {}
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ['_static']
|
||||
html_logo = '../images/paddle.png'
|
||||
html_css_files = [
|
||||
'custom.css',
|
||||
]
|
||||
|
||||
# Custom sidebar templates, must be a dictionary that maps document names
|
||||
# to template names.
|
||||
#
|
||||
# The default sidebars (for documents that don't match any pattern) are
|
||||
# defined by theme itself. Builtin themes are using these templates by
|
||||
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
|
||||
# 'searchbox.html']``.
|
||||
#
|
||||
# html_sidebars = {}
|
||||
|
||||
# -- Options for HTMLHelp output ---------------------------------------------
|
||||
|
||||
# Output file base name for HTML help builder.
|
||||
htmlhelp_basename = 'PaddleAudiodoc'
|
||||
|
||||
# -- Options for LaTeX output ------------------------------------------------
|
||||
|
||||
latex_elements = {
|
||||
# The paper size ('letterpaper' or 'a4paper').
|
||||
#
|
||||
# 'papersize': 'letterpaper',
|
||||
|
||||
# The font size ('10pt', '11pt' or '12pt').
|
||||
#
|
||||
# 'pointsize': '10pt',
|
||||
|
||||
# Additional stuff for the LaTeX preamble.
|
||||
#
|
||||
# 'preamble': '',
|
||||
|
||||
# Latex figure (float) alignment
|
||||
#
|
||||
# 'figure_align': 'htbp',
|
||||
}
|
||||
|
||||
# Grouping the document tree into LaTeX files. List of tuples
|
||||
# (source start file, target name, title,
|
||||
# author, documentclass [howto, manual, or own class]).
|
||||
latex_documents = [
|
||||
(master_doc, 'PaddleAudio.tex', 'PaddleAudio Documentation', 'PaddlePaddle',
|
||||
'manual'),
|
||||
]
|
||||
|
||||
# -- Options for manual page output ------------------------------------------
|
||||
|
||||
# One entry per manual page. List of tuples
|
||||
# (source start file, name, description, authors, manual section).
|
||||
man_pages = [(master_doc, 'paddleaudio', 'PaddleAudio Documentation', [author],
|
||||
1)]
|
||||
|
||||
# -- Options for Texinfo output ----------------------------------------------
|
||||
|
||||
# Grouping the document tree into Texinfo files. List of tuples
|
||||
# (source start file, target name, title, author,
|
||||
# dir menu entry, description, category)
|
||||
texinfo_documents = [
|
||||
(master_doc, 'PaddleAudio', 'PaddleAudio Documentation', author,
|
||||
'PaddleAudio', 'One line description of project.', 'Miscellaneous'),
|
||||
]
|
||||
|
||||
# -- Options for Epub output -------------------------------------------------
|
||||
|
||||
# Bibliographic Dublin Core info.
|
||||
epub_title = project
|
||||
|
||||
# The unique identifier of the text. This can be a ISBN number
|
||||
# or the project homepage.
|
||||
#
|
||||
# epub_identifier = ''
|
||||
|
||||
# A unique identification for the text.
|
||||
#
|
||||
# epub_uid = ''
|
||||
|
||||
# A list of files that should not be packed into the epub file.
|
||||
epub_exclude_files = ['search.html']
|
||||
|
||||
# -- Extension configuration -------------------------------------------------
|
||||
|
||||
# -- Options for intersphinx extension ---------------------------------------
|
||||
|
||||
# Example configuration for intersphinx: refer to the Python standard library.
|
||||
intersphinx_mapping = {'https://docs.python.org/': None}
|
@ -0,0 +1,22 @@
|
||||
.. PaddleAudio documentation master file, created by
|
||||
sphinx-quickstart on Tue Mar 22 15:57:16 2022.
|
||||
You can adapt this file completely to your liking, but it should at least
|
||||
contain the root `toctree` directive.
|
||||
|
||||
Welcome to PaddleAudio's documentation!
|
||||
=======================================
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
Index <self>
|
||||
|
||||
|
||||
API References
|
||||
--------------
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:titlesonly:
|
||||
|
||||
paddleaudio
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,34 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import unittest
|
||||
import urllib.request
|
||||
|
||||
mono_channel_wav = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
|
||||
multi_channels_wav = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav'
|
||||
|
||||
|
||||
class BackendTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.initWavInput()
|
||||
|
||||
def initWavInput(self):
|
||||
self.files = []
|
||||
for url in [mono_channel_wav, multi_channels_wav]:
|
||||
if not os.path.isfile(os.path.basename(url)):
|
||||
urllib.request.urlretrieve(url, os.path.basename(url))
|
||||
self.files.append(os.path.basename(url))
|
||||
|
||||
def initParmas(self):
|
||||
raise NotImplementedError
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,73 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import filecmp
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
import paddleaudio
|
||||
from ..base import BackendTest
|
||||
|
||||
|
||||
class TestIO(BackendTest):
|
||||
def test_load_mono_channel(self):
|
||||
sf_data, sf_sr = sf.read(self.files[0])
|
||||
pa_data, pa_sr = paddleaudio.load(
|
||||
self.files[0], normal=False, dtype='float64')
|
||||
|
||||
self.assertEqual(sf_data.dtype, pa_data.dtype)
|
||||
self.assertEqual(sf_sr, pa_sr)
|
||||
np.testing.assert_array_almost_equal(sf_data, pa_data)
|
||||
|
||||
def test_load_multi_channels(self):
|
||||
sf_data, sf_sr = sf.read(self.files[1])
|
||||
sf_data = sf_data.T # Channel dim first
|
||||
pa_data, pa_sr = paddleaudio.load(
|
||||
self.files[1], mono=False, normal=False, dtype='float64')
|
||||
|
||||
self.assertEqual(sf_data.dtype, pa_data.dtype)
|
||||
self.assertEqual(sf_sr, pa_sr)
|
||||
np.testing.assert_array_almost_equal(sf_data, pa_data)
|
||||
|
||||
def test_save_mono_channel(self):
|
||||
waveform, sr = np.random.randint(
|
||||
low=-32768, high=32768, size=(48000), dtype=np.int16), 16000
|
||||
sf_tmp_file = 'sf_tmp.wav'
|
||||
pa_tmp_file = 'pa_tmp.wav'
|
||||
|
||||
sf.write(sf_tmp_file, waveform, sr)
|
||||
paddleaudio.save(waveform, sr, pa_tmp_file)
|
||||
|
||||
self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file))
|
||||
for file in [sf_tmp_file, pa_tmp_file]:
|
||||
os.remove(file)
|
||||
|
||||
def test_save_multi_channels(self):
|
||||
waveform, sr = np.random.randint(
|
||||
low=-32768, high=32768, size=(2, 48000), dtype=np.int16), 16000
|
||||
sf_tmp_file = 'sf_tmp.wav'
|
||||
pa_tmp_file = 'pa_tmp.wav'
|
||||
|
||||
sf.write(sf_tmp_file, waveform.T, sr)
|
||||
paddleaudio.save(waveform.T, sr, pa_tmp_file)
|
||||
|
||||
self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file))
|
||||
for file in [sf_tmp_file, pa_tmp_file]:
|
||||
os.remove(file)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,39 @@
|
||||
# 1. Prepare
|
||||
First, install `pytest-benchmark` via pip.
|
||||
```sh
|
||||
pip install pytest-benchmark
|
||||
```
|
||||
|
||||
# 2. Run
|
||||
Run the script you want to profile, for example:
|
||||
```sh
|
||||
pytest melspectrogram.py
|
||||
```
|
||||
|
||||
Result:
|
||||
```sh
|
||||
========================================================================== test session starts ==========================================================================
|
||||
platform linux -- Python 3.7.7, pytest-7.0.1, pluggy-1.0.0
|
||||
benchmark: 3.4.1 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
|
||||
rootdir: /ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddleaudio
|
||||
plugins: typeguard-2.12.1, benchmark-3.4.1, anyio-3.5.0
|
||||
collected 4 items
|
||||
|
||||
melspectrogram.py .... [100%]
|
||||
|
||||
|
||||
-------------------------------------------------------------------------------------------------- benchmark: 4 tests -------------------------------------------------------------------------------------------------
|
||||
Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
|
||||
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
test_melspect_gpu_torchaudio 202.0765 (1.0) 360.6230 (1.0) 218.1168 (1.0) 16.3022 (1.0) 214.2871 (1.0) 21.8451 (1.0) 40;3 4,584.7001 (1.0) 286 1
|
||||
test_melspect_gpu 657.8509 (3.26) 908.0470 (2.52) 724.2545 (3.32) 106.5771 (6.54) 669.9096 (3.13) 113.4719 (5.19) 1;0 1,380.7300 (0.30) 5 1
|
||||
test_melspect_cpu_torchaudio 1,247.6053 (6.17) 2,892.5799 (8.02) 1,443.2853 (6.62) 345.3732 (21.19) 1,262.7263 (5.89) 221.6385 (10.15) 56;53 692.8637 (0.15) 399 1
|
||||
test_melspect_cpu 20,326.2549 (100.59) 20,607.8682 (57.15) 20,473.4125 (93.86) 63.8654 (3.92) 20,467.0429 (95.51) 68.4294 (3.13) 8;1 48.8438 (0.01) 29 1
|
||||
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Legend:
|
||||
Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
|
||||
OPS: Operations Per Second, computed as 1 / Mean
|
||||
========================================================================== 4 passed in 21.12s ===========================================================================
|
||||
|
||||
```
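
For reference, below is a minimal sketch of what a `pytest-benchmark` test looks like. The file name `benchmark_example.py` and the stub feature function are illustrative only; the real scripts in this directory wrap the paddleaudio and torchaudio feature extractors.

```python
# benchmark_example.py -- illustrative sketch, not one of the scripts in this directory.
import numpy as np


def melspectrogram_stub(x):
    # Stand-in for the feature extractor being profiled.
    return np.abs(np.fft.rfft(x.reshape(-1, 512), axis=-1))


def test_melspectrogram_stub(benchmark):
    # `benchmark` is the fixture provided by pytest-benchmark: it calls the
    # wrapped function repeatedly and records the Min/Max/Mean/StdDev columns
    # shown in the table above.
    x = np.random.randn(512 * 100).astype(np.float32)
    result = benchmark(melspectrogram_stub, x)
    assert result.shape == (100, 257)  # 512 // 2 + 1 frequency bins per frame
```

Running `pytest benchmark_example.py` would then produce a timing table of the same form as the one above.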
|
@ -0,0 +1,124 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import urllib.request
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import paddle
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import paddleaudio
|
||||
|
||||
wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
|
||||
if not os.path.isfile(os.path.basename(wav_url)):
|
||||
urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
|
||||
|
||||
waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url)))
|
||||
waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
|
||||
waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
|
||||
|
||||
# Feature conf
|
||||
mel_conf = {
|
||||
'sr': sr,
|
||||
'n_fft': 512,
|
||||
'hop_length': 128,
|
||||
'n_mels': 40,
|
||||
}
|
||||
|
||||
mel_conf_torchaudio = {
|
||||
'sample_rate': sr,
|
||||
'n_fft': 512,
|
||||
'hop_length': 128,
|
||||
'n_mels': 40,
|
||||
'norm': 'slaney',
|
||||
'mel_scale': 'slaney',
|
||||
}
|
||||
|
||||
|
||||
def enable_cpu_device():
|
||||
paddle.set_device('cpu')
|
||||
|
||||
|
||||
def enable_gpu_device():
|
||||
paddle.set_device('gpu')
|
||||
|
||||
|
||||
log_mel_extractor = paddleaudio.features.LogMelSpectrogram(
|
||||
**mel_conf, f_min=0.0, top_db=80.0, dtype=waveform_tensor.dtype)
|
||||
|
||||
|
||||
def log_melspectrogram():
|
||||
return log_mel_extractor(waveform_tensor).squeeze(0)
|
||||
|
||||
|
||||
def test_log_melspect_cpu(benchmark):
|
||||
enable_cpu_device()
|
||||
feature_paddleaudio = benchmark(log_melspectrogram)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
def test_log_melspect_gpu(benchmark):
|
||||
enable_gpu_device()
|
||||
feature_paddleaudio = benchmark(log_melspectrogram)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=2)
|
||||
|
||||
|
||||
mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram(
|
||||
**mel_conf_torchaudio, f_min=0.0)
|
||||
amplitude_to_DB = torchaudio.transforms.AmplitudeToDB('power', top_db=80.0)
|
||||
|
||||
|
||||
def melspectrogram_torchaudio():
|
||||
return mel_extractor_torchaudio(waveform_tensor_torch).squeeze(0)
|
||||
|
||||
|
||||
def log_melspectrogram_torchaudio():
|
||||
mel_specgram = mel_extractor_torchaudio(waveform_tensor_torch)
|
||||
return amplitude_to_DB(mel_specgram).squeeze(0)
|
||||
|
||||
|
||||
def test_log_melspect_cpu_torchaudio(benchmark):
|
||||
global waveform_tensor_torch, mel_extractor_torchaudio, amplitude_to_DB
|
||||
|
||||
mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu')
|
||||
waveform_tensor_torch = waveform_tensor_torch.to('cpu')
|
||||
amplitude_to_DB = amplitude_to_DB.to('cpu')
|
||||
|
||||
feature_paddleaudio = benchmark(log_melspectrogram_torchaudio)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
def test_log_melspect_gpu_torchaudio(benchmark):
|
||||
global waveform_tensor_torch, mel_extractor_torchaudio, amplitude_to_DB
|
||||
|
||||
mel_extractor_torchaudio = mel_extractor_torchaudio.to('cuda')
|
||||
waveform_tensor_torch = waveform_tensor_torch.to('cuda')
|
||||
amplitude_to_DB = amplitude_to_DB.to('cuda')
|
||||
|
||||
feature_torchaudio = benchmark(log_melspectrogram_torchaudio)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_torchaudio.cpu(), decimal=2)
|
@ -0,0 +1,108 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import urllib.request
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import paddle
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import paddleaudio
|
||||
|
||||
wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
|
||||
if not os.path.isfile(os.path.basename(wav_url)):
|
||||
urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
|
||||
|
||||
waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url)))
|
||||
waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
|
||||
waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
|
||||
|
||||
# Feature conf
|
||||
mel_conf = {
|
||||
'sr': sr,
|
||||
'n_fft': 512,
|
||||
'hop_length': 128,
|
||||
'n_mels': 40,
|
||||
}
|
||||
|
||||
mel_conf_torchaudio = {
|
||||
'sample_rate': sr,
|
||||
'n_fft': 512,
|
||||
'hop_length': 128,
|
||||
'n_mels': 40,
|
||||
'norm': 'slaney',
|
||||
'mel_scale': 'slaney',
|
||||
}
|
||||
|
||||
|
||||
def enable_cpu_device():
|
||||
paddle.set_device('cpu')
|
||||
|
||||
|
||||
def enable_gpu_device():
|
||||
paddle.set_device('gpu')
|
||||
|
||||
|
||||
mel_extractor = paddleaudio.features.MelSpectrogram(
|
||||
**mel_conf, f_min=0.0, dtype=waveform_tensor.dtype)
|
||||
|
||||
|
||||
def melspectrogram():
|
||||
return mel_extractor(waveform_tensor).squeeze(0)
|
||||
|
||||
|
||||
def test_melspect_cpu(benchmark):
|
||||
enable_cpu_device()
|
||||
feature_paddleaudio = benchmark(melspectrogram)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
def test_melspect_gpu(benchmark):
|
||||
enable_gpu_device()
|
||||
feature_paddleaudio = benchmark(melspectrogram)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram(
|
||||
**mel_conf_torchaudio, f_min=0.0)
|
||||
|
||||
|
||||
def melspectrogram_torchaudio():
|
||||
return mel_extractor_torchaudio(waveform_tensor_torch).squeeze(0)
|
||||
|
||||
|
||||
def test_melspect_cpu_torchaudio(benchmark):
|
||||
global waveform_tensor_torch, mel_extractor_torchaudio
|
||||
mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu')
|
||||
waveform_tensor_torch = waveform_tensor_torch.to('cpu')
|
||||
feature_paddleaudio = benchmark(melspectrogram_torchaudio)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
def test_melspect_gpu_torchaudio(benchmark):
|
||||
global waveform_tensor_torch, mel_extractor_torchaudio
|
||||
mel_extractor_torchaudio = mel_extractor_torchaudio.to('cuda')
|
||||
waveform_tensor_torch = waveform_tensor_torch.to('cuda')
|
||||
feature_torchaudio = benchmark(melspectrogram_torchaudio)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_torchaudio.cpu(), decimal=3)
|
@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import urllib.request
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import paddle
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import paddleaudio
|
||||
|
||||
wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
|
||||
if not os.path.isfile(os.path.basename(wav_url)):
|
||||
urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
|
||||
|
||||
waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url)))
|
||||
waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
|
||||
waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
|
||||
|
||||
# Feature conf
|
||||
mel_conf = {
|
||||
'sr': sr,
|
||||
'n_fft': 512,
|
||||
'hop_length': 128,
|
||||
'n_mels': 40,
|
||||
}
|
||||
mfcc_conf = {
|
||||
'n_mfcc': 20,
|
||||
'top_db': 80.0,
|
||||
}
|
||||
mfcc_conf.update(mel_conf)
|
||||
|
||||
mel_conf_torchaudio = {
|
||||
'sample_rate': sr,
|
||||
'n_fft': 512,
|
||||
'hop_length': 128,
|
||||
'n_mels': 40,
|
||||
'norm': 'slaney',
|
||||
'mel_scale': 'slaney',
|
||||
}
|
||||
mfcc_conf_torchaudio = {
|
||||
'sample_rate': sr,
|
||||
'n_mfcc': 20,
|
||||
}
|
||||
|
||||
|
||||
def enable_cpu_device():
|
||||
paddle.set_device('cpu')
|
||||
|
||||
|
||||
def enable_gpu_device():
|
||||
paddle.set_device('gpu')
|
||||
|
||||
|
||||
mfcc_extractor = paddleaudio.features.MFCC(
|
||||
**mfcc_conf, f_min=0.0, dtype=waveform_tensor.dtype)
|
||||
|
||||
|
||||
def mfcc():
|
||||
return mfcc_extractor(waveform_tensor).squeeze(0)
|
||||
|
||||
|
||||
def test_mfcc_cpu(benchmark):
|
||||
enable_cpu_device()
|
||||
feature_paddleaudio = benchmark(mfcc)
|
||||
feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
def test_mfcc_gpu(benchmark):
|
||||
enable_gpu_device()
|
||||
feature_paddleaudio = benchmark(mfcc)
|
||||
feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
del mel_conf_torchaudio['sample_rate']
|
||||
mfcc_extractor_torchaudio = torchaudio.transforms.MFCC(
|
||||
**mfcc_conf_torchaudio, melkwargs=mel_conf_torchaudio)
|
||||
|
||||
|
||||
def mfcc_torchaudio():
|
||||
return mfcc_extractor_torchaudio(waveform_tensor_torch).squeeze(0)
|
||||
|
||||
|
||||
def test_mfcc_cpu_torchaudio(benchmark):
|
||||
global waveform_tensor_torch, mfcc_extractor_torchaudio
|
||||
|
||||
mfcc_extractor_torchaudio = mfcc_extractor_torchaudio.to('cpu')
|
||||
waveform_tensor_torch = waveform_tensor_torch.to('cpu')
|
||||
|
||||
feature_paddleaudio = benchmark(mfcc_torchaudio)
|
||||
feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
def test_mfcc_gpu_torchaudio(benchmark):
|
||||
global waveform_tensor_torch, mfcc_extractor_torchaudio
|
||||
|
||||
mfcc_extractor_torchaudio = mfcc_extractor_torchaudio.to('cuda')
|
||||
waveform_tensor_torch = waveform_tensor_torch.to('cuda')
|
||||
|
||||
feature_torchaudio = benchmark(mfcc_torchaudio)
|
||||
feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_torchaudio.cpu(), decimal=3)
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,49 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import unittest
|
||||
import urllib.request
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from paddleaudio import load
|
||||
|
||||
wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
|
||||
|
||||
|
||||
class FeatTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.initParmas()
|
||||
self.initWavInput()
|
||||
self.setUpDevice()
|
||||
|
||||
def setUpDevice(self, device='cpu'):
|
||||
paddle.set_device(device)
|
||||
|
||||
def initWavInput(self, url=wav_url):
|
||||
if not os.path.isfile(os.path.basename(url)):
|
||||
urllib.request.urlretrieve(url, os.path.basename(url))
|
||||
self.waveform, self.sr = load(os.path.abspath(os.path.basename(url)))
|
||||
self.waveform = self.waveform.astype(
|
||||
np.float32
|
||||
) # paddlespeech.s2t.transform.spectrogram only supports float32
|
||||
dim = len(self.waveform.shape)
|
||||
|
||||
assert dim in [1, 2]
|
||||
if dim == 1:
|
||||
self.waveform = np.expand_dims(self.waveform, 0)
|
||||
|
||||
def initParmas(self):
|
||||
raise NotImplementedError
|
@ -0,0 +1,49 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from .base import FeatTest
|
||||
from paddleaudio.functional.window import get_window
|
||||
from paddlespeech.s2t.transform.spectrogram import IStft
|
||||
from paddlespeech.s2t.transform.spectrogram import Stft
|
||||
|
||||
|
||||
class TestIstft(FeatTest):
|
||||
def initParmas(self):
|
||||
self.n_fft = 512
|
||||
self.hop_length = 128
|
||||
self.window_str = 'hann'
|
||||
|
||||
def test_istft(self):
|
||||
ps_stft = Stft(self.n_fft, self.hop_length)
|
||||
ps_res = ps_stft(
|
||||
self.waveform.T).squeeze(1).T # (n_fft//2 + 1, n_frames)
|
||||
x = paddle.to_tensor(ps_res)
|
||||
|
||||
ps_istft = IStft(self.hop_length)
|
||||
ps_res = ps_istft(ps_res.T)
|
||||
|
||||
window = get_window(
|
||||
self.window_str, self.n_fft, dtype=self.waveform.dtype)
|
||||
pd_res = paddle.signal.istft(
|
||||
x, self.n_fft, self.hop_length, window=window)
|
||||
|
||||
np.testing.assert_array_almost_equal(ps_res, pd_res, decimal=5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,81 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import paddleaudio
|
||||
from .base import FeatTest
|
||||
|
||||
|
||||
class TestKaldi(FeatTest):
|
||||
def initParmas(self):
|
||||
self.window_size = 1024
|
||||
self.dtype = 'float32'
|
||||
|
||||
def test_window(self):
|
||||
t_hann_window = torch.hann_window(
|
||||
self.window_size, periodic=False, dtype=eval(f'torch.{self.dtype}'))
|
||||
t_hamm_window = torch.hamming_window(
|
||||
self.window_size,
|
||||
periodic=False,
|
||||
alpha=0.54,
|
||||
beta=0.46,
|
||||
dtype=eval(f'torch.{self.dtype}'))
|
||||
t_povey_window = torch.hann_window(
|
||||
self.window_size, periodic=False,
|
||||
dtype=eval(f'torch.{self.dtype}')).pow(0.85)
|
||||
|
||||
p_hann_window = paddleaudio.functional.window.get_window(
|
||||
'hann',
|
||||
self.window_size,
|
||||
fftbins=False,
|
||||
dtype=eval(f'paddle.{self.dtype}'))
|
||||
p_hamm_window = paddleaudio.functional.window.get_window(
|
||||
'hamming',
|
||||
self.window_size,
|
||||
fftbins=False,
|
||||
dtype=eval(f'paddle.{self.dtype}'))
|
||||
p_povey_window = paddleaudio.functional.window.get_window(
|
||||
'hann',
|
||||
self.window_size,
|
||||
fftbins=False,
|
||||
dtype=eval(f'paddle.{self.dtype}')).pow(0.85)
|
||||
|
||||
np.testing.assert_array_almost_equal(t_hann_window, p_hann_window)
|
||||
np.testing.assert_array_almost_equal(t_hamm_window, p_hamm_window)
|
||||
np.testing.assert_array_almost_equal(t_povey_window, p_povey_window)
|
||||
|
||||
def test_fbank(self):
|
||||
ta_features = torchaudio.compliance.kaldi.fbank(
|
||||
torch.from_numpy(self.waveform.astype(self.dtype)))
|
||||
pa_features = paddleaudio.compliance.kaldi.fbank(
|
||||
paddle.to_tensor(self.waveform.astype(self.dtype)))
|
||||
np.testing.assert_array_almost_equal(
|
||||
ta_features, pa_features, decimal=4)
|
||||
|
||||
def test_mfcc(self):
|
||||
ta_features = torchaudio.compliance.kaldi.mfcc(
|
||||
torch.from_numpy(self.waveform.astype(self.dtype)))
|
||||
pa_features = paddleaudio.compliance.kaldi.mfcc(
|
||||
paddle.to_tensor(self.waveform.astype(self.dtype)))
|
||||
np.testing.assert_array_almost_equal(
|
||||
ta_features, pa_features, decimal=4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,281 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
import paddleaudio
|
||||
from .base import FeatTest
|
||||
from paddleaudio.functional.window import get_window
|
||||
|
||||
|
||||
class TestLibrosa(FeatTest):
|
||||
def initParmas(self):
|
||||
self.n_fft = 512
|
||||
self.hop_length = 128
|
||||
self.n_mels = 40
|
||||
self.n_mfcc = 20
|
||||
self.fmin = 0.0
|
||||
self.window_str = 'hann'
|
||||
self.pad_mode = 'reflect'
|
||||
self.top_db = 80.0
|
||||
|
||||
def test_stft(self):
|
||||
if len(self.waveform.shape) == 2: # (C, T)
|
||||
self.waveform = self.waveform.squeeze(
|
||||
0) # 1D input for librosa.core.stft
|
||||
|
||||
feature_librosa = librosa.core.stft(
|
||||
y=self.waveform,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=None,
|
||||
window=self.window_str,
|
||||
center=True,
|
||||
dtype=None,
|
||||
pad_mode=self.pad_mode, )
|
||||
x = paddle.to_tensor(self.waveform).unsqueeze(0)
|
||||
window = get_window(self.window_str, self.n_fft, dtype=x.dtype)
|
||||
feature_paddle = paddle.signal.stft(
|
||||
x=x,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=None,
|
||||
window=window,
|
||||
center=True,
|
||||
pad_mode=self.pad_mode,
|
||||
normalized=False,
|
||||
onesided=True, ).squeeze(0)
|
||||
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddle, decimal=5)
|
||||
|
||||
def test_istft(self):
|
||||
if len(self.waveform.shape) == 2: # (C, T)
|
||||
self.waveform = self.waveform.squeeze(
|
||||
0) # 1D input for librosa.core.stft
|
||||
|
||||
# Get stft result from librosa.
|
||||
stft_matrix = librosa.core.stft(
|
||||
y=self.waveform,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=None,
|
||||
window=self.window_str,
|
||||
center=True,
|
||||
pad_mode=self.pad_mode, )
|
||||
|
||||
feature_librosa = librosa.core.istft(
|
||||
stft_matrix=stft_matrix,
|
||||
hop_length=self.hop_length,
|
||||
win_length=None,
|
||||
window=self.window_str,
|
||||
center=True,
|
||||
dtype=None,
|
||||
length=None, )
|
||||
|
||||
x = paddle.to_tensor(stft_matrix).unsqueeze(0)
|
||||
window = get_window(
|
||||
self.window_str,
|
||||
self.n_fft,
|
||||
dtype=paddle.to_tensor(self.waveform).dtype)
|
||||
feature_paddle = paddle.signal.istft(
|
||||
x=x,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=None,
|
||||
window=window,
|
||||
center=True,
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
length=None,
|
||||
return_complex=False, ).squeeze(0)
|
||||
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddle, decimal=5)
|
||||
|
||||
def test_mel(self):
|
||||
feature_librosa = librosa.filters.mel(
|
||||
sr=self.sr,
|
||||
n_fft=self.n_fft,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin,
|
||||
fmax=None,
|
||||
htk=False,
|
||||
norm='slaney',
|
||||
dtype=self.waveform.dtype, )
|
||||
feature_compliance = paddleaudio.compliance.librosa.compute_fbank_matrix(
|
||||
sr=self.sr,
|
||||
n_fft=self.n_fft,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin,
|
||||
fmax=None,
|
||||
htk=False,
|
||||
norm='slaney',
|
||||
dtype=self.waveform.dtype, )
|
||||
x = paddle.to_tensor(self.waveform)
|
||||
feature_functional = paddleaudio.functional.compute_fbank_matrix(
|
||||
sr=self.sr,
|
||||
n_fft=self.n_fft,
|
||||
n_mels=self.n_mels,
|
||||
f_min=self.fmin,
|
||||
f_max=None,
|
||||
htk=False,
|
||||
norm='slaney',
|
||||
dtype=x.dtype, )
|
||||
|
||||
np.testing.assert_array_almost_equal(feature_librosa,
|
||||
feature_compliance)
|
||||
np.testing.assert_array_almost_equal(feature_librosa,
|
||||
feature_functional)
|
||||
|
||||
def test_melspect(self):
|
||||
if len(self.waveform.shape) == 2: # (C, T)
|
||||
self.waveform = self.waveform.squeeze(
|
||||
0) # 1D input for librosa.feature.melspectrogram
|
||||
|
||||
# librosa:
|
||||
feature_librosa = librosa.feature.melspectrogram(
|
||||
y=self.waveform,
|
||||
sr=self.sr,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin)
|
||||
|
||||
# paddleaudio.compliance.librosa:
|
||||
feature_compliance = paddleaudio.compliance.librosa.melspectrogram(
|
||||
x=self.waveform,
|
||||
sr=self.sr,
|
||||
window_size=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin,
|
||||
to_db=False)
|
||||
|
||||
# paddleaudio.features.layer
|
||||
x = paddle.to_tensor(
|
||||
self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim.
|
||||
feature_extractor = paddleaudio.features.MelSpectrogram(
|
||||
sr=self.sr,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
f_min=self.fmin,
|
||||
dtype=x.dtype)
|
||||
feature_layer = feature_extractor(x).squeeze(0).numpy()
|
||||
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_compliance, decimal=5)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_layer, decimal=5)
|
||||
|
||||
def test_log_melspect(self):
|
||||
if len(self.waveform.shape) == 2: # (C, T)
|
||||
self.waveform = self.waveform.squeeze(
|
||||
0) # 1D input for librosa.feature.melspectrogram
|
||||
|
||||
# librosa:
|
||||
feature_librosa = librosa.feature.melspectrogram(
|
||||
y=self.waveform,
|
||||
sr=self.sr,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin)
|
||||
feature_librosa = librosa.power_to_db(feature_librosa, top_db=None)
|
||||
|
||||
# paddleaudio.compliance.librosa:
|
||||
feature_compliance = paddleaudio.compliance.librosa.melspectrogram(
|
||||
x=self.waveform,
|
||||
sr=self.sr,
|
||||
window_size=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin)
|
||||
|
||||
# paddleaudio.features.layer
|
||||
x = paddle.to_tensor(
|
||||
self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim.
|
||||
feature_extractor = paddleaudio.features.LogMelSpectrogram(
|
||||
sr=self.sr,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
f_min=self.fmin,
|
||||
dtype=x.dtype)
|
||||
feature_layer = feature_extractor(x).squeeze(0).numpy()
|
||||
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_compliance, decimal=5)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_layer, decimal=4)
|
||||
|
||||
def test_mfcc(self):
|
||||
if len(self.waveform.shape) == 2: # (C, T)
|
||||
self.waveform = self.waveform.squeeze(
|
||||
0) # 1D input for librosa.feature.melspectrogram
|
||||
|
||||
# librosa:
|
||||
feature_librosa = librosa.feature.mfcc(
|
||||
y=self.waveform,
|
||||
sr=self.sr,
|
||||
S=None,
|
||||
n_mfcc=self.n_mfcc,
|
||||
dct_type=2,
|
||||
norm='ortho',
|
||||
lifter=0,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin)
|
||||
|
||||
# paddleaudio.compliance.librosa:
|
||||
feature_compliance = paddleaudio.compliance.librosa.mfcc(
|
||||
x=self.waveform,
|
||||
sr=self.sr,
|
||||
n_mfcc=self.n_mfcc,
|
||||
dct_type=2,
|
||||
norm='ortho',
|
||||
lifter=0,
|
||||
window_size=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin,
|
||||
top_db=self.top_db)
|
||||
|
||||
# paddleaudio.features.layer
|
||||
x = paddle.to_tensor(
|
||||
self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim.
|
||||
feature_extractor = paddleaudio.features.MFCC(
|
||||
sr=self.sr,
|
||||
n_mfcc=self.n_mfcc,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
f_min=self.fmin,
|
||||
top_db=self.top_db,
|
||||
dtype=x.dtype)
|
||||
feature_layer = feature_extractor(x).squeeze(0).numpy()
|
||||
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_compliance, decimal=4)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_layer, decimal=4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,50 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
import paddleaudio
|
||||
from .base import FeatTest
|
||||
from paddlespeech.s2t.transform.spectrogram import LogMelSpectrogram
|
||||
|
||||
|
||||
class TestLogMelSpectrogram(FeatTest):
|
||||
def initParmas(self):
|
||||
self.n_fft = 512
|
||||
self.hop_length = 128
|
||||
self.n_mels = 40
|
||||
|
||||
def test_log_melspect(self):
|
||||
ps_melspect = LogMelSpectrogram(self.sr, self.n_mels, self.n_fft,
|
||||
self.hop_length)
|
||||
ps_res = ps_melspect(self.waveform.T).squeeze(1).T
|
||||
|
||||
x = paddle.to_tensor(self.waveform)
|
||||
# The paddlespeech.s2t features mix up magnitude and power spectra
|
||||
ps_melspect = paddleaudio.features.LogMelSpectrogram(
|
||||
self.sr,
|
||||
self.n_fft,
|
||||
self.hop_length,
|
||||
power=1.0,
|
||||
n_mels=self.n_mels,
|
||||
f_min=0.0)
|
||||
pa_res = (ps_melspect(x) / 10.0).squeeze(0).numpy()
|
||||
|
||||
np.testing.assert_array_almost_equal(ps_res, pa_res, decimal=5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,42 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
import paddleaudio
|
||||
from .base import FeatTest
|
||||
from paddlespeech.s2t.transform.spectrogram import Spectrogram
|
||||
|
||||
|
||||
class TestSpectrogram(FeatTest):
|
||||
def initParmas(self):
|
||||
self.n_fft = 512
|
||||
self.hop_length = 128
|
||||
|
||||
def test_spectrogram(self):
|
||||
ps_spect = Spectrogram(self.n_fft, self.hop_length)
|
||||
ps_res = ps_spect(self.waveform.T).squeeze(1).T # Magnitude
|
||||
|
||||
x = paddle.to_tensor(self.waveform)
|
||||
pa_spect = paddleaudio.features.Spectrogram(
|
||||
self.n_fft, self.hop_length, power=1.0)
|
||||
pa_res = pa_spect(x).squeeze(0).numpy()
|
||||
|
||||
np.testing.assert_array_almost_equal(ps_res, pa_res, decimal=5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,44 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from .base import FeatTest
|
||||
from paddleaudio.functional.window import get_window
|
||||
from paddlespeech.s2t.transform.spectrogram import Stft
|
||||
|
||||
|
||||
class TestStft(FeatTest):
|
||||
def initParmas(self):
|
||||
self.n_fft = 512
|
||||
self.hop_length = 128
|
||||
self.window_str = 'hann'
|
||||
|
||||
def test_stft(self):
|
||||
ps_stft = Stft(self.n_fft, self.hop_length)
|
||||
ps_res = ps_stft(
|
||||
self.waveform.T).squeeze(1).T # (n_fft//2 + 1, n_frames)
|
||||
|
||||
x = paddle.to_tensor(self.waveform)
|
||||
window = get_window(self.window_str, self.n_fft, dtype=x.dtype)
|
||||
pd_res = paddle.signal.stft(
|
||||
x, self.n_fft, self.hop_length, window=window).squeeze(0).numpy()
|
||||
|
||||
np.testing.assert_array_almost_equal(ps_res, pd_res, decimal=5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|