pull/1601/head
commit
490300f84f
@ -0,0 +1,88 @@
|
||||
version: '3.5'
|
||||
|
||||
services:
|
||||
etcd:
|
||||
container_name: milvus-etcd
|
||||
image: quay.io/coreos/etcd:v3.5.0
|
||||
networks:
|
||||
app_net:
|
||||
environment:
|
||||
- ETCD_AUTO_COMPACTION_MODE=revision
|
||||
- ETCD_AUTO_COMPACTION_RETENTION=1000
|
||||
- ETCD_QUOTA_BACKEND_BYTES=4294967296
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
|
||||
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
|
||||
|
||||
minio:
|
||||
container_name: milvus-minio
|
||||
image: minio/minio:RELEASE.2020-12-03T00-03-10Z
|
||||
networks:
|
||||
app_net:
|
||||
environment:
|
||||
MINIO_ACCESS_KEY: minioadmin
|
||||
MINIO_SECRET_KEY: minioadmin
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
|
||||
command: minio server /minio_data
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
standalone:
|
||||
container_name: milvus-standalone
|
||||
image: milvusdb/milvus:v2.0.1
|
||||
networks:
|
||||
app_net:
|
||||
ipv4_address: 172.16.23.10
|
||||
command: ["milvus", "run", "standalone"]
|
||||
environment:
|
||||
ETCD_ENDPOINTS: etcd:2379
|
||||
MINIO_ADDRESS: minio:9000
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
|
||||
ports:
|
||||
- "19530:19530"
|
||||
depends_on:
|
||||
- "etcd"
|
||||
- "minio"
|
||||
|
||||
mysql:
|
||||
container_name: audio-mysql
|
||||
image: mysql:5.7
|
||||
networks:
|
||||
app_net:
|
||||
ipv4_address: 172.16.23.11
|
||||
environment:
|
||||
- MYSQL_ROOT_PASSWORD=123456
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/mysql:/var/lib/mysql
|
||||
ports:
|
||||
- "3306:3306"
|
||||
|
||||
webclient:
|
||||
container_name: audio-webclient
|
||||
image: qingen1/paddlespeech-audio-search-client:2.3
|
||||
networks:
|
||||
app_net:
|
||||
ipv4_address: 172.16.23.13
|
||||
environment:
|
||||
API_URL: 'http://127.0.0.1:8002'
|
||||
ports:
|
||||
- "8068:80"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost/"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
networks:
|
||||
app_net:
|
||||
driver: bridge
|
||||
ipam:
|
||||
driver: default
|
||||
config:
|
||||
- subnet: 172.16.23.0/24
|
||||
gateway: 172.16.23.1
|
After Width: | Height: | Size: 29 KiB |
After Width: | Height: | Size: 80 KiB |
After Width: | Height: | Size: 33 KiB |
After Width: | Height: | Size: 84 KiB |
@ -0,0 +1,12 @@
|
||||
soundfile==0.10.3.post1
|
||||
librosa==0.8.0
|
||||
numpy
|
||||
pymysql
|
||||
fastapi
|
||||
uvicorn
|
||||
diskcache==5.2.1
|
||||
pymilvus==2.0.1
|
||||
python-multipart
|
||||
typing
|
||||
starlette
|
||||
pydantic
|
@ -0,0 +1,37 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
############### Milvus Configuration ###############
|
||||
MILVUS_HOST = os.getenv("MILVUS_HOST", "127.0.0.1")
|
||||
MILVUS_PORT = int(os.getenv("MILVUS_PORT", "19530"))
|
||||
VECTOR_DIMENSION = int(os.getenv("VECTOR_DIMENSION", "2048"))
|
||||
INDEX_FILE_SIZE = int(os.getenv("INDEX_FILE_SIZE", "1024"))
|
||||
METRIC_TYPE = os.getenv("METRIC_TYPE", "L2")
|
||||
DEFAULT_TABLE = os.getenv("DEFAULT_TABLE", "audio_table")
|
||||
TOP_K = int(os.getenv("TOP_K", "10"))
|
||||
|
||||
############### MySQL Configuration ###############
|
||||
MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
|
||||
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||
MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
MYSQL_PWD = os.getenv("MYSQL_PWD", "123456")
|
||||
MYSQL_DB = os.getenv("MYSQL_DB", "mysql")
|
||||
|
||||
############### Data Path ###############
|
||||
UPLOAD_PATH = os.getenv("UPLOAD_PATH", "tmp/audio-data")
|
||||
|
||||
############### Number of Log Files ###############
|
||||
LOGS_NUM = int(os.getenv("logs_num", "0"))
|
@ -0,0 +1,39 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
def get_audio_embedding(path):
|
||||
"""
|
||||
Use vpr_inference to generate embedding of audio
|
||||
"""
|
||||
try:
|
||||
RESAMPLE_RATE = 16000
|
||||
audio, _ = librosa.load(path, sr=RESAMPLE_RATE, mono=True)
|
||||
|
||||
# TODO add infer/python interface to get embedding, now fake it by rand
|
||||
# vpr = ECAPATDNN(checkpoint_path=None, device='cuda')
|
||||
# embedding = vpr.inference(audio)
|
||||
np.random.seed(hash(os.path.basename(path)) % 1000000)
|
||||
embedding = np.random.rand(1, 2048)
|
||||
embedding = embedding / np.linalg.norm(embedding)
|
||||
embedding = embedding.tolist()[0]
|
||||
return embedding
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error with embedding:{e}")
|
||||
return None
|
@ -0,0 +1,168 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from config import UPLOAD_PATH
|
||||
from diskcache import Cache
|
||||
from fastapi import FastAPI
|
||||
from fastapi import File
|
||||
from fastapi import UploadFile
|
||||
from logs import LOGGER
|
||||
from milvus_helpers import MilvusHelper
|
||||
from mysql_helpers import MySQLHelper
|
||||
from operations.count import do_count
|
||||
from operations.drop import do_drop
|
||||
from operations.load import do_load
|
||||
from operations.search import do_search
|
||||
from pydantic import BaseModel
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"])
|
||||
|
||||
MODEL = None
|
||||
MILVUS_CLI = MilvusHelper()
|
||||
MYSQL_CLI = MySQLHelper()
|
||||
|
||||
# Mkdir 'tmp/audio-data'
|
||||
if not os.path.exists(UPLOAD_PATH):
|
||||
os.makedirs(UPLOAD_PATH)
|
||||
LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")
|
||||
|
||||
|
||||
@app.get('/data')
|
||||
def audio_path(audio_path):
|
||||
# Get the audio file
|
||||
try:
|
||||
LOGGER.info(f"Successfully load audio: {audio_path}")
|
||||
return FileResponse(audio_path)
|
||||
except Exception as e:
|
||||
LOGGER.error(f"upload audio error: {e}")
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.get('/progress')
|
||||
def get_progress():
|
||||
# Get the progress of dealing with data
|
||||
try:
|
||||
cache = Cache('./tmp')
|
||||
return f"current: {cache['current']}, total: {cache['total']}"
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Upload data error: {e}")
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
Table: Optional[str] = None
|
||||
File: str
|
||||
|
||||
|
||||
@app.post('/audio/load')
|
||||
async def load_audios(item: Item):
|
||||
# Insert all the audio files under the file path to Milvus/MySQL
|
||||
try:
|
||||
total_num = do_load(item.Table, item.File, MILVUS_CLI, MYSQL_CLI)
|
||||
LOGGER.info(f"Successfully loaded data, total count: {total_num}")
|
||||
return {'status': True, 'msg': "Successfully loaded data!"}
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.post('/audio/search')
|
||||
async def search_audio(request: Request,
|
||||
table_name: str=None,
|
||||
audio: UploadFile=File(...)):
|
||||
# Search the uploaded audio in Milvus/MySQL
|
||||
try:
|
||||
# Save the upload data to server.
|
||||
content = await audio.read()
|
||||
query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
|
||||
with open(query_audio_path, "wb+") as f:
|
||||
f.write(content)
|
||||
host = request.headers['host']
|
||||
_, paths, distances = do_search(host, table_name, query_audio_path,
|
||||
MILVUS_CLI, MYSQL_CLI)
|
||||
names = []
|
||||
for path, score in zip(paths, distances):
|
||||
names.append(os.path.basename(path))
|
||||
LOGGER.info(f"search result {path}, score {score}")
|
||||
res = dict(zip(paths, zip(names, distances)))
|
||||
# Sort results by distance metric, closest distances first
|
||||
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
|
||||
LOGGER.info("Successfully searched similar audio!")
|
||||
return res
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.post('/audio/search/local')
|
||||
async def search_local_audio(request: Request,
|
||||
query_audio_path: str,
|
||||
table_name: str=None):
|
||||
# Search the uploaded audio in Milvus/MySQL
|
||||
try:
|
||||
host = request.headers['host']
|
||||
_, paths, distances = do_search(host, table_name, query_audio_path,
|
||||
MILVUS_CLI, MYSQL_CLI)
|
||||
names = []
|
||||
for path, score in zip(paths, distances):
|
||||
names.append(os.path.basename(path))
|
||||
LOGGER.info(f"search result {path}, score {score}")
|
||||
res = dict(zip(paths, zip(names, distances)))
|
||||
# Sort results by distance metric, closest distances first
|
||||
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
|
||||
LOGGER.info("Successfully searched similar audio!")
|
||||
return res
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.get('/audio/count')
|
||||
async def count_audio(table_name: str=None):
|
||||
# Returns the total number of vectors in the system
|
||||
try:
|
||||
num = do_count(table_name, MILVUS_CLI)
|
||||
LOGGER.info("Successfully count the number of data!")
|
||||
return num
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.post('/audio/drop')
|
||||
async def drop_tables(table_name: str=None):
|
||||
# Delete the collection of Milvus and MySQL
|
||||
try:
|
||||
status = do_drop(table_name, MILVUS_CLI, MYSQL_CLI)
|
||||
LOGGER.info("Successfully drop tables in Milvus and MySQL!")
|
||||
return status
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run(app=app, host='0.0.0.0', port=8002)
|
@ -0,0 +1,185 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
from config import METRIC_TYPE
|
||||
from config import MILVUS_HOST
|
||||
from config import MILVUS_PORT
|
||||
from config import VECTOR_DIMENSION
|
||||
from logs import LOGGER
|
||||
from pymilvus import Collection
|
||||
from pymilvus import CollectionSchema
|
||||
from pymilvus import connections
|
||||
from pymilvus import DataType
|
||||
from pymilvus import FieldSchema
|
||||
from pymilvus import utility
|
||||
|
||||
|
||||
class MilvusHelper:
|
||||
"""
|
||||
the basic operations of PyMilvus
|
||||
|
||||
# This example shows how to:
|
||||
# 1. connect to Milvus server
|
||||
# 2. create a collection
|
||||
# 3. insert entities
|
||||
# 4. create index
|
||||
# 5. search
|
||||
# 6. delete a collection
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
self.collection = None
|
||||
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
|
||||
LOGGER.debug(
|
||||
f"Successfully connect to Milvus with IP:{MILVUS_HOST} and PORT:{MILVUS_PORT}"
|
||||
)
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to connect Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def set_collection(self, collection_name):
|
||||
try:
|
||||
if self.has_collection(collection_name):
|
||||
self.collection = Collection(name=collection_name)
|
||||
else:
|
||||
raise Exception(
|
||||
f"There is no collection named:{collection_name}")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to set collection in Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def has_collection(self, collection_name):
|
||||
# Return if Milvus has the collection
|
||||
try:
|
||||
return utility.has_collection(collection_name)
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to check state of collection in Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def create_collection(self, collection_name):
|
||||
# Create milvus collection if not exists
|
||||
try:
|
||||
if not self.has_collection(collection_name):
|
||||
field1 = FieldSchema(
|
||||
name="id",
|
||||
dtype=DataType.INT64,
|
||||
descrition="int64",
|
||||
is_primary=True,
|
||||
auto_id=True)
|
||||
field2 = FieldSchema(
|
||||
name="embedding",
|
||||
dtype=DataType.FLOAT_VECTOR,
|
||||
descrition="speaker embeddings",
|
||||
dim=VECTOR_DIMENSION,
|
||||
is_primary=False)
|
||||
schema = CollectionSchema(
|
||||
fields=[field1, field2], description="embeddings info")
|
||||
self.collection = Collection(
|
||||
name=collection_name, schema=schema)
|
||||
LOGGER.debug(f"Create Milvus collection: {collection_name}")
|
||||
else:
|
||||
self.set_collection(collection_name)
|
||||
return "OK"
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to create collection in Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def insert(self, collection_name, vectors):
|
||||
# Batch insert vectors to milvus collection
|
||||
try:
|
||||
self.create_collection(collection_name)
|
||||
data = [vectors]
|
||||
self.set_collection(collection_name)
|
||||
mr = self.collection.insert(data)
|
||||
ids = mr.primary_keys
|
||||
self.collection.load()
|
||||
LOGGER.debug(
|
||||
f"Insert vectors to Milvus in collection: {collection_name} with {len(vectors)} rows"
|
||||
)
|
||||
return ids
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to insert data to Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def create_index(self, collection_name):
|
||||
# Create IVF_FLAT index on milvus collection
|
||||
try:
|
||||
self.set_collection(collection_name)
|
||||
default_index = {
|
||||
"index_type": "IVF_SQ8",
|
||||
"metric_type": METRIC_TYPE,
|
||||
"params": {
|
||||
"nlist": 16384
|
||||
}
|
||||
}
|
||||
status = self.collection.create_index(
|
||||
field_name="embedding", index_params=default_index)
|
||||
if not status.code:
|
||||
LOGGER.debug(
|
||||
f"Successfully create index in collection:{collection_name} with param:{default_index}"
|
||||
)
|
||||
return status
|
||||
else:
|
||||
raise Exception(status.message)
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to create index: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def delete_collection(self, collection_name):
|
||||
# Delete Milvus collection
|
||||
try:
|
||||
self.set_collection(collection_name)
|
||||
self.collection.drop()
|
||||
LOGGER.debug("Successfully drop collection!")
|
||||
return "ok"
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to drop collection: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def search_vectors(self, collection_name, vectors, top_k):
|
||||
# Search vector in milvus collection
|
||||
try:
|
||||
self.set_collection(collection_name)
|
||||
search_params = {
|
||||
"metric_type": METRIC_TYPE,
|
||||
"params": {
|
||||
"nprobe": 16
|
||||
}
|
||||
}
|
||||
res = self.collection.search(
|
||||
vectors,
|
||||
anns_field="embedding",
|
||||
param=search_params,
|
||||
limit=top_k)
|
||||
LOGGER.debug(f"Successfully search in collection: {res}")
|
||||
return res
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to search vectors in Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def count(self, collection_name):
|
||||
# Get the number of milvus collection
|
||||
try:
|
||||
self.set_collection(collection_name)
|
||||
num = self.collection.num_entities
|
||||
LOGGER.debug(
|
||||
f"Successfully get the num:{num} of the collection:{collection_name}"
|
||||
)
|
||||
return num
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to count vectors in Milvus: {e}")
|
||||
sys.exit(1)
|
@ -0,0 +1,133 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
import pymysql
|
||||
from config import MYSQL_DB
|
||||
from config import MYSQL_HOST
|
||||
from config import MYSQL_PORT
|
||||
from config import MYSQL_PWD
|
||||
from config import MYSQL_USER
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
class MySQLHelper():
|
||||
"""
|
||||
the basic operations of PyMySQL
|
||||
|
||||
# This example shows how to:
|
||||
# 1. connect to MySQL server
|
||||
# 2. create a table
|
||||
# 3. insert data to table
|
||||
# 4. search by milvus ids
|
||||
# 5. delete table
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.conn = pymysql.connect(
|
||||
host=MYSQL_HOST,
|
||||
user=MYSQL_USER,
|
||||
port=MYSQL_PORT,
|
||||
password=MYSQL_PWD,
|
||||
database=MYSQL_DB,
|
||||
local_infile=True)
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
def test_connection(self):
|
||||
try:
|
||||
self.conn.ping()
|
||||
except Exception:
|
||||
self.conn = pymysql.connect(
|
||||
host=MYSQL_HOST,
|
||||
user=MYSQL_USER,
|
||||
port=MYSQL_PORT,
|
||||
password=MYSQL_PWD,
|
||||
database=MYSQL_DB,
|
||||
local_infile=True)
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
def create_mysql_table(self, table_name):
|
||||
# Create mysql table if not exists
|
||||
self.test_connection()
|
||||
sql = "create table if not exists " + table_name + "(milvus_id TEXT, audio_path TEXT);"
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
LOGGER.debug(f"MYSQL create table: {table_name} with sql: {sql}")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def load_data_to_mysql(self, table_name, data):
|
||||
# Batch insert (Milvus_ids, img_path) to mysql
|
||||
self.test_connection()
|
||||
sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);"
|
||||
try:
|
||||
self.cursor.executemany(sql, data)
|
||||
self.conn.commit()
|
||||
LOGGER.debug(
|
||||
f"MYSQL loads data to table: {table_name} successfully")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def search_by_milvus_ids(self, ids, table_name):
|
||||
# Get the img_path according to the milvus ids
|
||||
self.test_connection()
|
||||
str_ids = str(ids).replace('[', '').replace(']', '')
|
||||
sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");"
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
results = self.cursor.fetchall()
|
||||
results = [res[0] for res in results]
|
||||
LOGGER.debug("MYSQL search by milvus id.")
|
||||
return results
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def delete_table(self, table_name):
|
||||
# Delete mysql table if exists
|
||||
self.test_connection()
|
||||
sql = "drop table if exists " + table_name + ";"
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
LOGGER.debug(f"MYSQL delete table:{table_name}")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def delete_all_data(self, table_name):
|
||||
# Delete all the data in mysql table
|
||||
self.test_connection()
|
||||
sql = 'delete from ' + table_name + ';'
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
self.conn.commit()
|
||||
LOGGER.debug(f"MYSQL delete all data in table:{table_name}")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def count_table(self, table_name):
|
||||
# Get the number of mysql table
|
||||
self.test_connection()
|
||||
sql = "select count(milvus_id) from " + table_name + ";"
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
results = self.cursor.fetchall()
|
||||
LOGGER.debug(f"MYSQL count table:{table_name}")
|
||||
return results[0][0]
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,33 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
from config import DEFAULT_TABLE
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
def do_count(table_name, milvus_cli):
|
||||
"""
|
||||
Returns the total number of vectors in the system
|
||||
"""
|
||||
if not table_name:
|
||||
table_name = DEFAULT_TABLE
|
||||
try:
|
||||
if not milvus_cli.has_collection(table_name):
|
||||
return None
|
||||
num = milvus_cli.count(table_name)
|
||||
return num
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error attempting to count table {e}")
|
||||
sys.exit(1)
|
@ -0,0 +1,34 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
from config import DEFAULT_TABLE
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
def do_drop(table_name, milvus_cli, mysql_cli):
|
||||
"""
|
||||
Delete the collection of Milvus and MySQL
|
||||
"""
|
||||
if not table_name:
|
||||
table_name = DEFAULT_TABLE
|
||||
try:
|
||||
if not milvus_cli.has_collection(table_name):
|
||||
return "Collection is not exist"
|
||||
status = milvus_cli.delete_collection(table_name)
|
||||
mysql_cli.delete_table(table_name)
|
||||
return status
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error attempting to drop table: {e}")
|
||||
sys.exit(1)
|
@ -0,0 +1,85 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
from config import DEFAULT_TABLE
|
||||
from diskcache import Cache
|
||||
from encode import get_audio_embedding
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
def get_audios(path):
|
||||
"""
|
||||
List all wav and aif files recursively under the path folder.
|
||||
"""
|
||||
supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
|
||||
return [
|
||||
item
|
||||
for sublist in [[os.path.join(dir, file) for file in files]
|
||||
for dir, _, files in list(os.walk(path))]
|
||||
for item in sublist if os.path.splitext(item)[1] in supported_formats
|
||||
]
|
||||
|
||||
|
||||
def extract_features(audio_dir):
|
||||
"""
|
||||
Get the vector of audio
|
||||
"""
|
||||
try:
|
||||
cache = Cache('./tmp')
|
||||
feats = []
|
||||
names = []
|
||||
audio_list = get_audios(audio_dir)
|
||||
total = len(audio_list)
|
||||
cache['total'] = total
|
||||
for i, audio_path in enumerate(audio_list):
|
||||
norm_feat = get_audio_embedding(audio_path)
|
||||
if norm_feat is None:
|
||||
continue
|
||||
feats.append(norm_feat)
|
||||
names.append(audio_path.encode())
|
||||
cache['current'] = i + 1
|
||||
print(
|
||||
f"Extracting feature from audio No. {i + 1} , {total} audios in total"
|
||||
)
|
||||
return feats, names
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error with extracting feature from audio {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def format_data(ids, names):
|
||||
"""
|
||||
Combine the id of the vector and the name of the audio into a list
|
||||
"""
|
||||
data = []
|
||||
for i in range(len(ids)):
|
||||
value = (str(ids[i]), names[i])
|
||||
data.append(value)
|
||||
return data
|
||||
|
||||
|
||||
def do_load(table_name, audio_dir, milvus_cli, mysql_cli):
|
||||
"""
|
||||
Import vectors to Milvus and data to Mysql respectively
|
||||
"""
|
||||
if not table_name:
|
||||
table_name = DEFAULT_TABLE
|
||||
vectors, names = extract_features(audio_dir)
|
||||
ids = milvus_cli.insert(table_name, vectors)
|
||||
milvus_cli.create_index(table_name)
|
||||
mysql_cli.create_mysql_table(table_name)
|
||||
mysql_cli.load_data_to_mysql(table_name, format_data(ids, names))
|
||||
return len(ids)
|
@ -0,0 +1,41 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
from config import DEFAULT_TABLE
|
||||
from config import TOP_K
|
||||
from encode import get_audio_embedding
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
def do_search(host, table_name, audio_path, milvus_cli, mysql_cli):
|
||||
"""
|
||||
Search the uploaded audio in Milvus/MySQL
|
||||
"""
|
||||
try:
|
||||
if not table_name:
|
||||
table_name = DEFAULT_TABLE
|
||||
feat = get_audio_embedding(audio_path)
|
||||
vectors = milvus_cli.search_vectors(table_name, [feat], TOP_K)
|
||||
vids = [str(x.id) for x in vectors[0]]
|
||||
paths = mysql_cli.search_by_milvus_ids(vids, table_name)
|
||||
distances = [x.distance for x in vectors[0]]
|
||||
for i in range(len(paths)):
|
||||
tmp = "http://" + str(host) + "/data?audio_path=" + str(paths[i])
|
||||
paths[i] = tmp
|
||||
distances[i] = (1 - distances[i]) * 100
|
||||
return vids, paths, distances
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error with search: {e}")
|
||||
sys.exit(1)
|
@ -0,0 +1,95 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import zipfile
|
||||
|
||||
import gdown
|
||||
from fastapi.testclient import TestClient
|
||||
from main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def download_audio_data():
|
||||
"""
|
||||
download audio data
|
||||
"""
|
||||
url = 'https://drive.google.com/uc?id=1bKu21JWBfcZBuEuzFEvPoAX6PmRrgnUp'
|
||||
gdown.download(url)
|
||||
|
||||
with zipfile.ZipFile('example_audio.zip', 'r') as zip_ref:
|
||||
zip_ref.extractall('./example_audio')
|
||||
|
||||
|
||||
def test_drop():
|
||||
"""
|
||||
Delete the collection of Milvus and MySQL
|
||||
"""
|
||||
response = client.post("/audio/drop")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_load():
|
||||
"""
|
||||
Insert all the audio files under the file path to Milvus/MySQL
|
||||
"""
|
||||
response = client.post("/audio/load", json={"File": "./example_audio"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
'status': True,
|
||||
'msg': "Successfully loaded data!"
|
||||
}
|
||||
|
||||
|
||||
def test_progress():
|
||||
"""
|
||||
Get the progress of dealing with data
|
||||
"""
|
||||
response = client.get("/progress")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "current: 20, total: 20"
|
||||
|
||||
|
||||
def test_count():
|
||||
"""
|
||||
Returns the total number of vectors in the system
|
||||
"""
|
||||
response = client.get("audio/count")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == 20
|
||||
|
||||
|
||||
def test_search():
|
||||
"""
|
||||
Search the uploaded audio in Milvus/MySQL
|
||||
"""
|
||||
response = client.post(
|
||||
"/audio/search/local?query_audio_path=.%2Fexample_audio%2Ftest.wav")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 10
|
||||
|
||||
|
||||
def test_data():
|
||||
"""
|
||||
Get the audio file
|
||||
"""
|
||||
response = client.get("/data?audio_path=.%2Fexample_audio%2Ftest.wav")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
download_audio_data()
|
||||
test_load()
|
||||
test_count()
|
||||
test_search()
|
||||
test_drop()
|
@ -0,0 +1 @@
|
||||
*.wav
|
@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
|
||||
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
|
||||
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input ./zh.wav --topk 1
|
@ -1,5 +1,9 @@
|
||||
# Changelog
|
||||
|
||||
Date: 2022-3-15, Author: Xiaojie Chen.
|
||||
- kaldi and librosa mfcc, fbank, spectrogram.
|
||||
- unit test and benchmark.
|
||||
|
||||
Date: 2022-2-25, Author: Hui Zhang.
|
||||
- Refactor architecture.
|
||||
- dtw distance and mcd style dtw
|
||||
- dtw distance and mcd style dtw.
|
||||
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,34 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import unittest
|
||||
import urllib.request
|
||||
|
||||
mono_channel_wav = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
|
||||
multi_channels_wav = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav'
|
||||
|
||||
|
||||
class BackendTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.initWavInput()
|
||||
|
||||
def initWavInput(self):
|
||||
self.files = []
|
||||
for url in [mono_channel_wav, multi_channels_wav]:
|
||||
if not os.path.isfile(os.path.basename(url)):
|
||||
urllib.request.urlretrieve(url, os.path.basename(url))
|
||||
self.files.append(os.path.basename(url))
|
||||
|
||||
def initParmas(self):
|
||||
raise NotImplementedError
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,73 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import filecmp
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
import paddleaudio
|
||||
from ..base import BackendTest
|
||||
|
||||
|
||||
class TestIO(BackendTest):
|
||||
def test_load_mono_channel(self):
|
||||
sf_data, sf_sr = sf.read(self.files[0])
|
||||
pa_data, pa_sr = paddleaudio.load(
|
||||
self.files[0], normal=False, dtype='float64')
|
||||
|
||||
self.assertEqual(sf_data.dtype, pa_data.dtype)
|
||||
self.assertEqual(sf_sr, pa_sr)
|
||||
np.testing.assert_array_almost_equal(sf_data, pa_data)
|
||||
|
||||
def test_load_multi_channels(self):
|
||||
sf_data, sf_sr = sf.read(self.files[1])
|
||||
sf_data = sf_data.T # Channel dim first
|
||||
pa_data, pa_sr = paddleaudio.load(
|
||||
self.files[1], mono=False, normal=False, dtype='float64')
|
||||
|
||||
self.assertEqual(sf_data.dtype, pa_data.dtype)
|
||||
self.assertEqual(sf_sr, pa_sr)
|
||||
np.testing.assert_array_almost_equal(sf_data, pa_data)
|
||||
|
||||
def test_save_mono_channel(self):
|
||||
waveform, sr = np.random.randint(
|
||||
low=-32768, high=32768, size=(48000), dtype=np.int16), 16000
|
||||
sf_tmp_file = 'sf_tmp.wav'
|
||||
pa_tmp_file = 'pa_tmp.wav'
|
||||
|
||||
sf.write(sf_tmp_file, waveform, sr)
|
||||
paddleaudio.save(waveform, sr, pa_tmp_file)
|
||||
|
||||
self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file))
|
||||
for file in [sf_tmp_file, pa_tmp_file]:
|
||||
os.remove(file)
|
||||
|
||||
def test_save_multi_channels(self):
|
||||
waveform, sr = np.random.randint(
|
||||
low=-32768, high=32768, size=(2, 48000), dtype=np.int16), 16000
|
||||
sf_tmp_file = 'sf_tmp.wav'
|
||||
pa_tmp_file = 'pa_tmp.wav'
|
||||
|
||||
sf.write(sf_tmp_file, waveform.T, sr)
|
||||
paddleaudio.save(waveform.T, sr, pa_tmp_file)
|
||||
|
||||
self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file))
|
||||
for file in [sf_tmp_file, pa_tmp_file]:
|
||||
os.remove(file)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,39 @@
|
||||
# 1. Prepare
|
||||
First, install `pytest-benchmark` via pip.
|
||||
```sh
|
||||
pip install pytest-benchmark
|
||||
```
|
||||
|
||||
# 2. Run
|
||||
Run the specific script for profiling.
|
||||
```sh
|
||||
pytest melspectrogram.py
|
||||
```
|
||||
|
||||
Result:
|
||||
```sh
|
||||
========================================================================== test session starts ==========================================================================
|
||||
platform linux -- Python 3.7.7, pytest-7.0.1, pluggy-1.0.0
|
||||
benchmark: 3.4.1 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
|
||||
rootdir: /ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddleaudio
|
||||
plugins: typeguard-2.12.1, benchmark-3.4.1, anyio-3.5.0
|
||||
collected 4 items
|
||||
|
||||
melspectrogram.py .... [100%]
|
||||
|
||||
|
||||
-------------------------------------------------------------------------------------------------- benchmark: 4 tests -------------------------------------------------------------------------------------------------
|
||||
Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
|
||||
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
test_melspect_gpu_torchaudio 202.0765 (1.0) 360.6230 (1.0) 218.1168 (1.0) 16.3022 (1.0) 214.2871 (1.0) 21.8451 (1.0) 40;3 4,584.7001 (1.0) 286 1
|
||||
test_melspect_gpu 657.8509 (3.26) 908.0470 (2.52) 724.2545 (3.32) 106.5771 (6.54) 669.9096 (3.13) 113.4719 (5.19) 1;0 1,380.7300 (0.30) 5 1
|
||||
test_melspect_cpu_torchaudio 1,247.6053 (6.17) 2,892.5799 (8.02) 1,443.2853 (6.62) 345.3732 (21.19) 1,262.7263 (5.89) 221.6385 (10.15) 56;53 692.8637 (0.15) 399 1
|
||||
test_melspect_cpu 20,326.2549 (100.59) 20,607.8682 (57.15) 20,473.4125 (93.86) 63.8654 (3.92) 20,467.0429 (95.51) 68.4294 (3.13) 8;1 48.8438 (0.01) 29 1
|
||||
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Legend:
|
||||
Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
|
||||
OPS: Operations Per Second, computed as 1 / Mean
|
||||
========================================================================== 4 passed in 21.12s ===========================================================================
|
||||
|
||||
```
|
@ -0,0 +1,124 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import urllib.request
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import paddle
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import paddleaudio
|
||||
|
||||
wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
|
||||
if not os.path.isfile(os.path.basename(wav_url)):
|
||||
urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
|
||||
|
||||
waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url)))
|
||||
waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
|
||||
waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
|
||||
|
||||
# Feature conf
|
||||
mel_conf = {
|
||||
'sr': sr,
|
||||
'n_fft': 512,
|
||||
'hop_length': 128,
|
||||
'n_mels': 40,
|
||||
}
|
||||
|
||||
mel_conf_torchaudio = {
|
||||
'sample_rate': sr,
|
||||
'n_fft': 512,
|
||||
'hop_length': 128,
|
||||
'n_mels': 40,
|
||||
'norm': 'slaney',
|
||||
'mel_scale': 'slaney',
|
||||
}
|
||||
|
||||
|
||||
def enable_cpu_device():
|
||||
paddle.set_device('cpu')
|
||||
|
||||
|
||||
def enable_gpu_device():
|
||||
paddle.set_device('gpu')
|
||||
|
||||
|
||||
log_mel_extractor = paddleaudio.features.LogMelSpectrogram(
|
||||
**mel_conf, f_min=0.0, top_db=80.0, dtype=waveform_tensor.dtype)
|
||||
|
||||
|
||||
def log_melspectrogram():
|
||||
return log_mel_extractor(waveform_tensor).squeeze(0)
|
||||
|
||||
|
||||
def test_log_melspect_cpu(benchmark):
|
||||
enable_cpu_device()
|
||||
feature_paddleaudio = benchmark(log_melspectrogram)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
def test_log_melspect_gpu(benchmark):
|
||||
enable_gpu_device()
|
||||
feature_paddleaudio = benchmark(log_melspectrogram)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=2)
|
||||
|
||||
|
||||
mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram(
|
||||
**mel_conf_torchaudio, f_min=0.0)
|
||||
amplitude_to_DB = torchaudio.transforms.AmplitudeToDB('power', top_db=80.0)
|
||||
|
||||
|
||||
def melspectrogram_torchaudio():
|
||||
return mel_extractor_torchaudio(waveform_tensor_torch).squeeze(0)
|
||||
|
||||
|
||||
def log_melspectrogram_torchaudio():
|
||||
mel_specgram = mel_extractor_torchaudio(waveform_tensor_torch)
|
||||
return amplitude_to_DB(mel_specgram).squeeze(0)
|
||||
|
||||
|
||||
def test_log_melspect_cpu_torchaudio(benchmark):
|
||||
global waveform_tensor_torch, mel_extractor_torchaudio, amplitude_to_DB
|
||||
|
||||
mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu')
|
||||
waveform_tensor_torch = waveform_tensor_torch.to('cpu')
|
||||
amplitude_to_DB = amplitude_to_DB.to('cpu')
|
||||
|
||||
feature_paddleaudio = benchmark(log_melspectrogram_torchaudio)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
def test_log_melspect_gpu_torchaudio(benchmark):
|
||||
global waveform_tensor_torch, mel_extractor_torchaudio, amplitude_to_DB
|
||||
|
||||
mel_extractor_torchaudio = mel_extractor_torchaudio.to('cuda')
|
||||
waveform_tensor_torch = waveform_tensor_torch.to('cuda')
|
||||
amplitude_to_DB = amplitude_to_DB.to('cuda')
|
||||
|
||||
feature_torchaudio = benchmark(log_melspectrogram_torchaudio)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_torchaudio.cpu(), decimal=2)
|
@ -0,0 +1,108 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import urllib.request
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import paddle
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import paddleaudio
|
||||
|
||||
wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
|
||||
if not os.path.isfile(os.path.basename(wav_url)):
|
||||
urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
|
||||
|
||||
waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url)))
|
||||
waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
|
||||
waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
|
||||
|
||||
# Feature conf
|
||||
mel_conf = {
|
||||
'sr': sr,
|
||||
'n_fft': 512,
|
||||
'hop_length': 128,
|
||||
'n_mels': 40,
|
||||
}
|
||||
|
||||
mel_conf_torchaudio = {
|
||||
'sample_rate': sr,
|
||||
'n_fft': 512,
|
||||
'hop_length': 128,
|
||||
'n_mels': 40,
|
||||
'norm': 'slaney',
|
||||
'mel_scale': 'slaney',
|
||||
}
|
||||
|
||||
|
||||
def enable_cpu_device():
|
||||
paddle.set_device('cpu')
|
||||
|
||||
|
||||
def enable_gpu_device():
|
||||
paddle.set_device('gpu')
|
||||
|
||||
|
||||
mel_extractor = paddleaudio.features.MelSpectrogram(
|
||||
**mel_conf, f_min=0.0, dtype=waveform_tensor.dtype)
|
||||
|
||||
|
||||
def melspectrogram():
|
||||
return mel_extractor(waveform_tensor).squeeze(0)
|
||||
|
||||
|
||||
def test_melspect_cpu(benchmark):
|
||||
enable_cpu_device()
|
||||
feature_paddleaudio = benchmark(melspectrogram)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
def test_melspect_gpu(benchmark):
|
||||
enable_gpu_device()
|
||||
feature_paddleaudio = benchmark(melspectrogram)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram(
|
||||
**mel_conf_torchaudio, f_min=0.0)
|
||||
|
||||
|
||||
def melspectrogram_torchaudio():
|
||||
return mel_extractor_torchaudio(waveform_tensor_torch).squeeze(0)
|
||||
|
||||
|
||||
def test_melspect_cpu_torchaudio(benchmark):
|
||||
global waveform_tensor_torch, mel_extractor_torchaudio
|
||||
mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu')
|
||||
waveform_tensor_torch = waveform_tensor_torch.to('cpu')
|
||||
feature_paddleaudio = benchmark(melspectrogram_torchaudio)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
def test_melspect_gpu_torchaudio(benchmark):
|
||||
global waveform_tensor_torch, mel_extractor_torchaudio
|
||||
mel_extractor_torchaudio = mel_extractor_torchaudio.to('cuda')
|
||||
waveform_tensor_torch = waveform_tensor_torch.to('cuda')
|
||||
feature_torchaudio = benchmark(melspectrogram_torchaudio)
|
||||
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_torchaudio.cpu(), decimal=3)
|
@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import urllib.request
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import paddle
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import paddleaudio
|
||||
|
||||
wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
|
||||
if not os.path.isfile(os.path.basename(wav_url)):
|
||||
urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
|
||||
|
||||
waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url)))
|
||||
waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
|
||||
waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
|
||||
|
||||
# Feature conf
|
||||
mel_conf = {
|
||||
'sr': sr,
|
||||
'n_fft': 512,
|
||||
'hop_length': 128,
|
||||
'n_mels': 40,
|
||||
}
|
||||
mfcc_conf = {
|
||||
'n_mfcc': 20,
|
||||
'top_db': 80.0,
|
||||
}
|
||||
mfcc_conf.update(mel_conf)
|
||||
|
||||
mel_conf_torchaudio = {
|
||||
'sample_rate': sr,
|
||||
'n_fft': 512,
|
||||
'hop_length': 128,
|
||||
'n_mels': 40,
|
||||
'norm': 'slaney',
|
||||
'mel_scale': 'slaney',
|
||||
}
|
||||
mfcc_conf_torchaudio = {
|
||||
'sample_rate': sr,
|
||||
'n_mfcc': 20,
|
||||
}
|
||||
|
||||
|
||||
def enable_cpu_device():
|
||||
paddle.set_device('cpu')
|
||||
|
||||
|
||||
def enable_gpu_device():
|
||||
paddle.set_device('gpu')
|
||||
|
||||
|
||||
mfcc_extractor = paddleaudio.features.MFCC(
|
||||
**mfcc_conf, f_min=0.0, dtype=waveform_tensor.dtype)
|
||||
|
||||
|
||||
def mfcc():
|
||||
return mfcc_extractor(waveform_tensor).squeeze(0)
|
||||
|
||||
|
||||
def test_mfcc_cpu(benchmark):
|
||||
enable_cpu_device()
|
||||
feature_paddleaudio = benchmark(mfcc)
|
||||
feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
def test_mfcc_gpu(benchmark):
|
||||
enable_gpu_device()
|
||||
feature_paddleaudio = benchmark(mfcc)
|
||||
feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
del mel_conf_torchaudio['sample_rate']
|
||||
mfcc_extractor_torchaudio = torchaudio.transforms.MFCC(
|
||||
**mfcc_conf_torchaudio, melkwargs=mel_conf_torchaudio)
|
||||
|
||||
|
||||
def mfcc_torchaudio():
|
||||
return mfcc_extractor_torchaudio(waveform_tensor_torch).squeeze(0)
|
||||
|
||||
|
||||
def test_mfcc_cpu_torchaudio(benchmark):
|
||||
global waveform_tensor_torch, mfcc_extractor_torchaudio
|
||||
|
||||
mel_extractor_torchaudio = mfcc_extractor_torchaudio.to('cpu')
|
||||
waveform_tensor_torch = waveform_tensor_torch.to('cpu')
|
||||
|
||||
feature_paddleaudio = benchmark(mfcc_torchaudio)
|
||||
feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddleaudio, decimal=3)
|
||||
|
||||
|
||||
def test_mfcc_gpu_torchaudio(benchmark):
|
||||
global waveform_tensor_torch, mfcc_extractor_torchaudio
|
||||
|
||||
mel_extractor_torchaudio = mfcc_extractor_torchaudio.to('cuda')
|
||||
waveform_tensor_torch = waveform_tensor_torch.to('cuda')
|
||||
|
||||
feature_torchaudio = benchmark(mfcc_torchaudio)
|
||||
feature_librosa = librosa.feature.mfcc(waveform, **mel_conf)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_torchaudio.cpu(), decimal=3)
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,49 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import unittest
|
||||
import urllib.request
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from paddleaudio import load
|
||||
|
||||
wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
|
||||
|
||||
|
||||
class FeatTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.initParmas()
|
||||
self.initWavInput()
|
||||
self.setUpDevice()
|
||||
|
||||
def setUpDevice(self, device='cpu'):
|
||||
paddle.set_device(device)
|
||||
|
||||
def initWavInput(self, url=wav_url):
|
||||
if not os.path.isfile(os.path.basename(url)):
|
||||
urllib.request.urlretrieve(url, os.path.basename(url))
|
||||
self.waveform, self.sr = load(os.path.abspath(os.path.basename(url)))
|
||||
self.waveform = self.waveform.astype(
|
||||
np.float32
|
||||
) # paddlespeech.s2t.transform.spectrogram only supports float32
|
||||
dim = len(self.waveform.shape)
|
||||
|
||||
assert dim in [1, 2]
|
||||
if dim == 1:
|
||||
self.waveform = np.expand_dims(self.waveform, 0)
|
||||
|
||||
def initParmas(self):
|
||||
raise NotImplementedError
|
@ -0,0 +1,49 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from .base import FeatTest
|
||||
from paddleaudio.functional.window import get_window
|
||||
from paddlespeech.s2t.transform.spectrogram import IStft
|
||||
from paddlespeech.s2t.transform.spectrogram import Stft
|
||||
|
||||
|
||||
class TestIstft(FeatTest):
|
||||
def initParmas(self):
|
||||
self.n_fft = 512
|
||||
self.hop_length = 128
|
||||
self.window_str = 'hann'
|
||||
|
||||
def test_istft(self):
|
||||
ps_stft = Stft(self.n_fft, self.hop_length)
|
||||
ps_res = ps_stft(
|
||||
self.waveform.T).squeeze(1).T # (n_fft//2 + 1, n_frmaes)
|
||||
x = paddle.to_tensor(ps_res)
|
||||
|
||||
ps_istft = IStft(self.hop_length)
|
||||
ps_res = ps_istft(ps_res.T)
|
||||
|
||||
window = get_window(
|
||||
self.window_str, self.n_fft, dtype=self.waveform.dtype)
|
||||
pd_res = paddle.signal.istft(
|
||||
x, self.n_fft, self.hop_length, window=window)
|
||||
|
||||
np.testing.assert_array_almost_equal(ps_res, pd_res, decimal=5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,81 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
import paddleaudio
|
||||
from .base import FeatTest
|
||||
|
||||
|
||||
class TestKaldi(FeatTest):
|
||||
def initParmas(self):
|
||||
self.window_size = 1024
|
||||
self.dtype = 'float32'
|
||||
|
||||
def test_window(self):
|
||||
t_hann_window = torch.hann_window(
|
||||
self.window_size, periodic=False, dtype=eval(f'torch.{self.dtype}'))
|
||||
t_hamm_window = torch.hamming_window(
|
||||
self.window_size,
|
||||
periodic=False,
|
||||
alpha=0.54,
|
||||
beta=0.46,
|
||||
dtype=eval(f'torch.{self.dtype}'))
|
||||
t_povey_window = torch.hann_window(
|
||||
self.window_size, periodic=False,
|
||||
dtype=eval(f'torch.{self.dtype}')).pow(0.85)
|
||||
|
||||
p_hann_window = paddleaudio.functional.window.get_window(
|
||||
'hann',
|
||||
self.window_size,
|
||||
fftbins=False,
|
||||
dtype=eval(f'paddle.{self.dtype}'))
|
||||
p_hamm_window = paddleaudio.functional.window.get_window(
|
||||
'hamming',
|
||||
self.window_size,
|
||||
fftbins=False,
|
||||
dtype=eval(f'paddle.{self.dtype}'))
|
||||
p_povey_window = paddleaudio.functional.window.get_window(
|
||||
'hann',
|
||||
self.window_size,
|
||||
fftbins=False,
|
||||
dtype=eval(f'paddle.{self.dtype}')).pow(0.85)
|
||||
|
||||
np.testing.assert_array_almost_equal(t_hann_window, p_hann_window)
|
||||
np.testing.assert_array_almost_equal(t_hamm_window, p_hamm_window)
|
||||
np.testing.assert_array_almost_equal(t_povey_window, p_povey_window)
|
||||
|
||||
def test_fbank(self):
|
||||
ta_features = torchaudio.compliance.kaldi.fbank(
|
||||
torch.from_numpy(self.waveform.astype(self.dtype)))
|
||||
pa_features = paddleaudio.compliance.kaldi.fbank(
|
||||
paddle.to_tensor(self.waveform.astype(self.dtype)))
|
||||
np.testing.assert_array_almost_equal(
|
||||
ta_features, pa_features, decimal=4)
|
||||
|
||||
def test_mfcc(self):
|
||||
ta_features = torchaudio.compliance.kaldi.mfcc(
|
||||
torch.from_numpy(self.waveform.astype(self.dtype)))
|
||||
pa_features = paddleaudio.compliance.kaldi.mfcc(
|
||||
paddle.to_tensor(self.waveform.astype(self.dtype)))
|
||||
np.testing.assert_array_almost_equal(
|
||||
ta_features, pa_features, decimal=4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,281 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
import paddleaudio
|
||||
from .base import FeatTest
|
||||
from paddleaudio.functional.window import get_window
|
||||
|
||||
|
||||
class TestLibrosa(FeatTest):
|
||||
def initParmas(self):
|
||||
self.n_fft = 512
|
||||
self.hop_length = 128
|
||||
self.n_mels = 40
|
||||
self.n_mfcc = 20
|
||||
self.fmin = 0.0
|
||||
self.window_str = 'hann'
|
||||
self.pad_mode = 'reflect'
|
||||
self.top_db = 80.0
|
||||
|
||||
def test_stft(self):
|
||||
if len(self.waveform.shape) == 2: # (C, T)
|
||||
self.waveform = self.waveform.squeeze(
|
||||
0) # 1D input for librosa.feature.melspectrogram
|
||||
|
||||
feature_librosa = librosa.core.stft(
|
||||
y=self.waveform,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=None,
|
||||
window=self.window_str,
|
||||
center=True,
|
||||
dtype=None,
|
||||
pad_mode=self.pad_mode, )
|
||||
x = paddle.to_tensor(self.waveform).unsqueeze(0)
|
||||
window = get_window(self.window_str, self.n_fft, dtype=x.dtype)
|
||||
feature_paddle = paddle.signal.stft(
|
||||
x=x,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=None,
|
||||
window=window,
|
||||
center=True,
|
||||
pad_mode=self.pad_mode,
|
||||
normalized=False,
|
||||
onesided=True, ).squeeze(0)
|
||||
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddle, decimal=5)
|
||||
|
||||
def test_istft(self):
|
||||
if len(self.waveform.shape) == 2: # (C, T)
|
||||
self.waveform = self.waveform.squeeze(
|
||||
0) # 1D input for librosa.feature.melspectrogram
|
||||
|
||||
# Get stft result from librosa.
|
||||
stft_matrix = librosa.core.stft(
|
||||
y=self.waveform,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=None,
|
||||
window=self.window_str,
|
||||
center=True,
|
||||
pad_mode=self.pad_mode, )
|
||||
|
||||
feature_librosa = librosa.core.istft(
|
||||
stft_matrix=stft_matrix,
|
||||
hop_length=self.hop_length,
|
||||
win_length=None,
|
||||
window=self.window_str,
|
||||
center=True,
|
||||
dtype=None,
|
||||
length=None, )
|
||||
|
||||
x = paddle.to_tensor(stft_matrix).unsqueeze(0)
|
||||
window = get_window(
|
||||
self.window_str,
|
||||
self.n_fft,
|
||||
dtype=paddle.to_tensor(self.waveform).dtype)
|
||||
feature_paddle = paddle.signal.istft(
|
||||
x=x,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=None,
|
||||
window=window,
|
||||
center=True,
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
length=None,
|
||||
return_complex=False, ).squeeze(0)
|
||||
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_paddle, decimal=5)
|
||||
|
||||
def test_mel(self):
|
||||
feature_librosa = librosa.filters.mel(
|
||||
sr=self.sr,
|
||||
n_fft=self.n_fft,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin,
|
||||
fmax=None,
|
||||
htk=False,
|
||||
norm='slaney',
|
||||
dtype=self.waveform.dtype, )
|
||||
feature_compliance = paddleaudio.compliance.librosa.compute_fbank_matrix(
|
||||
sr=self.sr,
|
||||
n_fft=self.n_fft,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin,
|
||||
fmax=None,
|
||||
htk=False,
|
||||
norm='slaney',
|
||||
dtype=self.waveform.dtype, )
|
||||
x = paddle.to_tensor(self.waveform)
|
||||
feature_functional = paddleaudio.functional.compute_fbank_matrix(
|
||||
sr=self.sr,
|
||||
n_fft=self.n_fft,
|
||||
n_mels=self.n_mels,
|
||||
f_min=self.fmin,
|
||||
f_max=None,
|
||||
htk=False,
|
||||
norm='slaney',
|
||||
dtype=x.dtype, )
|
||||
|
||||
np.testing.assert_array_almost_equal(feature_librosa,
|
||||
feature_compliance)
|
||||
np.testing.assert_array_almost_equal(feature_librosa,
|
||||
feature_functional)
|
||||
|
||||
def test_melspect(self):
|
||||
if len(self.waveform.shape) == 2: # (C, T)
|
||||
self.waveform = self.waveform.squeeze(
|
||||
0) # 1D input for librosa.feature.melspectrogram
|
||||
|
||||
# librosa:
|
||||
feature_librosa = librosa.feature.melspectrogram(
|
||||
y=self.waveform,
|
||||
sr=self.sr,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin)
|
||||
|
||||
# paddleaudio.compliance.librosa:
|
||||
feature_compliance = paddleaudio.compliance.librosa.melspectrogram(
|
||||
x=self.waveform,
|
||||
sr=self.sr,
|
||||
window_size=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin,
|
||||
to_db=False)
|
||||
|
||||
# paddleaudio.features.layer
|
||||
x = paddle.to_tensor(
|
||||
self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim.
|
||||
feature_extractor = paddleaudio.features.MelSpectrogram(
|
||||
sr=self.sr,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
f_min=self.fmin,
|
||||
dtype=x.dtype)
|
||||
feature_layer = feature_extractor(x).squeeze(0).numpy()
|
||||
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_compliance, decimal=5)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_layer, decimal=5)
|
||||
|
||||
def test_log_melspect(self):
|
||||
if len(self.waveform.shape) == 2: # (C, T)
|
||||
self.waveform = self.waveform.squeeze(
|
||||
0) # 1D input for librosa.feature.melspectrogram
|
||||
|
||||
# librosa:
|
||||
feature_librosa = librosa.feature.melspectrogram(
|
||||
y=self.waveform,
|
||||
sr=self.sr,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin)
|
||||
feature_librosa = librosa.power_to_db(feature_librosa, top_db=None)
|
||||
|
||||
# paddleaudio.compliance.librosa:
|
||||
feature_compliance = paddleaudio.compliance.librosa.melspectrogram(
|
||||
x=self.waveform,
|
||||
sr=self.sr,
|
||||
window_size=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin)
|
||||
|
||||
# paddleaudio.features.layer
|
||||
x = paddle.to_tensor(
|
||||
self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim.
|
||||
feature_extractor = paddleaudio.features.LogMelSpectrogram(
|
||||
sr=self.sr,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
f_min=self.fmin,
|
||||
dtype=x.dtype)
|
||||
feature_layer = feature_extractor(x).squeeze(0).numpy()
|
||||
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_compliance, decimal=5)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_layer, decimal=4)
|
||||
|
||||
def test_mfcc(self):
|
||||
if len(self.waveform.shape) == 2: # (C, T)
|
||||
self.waveform = self.waveform.squeeze(
|
||||
0) # 1D input for librosa.feature.melspectrogram
|
||||
|
||||
# librosa:
|
||||
feature_librosa = librosa.feature.mfcc(
|
||||
y=self.waveform,
|
||||
sr=self.sr,
|
||||
S=None,
|
||||
n_mfcc=self.n_mfcc,
|
||||
dct_type=2,
|
||||
norm='ortho',
|
||||
lifter=0,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin)
|
||||
|
||||
# paddleaudio.compliance.librosa:
|
||||
feature_compliance = paddleaudio.compliance.librosa.mfcc(
|
||||
x=self.waveform,
|
||||
sr=self.sr,
|
||||
n_mfcc=self.n_mfcc,
|
||||
dct_type=2,
|
||||
norm='ortho',
|
||||
lifter=0,
|
||||
window_size=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.fmin,
|
||||
top_db=self.top_db)
|
||||
|
||||
# paddleaudio.features.layer
|
||||
x = paddle.to_tensor(
|
||||
self.waveform, dtype=paddle.float64).unsqueeze(0) # Add batch dim.
|
||||
feature_extractor = paddleaudio.features.MFCC(
|
||||
sr=self.sr,
|
||||
n_mfcc=self.n_mfcc,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels,
|
||||
f_min=self.fmin,
|
||||
top_db=self.top_db,
|
||||
dtype=x.dtype)
|
||||
feature_layer = feature_extractor(x).squeeze(0).numpy()
|
||||
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_compliance, decimal=4)
|
||||
np.testing.assert_array_almost_equal(
|
||||
feature_librosa, feature_layer, decimal=4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,50 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
import paddleaudio
|
||||
from .base import FeatTest
|
||||
from paddlespeech.s2t.transform.spectrogram import LogMelSpectrogram
|
||||
|
||||
|
||||
class TestLogMelSpectrogram(FeatTest):
|
||||
def initParmas(self):
|
||||
self.n_fft = 512
|
||||
self.hop_length = 128
|
||||
self.n_mels = 40
|
||||
|
||||
def test_log_melspect(self):
|
||||
ps_melspect = LogMelSpectrogram(self.sr, self.n_mels, self.n_fft,
|
||||
self.hop_length)
|
||||
ps_res = ps_melspect(self.waveform.T).squeeze(1).T
|
||||
|
||||
x = paddle.to_tensor(self.waveform)
|
||||
# paddlespeech.s2t的特征存在幅度谱和功率谱滥用的情况
|
||||
ps_melspect = paddleaudio.features.LogMelSpectrogram(
|
||||
self.sr,
|
||||
self.n_fft,
|
||||
self.hop_length,
|
||||
power=1.0,
|
||||
n_mels=self.n_mels,
|
||||
f_min=0.0)
|
||||
pa_res = (ps_melspect(x) / 10.0).squeeze(0).numpy()
|
||||
|
||||
np.testing.assert_array_almost_equal(ps_res, pa_res, decimal=5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,42 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
import paddleaudio
|
||||
from .base import FeatTest
|
||||
from paddlespeech.s2t.transform.spectrogram import Spectrogram
|
||||
|
||||
|
||||
class TestSpectrogram(FeatTest):
|
||||
def initParmas(self):
|
||||
self.n_fft = 512
|
||||
self.hop_length = 128
|
||||
|
||||
def test_spectrogram(self):
|
||||
ps_spect = Spectrogram(self.n_fft, self.hop_length)
|
||||
ps_res = ps_spect(self.waveform.T).squeeze(1).T # Magnitude
|
||||
|
||||
x = paddle.to_tensor(self.waveform)
|
||||
pa_spect = paddleaudio.features.Spectrogram(
|
||||
self.n_fft, self.hop_length, power=1.0)
|
||||
pa_res = pa_spect(x).squeeze(0).numpy()
|
||||
|
||||
np.testing.assert_array_almost_equal(ps_res, pa_res, decimal=5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,44 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from .base import FeatTest
|
||||
from paddleaudio.functional.window import get_window
|
||||
from paddlespeech.s2t.transform.spectrogram import Stft
|
||||
|
||||
|
||||
class TestStft(FeatTest):
|
||||
def initParmas(self):
|
||||
self.n_fft = 512
|
||||
self.hop_length = 128
|
||||
self.window_str = 'hann'
|
||||
|
||||
def test_stft(self):
|
||||
ps_stft = Stft(self.n_fft, self.hop_length)
|
||||
ps_res = ps_stft(
|
||||
self.waveform.T).squeeze(1).T # (n_fft//2 + 1, n_frmaes)
|
||||
|
||||
x = paddle.to_tensor(self.waveform)
|
||||
window = get_window(self.window_str, self.n_fft, dtype=x.dtype)
|
||||
pd_res = paddle.signal.stft(
|
||||
x, self.n_fft, self.hop_length, window=window).squeeze(0).numpy()
|
||||
|
||||
np.testing.assert_array_almost_equal(ps_res, pd_res, decimal=5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,224 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import yaml
|
||||
|
||||
from paddlespeech.cli.cls.infer import CLSExecutor
|
||||
from paddlespeech.cli.log import logger
|
||||
from paddlespeech.cli.utils import download_and_decompress
|
||||
from paddlespeech.cli.utils import MODEL_HOME
|
||||
from paddlespeech.server.engine.base_engine import BaseEngine
|
||||
from paddlespeech.server.utils.paddle_predictor import init_predictor
|
||||
from paddlespeech.server.utils.paddle_predictor import run_model
|
||||
|
||||
__all__ = ['CLSEngine']
|
||||
|
||||
pretrained_models = {
|
||||
"panns_cnn6-32k": {
|
||||
'url':
|
||||
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz',
|
||||
'md5':
|
||||
'da087c31046d23281d8ec5188c1967da',
|
||||
'cfg_path':
|
||||
'panns.yaml',
|
||||
'model_path':
|
||||
'inference.pdmodel',
|
||||
'params_path':
|
||||
'inference.pdiparams',
|
||||
'label_file':
|
||||
'audioset_labels.txt',
|
||||
},
|
||||
"panns_cnn10-32k": {
|
||||
'url':
|
||||
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz',
|
||||
'md5':
|
||||
'5460cc6eafbfaf0f261cc75b90284ae1',
|
||||
'cfg_path':
|
||||
'panns.yaml',
|
||||
'model_path':
|
||||
'inference.pdmodel',
|
||||
'params_path':
|
||||
'inference.pdiparams',
|
||||
'label_file':
|
||||
'audioset_labels.txt',
|
||||
},
|
||||
"panns_cnn14-32k": {
|
||||
'url':
|
||||
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz',
|
||||
'md5':
|
||||
'ccc80b194821274da79466862b2ab00f',
|
||||
'cfg_path':
|
||||
'panns.yaml',
|
||||
'model_path':
|
||||
'inference.pdmodel',
|
||||
'params_path':
|
||||
'inference.pdiparams',
|
||||
'label_file':
|
||||
'audioset_labels.txt',
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class CLSServerExecutor(CLSExecutor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
def _get_pretrained_path(self, tag: str) -> os.PathLike:
|
||||
"""
|
||||
Download and returns pretrained resources path of current task.
|
||||
"""
|
||||
support_models = list(pretrained_models.keys())
|
||||
assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
|
||||
tag, '\n\t\t'.join(support_models))
|
||||
|
||||
res_path = os.path.join(MODEL_HOME, tag)
|
||||
decompressed_path = download_and_decompress(pretrained_models[tag],
|
||||
res_path)
|
||||
decompressed_path = os.path.abspath(decompressed_path)
|
||||
logger.info(
|
||||
'Use pretrained model stored in: {}'.format(decompressed_path))
|
||||
|
||||
return decompressed_path
|
||||
|
||||
def _init_from_path(
|
||||
self,
|
||||
model_type: str='panns_cnn14',
|
||||
cfg_path: Optional[os.PathLike]=None,
|
||||
model_path: Optional[os.PathLike]=None,
|
||||
params_path: Optional[os.PathLike]=None,
|
||||
label_file: Optional[os.PathLike]=None,
|
||||
predictor_conf: dict=None, ):
|
||||
"""
|
||||
Init model and other resources from a specific path.
|
||||
"""
|
||||
|
||||
if cfg_path is None or model_path is None or params_path is None or label_file is None:
|
||||
tag = model_type + '-' + '32k'
|
||||
self.res_path = self._get_pretrained_path(tag)
|
||||
self.cfg_path = os.path.join(self.res_path,
|
||||
pretrained_models[tag]['cfg_path'])
|
||||
self.model_path = os.path.join(self.res_path,
|
||||
pretrained_models[tag]['model_path'])
|
||||
self.params_path = os.path.join(
|
||||
self.res_path, pretrained_models[tag]['params_path'])
|
||||
self.label_file = os.path.join(self.res_path,
|
||||
pretrained_models[tag]['label_file'])
|
||||
else:
|
||||
self.cfg_path = os.path.abspath(cfg_path)
|
||||
self.model_path = os.path.abspath(model_path)
|
||||
self.params_path = os.path.abspath(params_path)
|
||||
self.label_file = os.path.abspath(label_file)
|
||||
|
||||
logger.info(self.cfg_path)
|
||||
logger.info(self.model_path)
|
||||
logger.info(self.params_path)
|
||||
logger.info(self.label_file)
|
||||
|
||||
# config
|
||||
with open(self.cfg_path, 'r') as f:
|
||||
self._conf = yaml.safe_load(f)
|
||||
logger.info("Read cfg file successfully.")
|
||||
|
||||
# labels
|
||||
self._label_list = []
|
||||
with open(self.label_file, 'r') as f:
|
||||
for line in f:
|
||||
self._label_list.append(line.strip())
|
||||
logger.info("Read label file successfully.")
|
||||
|
||||
# Create predictor
|
||||
self.predictor_conf = predictor_conf
|
||||
self.predictor = init_predictor(
|
||||
model_file=self.model_path,
|
||||
params_file=self.params_path,
|
||||
predictor_conf=self.predictor_conf)
|
||||
logger.info("Create predictor successfully.")
|
||||
|
||||
@paddle.no_grad()
|
||||
def infer(self):
|
||||
"""
|
||||
Model inference and result stored in self.output.
|
||||
"""
|
||||
output = run_model(self.predictor, [self._inputs['feats'].numpy()])
|
||||
self._outputs['logits'] = output[0]
|
||||
|
||||
|
||||
class CLSEngine(BaseEngine):
|
||||
"""CLS server engine
|
||||
|
||||
Args:
|
||||
metaclass: Defaults to Singleton.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(CLSEngine, self).__init__()
|
||||
|
||||
def init(self, config: dict) -> bool:
|
||||
"""init engine resource
|
||||
|
||||
Args:
|
||||
config_file (str): config file
|
||||
|
||||
Returns:
|
||||
bool: init failed or success
|
||||
"""
|
||||
self.executor = CLSServerExecutor()
|
||||
self.config = config
|
||||
self.executor._init_from_path(
|
||||
self.config.model_type, self.config.cfg_path,
|
||||
self.config.model_path, self.config.params_path,
|
||||
self.config.label_file, self.config.predictor_conf)
|
||||
|
||||
logger.info("Initialize CLS server engine successfully.")
|
||||
return True
|
||||
|
||||
def run(self, audio_data):
|
||||
"""engine run
|
||||
|
||||
Args:
|
||||
audio_data (bytes): base64.b64decode
|
||||
"""
|
||||
|
||||
self.executor.preprocess(io.BytesIO(audio_data))
|
||||
st = time.time()
|
||||
self.executor.infer()
|
||||
infer_time = time.time() - st
|
||||
|
||||
logger.info("inference time: {}".format(infer_time))
|
||||
logger.info("cls engine type: inference")
|
||||
|
||||
def postprocess(self, topk: int):
|
||||
"""postprocess
|
||||
"""
|
||||
assert topk <= len(self.executor._label_list
|
||||
), 'Value of topk is larger than number of labels.'
|
||||
|
||||
result = np.squeeze(self.executor._outputs['logits'], axis=0)
|
||||
topk_idx = (-result).argsort()[:topk]
|
||||
topk_results = []
|
||||
for idx in topk_idx:
|
||||
res = {}
|
||||
label, score = self.executor._label_list[idx], result[idx]
|
||||
res['class_name'] = label
|
||||
res['prob'] = score
|
||||
topk_results.append(res)
|
||||
|
||||
return topk_results
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,124 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import paddle
|
||||
|
||||
from paddlespeech.cli.cls.infer import CLSExecutor
|
||||
from paddlespeech.cli.log import logger
|
||||
from paddlespeech.server.engine.base_engine import BaseEngine
|
||||
|
||||
__all__ = ['CLSEngine']
|
||||
|
||||
|
||||
class CLSServerExecutor(CLSExecutor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
def get_topk_results(self, topk: int) -> List:
|
||||
assert topk <= len(
|
||||
self._label_list), 'Value of topk is larger than number of labels.'
|
||||
|
||||
result = self._outputs['logits'].squeeze(0).numpy()
|
||||
topk_idx = (-result).argsort()[:topk]
|
||||
res = {}
|
||||
topk_results = []
|
||||
for idx in topk_idx:
|
||||
label, score = self._label_list[idx], result[idx]
|
||||
res['class'] = label
|
||||
res['prob'] = score
|
||||
topk_results.append(res)
|
||||
return topk_results
|
||||
|
||||
|
||||
class CLSEngine(BaseEngine):
|
||||
"""CLS server engine
|
||||
|
||||
Args:
|
||||
metaclass: Defaults to Singleton.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(CLSEngine, self).__init__()
|
||||
|
||||
def init(self, config: dict) -> bool:
|
||||
"""init engine resource
|
||||
|
||||
Args:
|
||||
config_file (str): config file
|
||||
|
||||
Returns:
|
||||
bool: init failed or success
|
||||
"""
|
||||
self.input = None
|
||||
self.output = None
|
||||
self.executor = CLSServerExecutor()
|
||||
self.config = config
|
||||
try:
|
||||
if self.config.device:
|
||||
self.device = self.config.device
|
||||
else:
|
||||
self.device = paddle.get_device()
|
||||
paddle.set_device(self.device)
|
||||
except BaseException:
|
||||
logger.error(
|
||||
"Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
|
||||
)
|
||||
|
||||
try:
|
||||
self.executor._init_from_path(
|
||||
self.config.model, self.config.cfg_path, self.config.ckpt_path,
|
||||
self.config.label_file)
|
||||
except BaseException:
|
||||
logger.error("Initialize CLS server engine Failed.")
|
||||
return False
|
||||
|
||||
logger.info("Initialize CLS server engine successfully on device: %s." %
|
||||
(self.device))
|
||||
return True
|
||||
|
||||
def run(self, audio_data):
|
||||
"""engine run
|
||||
|
||||
Args:
|
||||
audio_data (bytes): base64.b64decode
|
||||
"""
|
||||
self.executor.preprocess(io.BytesIO(audio_data))
|
||||
st = time.time()
|
||||
self.executor.infer()
|
||||
infer_time = time.time() - st
|
||||
|
||||
logger.info("inference time: {}".format(infer_time))
|
||||
logger.info("cls engine type: python")
|
||||
|
||||
def postprocess(self, topk: int):
|
||||
"""postprocess
|
||||
"""
|
||||
assert topk <= len(self.executor._label_list
|
||||
), 'Value of topk is larger than number of labels.'
|
||||
|
||||
result = self.executor._outputs['logits'].squeeze(0).numpy()
|
||||
topk_idx = (-result).argsort()[:topk]
|
||||
topk_results = []
|
||||
for idx in topk_idx:
|
||||
res = {}
|
||||
label, score = self.executor._label_list[idx], result[idx]
|
||||
res['class_name'] = label
|
||||
res['prob'] = score
|
||||
topk_results.append(res)
|
||||
|
||||
return topk_results
|
@ -0,0 +1,92 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import base64
|
||||
import traceback
|
||||
from typing import Union
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from paddlespeech.server.engine.engine_pool import get_engine_pool
|
||||
from paddlespeech.server.restful.request import CLSRequest
|
||||
from paddlespeech.server.restful.response import CLSResponse
|
||||
from paddlespeech.server.restful.response import ErrorResponse
|
||||
from paddlespeech.server.utils.errors import ErrorCode
|
||||
from paddlespeech.server.utils.errors import failed_response
|
||||
from paddlespeech.server.utils.exception import ServerBaseException
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get('/paddlespeech/cls/help')
|
||||
def help():
|
||||
"""help
|
||||
|
||||
Returns:
|
||||
json: [description]
|
||||
"""
|
||||
response = {
|
||||
"success": "True",
|
||||
"code": 200,
|
||||
"message": {
|
||||
"global": "success"
|
||||
},
|
||||
"result": {
|
||||
"description": "cls server",
|
||||
"input": "base64 string of wavfile",
|
||||
"output": "classification result"
|
||||
}
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
@router.post(
|
||||
"/paddlespeech/cls", response_model=Union[CLSResponse, ErrorResponse])
|
||||
def cls(request_body: CLSRequest):
|
||||
"""cls api
|
||||
|
||||
Args:
|
||||
request_body (CLSRequest): [description]
|
||||
|
||||
Returns:
|
||||
json: [description]
|
||||
"""
|
||||
try:
|
||||
audio_data = base64.b64decode(request_body.audio)
|
||||
|
||||
# get single engine from engine pool
|
||||
engine_pool = get_engine_pool()
|
||||
cls_engine = engine_pool['cls']
|
||||
|
||||
cls_engine.run(audio_data)
|
||||
cls_results = cls_engine.postprocess(request_body.topk)
|
||||
|
||||
response = {
|
||||
"success": True,
|
||||
"code": 200,
|
||||
"message": {
|
||||
"description": "success"
|
||||
},
|
||||
"result": {
|
||||
"topk": request_body.topk,
|
||||
"results": cls_results
|
||||
}
|
||||
}
|
||||
|
||||
except ServerBaseException as e:
|
||||
response = failed_response(e.error_code, e.msg)
|
||||
except BaseException:
|
||||
response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
|
||||
traceback.print_exc()
|
||||
|
||||
return response
|
@ -0,0 +1,100 @@
|
||||
009901 昨日,这名伤者与医生全部被警方依法刑事拘留。
|
||||
009902 钱伟长想到上海来办学校是经过深思熟虑的。
|
||||
009903 她见我一进门就骂,吃饭时也骂,骂得我抬不起头。
|
||||
009904 李述德在离开之前,只说了一句柱驼杀父亲了。
|
||||
009905 这种车票和保险单捆绑出售属于重复性购买。
|
||||
009906 戴佩妮的男友西米露接唱情歌,让她非常开心。
|
||||
009907 观大势,谋大局,出大策始终是该院的办院方针。
|
||||
009908 他们骑着摩托回家,正好为农忙时的父母帮忙。
|
||||
009909 但是因为还没到退休年龄,只能掰着指头捱日子。
|
||||
009910 这几天雨水不断,人们恨不得待在家里不出门。
|
||||
009911 没想到徐赟,张海翔两人就此玩起了人间蒸发。
|
||||
009912 藤村此番发言可能是为了凸显野田的领导能力。
|
||||
009913 程长庚,生在清王朝嘉庆年间,安徽的潜山小县。
|
||||
009914 南海海域综合补给基地码头项目正在论证中。
|
||||
009915 也就是说今晚成都市民极有可能再次看到飘雪。
|
||||
009916 随着天气转热,各地的游泳场所开始人头攒动。
|
||||
009917 更让徐先生纳闷的是,房客的手机也打不通了。
|
||||
009918 遇到颠簸时,应听从乘务员的安全指令,回座位坐好。
|
||||
009919 他在后面呆惯了,怕自己一插身后的人会不满,不敢排进去。
|
||||
009920 傍晚七个小人回来了,白雪公主说,你们就是我命中的七个小矮人吧。
|
||||
009921 他本想说,教育局管这个,他们是一路的,这样一管岂不是妓女起嫖客?
|
||||
009922 一种表示商品所有权的财物证券,也称商品证券,如提货单,交货单。
|
||||
009923 会有很丰富的东西留下来,说都说不完。
|
||||
009924 这句话像从天而降,吓得四周一片寂静。
|
||||
009925 记者所在的是受害人家属所在的右区。
|
||||
009926 不管哈大爷去哪,它都一步不离地跟着。
|
||||
009927 大家抬头望去,一只老鼠正趴在吊顶上。
|
||||
009928 我决定过年就辞职,接手我爸的废品站!
|
||||
009929 最终,中国男子乒乓球队获得此奖项。
|
||||
009930 防汛抗旱两手抓,抗旱相对抓的不够。
|
||||
009931 图们江下游地区开发开放的进展如何?
|
||||
009932 这要求中国必须有一个坚强的政党领导。
|
||||
009933 再说,关于利益上的事俺俩都不好开口。
|
||||
009934 明代瓦剌,鞑靼入侵明境也是通过此地。
|
||||
009935 咪咪舔着孩子,把它身上的毛舔干净。
|
||||
009936 是否这次的国标修订被大企业绑架了?
|
||||
009937 判决后,姚某妻子胡某不服,提起上诉。
|
||||
009938 由此可以看出邯钢的经济效益来自何处。
|
||||
009939 琳达说,是瑜伽改变了她和马儿的生活。
|
||||
009940 楼下的保安告诉记者,这里不租也不卖。
|
||||
009941 习近平说,中斯两国人民传统友谊深厚。
|
||||
009942 传闻越来越多,后来连老汉儿自己都怕了。
|
||||
009943 我怒吼一声冲上去,举起砖头砸了过去。
|
||||
009944 我现在还不会,这就回去问问发明我的人。
|
||||
009945 显然,洛阳性奴案不具备上述两个前提。
|
||||
009946 另外,杰克逊有文唇线,眼线,眉毛的动作。
|
||||
009947 昨晚,华西都市报记者电话采访了尹琪。
|
||||
009948 涅拉季科未透露这些航空公司的名称。
|
||||
009949 从运行轨迹上来说,它也不可能是星星。
|
||||
009950 目前看,如果继续加息也存在两难问题。
|
||||
009951 曾宝仪在节目录制现场大爆观众糗事。
|
||||
009952 但任凭周某怎么叫,男子仍酣睡不醒。
|
||||
009953 老大爷说,小子,你挡我财路了,知道不?
|
||||
009954 没料到,闯下大头佛的阿伟还不知悔改。
|
||||
009955 卡扎菲部落式统治已遭遇部落内讧。
|
||||
009956 这个孩子的生命一半来源于另一位女士捐赠的冷冻卵子。
|
||||
009957 出现这种泥鳅内阁的局面既是野田有意为之,也实属无奈。
|
||||
009958 济青高速济南,华山,章丘,邹平,周村,淄博,临淄站。
|
||||
009959 赵凌飞的话,反映了沈阳赛区所有奥运志愿者的共同心声。
|
||||
009960 因为,我们所发出的力量必会因难度加大而减弱。
|
||||
009961 发生事故的楼梯拐角处仍可看到血迹。
|
||||
009962 想过进公安,可能身高不够,老汉儿也不让我进去。
|
||||
009963 路上关卡很多,为了方便撤离,只好轻装前进。
|
||||
009964 原来比尔盖茨就是美国微软公司联合创始人呀。
|
||||
009965 之后他们一家三口将与双方父母往峇里岛旅游。
|
||||
009966 谢谢总理,也感谢广大网友的参与,我们明年再见。
|
||||
009967 事实上是,从来没有一个欺善怕恶的人能作出过稍大一点的成就。
|
||||
009968 我会打开邮件,你可以从那里继续。
|
||||
009969 美方对近期东海局势表示关切。
|
||||
009970 据悉,奥巴马一家人对这座冬季白宫极为满意。
|
||||
009971 打扫完你会很有成就感的,试一试,你就信了。
|
||||
009972 诺曼站在滑板车上,各就各位,准备出发啦!
|
||||
009973 塔河的寒夜,气温降到了零下三十多摄氏度。
|
||||
009974 其间,连破六点六,六点五,六点四,六点三五等多个重要关口。
|
||||
009975 算命其实只是人们的一种自我安慰和自我暗示而已,我们还是要相信科学才好。
|
||||
009976 这一切都令人欢欣鼓舞,阿讷西没理由不坚持到最后。
|
||||
009977 直至公元前一万一千年,它又再次出现。
|
||||
009978 尽量少玩电脑,少看电视,少打游戏。
|
||||
009979 从五到七,前后也就是六个月的时间。
|
||||
009980 一进咖啡店,他就遇见一张熟悉的脸。
|
||||
009981 好在众弟兄看到了把她追了回来。
|
||||
009982 有一个人说,哥们儿我们跑过它才能活。
|
||||
009983 捅了她以后,模糊记得她没咋动了。
|
||||
009984 从小到大,葛启义没有收到过压岁钱。
|
||||
009985 舞台下的你会对舞台上的你说什么?
|
||||
009986 但考生普遍认为,试题的怪多过难。
|
||||
009987 我希望每个人都能够尊重我们的隐私。
|
||||
009988 漫天的红霞使劲给两人增添气氛。
|
||||
009989 晚上加完班开车回家,太累了,迷迷糊糊开着车,走一半的时候,铛一声!
|
||||
009990 该车将三人撞倒后,在大雾中逃窜。
|
||||
009991 这人一哆嗦,方向盘也把不稳了,差点撞上了高速边道护栏。
|
||||
009992 那女孩儿委屈的说,我一回头见你已经进去了我不敢进去啊!
|
||||
009993 小明摇摇头说,不是,我只是美女看多了,想换个口味而已。
|
||||
009994 接下来,红娘要求记者交费,记者表示不知表姐身份证号码。
|
||||
009995 李东蓊表示,自己当时在法庭上发表了一次独特的公诉意见。
|
||||
009996 另一男子扑了上来,手里拿着明晃晃的长刀,向他胸口直刺。
|
||||
009997 今天,快递员拿着一个快递在办公室喊,秦王是哪个,有他快递?
|
||||
009998 这场抗议活动究竟是如何发展演变的,又究竟是谁伤害了谁?
|
||||
009999 因华国锋肖鸡,墓地设计根据其属相设计。
|
||||
010000 在狱中,张明宝悔恨交加,写了一份忏悔书。
|
@ -0,0 +1,243 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import jit
|
||||
from paddle.static import InputSpec
|
||||
|
||||
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
|
||||
from paddlespeech.t2s.datasets.data_table import DataTable
|
||||
from paddlespeech.t2s.frontend import English
|
||||
from paddlespeech.t2s.frontend.zh_frontend import Frontend
|
||||
from paddlespeech.t2s.modules.normalizer import ZScore
|
||||
|
||||
model_alias = {
|
||||
# acoustic model
|
||||
"speedyspeech":
|
||||
"paddlespeech.t2s.models.speedyspeech:SpeedySpeech",
|
||||
"speedyspeech_inference":
|
||||
"paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference",
|
||||
"fastspeech2":
|
||||
"paddlespeech.t2s.models.fastspeech2:FastSpeech2",
|
||||
"fastspeech2_inference":
|
||||
"paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
|
||||
"tacotron2":
|
||||
"paddlespeech.t2s.models.tacotron2:Tacotron2",
|
||||
"tacotron2_inference":
|
||||
"paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
|
||||
# voc
|
||||
"pwgan":
|
||||
"paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
|
||||
"pwgan_inference":
|
||||
"paddlespeech.t2s.models.parallel_wavegan:PWGInference",
|
||||
"mb_melgan":
|
||||
"paddlespeech.t2s.models.melgan:MelGANGenerator",
|
||||
"mb_melgan_inference":
|
||||
"paddlespeech.t2s.models.melgan:MelGANInference",
|
||||
"style_melgan":
|
||||
"paddlespeech.t2s.models.melgan:StyleMelGANGenerator",
|
||||
"style_melgan_inference":
|
||||
"paddlespeech.t2s.models.melgan:StyleMelGANInference",
|
||||
"hifigan":
|
||||
"paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
|
||||
"hifigan_inference":
|
||||
"paddlespeech.t2s.models.hifigan:HiFiGANInference",
|
||||
"wavernn":
|
||||
"paddlespeech.t2s.models.wavernn:WaveRNN",
|
||||
"wavernn_inference":
|
||||
"paddlespeech.t2s.models.wavernn:WaveRNNInference",
|
||||
}
|
||||
|
||||
|
||||
# input
|
||||
def get_sentences(args):
|
||||
# construct dataset for evaluation
|
||||
sentences = []
|
||||
with open(args.text, 'rt') as f:
|
||||
for line in f:
|
||||
items = line.strip().split()
|
||||
utt_id = items[0]
|
||||
if 'lang' in args and args.lang == 'zh':
|
||||
sentence = "".join(items[1:])
|
||||
elif 'lang' in args and args.lang == 'en':
|
||||
sentence = " ".join(items[1:])
|
||||
sentences.append((utt_id, sentence))
|
||||
return sentences
|
||||
|
||||
|
||||
def get_test_dataset(args, test_metadata, am_name, am_dataset):
|
||||
if am_name == 'fastspeech2':
|
||||
fields = ["utt_id", "text"]
|
||||
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
|
||||
print("multiple speaker fastspeech2!")
|
||||
fields += ["spk_id"]
|
||||
elif 'voice_cloning' in args and args.voice_cloning:
|
||||
print("voice cloning!")
|
||||
fields += ["spk_emb"]
|
||||
else:
|
||||
print("single speaker fastspeech2!")
|
||||
elif am_name == 'speedyspeech':
|
||||
fields = ["utt_id", "phones", "tones"]
|
||||
elif am_name == 'tacotron2':
|
||||
fields = ["utt_id", "text"]
|
||||
if 'voice_cloning' in args and args.voice_cloning:
|
||||
print("voice cloning!")
|
||||
fields += ["spk_emb"]
|
||||
|
||||
test_dataset = DataTable(data=test_metadata, fields=fields)
|
||||
return test_dataset
|
||||
|
||||
|
||||
# frontend
|
||||
def get_frontend(args):
|
||||
if 'lang' in args and args.lang == 'zh':
|
||||
frontend = Frontend(
|
||||
phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
|
||||
elif 'lang' in args and args.lang == 'en':
|
||||
frontend = English(phone_vocab_path=args.phones_dict)
|
||||
else:
|
||||
print("wrong lang!")
|
||||
print("frontend done!")
|
||||
return frontend
|
||||
|
||||
|
||||
# dygraph
|
||||
def get_am_inference(args, am_config):
|
||||
with open(args.phones_dict, "r") as f:
|
||||
phn_id = [line.strip().split() for line in f.readlines()]
|
||||
vocab_size = len(phn_id)
|
||||
print("vocab_size:", vocab_size)
|
||||
|
||||
tone_size = None
|
||||
if 'tones_dict' in args and args.tones_dict:
|
||||
with open(args.tones_dict, "r") as f:
|
||||
tone_id = [line.strip().split() for line in f.readlines()]
|
||||
tone_size = len(tone_id)
|
||||
print("tone_size:", tone_size)
|
||||
|
||||
spk_num = None
|
||||
if 'speaker_dict' in args and args.speaker_dict:
|
||||
with open(args.speaker_dict, 'rt') as f:
|
||||
spk_id = [line.strip().split() for line in f.readlines()]
|
||||
spk_num = len(spk_id)
|
||||
print("spk_num:", spk_num)
|
||||
|
||||
odim = am_config.n_mels
|
||||
# model: {model_name}_{dataset}
|
||||
am_name = args.am[:args.am.rindex('_')]
|
||||
am_dataset = args.am[args.am.rindex('_') + 1:]
|
||||
|
||||
am_class = dynamic_import(am_name, model_alias)
|
||||
am_inference_class = dynamic_import(am_name + '_inference', model_alias)
|
||||
|
||||
if am_name == 'fastspeech2':
|
||||
am = am_class(
|
||||
idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
|
||||
elif am_name == 'speedyspeech':
|
||||
am = am_class(
|
||||
vocab_size=vocab_size,
|
||||
tone_size=tone_size,
|
||||
spk_num=spk_num,
|
||||
**am_config["model"])
|
||||
elif am_name == 'tacotron2':
|
||||
am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
|
||||
|
||||
am.set_state_dict(paddle.load(args.am_ckpt)["main_params"])
|
||||
am.eval()
|
||||
am_mu, am_std = np.load(args.am_stat)
|
||||
am_mu = paddle.to_tensor(am_mu)
|
||||
am_std = paddle.to_tensor(am_std)
|
||||
am_normalizer = ZScore(am_mu, am_std)
|
||||
am_inference = am_inference_class(am_normalizer, am)
|
||||
am_inference.eval()
|
||||
print("acoustic model done!")
|
||||
return am_inference, am_name, am_dataset
|
||||
|
||||
|
||||
def get_voc_inference(args, voc_config):
|
||||
# model: {model_name}_{dataset}
|
||||
voc_name = args.voc[:args.voc.rindex('_')]
|
||||
voc_class = dynamic_import(voc_name, model_alias)
|
||||
voc_inference_class = dynamic_import(voc_name + '_inference', model_alias)
|
||||
if voc_name != 'wavernn':
|
||||
voc = voc_class(**voc_config["generator_params"])
|
||||
voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"])
|
||||
voc.remove_weight_norm()
|
||||
voc.eval()
|
||||
else:
|
||||
voc = voc_class(**voc_config["model"])
|
||||
voc.set_state_dict(paddle.load(args.voc_ckpt)["main_params"])
|
||||
voc.eval()
|
||||
|
||||
voc_mu, voc_std = np.load(args.voc_stat)
|
||||
voc_mu = paddle.to_tensor(voc_mu)
|
||||
voc_std = paddle.to_tensor(voc_std)
|
||||
voc_normalizer = ZScore(voc_mu, voc_std)
|
||||
voc_inference = voc_inference_class(voc_normalizer, voc)
|
||||
voc_inference.eval()
|
||||
print("voc done!")
|
||||
return voc_inference
|
||||
|
||||
|
||||
# to static
|
||||
def am_to_static(args, am_inference, am_name, am_dataset):
|
||||
if am_name == 'fastspeech2':
|
||||
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
|
||||
am_inference = jit.to_static(
|
||||
am_inference,
|
||||
input_spec=[
|
||||
InputSpec([-1], dtype=paddle.int64),
|
||||
InputSpec([1], dtype=paddle.int64),
|
||||
], )
|
||||
else:
|
||||
am_inference = jit.to_static(
|
||||
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
|
||||
|
||||
elif am_name == 'speedyspeech':
|
||||
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
|
||||
am_inference = jit.to_static(
|
||||
am_inference,
|
||||
input_spec=[
|
||||
InputSpec([-1], dtype=paddle.int64), # text
|
||||
InputSpec([-1], dtype=paddle.int64), # tone
|
||||
InputSpec([1], dtype=paddle.int64), # spk_id
|
||||
None # duration
|
||||
])
|
||||
else:
|
||||
am_inference = jit.to_static(
|
||||
am_inference,
|
||||
input_spec=[
|
||||
InputSpec([-1], dtype=paddle.int64),
|
||||
InputSpec([-1], dtype=paddle.int64)
|
||||
])
|
||||
|
||||
elif am_name == 'tacotron2':
|
||||
am_inference = jit.to_static(
|
||||
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
|
||||
|
||||
paddle.jit.save(am_inference, os.path.join(args.inference_dir, args.am))
|
||||
am_inference = paddle.jit.load(os.path.join(args.inference_dir, args.am))
|
||||
return am_inference
|
||||
|
||||
|
||||
def voc_to_static(args, voc_inference):
|
||||
voc_inference = jit.to_static(
|
||||
voc_inference, input_spec=[
|
||||
InputSpec([-1, 80], dtype=paddle.float32),
|
||||
])
|
||||
paddle.jit.save(voc_inference, os.path.join(args.inference_dir, args.voc))
|
||||
voc_inference = paddle.jit.load(os.path.join(args.inference_dir, args.voc))
|
||||
return voc_inference
|
@ -0,0 +1 @@
|
||||
tools/valgrind*
|
@ -0,0 +1,61 @@
|
||||
# SpeechX -- All in One Speech Task Inference
|
||||
|
||||
## Environment
|
||||
|
||||
We develop under:
|
||||
* docker - registry.baidubce.com/paddlepaddle/paddle:2.1.1-gpu-cuda10.2-cudnn7
|
||||
* os - Ubuntu 16.04.7 LTS
|
||||
* gcc/g++ - 8.2.0
|
||||
* cmake - 3.16.0
|
||||
|
||||
> We make sure all things work fun under docker, and recommend using it to develop and deploy.
|
||||
|
||||
* [How to Install Docker](https://docs.docker.com/engine/install/)
|
||||
* [A Docker Tutorial for Beginners](https://docker-curriculum.com/)
|
||||
* [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/overview.html)
|
||||
|
||||
## Build
|
||||
|
||||
1. First to launch docker container.
|
||||
|
||||
```
|
||||
nvidia-docker run --privileged --net=host --ipc=host -it --rm -v $PWD:/workspace --name=dev registry.baidubce.com/paddlepaddle/paddle:2.1.1-gpu-cuda10.2-cudnn7 /bin/bash
|
||||
```
|
||||
|
||||
* More `Paddle` docker images you can see [here](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html).
|
||||
|
||||
* If you want only work under cpu, please download corresponded [image](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html), and using `docker` instead `nviida-docker`.
|
||||
|
||||
|
||||
2. Build `speechx` and `examples`.
|
||||
|
||||
```
|
||||
pushd /path/to/speechx
|
||||
./build.sh
|
||||
```
|
||||
|
||||
3. Go to `examples` to have a fun.
|
||||
|
||||
More details please see `README.md` under `examples`.
|
||||
|
||||
|
||||
## Valgrind (Optional)
|
||||
|
||||
> If using docker please check `--privileged` is set when `docker run`.
|
||||
|
||||
* Fatal error at startup: `a function redirection which is mandatory for this platform-tool combination cannot be set up`
|
||||
```
|
||||
apt-get install libc6-dbg
|
||||
```
|
||||
|
||||
* Install
|
||||
|
||||
```
|
||||
pushd tools
|
||||
./setup_valgrind.sh
|
||||
popd
|
||||
```
|
||||
|
||||
## TODO
|
||||
|
||||
* DecibelNormalizer: there is a little bit difference between offline and online db norm. The computation of online db norm read feature chunk by chunk, which causes the feature size is different with offline db norm. In normalizer.cc:73, the samples.size() is different, which causes the difference of result.
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue