commit
9497c93fb0
@ -0,0 +1,88 @@
|
|||||||
|
version: '3.5'
|
||||||
|
|
||||||
|
services:
|
||||||
|
etcd:
|
||||||
|
container_name: milvus-etcd
|
||||||
|
image: quay.io/coreos/etcd:v3.5.0
|
||||||
|
networks:
|
||||||
|
app_net:
|
||||||
|
environment:
|
||||||
|
- ETCD_AUTO_COMPACTION_MODE=revision
|
||||||
|
- ETCD_AUTO_COMPACTION_RETENTION=1000
|
||||||
|
- ETCD_QUOTA_BACKEND_BYTES=4294967296
|
||||||
|
volumes:
|
||||||
|
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
|
||||||
|
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
|
||||||
|
|
||||||
|
minio:
|
||||||
|
container_name: milvus-minio
|
||||||
|
image: minio/minio:RELEASE.2020-12-03T00-03-10Z
|
||||||
|
networks:
|
||||||
|
app_net:
|
||||||
|
environment:
|
||||||
|
MINIO_ACCESS_KEY: minioadmin
|
||||||
|
MINIO_SECRET_KEY: minioadmin
|
||||||
|
volumes:
|
||||||
|
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
|
||||||
|
command: minio server /minio_data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 20s
|
||||||
|
retries: 3
|
||||||
|
|
||||||
|
standalone:
|
||||||
|
container_name: milvus-standalone
|
||||||
|
image: milvusdb/milvus:v2.0.1
|
||||||
|
networks:
|
||||||
|
app_net:
|
||||||
|
ipv4_address: 172.16.23.10
|
||||||
|
command: ["milvus", "run", "standalone"]
|
||||||
|
environment:
|
||||||
|
ETCD_ENDPOINTS: etcd:2379
|
||||||
|
MINIO_ADDRESS: minio:9000
|
||||||
|
volumes:
|
||||||
|
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
|
||||||
|
ports:
|
||||||
|
- "19530:19530"
|
||||||
|
depends_on:
|
||||||
|
- "etcd"
|
||||||
|
- "minio"
|
||||||
|
|
||||||
|
mysql:
|
||||||
|
container_name: audio-mysql
|
||||||
|
image: mysql:5.7
|
||||||
|
networks:
|
||||||
|
app_net:
|
||||||
|
ipv4_address: 172.16.23.11
|
||||||
|
environment:
|
||||||
|
- MYSQL_ROOT_PASSWORD=123456
|
||||||
|
volumes:
|
||||||
|
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/mysql:/var/lib/mysql
|
||||||
|
ports:
|
||||||
|
- "3306:3306"
|
||||||
|
|
||||||
|
webclient:
|
||||||
|
container_name: audio-webclient
|
||||||
|
image: qingen1/paddlespeech-audio-search-client:2.3
|
||||||
|
networks:
|
||||||
|
app_net:
|
||||||
|
ipv4_address: 172.16.23.13
|
||||||
|
environment:
|
||||||
|
API_URL: 'http://127.0.0.1:8002'
|
||||||
|
ports:
|
||||||
|
- "8068:80"
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost/"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 20s
|
||||||
|
retries: 3
|
||||||
|
|
||||||
|
networks:
|
||||||
|
app_net:
|
||||||
|
driver: bridge
|
||||||
|
ipam:
|
||||||
|
driver: default
|
||||||
|
config:
|
||||||
|
- subnet: 172.16.23.0/24
|
||||||
|
gateway: 172.16.23.1
|
After Width: | Height: | Size: 29 KiB |
After Width: | Height: | Size: 80 KiB |
After Width: | Height: | Size: 33 KiB |
After Width: | Height: | Size: 84 KiB |
@ -0,0 +1,12 @@
|
|||||||
|
soundfile==0.10.3.post1
|
||||||
|
librosa==0.8.0
|
||||||
|
numpy
|
||||||
|
pymysql
|
||||||
|
fastapi
|
||||||
|
uvicorn
|
||||||
|
diskcache==5.2.1
|
||||||
|
pymilvus==2.0.1
|
||||||
|
python-multipart
|
||||||
|
typing
|
||||||
|
starlette
|
||||||
|
pydantic
|
@ -0,0 +1,37 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
############### Milvus Configuration ###############
|
||||||
|
MILVUS_HOST = os.getenv("MILVUS_HOST", "127.0.0.1")
|
||||||
|
MILVUS_PORT = int(os.getenv("MILVUS_PORT", "19530"))
|
||||||
|
VECTOR_DIMENSION = int(os.getenv("VECTOR_DIMENSION", "2048"))
|
||||||
|
INDEX_FILE_SIZE = int(os.getenv("INDEX_FILE_SIZE", "1024"))
|
||||||
|
METRIC_TYPE = os.getenv("METRIC_TYPE", "L2")
|
||||||
|
DEFAULT_TABLE = os.getenv("DEFAULT_TABLE", "audio_table")
|
||||||
|
TOP_K = int(os.getenv("TOP_K", "10"))
|
||||||
|
|
||||||
|
############### MySQL Configuration ###############
|
||||||
|
MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
|
||||||
|
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||||
|
MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||||
|
MYSQL_PWD = os.getenv("MYSQL_PWD", "123456")
|
||||||
|
MYSQL_DB = os.getenv("MYSQL_DB", "mysql")
|
||||||
|
|
||||||
|
############### Data Path ###############
|
||||||
|
UPLOAD_PATH = os.getenv("UPLOAD_PATH", "tmp/audio-data")
|
||||||
|
|
||||||
|
############### Number of Log Files ###############
|
||||||
|
LOGS_NUM = int(os.getenv("logs_num", "0"))
|
@ -0,0 +1,39 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
from logs import LOGGER
|
||||||
|
|
||||||
|
|
||||||
|
def get_audio_embedding(path):
|
||||||
|
"""
|
||||||
|
Use vpr_inference to generate embedding of audio
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
RESAMPLE_RATE = 16000
|
||||||
|
audio, _ = librosa.load(path, sr=RESAMPLE_RATE, mono=True)
|
||||||
|
|
||||||
|
# TODO add infer/python interface to get embedding, now fake it by rand
|
||||||
|
# vpr = ECAPATDNN(checkpoint_path=None, device='cuda')
|
||||||
|
# embedding = vpr.inference(audio)
|
||||||
|
np.random.seed(hash(os.path.basename(path)) % 1000000)
|
||||||
|
embedding = np.random.rand(1, 2048)
|
||||||
|
embedding = embedding / np.linalg.norm(embedding)
|
||||||
|
embedding = embedding.tolist()[0]
|
||||||
|
return embedding
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Error with embedding:{e}")
|
||||||
|
return None
|
@ -0,0 +1,168 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
from config import UPLOAD_PATH
|
||||||
|
from diskcache import Cache
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi import File
|
||||||
|
from fastapi import UploadFile
|
||||||
|
from logs import LOGGER
|
||||||
|
from milvus_helpers import MilvusHelper
|
||||||
|
from mysql_helpers import MySQLHelper
|
||||||
|
from operations.count import do_count
|
||||||
|
from operations.drop import do_drop
|
||||||
|
from operations.load import do_load
|
||||||
|
from operations.search import do_search
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import FileResponse
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"])
|
||||||
|
|
||||||
|
MODEL = None
|
||||||
|
MILVUS_CLI = MilvusHelper()
|
||||||
|
MYSQL_CLI = MySQLHelper()
|
||||||
|
|
||||||
|
# Mkdir 'tmp/audio-data'
|
||||||
|
if not os.path.exists(UPLOAD_PATH):
|
||||||
|
os.makedirs(UPLOAD_PATH)
|
||||||
|
LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get('/data')
|
||||||
|
def audio_path(audio_path):
|
||||||
|
# Get the audio file
|
||||||
|
try:
|
||||||
|
LOGGER.info(f"Successfully load audio: {audio_path}")
|
||||||
|
return FileResponse(audio_path)
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"upload audio error: {e}")
|
||||||
|
return {'status': False, 'msg': e}, 400
|
||||||
|
|
||||||
|
|
||||||
|
@app.get('/progress')
|
||||||
|
def get_progress():
|
||||||
|
# Get the progress of dealing with data
|
||||||
|
try:
|
||||||
|
cache = Cache('./tmp')
|
||||||
|
return f"current: {cache['current']}, total: {cache['total']}"
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Upload data error: {e}")
|
||||||
|
return {'status': False, 'msg': e}, 400
|
||||||
|
|
||||||
|
|
||||||
|
class Item(BaseModel):
|
||||||
|
Table: Optional[str] = None
|
||||||
|
File: str
|
||||||
|
|
||||||
|
|
||||||
|
@app.post('/audio/load')
|
||||||
|
async def load_audios(item: Item):
|
||||||
|
# Insert all the audio files under the file path to Milvus/MySQL
|
||||||
|
try:
|
||||||
|
total_num = do_load(item.Table, item.File, MILVUS_CLI, MYSQL_CLI)
|
||||||
|
LOGGER.info(f"Successfully loaded data, total count: {total_num}")
|
||||||
|
return {'status': True, 'msg': "Successfully loaded data!"}
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(e)
|
||||||
|
return {'status': False, 'msg': e}, 400
|
||||||
|
|
||||||
|
|
||||||
|
@app.post('/audio/search')
|
||||||
|
async def search_audio(request: Request,
|
||||||
|
table_name: str=None,
|
||||||
|
audio: UploadFile=File(...)):
|
||||||
|
# Search the uploaded audio in Milvus/MySQL
|
||||||
|
try:
|
||||||
|
# Save the upload data to server.
|
||||||
|
content = await audio.read()
|
||||||
|
query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
|
||||||
|
with open(query_audio_path, "wb+") as f:
|
||||||
|
f.write(content)
|
||||||
|
host = request.headers['host']
|
||||||
|
_, paths, distances = do_search(host, table_name, query_audio_path,
|
||||||
|
MILVUS_CLI, MYSQL_CLI)
|
||||||
|
names = []
|
||||||
|
for path, score in zip(paths, distances):
|
||||||
|
names.append(os.path.basename(path))
|
||||||
|
LOGGER.info(f"search result {path}, score {score}")
|
||||||
|
res = dict(zip(paths, zip(names, distances)))
|
||||||
|
# Sort results by distance metric, closest distances first
|
||||||
|
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
|
||||||
|
LOGGER.info("Successfully searched similar audio!")
|
||||||
|
return res
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(e)
|
||||||
|
return {'status': False, 'msg': e}, 400
|
||||||
|
|
||||||
|
|
||||||
|
@app.post('/audio/search/local')
|
||||||
|
async def search_local_audio(request: Request,
|
||||||
|
query_audio_path: str,
|
||||||
|
table_name: str=None):
|
||||||
|
# Search the uploaded audio in Milvus/MySQL
|
||||||
|
try:
|
||||||
|
host = request.headers['host']
|
||||||
|
_, paths, distances = do_search(host, table_name, query_audio_path,
|
||||||
|
MILVUS_CLI, MYSQL_CLI)
|
||||||
|
names = []
|
||||||
|
for path, score in zip(paths, distances):
|
||||||
|
names.append(os.path.basename(path))
|
||||||
|
LOGGER.info(f"search result {path}, score {score}")
|
||||||
|
res = dict(zip(paths, zip(names, distances)))
|
||||||
|
# Sort results by distance metric, closest distances first
|
||||||
|
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
|
||||||
|
LOGGER.info("Successfully searched similar audio!")
|
||||||
|
return res
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(e)
|
||||||
|
return {'status': False, 'msg': e}, 400
|
||||||
|
|
||||||
|
|
||||||
|
@app.get('/audio/count')
|
||||||
|
async def count_audio(table_name: str=None):
|
||||||
|
# Returns the total number of vectors in the system
|
||||||
|
try:
|
||||||
|
num = do_count(table_name, MILVUS_CLI)
|
||||||
|
LOGGER.info("Successfully count the number of data!")
|
||||||
|
return num
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(e)
|
||||||
|
return {'status': False, 'msg': e}, 400
|
||||||
|
|
||||||
|
|
||||||
|
@app.post('/audio/drop')
|
||||||
|
async def drop_tables(table_name: str=None):
|
||||||
|
# Delete the collection of Milvus and MySQL
|
||||||
|
try:
|
||||||
|
status = do_drop(table_name, MILVUS_CLI, MYSQL_CLI)
|
||||||
|
LOGGER.info("Successfully drop tables in Milvus and MySQL!")
|
||||||
|
return status
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(e)
|
||||||
|
return {'status': False, 'msg': e}, 400
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
uvicorn.run(app=app, host='0.0.0.0', port=8002)
|
@ -0,0 +1,185 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from config import METRIC_TYPE
|
||||||
|
from config import MILVUS_HOST
|
||||||
|
from config import MILVUS_PORT
|
||||||
|
from config import VECTOR_DIMENSION
|
||||||
|
from logs import LOGGER
|
||||||
|
from pymilvus import Collection
|
||||||
|
from pymilvus import CollectionSchema
|
||||||
|
from pymilvus import connections
|
||||||
|
from pymilvus import DataType
|
||||||
|
from pymilvus import FieldSchema
|
||||||
|
from pymilvus import utility
|
||||||
|
|
||||||
|
|
||||||
|
class MilvusHelper:
|
||||||
|
"""
|
||||||
|
the basic operations of PyMilvus
|
||||||
|
|
||||||
|
# This example shows how to:
|
||||||
|
# 1. connect to Milvus server
|
||||||
|
# 2. create a collection
|
||||||
|
# 3. insert entities
|
||||||
|
# 4. create index
|
||||||
|
# 5. search
|
||||||
|
# 6. delete a collection
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
try:
|
||||||
|
self.collection = None
|
||||||
|
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
|
||||||
|
LOGGER.debug(
|
||||||
|
f"Successfully connect to Milvus with IP:{MILVUS_HOST} and PORT:{MILVUS_PORT}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to connect Milvus: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def set_collection(self, collection_name):
|
||||||
|
try:
|
||||||
|
if self.has_collection(collection_name):
|
||||||
|
self.collection = Collection(name=collection_name)
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
f"There is no collection named:{collection_name}")
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to set collection in Milvus: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def has_collection(self, collection_name):
|
||||||
|
# Return if Milvus has the collection
|
||||||
|
try:
|
||||||
|
return utility.has_collection(collection_name)
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to check state of collection in Milvus: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def create_collection(self, collection_name):
|
||||||
|
# Create milvus collection if not exists
|
||||||
|
try:
|
||||||
|
if not self.has_collection(collection_name):
|
||||||
|
field1 = FieldSchema(
|
||||||
|
name="id",
|
||||||
|
dtype=DataType.INT64,
|
||||||
|
descrition="int64",
|
||||||
|
is_primary=True,
|
||||||
|
auto_id=True)
|
||||||
|
field2 = FieldSchema(
|
||||||
|
name="embedding",
|
||||||
|
dtype=DataType.FLOAT_VECTOR,
|
||||||
|
descrition="speaker embeddings",
|
||||||
|
dim=VECTOR_DIMENSION,
|
||||||
|
is_primary=False)
|
||||||
|
schema = CollectionSchema(
|
||||||
|
fields=[field1, field2], description="embeddings info")
|
||||||
|
self.collection = Collection(
|
||||||
|
name=collection_name, schema=schema)
|
||||||
|
LOGGER.debug(f"Create Milvus collection: {collection_name}")
|
||||||
|
else:
|
||||||
|
self.set_collection(collection_name)
|
||||||
|
return "OK"
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to create collection in Milvus: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def insert(self, collection_name, vectors):
|
||||||
|
# Batch insert vectors to milvus collection
|
||||||
|
try:
|
||||||
|
self.create_collection(collection_name)
|
||||||
|
data = [vectors]
|
||||||
|
self.set_collection(collection_name)
|
||||||
|
mr = self.collection.insert(data)
|
||||||
|
ids = mr.primary_keys
|
||||||
|
self.collection.load()
|
||||||
|
LOGGER.debug(
|
||||||
|
f"Insert vectors to Milvus in collection: {collection_name} with {len(vectors)} rows"
|
||||||
|
)
|
||||||
|
return ids
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to insert data to Milvus: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def create_index(self, collection_name):
|
||||||
|
# Create IVF_FLAT index on milvus collection
|
||||||
|
try:
|
||||||
|
self.set_collection(collection_name)
|
||||||
|
default_index = {
|
||||||
|
"index_type": "IVF_SQ8",
|
||||||
|
"metric_type": METRIC_TYPE,
|
||||||
|
"params": {
|
||||||
|
"nlist": 16384
|
||||||
|
}
|
||||||
|
}
|
||||||
|
status = self.collection.create_index(
|
||||||
|
field_name="embedding", index_params=default_index)
|
||||||
|
if not status.code:
|
||||||
|
LOGGER.debug(
|
||||||
|
f"Successfully create index in collection:{collection_name} with param:{default_index}"
|
||||||
|
)
|
||||||
|
return status
|
||||||
|
else:
|
||||||
|
raise Exception(status.message)
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to create index: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def delete_collection(self, collection_name):
|
||||||
|
# Delete Milvus collection
|
||||||
|
try:
|
||||||
|
self.set_collection(collection_name)
|
||||||
|
self.collection.drop()
|
||||||
|
LOGGER.debug("Successfully drop collection!")
|
||||||
|
return "ok"
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to drop collection: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def search_vectors(self, collection_name, vectors, top_k):
|
||||||
|
# Search vector in milvus collection
|
||||||
|
try:
|
||||||
|
self.set_collection(collection_name)
|
||||||
|
search_params = {
|
||||||
|
"metric_type": METRIC_TYPE,
|
||||||
|
"params": {
|
||||||
|
"nprobe": 16
|
||||||
|
}
|
||||||
|
}
|
||||||
|
res = self.collection.search(
|
||||||
|
vectors,
|
||||||
|
anns_field="embedding",
|
||||||
|
param=search_params,
|
||||||
|
limit=top_k)
|
||||||
|
LOGGER.debug(f"Successfully search in collection: {res}")
|
||||||
|
return res
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to search vectors in Milvus: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def count(self, collection_name):
|
||||||
|
# Get the number of milvus collection
|
||||||
|
try:
|
||||||
|
self.set_collection(collection_name)
|
||||||
|
num = self.collection.num_entities
|
||||||
|
LOGGER.debug(
|
||||||
|
f"Successfully get the num:{num} of the collection:{collection_name}"
|
||||||
|
)
|
||||||
|
return num
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Failed to count vectors in Milvus: {e}")
|
||||||
|
sys.exit(1)
|
@ -0,0 +1,133 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pymysql
|
||||||
|
from config import MYSQL_DB
|
||||||
|
from config import MYSQL_HOST
|
||||||
|
from config import MYSQL_PORT
|
||||||
|
from config import MYSQL_PWD
|
||||||
|
from config import MYSQL_USER
|
||||||
|
from logs import LOGGER
|
||||||
|
|
||||||
|
|
||||||
|
class MySQLHelper():
|
||||||
|
"""
|
||||||
|
the basic operations of PyMySQL
|
||||||
|
|
||||||
|
# This example shows how to:
|
||||||
|
# 1. connect to MySQL server
|
||||||
|
# 2. create a table
|
||||||
|
# 3. insert data to table
|
||||||
|
# 4. search by milvus ids
|
||||||
|
# 5. delete table
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.conn = pymysql.connect(
|
||||||
|
host=MYSQL_HOST,
|
||||||
|
user=MYSQL_USER,
|
||||||
|
port=MYSQL_PORT,
|
||||||
|
password=MYSQL_PWD,
|
||||||
|
database=MYSQL_DB,
|
||||||
|
local_infile=True)
|
||||||
|
self.cursor = self.conn.cursor()
|
||||||
|
|
||||||
|
def test_connection(self):
|
||||||
|
try:
|
||||||
|
self.conn.ping()
|
||||||
|
except Exception:
|
||||||
|
self.conn = pymysql.connect(
|
||||||
|
host=MYSQL_HOST,
|
||||||
|
user=MYSQL_USER,
|
||||||
|
port=MYSQL_PORT,
|
||||||
|
password=MYSQL_PWD,
|
||||||
|
database=MYSQL_DB,
|
||||||
|
local_infile=True)
|
||||||
|
self.cursor = self.conn.cursor()
|
||||||
|
|
||||||
|
def create_mysql_table(self, table_name):
|
||||||
|
# Create mysql table if not exists
|
||||||
|
self.test_connection()
|
||||||
|
sql = "create table if not exists " + table_name + "(milvus_id TEXT, audio_path TEXT);"
|
||||||
|
try:
|
||||||
|
self.cursor.execute(sql)
|
||||||
|
LOGGER.debug(f"MYSQL create table: {table_name} with sql: {sql}")
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def load_data_to_mysql(self, table_name, data):
|
||||||
|
# Batch insert (Milvus_ids, img_path) to mysql
|
||||||
|
self.test_connection()
|
||||||
|
sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);"
|
||||||
|
try:
|
||||||
|
self.cursor.executemany(sql, data)
|
||||||
|
self.conn.commit()
|
||||||
|
LOGGER.debug(
|
||||||
|
f"MYSQL loads data to table: {table_name} successfully")
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def search_by_milvus_ids(self, ids, table_name):
|
||||||
|
# Get the img_path according to the milvus ids
|
||||||
|
self.test_connection()
|
||||||
|
str_ids = str(ids).replace('[', '').replace(']', '')
|
||||||
|
sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");"
|
||||||
|
try:
|
||||||
|
self.cursor.execute(sql)
|
||||||
|
results = self.cursor.fetchall()
|
||||||
|
results = [res[0] for res in results]
|
||||||
|
LOGGER.debug("MYSQL search by milvus id.")
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def delete_table(self, table_name):
|
||||||
|
# Delete mysql table if exists
|
||||||
|
self.test_connection()
|
||||||
|
sql = "drop table if exists " + table_name + ";"
|
||||||
|
try:
|
||||||
|
self.cursor.execute(sql)
|
||||||
|
LOGGER.debug(f"MYSQL delete table:{table_name}")
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def delete_all_data(self, table_name):
|
||||||
|
# Delete all the data in mysql table
|
||||||
|
self.test_connection()
|
||||||
|
sql = 'delete from ' + table_name + ';'
|
||||||
|
try:
|
||||||
|
self.cursor.execute(sql)
|
||||||
|
self.conn.commit()
|
||||||
|
LOGGER.debug(f"MYSQL delete all data in table:{table_name}")
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def count_table(self, table_name):
|
||||||
|
# Get the number of mysql table
|
||||||
|
self.test_connection()
|
||||||
|
sql = "select count(milvus_id) from " + table_name + ";"
|
||||||
|
try:
|
||||||
|
self.cursor.execute(sql)
|
||||||
|
results = self.cursor.fetchall()
|
||||||
|
LOGGER.debug(f"MYSQL count table:{table_name}")
|
||||||
|
return results[0][0]
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||||
|
sys.exit(1)
|
@ -0,0 +1,33 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from config import DEFAULT_TABLE
|
||||||
|
from logs import LOGGER
|
||||||
|
|
||||||
|
|
||||||
|
def do_count(table_name, milvus_cli):
|
||||||
|
"""
|
||||||
|
Returns the total number of vectors in the system
|
||||||
|
"""
|
||||||
|
if not table_name:
|
||||||
|
table_name = DEFAULT_TABLE
|
||||||
|
try:
|
||||||
|
if not milvus_cli.has_collection(table_name):
|
||||||
|
return None
|
||||||
|
num = milvus_cli.count(table_name)
|
||||||
|
return num
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Error attempting to count table {e}")
|
||||||
|
sys.exit(1)
|
@ -0,0 +1,34 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from config import DEFAULT_TABLE
|
||||||
|
from logs import LOGGER
|
||||||
|
|
||||||
|
|
||||||
|
def do_drop(table_name, milvus_cli, mysql_cli):
|
||||||
|
"""
|
||||||
|
Delete the collection of Milvus and MySQL
|
||||||
|
"""
|
||||||
|
if not table_name:
|
||||||
|
table_name = DEFAULT_TABLE
|
||||||
|
try:
|
||||||
|
if not milvus_cli.has_collection(table_name):
|
||||||
|
return "Collection is not exist"
|
||||||
|
status = milvus_cli.delete_collection(table_name)
|
||||||
|
mysql_cli.delete_table(table_name)
|
||||||
|
return status
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Error attempting to drop table: {e}")
|
||||||
|
sys.exit(1)
|
@ -0,0 +1,85 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from config import DEFAULT_TABLE
|
||||||
|
from diskcache import Cache
|
||||||
|
from encode import get_audio_embedding
|
||||||
|
from logs import LOGGER
|
||||||
|
|
||||||
|
|
||||||
|
def get_audios(path):
|
||||||
|
"""
|
||||||
|
List all wav and aif files recursively under the path folder.
|
||||||
|
"""
|
||||||
|
supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
|
||||||
|
return [
|
||||||
|
item
|
||||||
|
for sublist in [[os.path.join(dir, file) for file in files]
|
||||||
|
for dir, _, files in list(os.walk(path))]
|
||||||
|
for item in sublist if os.path.splitext(item)[1] in supported_formats
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def extract_features(audio_dir):
|
||||||
|
"""
|
||||||
|
Get the vector of audio
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cache = Cache('./tmp')
|
||||||
|
feats = []
|
||||||
|
names = []
|
||||||
|
audio_list = get_audios(audio_dir)
|
||||||
|
total = len(audio_list)
|
||||||
|
cache['total'] = total
|
||||||
|
for i, audio_path in enumerate(audio_list):
|
||||||
|
norm_feat = get_audio_embedding(audio_path)
|
||||||
|
if norm_feat is None:
|
||||||
|
continue
|
||||||
|
feats.append(norm_feat)
|
||||||
|
names.append(audio_path.encode())
|
||||||
|
cache['current'] = i + 1
|
||||||
|
print(
|
||||||
|
f"Extracting feature from audio No. {i + 1} , {total} audios in total"
|
||||||
|
)
|
||||||
|
return feats, names
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Error with extracting feature from audio {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def format_data(ids, names):
|
||||||
|
"""
|
||||||
|
Combine the id of the vector and the name of the audio into a list
|
||||||
|
"""
|
||||||
|
data = []
|
||||||
|
for i in range(len(ids)):
|
||||||
|
value = (str(ids[i]), names[i])
|
||||||
|
data.append(value)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def do_load(table_name, audio_dir, milvus_cli, mysql_cli):
|
||||||
|
"""
|
||||||
|
Import vectors to Milvus and data to Mysql respectively
|
||||||
|
"""
|
||||||
|
if not table_name:
|
||||||
|
table_name = DEFAULT_TABLE
|
||||||
|
vectors, names = extract_features(audio_dir)
|
||||||
|
ids = milvus_cli.insert(table_name, vectors)
|
||||||
|
milvus_cli.create_index(table_name)
|
||||||
|
mysql_cli.create_mysql_table(table_name)
|
||||||
|
mysql_cli.load_data_to_mysql(table_name, format_data(ids, names))
|
||||||
|
return len(ids)
|
@ -0,0 +1,41 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from config import DEFAULT_TABLE
|
||||||
|
from config import TOP_K
|
||||||
|
from encode import get_audio_embedding
|
||||||
|
from logs import LOGGER
|
||||||
|
|
||||||
|
|
||||||
|
def do_search(host, table_name, audio_path, milvus_cli, mysql_cli):
|
||||||
|
"""
|
||||||
|
Search the uploaded audio in Milvus/MySQL
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not table_name:
|
||||||
|
table_name = DEFAULT_TABLE
|
||||||
|
feat = get_audio_embedding(audio_path)
|
||||||
|
vectors = milvus_cli.search_vectors(table_name, [feat], TOP_K)
|
||||||
|
vids = [str(x.id) for x in vectors[0]]
|
||||||
|
paths = mysql_cli.search_by_milvus_ids(vids, table_name)
|
||||||
|
distances = [x.distance for x in vectors[0]]
|
||||||
|
for i in range(len(paths)):
|
||||||
|
tmp = "http://" + str(host) + "/data?audio_path=" + str(paths[i])
|
||||||
|
paths[i] = tmp
|
||||||
|
distances[i] = (1 - distances[i]) * 100
|
||||||
|
return vids, paths, distances
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.error(f"Error with search: {e}")
|
||||||
|
sys.exit(1)
|
@ -0,0 +1,95 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
import gdown
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from main import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def download_audio_data():
|
||||||
|
"""
|
||||||
|
download audio data
|
||||||
|
"""
|
||||||
|
url = 'https://drive.google.com/uc?id=1bKu21JWBfcZBuEuzFEvPoAX6PmRrgnUp'
|
||||||
|
gdown.download(url)
|
||||||
|
|
||||||
|
with zipfile.ZipFile('example_audio.zip', 'r') as zip_ref:
|
||||||
|
zip_ref.extractall('./example_audio')
|
||||||
|
|
||||||
|
|
||||||
|
def test_drop():
|
||||||
|
"""
|
||||||
|
Delete the collection of Milvus and MySQL
|
||||||
|
"""
|
||||||
|
response = client.post("/audio/drop")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_load():
|
||||||
|
"""
|
||||||
|
Insert all the audio files under the file path to Milvus/MySQL
|
||||||
|
"""
|
||||||
|
response = client.post("/audio/load", json={"File": "./example_audio"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {
|
||||||
|
'status': True,
|
||||||
|
'msg': "Successfully loaded data!"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_progress():
|
||||||
|
"""
|
||||||
|
Get the progress of dealing with data
|
||||||
|
"""
|
||||||
|
response = client.get("/progress")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == "current: 20, total: 20"
|
||||||
|
|
||||||
|
|
||||||
|
def test_count():
|
||||||
|
"""
|
||||||
|
Returns the total number of vectors in the system
|
||||||
|
"""
|
||||||
|
response = client.get("audio/count")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == 20
|
||||||
|
|
||||||
|
|
||||||
|
def test_search():
|
||||||
|
"""
|
||||||
|
Search the uploaded audio in Milvus/MySQL
|
||||||
|
"""
|
||||||
|
response = client.post(
|
||||||
|
"/audio/search/local?query_audio_path=.%2Fexample_audio%2Ftest.wav")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert len(response.json()) == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_data():
|
||||||
|
"""
|
||||||
|
Get the audio file
|
||||||
|
"""
|
||||||
|
response = client.get("/data?audio_path=.%2Fexample_audio%2Ftest.wav")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
download_audio_data()
|
||||||
|
test_load()
|
||||||
|
test_count()
|
||||||
|
test_search()
|
||||||
|
test_drop()
|
@ -0,0 +1 @@
|
|||||||
|
*.wav
|
@ -0,0 +1,4 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
|
||||||
|
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input ./zh.wav --topk 1
|
@ -1,27 +1,137 @@
|
|||||||
# This is the parameter configuration file for PaddleSpeech Serving.
|
# This is the parameter configuration file for PaddleSpeech Serving.
|
||||||
|
|
||||||
##################################################################
|
#################################################################################
|
||||||
# SERVER SETTING #
|
# SERVER SETTING #
|
||||||
##################################################################
|
#################################################################################
|
||||||
host: '127.0.0.1'
|
host: 127.0.0.1
|
||||||
port: 8090
|
port: 8090
|
||||||
|
|
||||||
##################################################################
|
# The task format in the engin_list is: <speech task>_<engine type>
|
||||||
# CONFIG FILE #
|
# task choices = ['asr_python', 'asr_inference', 'tts_python', 'tts_inference']
|
||||||
##################################################################
|
|
||||||
# add engine backend type (Options: asr, tts) and config file here.
|
|
||||||
# Adding a speech task to engine_backend means starting the service.
|
|
||||||
engine_backend:
|
|
||||||
asr: 'conf/asr/asr.yaml'
|
|
||||||
tts: 'conf/tts/tts.yaml'
|
|
||||||
|
|
||||||
# The engine_type of speech task needs to keep the same type as the config file of speech task.
|
|
||||||
# E.g: The engine_type of asr is 'python', the engine_backend of asr is 'XX/asr.yaml'
|
|
||||||
# E.g: The engine_type of asr is 'inference', the engine_backend of asr is 'XX/asr_pd.yaml'
|
|
||||||
#
|
|
||||||
# add engine type (Options: python, inference)
|
|
||||||
engine_type:
|
|
||||||
asr: 'python'
|
|
||||||
tts: 'python'
|
|
||||||
|
|
||||||
|
engine_list: ['asr_python', 'tts_python', 'cls_python']
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# ENGINE CONFIG #
|
||||||
|
#################################################################################
|
||||||
|
|
||||||
|
################################### ASR #########################################
|
||||||
|
################### speech task: asr; engine_type: python #######################
|
||||||
|
asr_python:
|
||||||
|
model: 'conformer_wenetspeech'
|
||||||
|
lang: 'zh'
|
||||||
|
sample_rate: 16000
|
||||||
|
cfg_path: # [optional]
|
||||||
|
ckpt_path: # [optional]
|
||||||
|
decode_method: 'attention_rescoring'
|
||||||
|
force_yes: True
|
||||||
|
device: # set 'gpu:id' or 'cpu'
|
||||||
|
|
||||||
|
|
||||||
|
################### speech task: asr; engine_type: inference #######################
|
||||||
|
asr_inference:
|
||||||
|
# model_type choices=['deepspeech2offline_aishell']
|
||||||
|
model_type: 'deepspeech2offline_aishell'
|
||||||
|
am_model: # the pdmodel file of am static model [optional]
|
||||||
|
am_params: # the pdiparams file of am static model [optional]
|
||||||
|
lang: 'zh'
|
||||||
|
sample_rate: 16000
|
||||||
|
cfg_path:
|
||||||
|
decode_method:
|
||||||
|
force_yes: True
|
||||||
|
|
||||||
|
am_predictor_conf:
|
||||||
|
device: # set 'gpu:id' or 'cpu'
|
||||||
|
switch_ir_optim: True
|
||||||
|
glog_info: False # True -> print glog
|
||||||
|
summary: True # False -> do not show predictor config
|
||||||
|
|
||||||
|
|
||||||
|
################################### TTS #########################################
|
||||||
|
################### speech task: tts; engine_type: python #######################
|
||||||
|
tts_python:
|
||||||
|
# am (acoustic model) choices=['speedyspeech_csmsc', 'fastspeech2_csmsc',
|
||||||
|
# 'fastspeech2_ljspeech', 'fastspeech2_aishell3',
|
||||||
|
# 'fastspeech2_vctk']
|
||||||
|
am: 'fastspeech2_csmsc'
|
||||||
|
am_config:
|
||||||
|
am_ckpt:
|
||||||
|
am_stat:
|
||||||
|
phones_dict:
|
||||||
|
tones_dict:
|
||||||
|
speaker_dict:
|
||||||
|
spk_id: 0
|
||||||
|
|
||||||
|
# voc (vocoder) choices=['pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3',
|
||||||
|
# 'pwgan_vctk', 'mb_melgan_csmsc']
|
||||||
|
voc: 'pwgan_csmsc'
|
||||||
|
voc_config:
|
||||||
|
voc_ckpt:
|
||||||
|
voc_stat:
|
||||||
|
|
||||||
|
# others
|
||||||
|
lang: 'zh'
|
||||||
|
device: # set 'gpu:id' or 'cpu'
|
||||||
|
|
||||||
|
|
||||||
|
################### speech task: tts; engine_type: inference #######################
|
||||||
|
tts_inference:
|
||||||
|
# am (acoustic model) choices=['speedyspeech_csmsc', 'fastspeech2_csmsc']
|
||||||
|
am: 'fastspeech2_csmsc'
|
||||||
|
am_model: # the pdmodel file of your am static model (XX.pdmodel)
|
||||||
|
am_params: # the pdiparams file of your am static model (XX.pdipparams)
|
||||||
|
am_sample_rate: 24000
|
||||||
|
phones_dict:
|
||||||
|
tones_dict:
|
||||||
|
speaker_dict:
|
||||||
|
spk_id: 0
|
||||||
|
|
||||||
|
am_predictor_conf:
|
||||||
|
device: # set 'gpu:id' or 'cpu'
|
||||||
|
switch_ir_optim: True
|
||||||
|
glog_info: False # True -> print glog
|
||||||
|
summary: True # False -> do not show predictor config
|
||||||
|
|
||||||
|
# voc (vocoder) choices=['pwgan_csmsc', 'mb_melgan_csmsc','hifigan_csmsc']
|
||||||
|
voc: 'pwgan_csmsc'
|
||||||
|
voc_model: # the pdmodel file of your vocoder static model (XX.pdmodel)
|
||||||
|
voc_params: # the pdiparams file of your vocoder static model (XX.pdipparams)
|
||||||
|
voc_sample_rate: 24000
|
||||||
|
|
||||||
|
voc_predictor_conf:
|
||||||
|
device: # set 'gpu:id' or 'cpu'
|
||||||
|
switch_ir_optim: True
|
||||||
|
glog_info: False # True -> print glog
|
||||||
|
summary: True # False -> do not show predictor config
|
||||||
|
|
||||||
|
# others
|
||||||
|
lang: 'zh'
|
||||||
|
|
||||||
|
|
||||||
|
################################### CLS #########################################
|
||||||
|
################### speech task: cls; engine_type: python #######################
|
||||||
|
cls_python:
|
||||||
|
# model choices=['panns_cnn14', 'panns_cnn10', 'panns_cnn6']
|
||||||
|
model: 'panns_cnn14'
|
||||||
|
cfg_path: # [optional] Config of cls task.
|
||||||
|
ckpt_path: # [optional] Checkpoint file of model.
|
||||||
|
label_file: # [optional] Label file of cls task.
|
||||||
|
device: # set 'gpu:id' or 'cpu'
|
||||||
|
|
||||||
|
|
||||||
|
################### speech task: cls; engine_type: inference #######################
|
||||||
|
cls_inference:
|
||||||
|
# model_type choices=['panns_cnn14', 'panns_cnn10', 'panns_cnn6']
|
||||||
|
model_type: 'panns_cnn14'
|
||||||
|
cfg_path:
|
||||||
|
model_path: # the pdmodel file of am static model [optional]
|
||||||
|
params_path: # the pdiparams file of am static model [optional]
|
||||||
|
label_file: # [optional] Label file of cls task.
|
||||||
|
|
||||||
|
predictor_conf:
|
||||||
|
device: # set 'gpu:id' or 'cpu'
|
||||||
|
switch_ir_optim: True
|
||||||
|
glog_info: False # True -> print glog
|
||||||
|
summary: True # False -> do not show predictor config
|
||||||
|
|
||||||
|
@ -1,8 +0,0 @@
|
|||||||
model: 'conformer_wenetspeech'
|
|
||||||
lang: 'zh'
|
|
||||||
sample_rate: 16000
|
|
||||||
cfg_path: # [optional]
|
|
||||||
ckpt_path: # [optional]
|
|
||||||
decode_method: 'attention_rescoring'
|
|
||||||
force_yes: True
|
|
||||||
device: # set 'gpu:id' or 'cpu'
|
|
@ -1,26 +0,0 @@
|
|||||||
# This is the parameter configuration file for ASR server.
|
|
||||||
# These are the static models that support paddle inference.
|
|
||||||
|
|
||||||
##################################################################
|
|
||||||
# ACOUSTIC MODEL SETTING #
|
|
||||||
# am choices=['deepspeech2offline_aishell'] TODO
|
|
||||||
##################################################################
|
|
||||||
model_type: 'deepspeech2offline_aishell'
|
|
||||||
am_model: # the pdmodel file of am static model [optional]
|
|
||||||
am_params: # the pdiparams file of am static model [optional]
|
|
||||||
lang: 'zh'
|
|
||||||
sample_rate: 16000
|
|
||||||
cfg_path:
|
|
||||||
decode_method:
|
|
||||||
force_yes: True
|
|
||||||
|
|
||||||
am_predictor_conf:
|
|
||||||
device: # set 'gpu:id' or 'cpu'
|
|
||||||
switch_ir_optim: True
|
|
||||||
glog_info: False # True -> print glog
|
|
||||||
summary: True # False -> do not show predictor config
|
|
||||||
|
|
||||||
|
|
||||||
##################################################################
|
|
||||||
# OTHERS #
|
|
||||||
##################################################################
|
|
@ -1,32 +0,0 @@
|
|||||||
# This is the parameter configuration file for TTS server.
|
|
||||||
|
|
||||||
##################################################################
|
|
||||||
# ACOUSTIC MODEL SETTING #
|
|
||||||
# am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc',
|
|
||||||
# 'fastspeech2_ljspeech', 'fastspeech2_aishell3',
|
|
||||||
# 'fastspeech2_vctk']
|
|
||||||
##################################################################
|
|
||||||
am: 'fastspeech2_csmsc'
|
|
||||||
am_config:
|
|
||||||
am_ckpt:
|
|
||||||
am_stat:
|
|
||||||
phones_dict:
|
|
||||||
tones_dict:
|
|
||||||
speaker_dict:
|
|
||||||
spk_id: 0
|
|
||||||
|
|
||||||
##################################################################
|
|
||||||
# VOCODER SETTING #
|
|
||||||
# voc choices=['pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3',
|
|
||||||
# 'pwgan_vctk', 'mb_melgan_csmsc']
|
|
||||||
##################################################################
|
|
||||||
voc: 'pwgan_csmsc'
|
|
||||||
voc_config:
|
|
||||||
voc_ckpt:
|
|
||||||
voc_stat:
|
|
||||||
|
|
||||||
##################################################################
|
|
||||||
# OTHERS #
|
|
||||||
##################################################################
|
|
||||||
lang: 'zh'
|
|
||||||
device: # set 'gpu:id' or 'cpu'
|
|
@ -1,42 +0,0 @@
|
|||||||
# This is the parameter configuration file for TTS server.
|
|
||||||
# These are the static models that support paddle inference.
|
|
||||||
|
|
||||||
##################################################################
|
|
||||||
# ACOUSTIC MODEL SETTING #
|
|
||||||
# am choices=['speedyspeech_csmsc', 'fastspeech2_csmsc']
|
|
||||||
##################################################################
|
|
||||||
am: 'fastspeech2_csmsc'
|
|
||||||
am_model: # the pdmodel file of your am static model (XX.pdmodel)
|
|
||||||
am_params: # the pdiparams file of your am static model (XX.pdipparams)
|
|
||||||
am_sample_rate: 24000
|
|
||||||
phones_dict:
|
|
||||||
tones_dict:
|
|
||||||
speaker_dict:
|
|
||||||
spk_id: 0
|
|
||||||
|
|
||||||
am_predictor_conf:
|
|
||||||
device: # set 'gpu:id' or 'cpu'
|
|
||||||
switch_ir_optim: True
|
|
||||||
glog_info: False # True -> print glog
|
|
||||||
summary: True # False -> do not show predictor config
|
|
||||||
|
|
||||||
|
|
||||||
##################################################################
|
|
||||||
# VOCODER SETTING #
|
|
||||||
# voc choices=['pwgan_csmsc', 'mb_melgan_csmsc','hifigan_csmsc']
|
|
||||||
##################################################################
|
|
||||||
voc: 'pwgan_csmsc'
|
|
||||||
voc_model: # the pdmodel file of your vocoder static model (XX.pdmodel)
|
|
||||||
voc_params: # the pdiparams file of your vocoder static model (XX.pdipparams)
|
|
||||||
voc_sample_rate: 24000
|
|
||||||
|
|
||||||
voc_predictor_conf:
|
|
||||||
device: # set 'gpu:id' or 'cpu'
|
|
||||||
switch_ir_optim: True
|
|
||||||
glog_info: False # True -> print glog
|
|
||||||
summary: True # False -> do not show predictor config
|
|
||||||
|
|
||||||
##################################################################
|
|
||||||
# OTHERS #
|
|
||||||
##################################################################
|
|
||||||
lang: 'zh'
|
|
@ -1,3 +1,3 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
paddlespeech_server start --config_file ./conf/application.yaml
|
paddlespeech_server start --config_file ./conf/application.yaml
|
||||||
|
@ -0,0 +1,156 @@
|
|||||||
|
# HiFiGAN with AISHELL-3
|
||||||
|
This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.05646) model with [AISHELL-3](http://www.aishelltech.com/aishell_3).
|
||||||
|
|
||||||
|
AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus that could be used to train multi-speaker Text-to-Speech (TTS) systems.
|
||||||
|
## Dataset
|
||||||
|
### Download and Extract
|
||||||
|
Download AISHELL-3.
|
||||||
|
```bash
|
||||||
|
wget https://www.openslr.org/resources/93/data_aishell3.tgz
|
||||||
|
```
|
||||||
|
Extract AISHELL-3.
|
||||||
|
```bash
|
||||||
|
mkdir data_aishell3
|
||||||
|
tar zxvf data_aishell3.tgz -C data_aishell3
|
||||||
|
```
|
||||||
|
### Get MFA Result and Extract
|
||||||
|
We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2.
|
||||||
|
You can download from here [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (use MFA1.x now) of our repo.
|
||||||
|
|
||||||
|
## Get Started
|
||||||
|
Assume the path to the dataset is `~/datasets/data_aishell3`.
|
||||||
|
Assume the path to the MFA result of AISHELL-3 is `./aishell3_alignment_tone`.
|
||||||
|
Run the command below to
|
||||||
|
1. **source path**.
|
||||||
|
2. preprocess the dataset.
|
||||||
|
3. train the model.
|
||||||
|
4. synthesize wavs.
|
||||||
|
- synthesize waveform from `metadata.jsonl`.
|
||||||
|
```bash
|
||||||
|
./run.sh
|
||||||
|
```
|
||||||
|
You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage, for example, run the following command will only preprocess the dataset.
|
||||||
|
```bash
|
||||||
|
./run.sh --stage 0 --stop-stage 0
|
||||||
|
```
|
||||||
|
### Data Preprocessing
|
||||||
|
```bash
|
||||||
|
./local/preprocess.sh ${conf_path}
|
||||||
|
```
|
||||||
|
When it is done. A `dump` folder is created in the current directory. The structure of the dump folder is listed below.
|
||||||
|
|
||||||
|
```text
|
||||||
|
dump
|
||||||
|
├── dev
|
||||||
|
│ ├── norm
|
||||||
|
│ └── raw
|
||||||
|
├── test
|
||||||
|
│ ├── norm
|
||||||
|
│ └── raw
|
||||||
|
└── train
|
||||||
|
├── norm
|
||||||
|
├── raw
|
||||||
|
└── feats_stats.npy
|
||||||
|
```
|
||||||
|
|
||||||
|
The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the norm folder contains the normalized spectrogram. The statistics used to normalize the spectrogram are computed from the training set, which is located in `dump/train/feats_stats.npy`.
|
||||||
|
|
||||||
|
Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains id and paths to the spectrogram of each utterance.
|
||||||
|
|
||||||
|
### Model Training
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
|
||||||
|
```
|
||||||
|
`./local/train.sh` calls `${BIN_DIR}/train.py`.
|
||||||
|
Here's the complete help message.
|
||||||
|
|
||||||
|
```text
|
||||||
|
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
|
||||||
|
[--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
|
||||||
|
[--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER]
|
||||||
|
[--run-benchmark RUN_BENCHMARK]
|
||||||
|
[--profiler_options PROFILER_OPTIONS]
|
||||||
|
|
||||||
|
Train a ParallelWaveGAN model.
|
||||||
|
|
||||||
|
optional arguments:
|
||||||
|
-h, --help show this help message and exit
|
||||||
|
--config CONFIG config file to overwrite default config.
|
||||||
|
--train-metadata TRAIN_METADATA
|
||||||
|
training data.
|
||||||
|
--dev-metadata DEV_METADATA
|
||||||
|
dev data.
|
||||||
|
--output-dir OUTPUT_DIR
|
||||||
|
output dir.
|
||||||
|
--ngpu NGPU if ngpu == 0, use cpu.
|
||||||
|
|
||||||
|
benchmark:
|
||||||
|
arguments related to benchmark.
|
||||||
|
|
||||||
|
--batch-size BATCH_SIZE
|
||||||
|
batch size.
|
||||||
|
--max-iter MAX_ITER train max steps.
|
||||||
|
--run-benchmark RUN_BENCHMARK
|
||||||
|
runing benchmark or not, if True, use the --batch-size
|
||||||
|
and --max-iter.
|
||||||
|
--profiler_options PROFILER_OPTIONS
|
||||||
|
The option of profiler, which should be in format
|
||||||
|
"key1=value1;key2=value2;key3=value3".
|
||||||
|
```
|
||||||
|
|
||||||
|
1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
|
||||||
|
2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder.
|
||||||
|
3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
|
||||||
|
4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
|
||||||
|
|
||||||
|
### Synthesizing
|
||||||
|
`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
|
||||||
|
```
|
||||||
|
```text
|
||||||
|
usage: synthesize.py [-h] [--generator-type GENERATOR_TYPE] [--config CONFIG]
|
||||||
|
[--checkpoint CHECKPOINT] [--test-metadata TEST_METADATA]
|
||||||
|
[--output-dir OUTPUT_DIR] [--ngpu NGPU]
|
||||||
|
|
||||||
|
Synthesize with GANVocoder.
|
||||||
|
|
||||||
|
optional arguments:
|
||||||
|
-h, --help show this help message and exit
|
||||||
|
--generator-type GENERATOR_TYPE
|
||||||
|
type of GANVocoder, should in {pwgan, mb_melgan,
|
||||||
|
style_melgan, } now
|
||||||
|
--config CONFIG GANVocoder config file.
|
||||||
|
--checkpoint CHECKPOINT
|
||||||
|
snapshot to load.
|
||||||
|
--test-metadata TEST_METADATA
|
||||||
|
dev data.
|
||||||
|
--output-dir OUTPUT_DIR
|
||||||
|
output dir.
|
||||||
|
--ngpu NGPU if ngpu == 0, use cpu.
|
||||||
|
```
|
||||||
|
|
||||||
|
1. `--config` config file. You should use the same config with which the model is trained.
|
||||||
|
2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory.
|
||||||
|
3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory.
|
||||||
|
4. `--output-dir` is the directory to save the synthesized audio files.
|
||||||
|
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
|
||||||
|
## Pretrained Models
|
||||||
|
The pretrained model can be downloaded here [hifigan_aishell3_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip).
|
||||||
|
|
||||||
|
|
||||||
|
Model | Step | eval/generator_loss | eval/mel_loss| eval/feature_matching_loss
|
||||||
|
:-------------:| :------------:| :-----: | :-----: | :--------:
|
||||||
|
default| 1(gpu) x 2500000|24.060|0.1068|7.499
|
||||||
|
|
||||||
|
HiFiGAN checkpoint contains files listed below.
|
||||||
|
|
||||||
|
```text
|
||||||
|
hifigan_aishell3_ckpt_0.2.0
|
||||||
|
├── default.yaml # default config used to train hifigan
|
||||||
|
├── feats_stats.npy # statistics used to normalize spectrogram when training hifigan
|
||||||
|
└── snapshot_iter_2500000.pdz # generator parameters of hifigan
|
||||||
|
```
|
||||||
|
|
||||||
|
## Acknowledgement
|
||||||
|
We adapted some code from https://github.com/kan-bayashi/ParallelWaveGAN.
|
@ -0,0 +1,168 @@
|
|||||||
|
# This is the configuration file for AISHELL-3 dataset.
|
||||||
|
# This configuration is based on HiFiGAN V1, which is
|
||||||
|
# an official configuration. But I found that the optimizer
|
||||||
|
# setting does not work well with my implementation.
|
||||||
|
# So I changed optimizer settings as follows:
|
||||||
|
# - AdamW -> Adam
|
||||||
|
# - betas: [0.8, 0.99] -> betas: [0.5, 0.9]
|
||||||
|
# - Scheduler: ExponentialLR -> MultiStepLR
|
||||||
|
# To match the shift size difference, the upsample scales
|
||||||
|
# is also modified from the original 256 shift setting.
|
||||||
|
###########################################################
|
||||||
|
# FEATURE EXTRACTION SETTING #
|
||||||
|
###########################################################
|
||||||
|
fs: 24000 # Sampling rate.
|
||||||
|
n_fft: 2048 # FFT size (samples).
|
||||||
|
n_shift: 300 # Hop size (samples). 12.5ms
|
||||||
|
win_length: 1200 # Window length (samples). 50ms
|
||||||
|
# If set to null, it will be the same as fft_size.
|
||||||
|
window: "hann" # Window function.
|
||||||
|
n_mels: 80 # Number of mel basis.
|
||||||
|
fmin: 80 # Minimum freq in mel basis calculation. (Hz)
|
||||||
|
fmax: 7600 # Maximum frequency in mel basis calculation. (Hz)
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# GENERATOR NETWORK ARCHITECTURE SETTING #
|
||||||
|
###########################################################
|
||||||
|
generator_params:
|
||||||
|
in_channels: 80 # Number of input channels.
|
||||||
|
out_channels: 1 # Number of output channels.
|
||||||
|
channels: 512 # Number of initial channels.
|
||||||
|
kernel_size: 7 # Kernel size of initial and final conv layers.
|
||||||
|
upsample_scales: [5, 5, 4, 3] # Upsampling scales.
|
||||||
|
upsample_kernel_sizes: [10, 10, 8, 6] # Kernel size for upsampling layers.
|
||||||
|
resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks.
|
||||||
|
resblock_dilations: # Dilations for residual blocks.
|
||||||
|
- [1, 3, 5]
|
||||||
|
- [1, 3, 5]
|
||||||
|
- [1, 3, 5]
|
||||||
|
use_additional_convs: True # Whether to use additional conv layer in residual blocks.
|
||||||
|
bias: True # Whether to use bias parameter in conv.
|
||||||
|
nonlinear_activation: "leakyrelu" # Nonlinear activation type.
|
||||||
|
nonlinear_activation_params: # Nonlinear activation paramters.
|
||||||
|
negative_slope: 0.1
|
||||||
|
use_weight_norm: True # Whether to apply weight normalization.
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
|
||||||
|
###########################################################
|
||||||
|
discriminator_params:
|
||||||
|
scales: 3 # Number of multi-scale discriminator.
|
||||||
|
scale_downsample_pooling: "AvgPool1D" # Pooling operation for scale discriminator.
|
||||||
|
scale_downsample_pooling_params:
|
||||||
|
kernel_size: 4 # Pooling kernel size.
|
||||||
|
stride: 2 # Pooling stride.
|
||||||
|
padding: 2 # Padding size.
|
||||||
|
scale_discriminator_params:
|
||||||
|
in_channels: 1 # Number of input channels.
|
||||||
|
out_channels: 1 # Number of output channels.
|
||||||
|
kernel_sizes: [15, 41, 5, 3] # List of kernel sizes.
|
||||||
|
channels: 128 # Initial number of channels.
|
||||||
|
max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
|
||||||
|
max_groups: 16 # Maximum number of groups in downsampling conv layers.
|
||||||
|
bias: True
|
||||||
|
downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales.
|
||||||
|
nonlinear_activation: "leakyrelu" # Nonlinear activation.
|
||||||
|
nonlinear_activation_params:
|
||||||
|
negative_slope: 0.1
|
||||||
|
follow_official_norm: True # Whether to follow the official norm setting.
|
||||||
|
periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator.
|
||||||
|
period_discriminator_params:
|
||||||
|
in_channels: 1 # Number of input channels.
|
||||||
|
out_channels: 1 # Number of output channels.
|
||||||
|
kernel_sizes: [5, 3] # List of kernel sizes.
|
||||||
|
channels: 32 # Initial number of channels.
|
||||||
|
downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
|
||||||
|
max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
|
||||||
|
bias: True # Whether to use bias parameter in conv layer."
|
||||||
|
nonlinear_activation: "leakyrelu" # Nonlinear activation.
|
||||||
|
nonlinear_activation_params: # Nonlinear activation paramters.
|
||||||
|
negative_slope: 0.1
|
||||||
|
use_weight_norm: True # Whether to apply weight normalization.
|
||||||
|
use_spectral_norm: False # Whether to apply spectral normalization.
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# STFT LOSS SETTING #
|
||||||
|
###########################################################
|
||||||
|
use_stft_loss: False # Whether to use multi-resolution STFT loss.
|
||||||
|
use_mel_loss: True # Whether to use Mel-spectrogram loss.
|
||||||
|
mel_loss_params:
|
||||||
|
fs: 24000
|
||||||
|
fft_size: 2048
|
||||||
|
hop_size: 300
|
||||||
|
win_length: 1200
|
||||||
|
window: "hann"
|
||||||
|
num_mels: 80
|
||||||
|
fmin: 0
|
||||||
|
fmax: 12000
|
||||||
|
log_base: null
|
||||||
|
generator_adv_loss_params:
|
||||||
|
average_by_discriminators: False # Whether to average loss by #discriminators.
|
||||||
|
discriminator_adv_loss_params:
|
||||||
|
average_by_discriminators: False # Whether to average loss by #discriminators.
|
||||||
|
use_feat_match_loss: True
|
||||||
|
feat_match_loss_params:
|
||||||
|
average_by_discriminators: False # Whether to average loss by #discriminators.
|
||||||
|
average_by_layers: False # Whether to average loss by #layers in each discriminator.
|
||||||
|
include_final_outputs: False # Whether to include final outputs in feat match loss calculation.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# ADVERSARIAL LOSS SETTING #
|
||||||
|
###########################################################
|
||||||
|
lambda_aux: 45.0 # Loss balancing coefficient for STFT loss.
|
||||||
|
lambda_adv: 1.0 # Loss balancing coefficient for adversarial loss.
|
||||||
|
lambda_feat_match: 2.0 # Loss balancing coefficient for feat match loss..
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# DATA LOADER SETTING #
|
||||||
|
###########################################################
|
||||||
|
batch_size: 16 # Batch size.
|
||||||
|
batch_max_steps: 8400 # Length of each audio in batch. Make sure dividable by hop_size.
|
||||||
|
num_workers: 2 # Number of workers in DataLoader.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# OPTIMIZER & SCHEDULER SETTING #
|
||||||
|
###########################################################
|
||||||
|
generator_optimizer_params:
|
||||||
|
beta1: 0.5
|
||||||
|
beta2: 0.9
|
||||||
|
weight_decay: 0.0 # Generator's weight decay coefficient.
|
||||||
|
generator_scheduler_params:
|
||||||
|
learning_rate: 2.0e-4 # Generator's learning rate.
|
||||||
|
gamma: 0.5 # Generator's scheduler gamma.
|
||||||
|
milestones: # At each milestone, lr will be multiplied by gamma.
|
||||||
|
- 200000
|
||||||
|
- 400000
|
||||||
|
- 600000
|
||||||
|
- 800000
|
||||||
|
generator_grad_norm: -1 # Generator's gradient norm.
|
||||||
|
discriminator_optimizer_params:
|
||||||
|
beta1: 0.5
|
||||||
|
beta2: 0.9
|
||||||
|
weight_decay: 0.0 # Discriminator's weight decay coefficient.
|
||||||
|
discriminator_scheduler_params:
|
||||||
|
learning_rate: 2.0e-4 # Discriminator's learning rate.
|
||||||
|
gamma: 0.5 # Discriminator's scheduler gamma.
|
||||||
|
milestones: # At each milestone, lr will be multiplied by gamma.
|
||||||
|
- 200000
|
||||||
|
- 400000
|
||||||
|
- 600000
|
||||||
|
- 800000
|
||||||
|
discriminator_grad_norm: -1 # Discriminator's gradient norm.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# INTERVAL SETTING #
|
||||||
|
###########################################################
|
||||||
|
generator_train_start_steps: 1 # Number of steps to start to train discriminator.
|
||||||
|
discriminator_train_start_steps: 0 # Number of steps to start to train discriminator.
|
||||||
|
train_max_steps: 2500000 # Number of training steps.
|
||||||
|
save_interval_steps: 5000 # Interval steps to save checkpoint.
|
||||||
|
eval_interval_steps: 1000 # Interval steps to evaluate the network.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# OTHER SETTING #
|
||||||
|
###########################################################
|
||||||
|
num_snapshots: 10 # max number of snapshots to keep while training
|
||||||
|
seed: 42 # random seed for paddle, random, and np.random
|
@ -0,0 +1,55 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# get durations from MFA's result
|
||||||
|
echo "Generate durations.txt from MFA results ..."
|
||||||
|
python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
|
||||||
|
--inputdir=./aishell3_alignment_tone \
|
||||||
|
--output=durations.txt \
|
||||||
|
--config=${config_path}
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# extract features
|
||||||
|
echo "Extract features ..."
|
||||||
|
python3 ${BIN_DIR}/../preprocess.py \
|
||||||
|
--rootdir=~/datasets/data_aishell3/ \
|
||||||
|
--dataset=aishell3 \
|
||||||
|
--dumpdir=dump \
|
||||||
|
--dur-file=durations.txt \
|
||||||
|
--config=${config_path} \
|
||||||
|
--cut-sil=True \
|
||||||
|
--num-cpu=20
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
# get features' stats(mean and std)
|
||||||
|
echo "Get features' stats ..."
|
||||||
|
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
|
||||||
|
--metadata=dump/train/raw/metadata.jsonl \
|
||||||
|
--field-name="feats"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
# normalize, dev and test should use train's stats
|
||||||
|
echo "Normalize ..."
|
||||||
|
|
||||||
|
python3 ${BIN_DIR}/../normalize.py \
|
||||||
|
--metadata=dump/train/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/train/norm \
|
||||||
|
--stats=dump/train/feats_stats.npy
|
||||||
|
python3 ${BIN_DIR}/../normalize.py \
|
||||||
|
--metadata=dump/dev/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/dev/norm \
|
||||||
|
--stats=dump/train/feats_stats.npy
|
||||||
|
|
||||||
|
python3 ${BIN_DIR}/../normalize.py \
|
||||||
|
--metadata=dump/test/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/test/norm \
|
||||||
|
--stats=dump/train/feats_stats.npy
|
||||||
|
fi
|
@ -0,0 +1,14 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
ckpt_name=$3
|
||||||
|
|
||||||
|
FLAGS_allocator_strategy=naive_best_fit \
|
||||||
|
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
|
||||||
|
python3 ${BIN_DIR}/../synthesize.py \
|
||||||
|
--config=${config_path} \
|
||||||
|
--checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
|
||||||
|
--test-metadata=dump/test/norm/metadata.jsonl \
|
||||||
|
--output-dir=${train_output_path}/test \
|
||||||
|
--generator-type=hifigan
|
@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
|
||||||
|
FLAGS_cudnn_exhaustive_search=true \
|
||||||
|
FLAGS_conv_workspace_size_limit=4000 \
|
||||||
|
python ${BIN_DIR}/train.py \
|
||||||
|
--train-metadata=dump/train/norm/metadata.jsonl \
|
||||||
|
--dev-metadata=dump/dev/norm/metadata.jsonl \
|
||||||
|
--config=${config_path} \
|
||||||
|
--output-dir=${train_output_path} \
|
||||||
|
--ngpu=1
|
@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
export MAIN_ROOT=`realpath ${PWD}/../../../`
|
||||||
|
|
||||||
|
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
|
||||||
|
export LC_ALL=C
|
||||||
|
|
||||||
|
export PYTHONDONTWRITEBYTECODE=1
|
||||||
|
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||||
|
export PYTHONIOENCODING=UTF-8
|
||||||
|
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
|
||||||
|
|
||||||
|
MODEL=hifigan
|
||||||
|
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/gan_vocoder/${MODEL}
|
@ -0,0 +1,32 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
source path.sh
|
||||||
|
|
||||||
|
gpus=0
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
conf_path=conf/default.yaml
|
||||||
|
train_output_path=exp/default
|
||||||
|
ckpt_name=snapshot_iter_5000.pdz
|
||||||
|
|
||||||
|
# with the following command, you can choose the stage range you want to run
|
||||||
|
# such as `./run.sh --stage 0 --stop-stage 0`
|
||||||
|
# this can not be mixed use with `$1`, `$2` ...
|
||||||
|
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# prepare data
|
||||||
|
./local/preprocess.sh ${conf_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
# synthesize
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
|
||||||
|
fi
|
@ -0,0 +1,133 @@
|
|||||||
|
# HiFiGAN with the LJSpeech-1.1
|
||||||
|
This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.05646) model with [LJSpeech-1.1](https://keithito.com/LJ-Speech-Dataset/).
|
||||||
|
## Dataset
|
||||||
|
### Download and Extract
|
||||||
|
Download LJSpeech-1.1 from the [official website](https://keithito.com/LJ-Speech-Dataset/).
|
||||||
|
### Get MFA Result and Extract
|
||||||
|
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence in the edge of audio.
|
||||||
|
You can download from here [ljspeech_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) of our repo.
|
||||||
|
|
||||||
|
## Get Started
|
||||||
|
Assume the path to the dataset is `~/datasets/LJSpeech-1.1`.
|
||||||
|
Assume the path to the MFA result of LJSpeech-1.1 is `./ljspeech_alignment`.
|
||||||
|
Run the command below to
|
||||||
|
1. **source path**.
|
||||||
|
2. preprocess the dataset.
|
||||||
|
3. train the model.
|
||||||
|
4. synthesize wavs.
|
||||||
|
- synthesize waveform from `metadata.jsonl`.
|
||||||
|
```bash
|
||||||
|
./run.sh
|
||||||
|
```
|
||||||
|
You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage, for example, running the following command will only preprocess the dataset.
|
||||||
|
```bash
|
||||||
|
./run.sh --stage 0 --stop-stage 0
|
||||||
|
```
|
||||||
|
### Data Preprocessing
|
||||||
|
```bash
|
||||||
|
./local/preprocess.sh ${conf_path}
|
||||||
|
```
|
||||||
|
When it is done. A `dump` folder is created in the current directory. The structure of the dump folder is listed below.
|
||||||
|
|
||||||
|
```text
|
||||||
|
dump
|
||||||
|
├── dev
|
||||||
|
│ ├── norm
|
||||||
|
│ └── raw
|
||||||
|
├── test
|
||||||
|
│ ├── norm
|
||||||
|
│ └── raw
|
||||||
|
└── train
|
||||||
|
├── norm
|
||||||
|
├── raw
|
||||||
|
└── feats_stats.npy
|
||||||
|
```
|
||||||
|
|
||||||
|
The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the norm folder contains the normalized spectrogram. The statistics used to normalize the spectrogram are computed from the training set, which is located in `dump/train/feats_stats.npy`.
|
||||||
|
|
||||||
|
Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains id and paths to the spectrogram of each utterance.
|
||||||
|
|
||||||
|
### Model Training
|
||||||
|
`./local/train.sh` calls `${BIN_DIR}/train.py`.
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
|
||||||
|
```
|
||||||
|
Here's the complete help message.
|
||||||
|
|
||||||
|
```text
|
||||||
|
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
|
||||||
|
[--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
|
||||||
|
[--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER]
|
||||||
|
[--run-benchmark RUN_BENCHMARK]
|
||||||
|
[--profiler_options PROFILER_OPTIONS]
|
||||||
|
|
||||||
|
Train a ParallelWaveGAN model.
|
||||||
|
|
||||||
|
optional arguments:
|
||||||
|
-h, --help show this help message and exit
|
||||||
|
--config CONFIG config file to overwrite default config.
|
||||||
|
--train-metadata TRAIN_METADATA
|
||||||
|
training data.
|
||||||
|
--dev-metadata DEV_METADATA
|
||||||
|
dev data.
|
||||||
|
--output-dir OUTPUT_DIR
|
||||||
|
output dir.
|
||||||
|
--ngpu NGPU if ngpu == 0, use cpu.
|
||||||
|
|
||||||
|
benchmark:
|
||||||
|
arguments related to benchmark.
|
||||||
|
|
||||||
|
--batch-size BATCH_SIZE
|
||||||
|
batch size.
|
||||||
|
--max-iter MAX_ITER train max steps.
|
||||||
|
--run-benchmark RUN_BENCHMARK
|
||||||
|
runing benchmark or not, if True, use the --batch-size
|
||||||
|
and --max-iter.
|
||||||
|
--profiler_options PROFILER_OPTIONS
|
||||||
|
The option of profiler, which should be in format
|
||||||
|
"key1=value1;key2=value2;key3=value3".
|
||||||
|
```
|
||||||
|
|
||||||
|
1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
|
||||||
|
2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder.
|
||||||
|
3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
|
||||||
|
4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
|
||||||
|
|
||||||
|
### Synthesizing
|
||||||
|
`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
|
||||||
|
```
|
||||||
|
```text
|
||||||
|
usage: synthesize.py [-h] [--generator-type GENERATOR_TYPE] [--config CONFIG]
|
||||||
|
[--checkpoint CHECKPOINT] [--test-metadata TEST_METADATA]
|
||||||
|
[--output-dir OUTPUT_DIR] [--ngpu NGPU]
|
||||||
|
|
||||||
|
Synthesize with GANVocoder.
|
||||||
|
|
||||||
|
optional arguments:
|
||||||
|
-h, --help show this help message and exit
|
||||||
|
--generator-type GENERATOR_TYPE
|
||||||
|
type of GANVocoder, should in {pwgan, mb_melgan,
|
||||||
|
style_melgan, } now
|
||||||
|
--config CONFIG GANVocoder config file.
|
||||||
|
--checkpoint CHECKPOINT
|
||||||
|
snapshot to load.
|
||||||
|
--test-metadata TEST_METADATA
|
||||||
|
dev data.
|
||||||
|
--output-dir OUTPUT_DIR
|
||||||
|
output dir.
|
||||||
|
--ngpu NGPU if ngpu == 0, use cpu.
|
||||||
|
```
|
||||||
|
|
||||||
|
1. `--config` parallel wavegan config file. You should use the same config with which the model is trained.
|
||||||
|
2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory.
|
||||||
|
3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory.
|
||||||
|
4. `--output-dir` is the directory to save the synthesized audio files.
|
||||||
|
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
|
||||||
|
|
||||||
|
## Pretrained Model
|
||||||
|
|
||||||
|
|
||||||
|
## Acknowledgement
|
||||||
|
We adapted some code from https://github.com/kan-bayashi/ParallelWaveGAN.
|
@ -0,0 +1,167 @@
|
|||||||
|
# This is the configuration file for LJSpeech dataset.
|
||||||
|
# This configuration is based on HiFiGAN V1, which is an official configuration.
|
||||||
|
# But I found that the optimizer setting does not work well with my implementation.
|
||||||
|
# So I changed optimizer settings as follows:
|
||||||
|
# - AdamW -> Adam
|
||||||
|
# - betas: [0.8, 0.99] -> betas: [0.5, 0.9]
|
||||||
|
# - Scheduler: ExponentialLR -> MultiStepLR
|
||||||
|
# To match the shift size difference, the upsample scales is also modified from the original 256 shift setting.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# FEATURE EXTRACTION SETTING #
|
||||||
|
###########################################################
|
||||||
|
fs: 22050 # Sampling rate.
|
||||||
|
n_fft: 1024 # FFT size (samples).
|
||||||
|
n_shift: 256 # Hop size (samples). 11.6ms
|
||||||
|
win_length: null # Window length (samples).
|
||||||
|
# If set to null, it will be the same as fft_size.
|
||||||
|
window: "hann" # Window function.
|
||||||
|
n_mels: 80 # Number of mel basis.
|
||||||
|
fmin: 80 # Minimum freq in mel basis calculation. (Hz)
|
||||||
|
fmax: 7600 # Maximum frequency in mel basis calculation. (Hz)
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# GENERATOR NETWORK ARCHITECTURE SETTING #
|
||||||
|
###########################################################
|
||||||
|
generator_params:
|
||||||
|
in_channels: 80 # Number of input channels.
|
||||||
|
out_channels: 1 # Number of output channels.
|
||||||
|
channels: 512 # Number of initial channels.
|
||||||
|
kernel_size: 7 # Kernel size of initial and final conv layers.
|
||||||
|
upsample_scales: [8, 8, 2, 2] # Upsampling scales.
|
||||||
|
upsample_kernel_sizes: [16, 16, 4, 4] # Kernel size for upsampling layers.
|
||||||
|
resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks.
|
||||||
|
resblock_dilations: # Dilations for residual blocks.
|
||||||
|
- [1, 3, 5]
|
||||||
|
- [1, 3, 5]
|
||||||
|
- [1, 3, 5]
|
||||||
|
use_additional_convs: True # Whether to use additional conv layer in residual blocks.
|
||||||
|
bias: True # Whether to use bias parameter in conv.
|
||||||
|
nonlinear_activation: "leakyrelu" # Nonlinear activation type.
|
||||||
|
nonlinear_activation_params: # Nonlinear activation paramters.
|
||||||
|
negative_slope: 0.1
|
||||||
|
use_weight_norm: True # Whether to apply weight normalization.
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
|
||||||
|
###########################################################
|
||||||
|
discriminator_params:
|
||||||
|
scales: 3 # Number of multi-scale discriminator.
|
||||||
|
scale_downsample_pooling: "AvgPool1D" # Pooling operation for scale discriminator.
|
||||||
|
scale_downsample_pooling_params:
|
||||||
|
kernel_size: 4 # Pooling kernel size.
|
||||||
|
stride: 2 # Pooling stride.
|
||||||
|
padding: 2 # Padding size.
|
||||||
|
scale_discriminator_params:
|
||||||
|
in_channels: 1 # Number of input channels.
|
||||||
|
out_channels: 1 # Number of output channels.
|
||||||
|
kernel_sizes: [15, 41, 5, 3] # List of kernel sizes.
|
||||||
|
channels: 128 # Initial number of channels.
|
||||||
|
max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
|
||||||
|
max_groups: 16 # Maximum number of groups in downsampling conv layers.
|
||||||
|
bias: True
|
||||||
|
downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales.
|
||||||
|
nonlinear_activation: "leakyrelu" # Nonlinear activation.
|
||||||
|
nonlinear_activation_params:
|
||||||
|
negative_slope: 0.1
|
||||||
|
follow_official_norm: True # Whether to follow the official norm setting.
|
||||||
|
periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator.
|
||||||
|
period_discriminator_params:
|
||||||
|
in_channels: 1 # Number of input channels.
|
||||||
|
out_channels: 1 # Number of output channels.
|
||||||
|
kernel_sizes: [5, 3] # List of kernel sizes.
|
||||||
|
channels: 32 # Initial number of channels.
|
||||||
|
downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
|
||||||
|
max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
|
||||||
|
bias: True # Whether to use bias parameter in conv layer."
|
||||||
|
nonlinear_activation: "leakyrelu" # Nonlinear activation.
|
||||||
|
nonlinear_activation_params: # Nonlinear activation paramters.
|
||||||
|
negative_slope: 0.1
|
||||||
|
use_weight_norm: True # Whether to apply weight normalization.
|
||||||
|
use_spectral_norm: False # Whether to apply spectral normalization.
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# STFT LOSS SETTING #
|
||||||
|
###########################################################
|
||||||
|
use_stft_loss: False # Whether to use multi-resolution STFT loss.
|
||||||
|
use_mel_loss: True # Whether to use Mel-spectrogram loss.
|
||||||
|
mel_loss_params:
|
||||||
|
fs: 22050
|
||||||
|
fft_size: 1024
|
||||||
|
hop_size: 256
|
||||||
|
win_length: null
|
||||||
|
window: "hann"
|
||||||
|
num_mels: 80
|
||||||
|
fmin: 0
|
||||||
|
fmax: 11025
|
||||||
|
log_base: null
|
||||||
|
generator_adv_loss_params:
|
||||||
|
average_by_discriminators: False # Whether to average loss by #discriminators.
|
||||||
|
discriminator_adv_loss_params:
|
||||||
|
average_by_discriminators: False # Whether to average loss by #discriminators.
|
||||||
|
use_feat_match_loss: True
|
||||||
|
feat_match_loss_params:
|
||||||
|
average_by_discriminators: False # Whether to average loss by #discriminators.
|
||||||
|
average_by_layers: False # Whether to average loss by #layers in each discriminator.
|
||||||
|
include_final_outputs: False # Whether to include final outputs in feat match loss calculation.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# ADVERSARIAL LOSS SETTING #
|
||||||
|
###########################################################
|
||||||
|
lambda_aux: 45.0 # Loss balancing coefficient for STFT loss.
|
||||||
|
lambda_adv: 1.0 # Loss balancing coefficient for adversarial loss.
|
||||||
|
lambda_feat_match: 2.0 # Loss balancing coefficient for feat match loss..
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# DATA LOADER SETTING #
|
||||||
|
###########################################################
|
||||||
|
batch_size: 16 # Batch size.
|
||||||
|
batch_max_steps: 8192 # Length of each audio in batch. Make sure dividable by hop_size.
|
||||||
|
num_workers: 2 # Number of workers in DataLoader.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# OPTIMIZER & SCHEDULER SETTING #
|
||||||
|
###########################################################
|
||||||
|
generator_optimizer_params:
|
||||||
|
beta1: 0.5
|
||||||
|
beta2: 0.9
|
||||||
|
weight_decay: 0.0 # Generator's weight decay coefficient.
|
||||||
|
generator_scheduler_params:
|
||||||
|
learning_rate: 2.0e-4 # Generator's learning rate.
|
||||||
|
gamma: 0.5 # Generator's scheduler gamma.
|
||||||
|
milestones: # At each milestone, lr will be multiplied by gamma.
|
||||||
|
- 200000
|
||||||
|
- 400000
|
||||||
|
- 600000
|
||||||
|
- 800000
|
||||||
|
generator_grad_norm: -1 # Generator's gradient norm.
|
||||||
|
discriminator_optimizer_params:
|
||||||
|
beta1: 0.5
|
||||||
|
beta2: 0.9
|
||||||
|
weight_decay: 0.0 # Discriminator's weight decay coefficient.
|
||||||
|
discriminator_scheduler_params:
|
||||||
|
learning_rate: 2.0e-4 # Discriminator's learning rate.
|
||||||
|
gamma: 0.5 # Discriminator's scheduler gamma.
|
||||||
|
milestones: # At each milestone, lr will be multiplied by gamma.
|
||||||
|
- 200000
|
||||||
|
- 400000
|
||||||
|
- 600000
|
||||||
|
- 800000
|
||||||
|
discriminator_grad_norm: -1 # Discriminator's gradient norm.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# INTERVAL SETTING #
|
||||||
|
###########################################################
|
||||||
|
generator_train_start_steps: 1 # Number of steps to start to train discriminator.
|
||||||
|
discriminator_train_start_steps: 0 # Number of steps to start to train discriminator.
|
||||||
|
train_max_steps: 2500000 # Number of training steps.
|
||||||
|
save_interval_steps: 5000 # Interval steps to save checkpoint.
|
||||||
|
eval_interval_steps: 1000 # Interval steps to evaluate the network.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# OTHER SETTING #
|
||||||
|
###########################################################
|
||||||
|
num_snapshots: 10 # max number of snapshots to keep while training
|
||||||
|
seed: 42 # random seed for paddle, random, and np.random
|
@ -0,0 +1,55 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# get durations from MFA's result
|
||||||
|
echo "Generate durations.txt from MFA results ..."
|
||||||
|
python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
|
||||||
|
--inputdir=./ljspeech_alignment \
|
||||||
|
--output=durations.txt \
|
||||||
|
--config=${config_path}
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# extract features
|
||||||
|
echo "Extract features ..."
|
||||||
|
python3 ${BIN_DIR}/../preprocess.py \
|
||||||
|
--rootdir=~/datasets/LJSpeech-1.1/ \
|
||||||
|
--dataset=ljspeech \
|
||||||
|
--dumpdir=dump \
|
||||||
|
--dur-file=durations.txt \
|
||||||
|
--config=${config_path} \
|
||||||
|
--cut-sil=True \
|
||||||
|
--num-cpu=20
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
# get features' stats(mean and std)
|
||||||
|
echo "Get features' stats ..."
|
||||||
|
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
|
||||||
|
--metadata=dump/train/raw/metadata.jsonl \
|
||||||
|
--field-name="feats"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
# normalize, dev and test should use train's stats
|
||||||
|
echo "Normalize ..."
|
||||||
|
|
||||||
|
python3 ${BIN_DIR}/../normalize.py \
|
||||||
|
--metadata=dump/train/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/train/norm \
|
||||||
|
--stats=dump/train/feats_stats.npy
|
||||||
|
python3 ${BIN_DIR}/../normalize.py \
|
||||||
|
--metadata=dump/dev/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/dev/norm \
|
||||||
|
--stats=dump/train/feats_stats.npy
|
||||||
|
|
||||||
|
python3 ${BIN_DIR}/../normalize.py \
|
||||||
|
--metadata=dump/test/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/test/norm \
|
||||||
|
--stats=dump/train/feats_stats.npy
|
||||||
|
fi
|
@ -0,0 +1,14 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
ckpt_name=$3
|
||||||
|
|
||||||
|
FLAGS_allocator_strategy=naive_best_fit \
|
||||||
|
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
|
||||||
|
python3 ${BIN_DIR}/../synthesize.py \
|
||||||
|
--config=${config_path} \
|
||||||
|
--checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
|
||||||
|
--test-metadata=dump/test/norm/metadata.jsonl \
|
||||||
|
--output-dir=${train_output_path}/test \
|
||||||
|
--generator-type=hifigan
|
@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
|
||||||
|
FLAGS_cudnn_exhaustive_search=true \
|
||||||
|
FLAGS_conv_workspace_size_limit=4000 \
|
||||||
|
python ${BIN_DIR}/train.py \
|
||||||
|
--train-metadata=dump/train/norm/metadata.jsonl \
|
||||||
|
--dev-metadata=dump/dev/norm/metadata.jsonl \
|
||||||
|
--config=${config_path} \
|
||||||
|
--output-dir=${train_output_path} \
|
||||||
|
--ngpu=1
|
@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
export MAIN_ROOT=`realpath ${PWD}/../../../`
|
||||||
|
|
||||||
|
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
|
||||||
|
export LC_ALL=C
|
||||||
|
|
||||||
|
export PYTHONDONTWRITEBYTECODE=1
|
||||||
|
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||||
|
export PYTHONIOENCODING=UTF-8
|
||||||
|
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
|
||||||
|
|
||||||
|
MODEL=hifigan
|
||||||
|
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/gan_vocoder/${MODEL}
|
@ -0,0 +1,32 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
source path.sh
|
||||||
|
|
||||||
|
gpus=0,1
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
conf_path=conf/default.yaml
|
||||||
|
train_output_path=exp/default
|
||||||
|
ckpt_name=snapshot_iter_5000.pdz
|
||||||
|
|
||||||
|
# with the following command, you can choose the stage range you want to run
|
||||||
|
# such as `./run.sh --stage 0 --stop-stage 0`
|
||||||
|
# this can not be mixed use with `$1`, `$2` ...
|
||||||
|
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# prepare data
|
||||||
|
./local/preprocess.sh ${conf_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
# synthesize
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
|
||||||
|
fi
|
@ -1 +1,9 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
Date: 2022-3-15, Author: Xiaojie Chen.
|
||||||
|
- kaldi and librosa mfcc, fbank, spectrogram.
|
||||||
|
- unit test and benchmark.
|
||||||
|
|
||||||
|
Date: 2022-2-25, Author: Hui Zhang.
|
||||||
|
- Refactor architecture.
|
||||||
|
- dtw distance and mcd style dtw.
|
||||||
|
@ -1,170 +0,0 @@
|
|||||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from numpy import ndarray as array
|
|
||||||
|
|
||||||
from ..backends import depth_convert
|
|
||||||
from ..utils import ParameterError
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'depth_augment',
|
|
||||||
'spect_augment',
|
|
||||||
'random_crop1d',
|
|
||||||
'random_crop2d',
|
|
||||||
'adaptive_spect_augment',
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def randint(high: int) -> int:
|
|
||||||
"""Generate one random integer in range [0 high)
|
|
||||||
|
|
||||||
This is a helper function for random data augmentaiton
|
|
||||||
"""
|
|
||||||
return int(np.random.randint(0, high=high))
|
|
||||||
|
|
||||||
|
|
||||||
def rand() -> float:
|
|
||||||
"""Generate one floating-point number in range [0 1)
|
|
||||||
|
|
||||||
This is a helper function for random data augmentaiton
|
|
||||||
"""
|
|
||||||
return float(np.random.rand(1))
|
|
||||||
|
|
||||||
|
|
||||||
def depth_augment(y: array,
|
|
||||||
choices: List=['int8', 'int16'],
|
|
||||||
probs: List[float]=[0.5, 0.5]) -> array:
|
|
||||||
""" Audio depth augmentation
|
|
||||||
|
|
||||||
Do audio depth augmentation to simulate the distortion brought by quantization.
|
|
||||||
"""
|
|
||||||
assert len(probs) == len(
|
|
||||||
choices
|
|
||||||
), 'number of choices {} must be equal to size of probs {}'.format(
|
|
||||||
len(choices), len(probs))
|
|
||||||
depth = np.random.choice(choices, p=probs)
|
|
||||||
src_depth = y.dtype
|
|
||||||
y1 = depth_convert(y, depth)
|
|
||||||
y2 = depth_convert(y1, src_depth)
|
|
||||||
|
|
||||||
return y2
|
|
||||||
|
|
||||||
|
|
||||||
def adaptive_spect_augment(spect: array, tempo_axis: int=0,
|
|
||||||
level: float=0.1) -> array:
|
|
||||||
"""Do adpative spectrogram augmentation
|
|
||||||
|
|
||||||
The level of the augmentation is gowern by the paramter level,
|
|
||||||
ranging from 0 to 1, with 0 represents no augmentation。
|
|
||||||
|
|
||||||
"""
|
|
||||||
assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
|
|
||||||
if tempo_axis == 0:
|
|
||||||
nt, nf = spect.shape
|
|
||||||
else:
|
|
||||||
nf, nt = spect.shape
|
|
||||||
|
|
||||||
time_mask_width = int(nt * level * 0.5)
|
|
||||||
freq_mask_width = int(nf * level * 0.5)
|
|
||||||
|
|
||||||
num_time_mask = int(10 * level)
|
|
||||||
num_freq_mask = int(10 * level)
|
|
||||||
|
|
||||||
if tempo_axis == 0:
|
|
||||||
for _ in range(num_time_mask):
|
|
||||||
start = randint(nt - time_mask_width)
|
|
||||||
spect[start:start + time_mask_width, :] = 0
|
|
||||||
for _ in range(num_freq_mask):
|
|
||||||
start = randint(nf - freq_mask_width)
|
|
||||||
spect[:, start:start + freq_mask_width] = 0
|
|
||||||
else:
|
|
||||||
for _ in range(num_time_mask):
|
|
||||||
start = randint(nt - time_mask_width)
|
|
||||||
spect[:, start:start + time_mask_width] = 0
|
|
||||||
for _ in range(num_freq_mask):
|
|
||||||
start = randint(nf - freq_mask_width)
|
|
||||||
spect[start:start + freq_mask_width, :] = 0
|
|
||||||
|
|
||||||
return spect
|
|
||||||
|
|
||||||
|
|
||||||
def spect_augment(spect: array,
|
|
||||||
tempo_axis: int=0,
|
|
||||||
max_time_mask: int=3,
|
|
||||||
max_freq_mask: int=3,
|
|
||||||
max_time_mask_width: int=30,
|
|
||||||
max_freq_mask_width: int=20) -> array:
|
|
||||||
"""Do spectrogram augmentation in both time and freq axis
|
|
||||||
|
|
||||||
Reference:
|
|
||||||
|
|
||||||
"""
|
|
||||||
assert spect.ndim == 2., 'only supports 2d tensor or numpy array'
|
|
||||||
if tempo_axis == 0:
|
|
||||||
nt, nf = spect.shape
|
|
||||||
else:
|
|
||||||
nf, nt = spect.shape
|
|
||||||
|
|
||||||
num_time_mask = randint(max_time_mask)
|
|
||||||
num_freq_mask = randint(max_freq_mask)
|
|
||||||
|
|
||||||
time_mask_width = randint(max_time_mask_width)
|
|
||||||
freq_mask_width = randint(max_freq_mask_width)
|
|
||||||
|
|
||||||
if tempo_axis == 0:
|
|
||||||
for _ in range(num_time_mask):
|
|
||||||
start = randint(nt - time_mask_width)
|
|
||||||
spect[start:start + time_mask_width, :] = 0
|
|
||||||
for _ in range(num_freq_mask):
|
|
||||||
start = randint(nf - freq_mask_width)
|
|
||||||
spect[:, start:start + freq_mask_width] = 0
|
|
||||||
else:
|
|
||||||
for _ in range(num_time_mask):
|
|
||||||
start = randint(nt - time_mask_width)
|
|
||||||
spect[:, start:start + time_mask_width] = 0
|
|
||||||
for _ in range(num_freq_mask):
|
|
||||||
start = randint(nf - freq_mask_width)
|
|
||||||
spect[start:start + freq_mask_width, :] = 0
|
|
||||||
|
|
||||||
return spect
|
|
||||||
|
|
||||||
|
|
||||||
def random_crop1d(y: array, crop_len: int) -> array:
|
|
||||||
""" Do random cropping on 1d input signal
|
|
||||||
|
|
||||||
The input is a 1d signal, typically a sound waveform
|
|
||||||
"""
|
|
||||||
if y.ndim != 1:
|
|
||||||
'only accept 1d tensor or numpy array'
|
|
||||||
n = len(y)
|
|
||||||
idx = randint(n - crop_len)
|
|
||||||
return y[idx:idx + crop_len]
|
|
||||||
|
|
||||||
|
|
||||||
def random_crop2d(s: array, crop_len: int, tempo_axis: int=0) -> array:
|
|
||||||
""" Do random cropping for 2D array, typically a spectrogram.
|
|
||||||
|
|
||||||
The cropping is done in temporal direction on the time-freq input signal.
|
|
||||||
"""
|
|
||||||
if tempo_axis >= s.ndim:
|
|
||||||
raise ParameterError('axis out of range')
|
|
||||||
|
|
||||||
n = s.shape[tempo_axis]
|
|
||||||
idx = randint(high=n - crop_len)
|
|
||||||
sli = [slice(None) for i in range(s.ndim)]
|
|
||||||
sli[tempo_axis] = slice(idx, idx + crop_len)
|
|
||||||
out = s[tuple(sli)]
|
|
||||||
return out
|
|
@ -1,461 +0,0 @@
|
|||||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import math
|
|
||||||
from functools import partial
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import paddle
|
|
||||||
import paddle.nn as nn
|
|
||||||
|
|
||||||
from .window import get_window
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'Spectrogram',
|
|
||||||
'MelSpectrogram',
|
|
||||||
'LogMelSpectrogram',
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def hz_to_mel(freq: Union[paddle.Tensor, float],
|
|
||||||
htk: bool=False) -> Union[paddle.Tensor, float]:
|
|
||||||
"""Convert Hz to Mels.
|
|
||||||
Parameters:
|
|
||||||
freq: the input tensor of arbitrary shape, or a single floating point number.
|
|
||||||
htk: use HTK formula to do the conversion.
|
|
||||||
The default value is False.
|
|
||||||
Returns:
|
|
||||||
The frequencies represented in Mel-scale.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if htk:
|
|
||||||
if isinstance(freq, paddle.Tensor):
|
|
||||||
return 2595.0 * paddle.log10(1.0 + freq / 700.0)
|
|
||||||
else:
|
|
||||||
return 2595.0 * math.log10(1.0 + freq / 700.0)
|
|
||||||
|
|
||||||
# Fill in the linear part
|
|
||||||
f_min = 0.0
|
|
||||||
f_sp = 200.0 / 3
|
|
||||||
|
|
||||||
mels = (freq - f_min) / f_sp
|
|
||||||
|
|
||||||
# Fill in the log-scale part
|
|
||||||
|
|
||||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
|
||||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
|
||||||
logstep = math.log(6.4) / 27.0 # step size for log region
|
|
||||||
|
|
||||||
if isinstance(freq, paddle.Tensor):
|
|
||||||
target = min_log_mel + paddle.log(
|
|
||||||
freq / min_log_hz + 1e-10) / logstep # prevent nan with 1e-10
|
|
||||||
mask = (freq > min_log_hz).astype(freq.dtype)
|
|
||||||
mels = target * mask + mels * (
|
|
||||||
1 - mask) # will replace by masked_fill OP in future
|
|
||||||
else:
|
|
||||||
if freq >= min_log_hz:
|
|
||||||
mels = min_log_mel + math.log(freq / min_log_hz + 1e-10) / logstep
|
|
||||||
|
|
||||||
return mels
|
|
||||||
|
|
||||||
|
|
||||||
def mel_to_hz(mel: Union[float, paddle.Tensor],
|
|
||||||
htk: bool=False) -> Union[float, paddle.Tensor]:
|
|
||||||
"""Convert mel bin numbers to frequencies.
|
|
||||||
Parameters:
|
|
||||||
mel: the mel frequency represented as a tensor of arbitrary shape, or a floating point number.
|
|
||||||
htk: use HTK formula to do the conversion.
|
|
||||||
Returns:
|
|
||||||
The frequencies represented in hz.
|
|
||||||
"""
|
|
||||||
if htk:
|
|
||||||
return 700.0 * (10.0**(mel / 2595.0) - 1.0)
|
|
||||||
|
|
||||||
f_min = 0.0
|
|
||||||
f_sp = 200.0 / 3
|
|
||||||
freqs = f_min + f_sp * mel
|
|
||||||
# And now the nonlinear scale
|
|
||||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
|
||||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
|
||||||
logstep = math.log(6.4) / 27.0 # step size for log region
|
|
||||||
if isinstance(mel, paddle.Tensor):
|
|
||||||
target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel))
|
|
||||||
mask = (mel > min_log_mel).astype(mel.dtype)
|
|
||||||
freqs = target * mask + freqs * (
|
|
||||||
1 - mask) # will replace by masked_fill OP in future
|
|
||||||
else:
|
|
||||||
if mel >= min_log_mel:
|
|
||||||
freqs = min_log_hz * math.exp(logstep * (mel - min_log_mel))
|
|
||||||
|
|
||||||
return freqs
|
|
||||||
|
|
||||||
|
|
||||||
def mel_frequencies(n_mels: int=64,
|
|
||||||
f_min: float=0.0,
|
|
||||||
f_max: float=11025.0,
|
|
||||||
htk: bool=False,
|
|
||||||
dtype: str=paddle.float32):
|
|
||||||
"""Compute mel frequencies.
|
|
||||||
Parameters:
|
|
||||||
n_mels(int): number of Mel bins.
|
|
||||||
f_min(float): the lower cut-off frequency, below which the filter response is zero.
|
|
||||||
f_max(float): the upper cut-off frequency, above which the filter response is zero.
|
|
||||||
htk(bool): whether to use htk formula.
|
|
||||||
dtype(str): the datatype of the return frequencies.
|
|
||||||
Returns:
|
|
||||||
The frequencies represented in Mel-scale
|
|
||||||
"""
|
|
||||||
# 'Center freqs' of mel bands - uniformly spaced between limits
|
|
||||||
min_mel = hz_to_mel(f_min, htk=htk)
|
|
||||||
max_mel = hz_to_mel(f_max, htk=htk)
|
|
||||||
mels = paddle.linspace(min_mel, max_mel, n_mels, dtype=dtype)
|
|
||||||
freqs = mel_to_hz(mels, htk=htk)
|
|
||||||
return freqs
|
|
||||||
|
|
||||||
|
|
||||||
def fft_frequencies(sr: int, n_fft: int, dtype: str=paddle.float32):
|
|
||||||
"""Compute fourier frequencies.
|
|
||||||
Parameters:
|
|
||||||
sr(int): the audio sample rate.
|
|
||||||
n_fft(float): the number of fft bins.
|
|
||||||
dtype(str): the datatype of the return frequencies.
|
|
||||||
Returns:
|
|
||||||
The frequencies represented in hz.
|
|
||||||
"""
|
|
||||||
return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2), dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_matrix(sr: int,
|
|
||||||
n_fft: int,
|
|
||||||
n_mels: int=64,
|
|
||||||
f_min: float=0.0,
|
|
||||||
f_max: Optional[float]=None,
|
|
||||||
htk: bool=False,
|
|
||||||
norm: Union[str, float]='slaney',
|
|
||||||
dtype: str=paddle.float32):
|
|
||||||
"""Compute fbank matrix.
|
|
||||||
Parameters:
|
|
||||||
sr(int): the audio sample rate.
|
|
||||||
n_fft(int): the number of fft bins.
|
|
||||||
n_mels(int): the number of Mel bins.
|
|
||||||
f_min(float): the lower cut-off frequency, below which the filter response is zero.
|
|
||||||
f_max(float): the upper cut-off frequency, above which the filter response is zero.
|
|
||||||
htk: whether to use htk formula.
|
|
||||||
return_complex(bool): whether to return complex matrix. If True, the matrix will
|
|
||||||
be complex type. Otherwise, the real and image part will be stored in the last
|
|
||||||
axis of returned tensor.
|
|
||||||
dtype(str): the datatype of the returned fbank matrix.
|
|
||||||
Returns:
|
|
||||||
The fbank matrix of shape (n_mels, int(1+n_fft//2)).
|
|
||||||
Shape:
|
|
||||||
output: (n_mels, int(1+n_fft//2))
|
|
||||||
"""
|
|
||||||
|
|
||||||
if f_max is None:
|
|
||||||
f_max = float(sr) / 2
|
|
||||||
|
|
||||||
# Initialize the weights
|
|
||||||
weights = paddle.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
|
|
||||||
|
|
||||||
# Center freqs of each FFT bin
|
|
||||||
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft, dtype=dtype)
|
|
||||||
|
|
||||||
# 'Center freqs' of mel bands - uniformly spaced between limits
|
|
||||||
mel_f = mel_frequencies(
|
|
||||||
n_mels + 2, f_min=f_min, f_max=f_max, htk=htk, dtype=dtype)
|
|
||||||
|
|
||||||
fdiff = mel_f[1:] - mel_f[:-1] #np.diff(mel_f)
|
|
||||||
ramps = mel_f.unsqueeze(1) - fftfreqs.unsqueeze(0)
|
|
||||||
#ramps = np.subtract.outer(mel_f, fftfreqs)
|
|
||||||
|
|
||||||
for i in range(n_mels):
|
|
||||||
# lower and upper slopes for all bins
|
|
||||||
lower = -ramps[i] / fdiff[i]
|
|
||||||
upper = ramps[i + 2] / fdiff[i + 1]
|
|
||||||
|
|
||||||
# .. then intersect them with each other and zero
|
|
||||||
weights[i] = paddle.maximum(
|
|
||||||
paddle.zeros_like(lower), paddle.minimum(lower, upper))
|
|
||||||
|
|
||||||
# Slaney-style mel is scaled to be approx constant energy per channel
|
|
||||||
if norm == 'slaney':
|
|
||||||
enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
|
|
||||||
weights *= enorm.unsqueeze(1)
|
|
||||||
elif isinstance(norm, int) or isinstance(norm, float):
|
|
||||||
weights = paddle.nn.functional.normalize(weights, p=norm, axis=-1)
|
|
||||||
|
|
||||||
return weights
|
|
||||||
|
|
||||||
|
|
||||||
def power_to_db(magnitude: paddle.Tensor,
|
|
||||||
ref_value: float=1.0,
|
|
||||||
amin: float=1e-10,
|
|
||||||
top_db: Optional[float]=None) -> paddle.Tensor:
|
|
||||||
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units.
|
|
||||||
The function computes the scaling ``10 * log10(x / ref)`` in a numerically
|
|
||||||
stable way.
|
|
||||||
Parameters:
|
|
||||||
magnitude(Tensor): the input magnitude tensor of any shape.
|
|
||||||
ref_value(float): the reference value. If smaller than 1.0, the db level
|
|
||||||
of the signal will be pulled up accordingly. Otherwise, the db level
|
|
||||||
is pushed down.
|
|
||||||
amin(float): the minimum value of input magnitude, below which the input
|
|
||||||
magnitude is clipped(to amin).
|
|
||||||
top_db(float): the maximum db value of resulting spectrum, above which the
|
|
||||||
spectrum is clipped(to top_db).
|
|
||||||
Returns:
|
|
||||||
The spectrogram in log-scale.
|
|
||||||
shape:
|
|
||||||
input: any shape
|
|
||||||
output: same as input
|
|
||||||
"""
|
|
||||||
if amin <= 0:
|
|
||||||
raise Exception("amin must be strictly positive")
|
|
||||||
|
|
||||||
if ref_value <= 0:
|
|
||||||
raise Exception("ref_value must be strictly positive")
|
|
||||||
|
|
||||||
ones = paddle.ones_like(magnitude)
|
|
||||||
log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, magnitude))
|
|
||||||
log_spec -= 10.0 * math.log10(max(ref_value, amin))
|
|
||||||
|
|
||||||
if top_db is not None:
|
|
||||||
if top_db < 0:
|
|
||||||
raise Exception("top_db must be non-negative")
|
|
||||||
log_spec = paddle.maximum(log_spec, ones * (log_spec.max() - top_db))
|
|
||||||
|
|
||||||
return log_spec
|
|
||||||
|
|
||||||
|
|
||||||
class Spectrogram(nn.Layer):
|
|
||||||
def __init__(self,
|
|
||||||
n_fft: int=512,
|
|
||||||
hop_length: Optional[int]=None,
|
|
||||||
win_length: Optional[int]=None,
|
|
||||||
window: str='hann',
|
|
||||||
center: bool=True,
|
|
||||||
pad_mode: str='reflect',
|
|
||||||
dtype: str=paddle.float32):
|
|
||||||
"""Compute spectrogram of a given signal, typically an audio waveform.
|
|
||||||
The spectorgram is defined as the complex norm of the short-time
|
|
||||||
Fourier transformation.
|
|
||||||
Parameters:
|
|
||||||
n_fft(int): the number of frequency components of the discrete Fourier transform.
|
|
||||||
The default value is 2048,
|
|
||||||
hop_length(int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
|
|
||||||
The default value is None.
|
|
||||||
win_length: the window length of the short time FFt. If None, it is set to same as n_fft.
|
|
||||||
The default value is None.
|
|
||||||
window(str): the name of the window function applied to the single before the Fourier transform.
|
|
||||||
The folllowing window names are supported: 'hamming','hann','kaiser','gaussian',
|
|
||||||
'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
|
|
||||||
The default value is 'hann'
|
|
||||||
center(bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
|
|
||||||
If False, frame t begins at x[t * hop_length]
|
|
||||||
The default value is True
|
|
||||||
pad_mode(str): the mode to pad the signal if necessary. The supported modes are 'reflect'
|
|
||||||
and 'constant'. The default value is 'reflect'.
|
|
||||||
dtype(str): the data type of input and window.
|
|
||||||
Notes:
|
|
||||||
The Spectrogram transform relies on STFT transform to compute the spectrogram.
|
|
||||||
By default, the weights are not learnable. To fine-tune the Fourier coefficients,
|
|
||||||
set stop_gradient=False before training.
|
|
||||||
For more information, see STFT().
|
|
||||||
"""
|
|
||||||
super(Spectrogram, self).__init__()
|
|
||||||
|
|
||||||
if win_length is None:
|
|
||||||
win_length = n_fft
|
|
||||||
|
|
||||||
fft_window = get_window(window, win_length, fftbins=True, dtype=dtype)
|
|
||||||
self._stft = partial(
|
|
||||||
paddle.signal.stft,
|
|
||||||
n_fft=n_fft,
|
|
||||||
hop_length=hop_length,
|
|
||||||
win_length=win_length,
|
|
||||||
window=fft_window,
|
|
||||||
center=center,
|
|
||||||
pad_mode=pad_mode)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
stft = self._stft(x)
|
|
||||||
spectrogram = paddle.square(paddle.abs(stft))
|
|
||||||
return spectrogram
|
|
||||||
|
|
||||||
|
|
||||||
class MelSpectrogram(nn.Layer):
|
|
||||||
def __init__(self,
|
|
||||||
sr: int=22050,
|
|
||||||
n_fft: int=512,
|
|
||||||
hop_length: Optional[int]=None,
|
|
||||||
win_length: Optional[int]=None,
|
|
||||||
window: str='hann',
|
|
||||||
center: bool=True,
|
|
||||||
pad_mode: str='reflect',
|
|
||||||
n_mels: int=64,
|
|
||||||
f_min: float=50.0,
|
|
||||||
f_max: Optional[float]=None,
|
|
||||||
htk: bool=False,
|
|
||||||
norm: Union[str, float]='slaney',
|
|
||||||
dtype: str=paddle.float32):
|
|
||||||
"""Compute the melspectrogram of a given signal, typically an audio waveform.
|
|
||||||
The melspectrogram is also known as filterbank or fbank feature in audio community.
|
|
||||||
It is computed by multiplying spectrogram with Mel filter bank matrix.
|
|
||||||
Parameters:
|
|
||||||
sr(int): the audio sample rate.
|
|
||||||
The default value is 22050.
|
|
||||||
n_fft(int): the number of frequency components of the discrete Fourier transform.
|
|
||||||
The default value is 2048,
|
|
||||||
hop_length(int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
|
|
||||||
The default value is None.
|
|
||||||
win_length: the window length of the short time FFt. If None, it is set to same as n_fft.
|
|
||||||
The default value is None.
|
|
||||||
window(str): the name of the window function applied to the single before the Fourier transform.
|
|
||||||
The folllowing window names are supported: 'hamming','hann','kaiser','gaussian',
|
|
||||||
'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
|
|
||||||
The default value is 'hann'
|
|
||||||
center(bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
|
|
||||||
If False, frame t begins at x[t * hop_length]
|
|
||||||
The default value is True
|
|
||||||
pad_mode(str): the mode to pad the signal if necessary. The supported modes are 'reflect'
|
|
||||||
and 'constant'.
|
|
||||||
The default value is 'reflect'.
|
|
||||||
n_mels(int): the mel bins.
|
|
||||||
f_min(float): the lower cut-off frequency, below which the filter response is zero.
|
|
||||||
f_max(float): the upper cut-off frequency, above which the filter response is zeros.
|
|
||||||
htk(bool): whether to use HTK formula in computing fbank matrix.
|
|
||||||
norm(str|float): the normalization type in computing fbank matrix. Slaney-style is used by default.
|
|
||||||
You can specify norm=1.0/2.0 to use customized p-norm normalization.
|
|
||||||
dtype(str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical
|
|
||||||
accuracy. Note that the final transform will be conducted in float32 regardless of dtype of fbank matrix.
|
|
||||||
"""
|
|
||||||
super(MelSpectrogram, self).__init__()
|
|
||||||
|
|
||||||
self._spectrogram = Spectrogram(
|
|
||||||
n_fft=n_fft,
|
|
||||||
hop_length=hop_length,
|
|
||||||
win_length=win_length,
|
|
||||||
window=window,
|
|
||||||
center=center,
|
|
||||||
pad_mode=pad_mode,
|
|
||||||
dtype=dtype)
|
|
||||||
self.n_mels = n_mels
|
|
||||||
self.f_min = f_min
|
|
||||||
self.f_max = f_max
|
|
||||||
self.htk = htk
|
|
||||||
self.norm = norm
|
|
||||||
if f_max is None:
|
|
||||||
f_max = sr // 2
|
|
||||||
self.fbank_matrix = compute_fbank_matrix(
|
|
||||||
sr=sr,
|
|
||||||
n_fft=n_fft,
|
|
||||||
n_mels=n_mels,
|
|
||||||
f_min=f_min,
|
|
||||||
f_max=f_max,
|
|
||||||
htk=htk,
|
|
||||||
norm=norm,
|
|
||||||
dtype=dtype) # float64 for better numerical results
|
|
||||||
self.register_buffer('fbank_matrix', self.fbank_matrix)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
spect_feature = self._spectrogram(x)
|
|
||||||
mel_feature = paddle.matmul(self.fbank_matrix, spect_feature)
|
|
||||||
return mel_feature
|
|
||||||
|
|
||||||
|
|
||||||
class LogMelSpectrogram(nn.Layer):
|
|
||||||
def __init__(self,
|
|
||||||
sr: int=22050,
|
|
||||||
n_fft: int=512,
|
|
||||||
hop_length: Optional[int]=None,
|
|
||||||
win_length: Optional[int]=None,
|
|
||||||
window: str='hann',
|
|
||||||
center: bool=True,
|
|
||||||
pad_mode: str='reflect',
|
|
||||||
n_mels: int=64,
|
|
||||||
f_min: float=50.0,
|
|
||||||
f_max: Optional[float]=None,
|
|
||||||
htk: bool=False,
|
|
||||||
norm: Union[str, float]='slaney',
|
|
||||||
ref_value: float=1.0,
|
|
||||||
amin: float=1e-10,
|
|
||||||
top_db: Optional[float]=None,
|
|
||||||
dtype: str=paddle.float32):
|
|
||||||
"""Compute log-mel-spectrogram(also known as LogFBank) feature of a given signal,
|
|
||||||
typically an audio waveform.
|
|
||||||
Parameters:
|
|
||||||
sr(int): the audio sample rate.
|
|
||||||
The default value is 22050.
|
|
||||||
n_fft(int): the number of frequency components of the discrete Fourier transform.
|
|
||||||
The default value is 2048,
|
|
||||||
hop_length(int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
|
|
||||||
The default value is None.
|
|
||||||
win_length: the window length of the short time FFt. If None, it is set to same as n_fft.
|
|
||||||
The default value is None.
|
|
||||||
window(str): the name of the window function applied to the single before the Fourier transform.
|
|
||||||
The folllowing window names are supported: 'hamming','hann','kaiser','gaussian',
|
|
||||||
'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
|
|
||||||
The default value is 'hann'
|
|
||||||
center(bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
|
|
||||||
If False, frame t begins at x[t * hop_length]
|
|
||||||
The default value is True
|
|
||||||
pad_mode(str): the mode to pad the signal if necessary. The supported modes are 'reflect'
|
|
||||||
and 'constant'.
|
|
||||||
The default value is 'reflect'.
|
|
||||||
n_mels(int): the mel bins.
|
|
||||||
f_min(float): the lower cut-off frequency, below which the filter response is zero.
|
|
||||||
f_max(float): the upper cut-off frequency, above which the filter response is zeros.
|
|
||||||
ref_value(float): the reference value. If smaller than 1.0, the db level
|
|
||||||
htk(bool): whether to use HTK formula in computing fbank matrix.
|
|
||||||
norm(str|float): the normalization type in computing fbank matrix. Slaney-style is used by default.
|
|
||||||
You can specify norm=1.0/2.0 to use customized p-norm normalization.
|
|
||||||
dtype(str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical
|
|
||||||
accuracy. Note that the final transform will be conducted in float32 regardless of dtype of fbank matrix.
|
|
||||||
amin(float): the minimum value of input magnitude, below which the input of the signal will be pulled up accordingly.
|
|
||||||
Otherwise, the db level is pushed down.
|
|
||||||
magnitude is clipped(to amin). For numerical stability, set amin to a larger value,
|
|
||||||
e.g., 1e-3.
|
|
||||||
top_db(float): the maximum db value of resulting spectrum, above which the
|
|
||||||
spectrum is clipped(to top_db).
|
|
||||||
"""
|
|
||||||
super(LogMelSpectrogram, self).__init__()
|
|
||||||
|
|
||||||
self._melspectrogram = MelSpectrogram(
|
|
||||||
sr=sr,
|
|
||||||
n_fft=n_fft,
|
|
||||||
hop_length=hop_length,
|
|
||||||
win_length=win_length,
|
|
||||||
window=window,
|
|
||||||
center=center,
|
|
||||||
pad_mode=pad_mode,
|
|
||||||
n_mels=n_mels,
|
|
||||||
f_min=f_min,
|
|
||||||
f_max=f_max,
|
|
||||||
htk=htk,
|
|
||||||
norm=norm,
|
|
||||||
dtype=dtype)
|
|
||||||
|
|
||||||
self.ref_value = ref_value
|
|
||||||
self.amin = amin
|
|
||||||
self.top_db = top_db
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
mel_feature = self._melspectrogram(x)
|
|
||||||
log_mel_feature = power_to_db(
|
|
||||||
mel_feature,
|
|
||||||
ref_value=self.ref_value,
|
|
||||||
amin=self.amin,
|
|
||||||
top_db=self.top_db)
|
|
||||||
return log_mel_feature
|
|
@ -0,0 +1,22 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from . import compliance
|
||||||
|
from . import datasets
|
||||||
|
from . import features
|
||||||
|
from . import functional
|
||||||
|
from . import io
|
||||||
|
from . import metric
|
||||||
|
from . import sox_effects
|
||||||
|
from .backends import load
|
||||||
|
from .backends import save
|
@ -0,0 +1,19 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .soundfile_backend import depth_convert
|
||||||
|
from .soundfile_backend import load
|
||||||
|
from .soundfile_backend import normalize
|
||||||
|
from .soundfile_backend import resample
|
||||||
|
from .soundfile_backend import save
|
||||||
|
from .soundfile_backend import to_mono
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,638 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from torchaudio(https://github.com/pytorch/audio)
|
||||||
|
import math
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import Tensor
|
||||||
|
|
||||||
|
from ..functional import create_dct
|
||||||
|
from ..functional.window import get_window
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'spectrogram',
|
||||||
|
'fbank',
|
||||||
|
'mfcc',
|
||||||
|
]
|
||||||
|
|
||||||
|
# window types
|
||||||
|
HANNING = 'hann'
|
||||||
|
HAMMING = 'hamming'
|
||||||
|
POVEY = 'povey'
|
||||||
|
RECTANGULAR = 'rect'
|
||||||
|
BLACKMAN = 'blackman'
|
||||||
|
|
||||||
|
|
||||||
|
def _get_epsilon(dtype):
|
||||||
|
return paddle.to_tensor(1e-07, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def _next_power_of_2(x: int) -> int:
|
||||||
|
return 1 if x == 0 else 2**(x - 1).bit_length()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_strided(waveform: Tensor,
|
||||||
|
window_size: int,
|
||||||
|
window_shift: int,
|
||||||
|
snip_edges: bool) -> Tensor:
|
||||||
|
assert waveform.dim() == 1
|
||||||
|
num_samples = waveform.shape[0]
|
||||||
|
|
||||||
|
if snip_edges:
|
||||||
|
if num_samples < window_size:
|
||||||
|
return paddle.empty((0, 0), dtype=waveform.dtype)
|
||||||
|
else:
|
||||||
|
m = 1 + (num_samples - window_size) // window_shift
|
||||||
|
else:
|
||||||
|
reversed_waveform = paddle.flip(waveform, [0])
|
||||||
|
m = (num_samples + (window_shift // 2)) // window_shift
|
||||||
|
pad = window_size // 2 - window_shift // 2
|
||||||
|
pad_right = reversed_waveform
|
||||||
|
if pad > 0:
|
||||||
|
pad_left = reversed_waveform[-pad:]
|
||||||
|
waveform = paddle.concat((pad_left, waveform, pad_right), axis=0)
|
||||||
|
else:
|
||||||
|
waveform = paddle.concat((waveform[-pad:], pad_right), axis=0)
|
||||||
|
|
||||||
|
return paddle.signal.frame(waveform, window_size, window_shift)[:, :m].T
|
||||||
|
|
||||||
|
|
||||||
|
def _feature_window_function(
|
||||||
|
window_type: str,
|
||||||
|
window_size: int,
|
||||||
|
blackman_coeff: float,
|
||||||
|
dtype: int, ) -> Tensor:
|
||||||
|
if window_type == HANNING:
|
||||||
|
return get_window('hann', window_size, fftbins=False, dtype=dtype)
|
||||||
|
elif window_type == HAMMING:
|
||||||
|
return get_window('hamming', window_size, fftbins=False, dtype=dtype)
|
||||||
|
elif window_type == POVEY:
|
||||||
|
return get_window(
|
||||||
|
'hann', window_size, fftbins=False, dtype=dtype).pow(0.85)
|
||||||
|
elif window_type == RECTANGULAR:
|
||||||
|
return paddle.ones([window_size], dtype=dtype)
|
||||||
|
elif window_type == BLACKMAN:
|
||||||
|
a = 2 * math.pi / (window_size - 1)
|
||||||
|
window_function = paddle.arange(window_size, dtype=dtype)
|
||||||
|
return (blackman_coeff - 0.5 * paddle.cos(a * window_function) +
|
||||||
|
(0.5 - blackman_coeff) * paddle.cos(2 * a * window_function)
|
||||||
|
).astype(dtype)
|
||||||
|
else:
|
||||||
|
raise Exception('Invalid window type ' + window_type)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_log_energy(strided_input: Tensor, epsilon: Tensor,
|
||||||
|
energy_floor: float) -> Tensor:
|
||||||
|
log_energy = paddle.maximum(strided_input.pow(2).sum(1), epsilon).log()
|
||||||
|
if energy_floor == 0.0:
|
||||||
|
return log_energy
|
||||||
|
return paddle.maximum(
|
||||||
|
log_energy,
|
||||||
|
paddle.to_tensor(math.log(energy_floor), dtype=strided_input.dtype))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_waveform_and_window_properties(
|
||||||
|
waveform: Tensor,
|
||||||
|
channel: int,
|
||||||
|
sr: int,
|
||||||
|
frame_shift: float,
|
||||||
|
frame_length: float,
|
||||||
|
round_to_power_of_two: bool,
|
||||||
|
preemphasis_coefficient: float) -> Tuple[Tensor, int, int, int]:
|
||||||
|
channel = max(channel, 0)
|
||||||
|
assert channel < waveform.shape[0], (
|
||||||
|
'Invalid channel {} for size {}'.format(channel, waveform.shape[0]))
|
||||||
|
waveform = waveform[channel, :] # size (n)
|
||||||
|
window_shift = int(
|
||||||
|
sr * frame_shift *
|
||||||
|
0.001) # pass frame_shift and frame_length in milliseconds
|
||||||
|
window_size = int(sr * frame_length * 0.001)
|
||||||
|
padded_window_size = _next_power_of_2(
|
||||||
|
window_size) if round_to_power_of_two else window_size
|
||||||
|
|
||||||
|
assert 2 <= window_size <= len(waveform), (
|
||||||
|
'choose a window size {} that is [2, {}]'.format(window_size,
|
||||||
|
len(waveform)))
|
||||||
|
assert 0 < window_shift, '`window_shift` must be greater than 0'
|
||||||
|
assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \
|
||||||
|
' use `round_to_power_of_two` or change `frame_length`'
|
||||||
|
assert 0. <= preemphasis_coefficient <= 1.0, '`preemphasis_coefficient` must be between [0,1]'
|
||||||
|
assert sr > 0, '`sr` must be greater than zero'
|
||||||
|
return waveform, window_shift, window_size, padded_window_size
|
||||||
|
|
||||||
|
|
||||||
|
def _get_window(waveform: Tensor,
|
||||||
|
padded_window_size: int,
|
||||||
|
window_size: int,
|
||||||
|
window_shift: int,
|
||||||
|
window_type: str,
|
||||||
|
blackman_coeff: float,
|
||||||
|
snip_edges: bool,
|
||||||
|
raw_energy: bool,
|
||||||
|
energy_floor: float,
|
||||||
|
dither: float,
|
||||||
|
remove_dc_offset: bool,
|
||||||
|
preemphasis_coefficient: float) -> Tuple[Tensor, Tensor]:
|
||||||
|
dtype = waveform.dtype
|
||||||
|
epsilon = _get_epsilon(dtype)
|
||||||
|
|
||||||
|
# (m, window_size)
|
||||||
|
strided_input = _get_strided(waveform, window_size, window_shift,
|
||||||
|
snip_edges)
|
||||||
|
|
||||||
|
if dither != 0.0:
|
||||||
|
x = paddle.maximum(epsilon,
|
||||||
|
paddle.rand(strided_input.shape, dtype=dtype))
|
||||||
|
rand_gauss = paddle.sqrt(-2 * x.log()) * paddle.cos(2 * math.pi * x)
|
||||||
|
strided_input = strided_input + rand_gauss * dither
|
||||||
|
|
||||||
|
if remove_dc_offset:
|
||||||
|
row_means = paddle.mean(strided_input, axis=1).unsqueeze(1) # (m, 1)
|
||||||
|
strided_input = strided_input - row_means
|
||||||
|
|
||||||
|
if raw_energy:
|
||||||
|
signal_log_energy = _get_log_energy(strided_input, epsilon,
|
||||||
|
energy_floor) # (m)
|
||||||
|
|
||||||
|
if preemphasis_coefficient != 0.0:
|
||||||
|
offset_strided_input = paddle.nn.functional.pad(
|
||||||
|
strided_input.unsqueeze(0), (1, 0),
|
||||||
|
data_format='NCL',
|
||||||
|
mode='replicate').squeeze(0) # (m, window_size + 1)
|
||||||
|
strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :
|
||||||
|
-1]
|
||||||
|
|
||||||
|
window_function = _feature_window_function(
|
||||||
|
window_type, window_size, blackman_coeff,
|
||||||
|
dtype).unsqueeze(0) # (1, window_size)
|
||||||
|
strided_input = strided_input * window_function # (m, window_size)
|
||||||
|
|
||||||
|
# (m, padded_window_size)
|
||||||
|
if padded_window_size != window_size:
|
||||||
|
padding_right = padded_window_size - window_size
|
||||||
|
strided_input = paddle.nn.functional.pad(
|
||||||
|
strided_input.unsqueeze(0), (0, padding_right),
|
||||||
|
data_format='NCL',
|
||||||
|
mode='constant',
|
||||||
|
value=0).squeeze(0)
|
||||||
|
|
||||||
|
if not raw_energy:
|
||||||
|
signal_log_energy = _get_log_energy(strided_input, epsilon,
|
||||||
|
energy_floor) # size (m)
|
||||||
|
|
||||||
|
return strided_input, signal_log_energy
|
||||||
|
|
||||||
|
|
||||||
|
def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
|
||||||
|
if subtract_mean:
|
||||||
|
col_means = paddle.mean(tensor, axis=0).unsqueeze(0)
|
||||||
|
tensor = tensor - col_means
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def spectrogram(waveform: Tensor,
|
||||||
|
blackman_coeff: float=0.42,
|
||||||
|
channel: int=-1,
|
||||||
|
dither: float=0.0,
|
||||||
|
energy_floor: float=1.0,
|
||||||
|
frame_length: float=25.0,
|
||||||
|
frame_shift: float=10.0,
|
||||||
|
preemphasis_coefficient: float=0.97,
|
||||||
|
raw_energy: bool=True,
|
||||||
|
remove_dc_offset: bool=True,
|
||||||
|
round_to_power_of_two: bool=True,
|
||||||
|
sr: int=16000,
|
||||||
|
snip_edges: bool=True,
|
||||||
|
subtract_mean: bool=False,
|
||||||
|
window_type: str=POVEY) -> Tensor:
|
||||||
|
"""Compute and return a spectrogram from a waveform. The output is identical to Kaldi's.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
waveform (Tensor): A waveform tensor with shape [C, T].
|
||||||
|
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
|
||||||
|
channel (int, optional): Select the channel of waveform. Defaults to -1.
|
||||||
|
dither (float, optional): Dithering constant . Defaults to 0.0.
|
||||||
|
energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
|
||||||
|
frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
|
||||||
|
frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
|
||||||
|
preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
|
||||||
|
raw_energy (bool, optional): Whether to compute before preemphasis and windowing. Defaults to True.
|
||||||
|
remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
|
||||||
|
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||||
|
to FFT. Defaults to True.
|
||||||
|
sr (int, optional): Sample rate of input waveform. Defaults to 16000.
|
||||||
|
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it
|
||||||
|
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
|
||||||
|
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
|
||||||
|
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: A spectrogram tensor with shape (m, padded_window_size // 2 + 1) where m is the number of frames
|
||||||
|
depends on frame_length and frame_shift.
|
||||||
|
"""
|
||||||
|
dtype = waveform.dtype
|
||||||
|
epsilon = _get_epsilon(dtype)
|
||||||
|
|
||||||
|
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
|
||||||
|
waveform, channel, sr, frame_shift, frame_length, round_to_power_of_two,
|
||||||
|
preemphasis_coefficient)
|
||||||
|
|
||||||
|
strided_input, signal_log_energy = _get_window(
|
||||||
|
waveform, padded_window_size, window_size, window_shift, window_type,
|
||||||
|
blackman_coeff, snip_edges, raw_energy, energy_floor, dither,
|
||||||
|
remove_dc_offset, preemphasis_coefficient)
|
||||||
|
|
||||||
|
# (m, padded_window_size // 2 + 1, 2)
|
||||||
|
fft = paddle.fft.rfft(strided_input)
|
||||||
|
|
||||||
|
power_spectrum = paddle.maximum(
|
||||||
|
fft.abs().pow(2.), epsilon).log() # (m, padded_window_size // 2 + 1)
|
||||||
|
power_spectrum[:, 0] = signal_log_energy
|
||||||
|
|
||||||
|
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
|
||||||
|
return power_spectrum
|
||||||
|
|
||||||
|
|
||||||
|
def _inverse_mel_scale_scalar(mel_freq: float) -> float:
|
||||||
|
return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _inverse_mel_scale(mel_freq: Tensor) -> Tensor:
|
||||||
|
return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _mel_scale_scalar(freq: float) -> float:
|
||||||
|
return 1127.0 * math.log(1.0 + freq / 700.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _mel_scale(freq: Tensor) -> Tensor:
|
||||||
|
return 1127.0 * (1.0 + freq / 700.0).log()
|
||||||
|
|
||||||
|
|
||||||
|
def _vtln_warp_freq(vtln_low_cutoff: float,
|
||||||
|
vtln_high_cutoff: float,
|
||||||
|
low_freq: float,
|
||||||
|
high_freq: float,
|
||||||
|
vtln_warp_factor: float,
|
||||||
|
freq: Tensor) -> Tensor:
|
||||||
|
assert vtln_low_cutoff > low_freq, 'be sure to set the vtln_low option higher than low_freq'
|
||||||
|
assert vtln_high_cutoff < high_freq, 'be sure to set the vtln_high option lower than high_freq [or negative]'
|
||||||
|
l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
|
||||||
|
h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
|
||||||
|
scale = 1.0 / vtln_warp_factor
|
||||||
|
Fl = scale * l
|
||||||
|
Fh = scale * h
|
||||||
|
assert l > low_freq and h < high_freq
|
||||||
|
scale_left = (Fl - low_freq) / (l - low_freq)
|
||||||
|
scale_right = (high_freq - Fh) / (high_freq - h)
|
||||||
|
res = paddle.empty_like(freq)
|
||||||
|
|
||||||
|
outside_low_high_freq = paddle.less_than(freq, paddle.to_tensor(low_freq)) \
|
||||||
|
| paddle.greater_than(freq, paddle.to_tensor(high_freq))
|
||||||
|
before_l = paddle.less_than(freq, paddle.to_tensor(l))
|
||||||
|
before_h = paddle.less_than(freq, paddle.to_tensor(h))
|
||||||
|
after_h = paddle.greater_equal(freq, paddle.to_tensor(h))
|
||||||
|
|
||||||
|
res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
|
||||||
|
res[before_h] = scale * freq[before_h]
|
||||||
|
res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
|
||||||
|
res[outside_low_high_freq] = freq[outside_low_high_freq]
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def _vtln_warp_mel_freq(vtln_low_cutoff: float,
|
||||||
|
vtln_high_cutoff: float,
|
||||||
|
low_freq,
|
||||||
|
high_freq: float,
|
||||||
|
vtln_warp_factor: float,
|
||||||
|
mel_freq: Tensor) -> Tensor:
|
||||||
|
return _mel_scale(
|
||||||
|
_vtln_warp_freq(vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq,
|
||||||
|
vtln_warp_factor, _inverse_mel_scale(mel_freq)))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_mel_banks(num_bins: int,
|
||||||
|
window_length_padded: int,
|
||||||
|
sample_freq: float,
|
||||||
|
low_freq: float,
|
||||||
|
high_freq: float,
|
||||||
|
vtln_low: float,
|
||||||
|
vtln_high: float,
|
||||||
|
vtln_warp_factor: float) -> Tuple[Tensor, Tensor]:
|
||||||
|
assert num_bins > 3, 'Must have at least 3 mel bins'
|
||||||
|
assert window_length_padded % 2 == 0
|
||||||
|
num_fft_bins = window_length_padded / 2
|
||||||
|
nyquist = 0.5 * sample_freq
|
||||||
|
|
||||||
|
if high_freq <= 0.0:
|
||||||
|
high_freq += nyquist
|
||||||
|
|
||||||
|
assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \
|
||||||
|
('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist))
|
||||||
|
|
||||||
|
fft_bin_width = sample_freq / window_length_padded
|
||||||
|
mel_low_freq = _mel_scale_scalar(low_freq)
|
||||||
|
mel_high_freq = _mel_scale_scalar(high_freq)
|
||||||
|
|
||||||
|
mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
|
||||||
|
|
||||||
|
if vtln_high < 0.0:
|
||||||
|
vtln_high += nyquist
|
||||||
|
|
||||||
|
assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and
|
||||||
|
(0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \
|
||||||
|
('Bad values in options: vtln-low {} and vtln-high {}, versus '
|
||||||
|
'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
|
||||||
|
|
||||||
|
bin = paddle.arange(num_bins).unsqueeze(1)
|
||||||
|
left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1)
|
||||||
|
center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1)
|
||||||
|
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1)
|
||||||
|
|
||||||
|
if vtln_warp_factor != 1.0:
|
||||||
|
left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq,
|
||||||
|
vtln_warp_factor, left_mel)
|
||||||
|
center_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq,
|
||||||
|
high_freq, vtln_warp_factor,
|
||||||
|
center_mel)
|
||||||
|
right_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq,
|
||||||
|
high_freq, vtln_warp_factor, right_mel)
|
||||||
|
|
||||||
|
center_freqs = _inverse_mel_scale(center_mel) # (num_bins)
|
||||||
|
# (1, num_fft_bins)
|
||||||
|
mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins)).unsqueeze(0)
|
||||||
|
|
||||||
|
# (num_bins, num_fft_bins)
|
||||||
|
up_slope = (mel - left_mel) / (center_mel - left_mel)
|
||||||
|
down_slope = (right_mel - mel) / (right_mel - center_mel)
|
||||||
|
|
||||||
|
if vtln_warp_factor == 1.0:
|
||||||
|
bins = paddle.maximum(
|
||||||
|
paddle.zeros([1]), paddle.minimum(up_slope, down_slope))
|
||||||
|
else:
|
||||||
|
bins = paddle.zeros_like(up_slope)
|
||||||
|
up_idx = paddle.greater_than(mel, left_mel) & paddle.less_than(
|
||||||
|
mel, center_mel)
|
||||||
|
down_idx = paddle.greater_than(mel, center_mel) & paddle.less_than(
|
||||||
|
mel, right_mel)
|
||||||
|
bins[up_idx] = up_slope[up_idx]
|
||||||
|
bins[down_idx] = down_slope[down_idx]
|
||||||
|
|
||||||
|
return bins, center_freqs
|
||||||
|
|
||||||
|
|
||||||
|
def fbank(waveform: Tensor,
|
||||||
|
blackman_coeff: float=0.42,
|
||||||
|
channel: int=-1,
|
||||||
|
dither: float=0.0,
|
||||||
|
energy_floor: float=1.0,
|
||||||
|
frame_length: float=25.0,
|
||||||
|
frame_shift: float=10.0,
|
||||||
|
high_freq: float=0.0,
|
||||||
|
htk_compat: bool=False,
|
||||||
|
low_freq: float=20.0,
|
||||||
|
n_mels: int=23,
|
||||||
|
preemphasis_coefficient: float=0.97,
|
||||||
|
raw_energy: bool=True,
|
||||||
|
remove_dc_offset: bool=True,
|
||||||
|
round_to_power_of_two: bool=True,
|
||||||
|
sr: int=16000,
|
||||||
|
snip_edges: bool=True,
|
||||||
|
subtract_mean: bool=False,
|
||||||
|
use_energy: bool=False,
|
||||||
|
use_log_fbank: bool=True,
|
||||||
|
use_power: bool=True,
|
||||||
|
vtln_high: float=-500.0,
|
||||||
|
vtln_low: float=100.0,
|
||||||
|
vtln_warp: float=1.0,
|
||||||
|
window_type: str=POVEY) -> Tensor:
|
||||||
|
"""Compute and return filter banks from a waveform. The output is identical to Kaldi's.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
waveform (Tensor): A waveform tensor with shape [C, T].
|
||||||
|
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
|
||||||
|
channel (int, optional): Select the channel of waveform. Defaults to -1.
|
||||||
|
dither (float, optional): Dithering constant . Defaults to 0.0.
|
||||||
|
energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
|
||||||
|
frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
|
||||||
|
frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
|
||||||
|
high_freq (float, optional): The upper cut-off frequency. Defaults to 0.0.
|
||||||
|
htk_compat (bool, optional): Put energy to the last when it is set True. Defaults to False.
|
||||||
|
low_freq (float, optional): The lower cut-off frequency. Defaults to 20.0.
|
||||||
|
n_mels (int, optional): Number of output mel bins. Defaults to 23.
|
||||||
|
preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
|
||||||
|
raw_energy (bool, optional): Whether to compute before preemphasis and windowing. Defaults to True.
|
||||||
|
remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
|
||||||
|
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||||
|
to FFT. Defaults to True.
|
||||||
|
sr (int, optional): Sample rate of input waveform. Defaults to 16000.
|
||||||
|
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it
|
||||||
|
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
|
||||||
|
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
|
||||||
|
use_energy (bool, optional): Add an dimension with energy of spectrogram to the output. Defaults to False.
|
||||||
|
use_log_fbank (bool, optional): Return log fbank when it is set True. Defaults to True.
|
||||||
|
use_power (bool, optional): Whether to use power instead of magnitude. Defaults to True.
|
||||||
|
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
|
||||||
|
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
|
||||||
|
vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
|
||||||
|
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: A filter banks tensor with shape (m, n_mels).
|
||||||
|
"""
|
||||||
|
dtype = waveform.dtype
|
||||||
|
|
||||||
|
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
|
||||||
|
waveform, channel, sr, frame_shift, frame_length, round_to_power_of_two,
|
||||||
|
preemphasis_coefficient)
|
||||||
|
|
||||||
|
strided_input, signal_log_energy = _get_window(
|
||||||
|
waveform, padded_window_size, window_size, window_shift, window_type,
|
||||||
|
blackman_coeff, snip_edges, raw_energy, energy_floor, dither,
|
||||||
|
remove_dc_offset, preemphasis_coefficient)
|
||||||
|
|
||||||
|
# (m, padded_window_size // 2 + 1)
|
||||||
|
spectrum = paddle.fft.rfft(strided_input).abs()
|
||||||
|
if use_power:
|
||||||
|
spectrum = spectrum.pow(2.)
|
||||||
|
|
||||||
|
# (n_mels, padded_window_size // 2)
|
||||||
|
mel_energies, _ = _get_mel_banks(n_mels, padded_window_size, sr, low_freq,
|
||||||
|
high_freq, vtln_low, vtln_high, vtln_warp)
|
||||||
|
mel_energies = mel_energies.astype(dtype)
|
||||||
|
|
||||||
|
# (n_mels, padded_window_size // 2 + 1)
|
||||||
|
mel_energies = paddle.nn.functional.pad(
|
||||||
|
mel_energies.unsqueeze(0), (0, 1),
|
||||||
|
data_format='NCL',
|
||||||
|
mode='constant',
|
||||||
|
value=0).squeeze(0)
|
||||||
|
|
||||||
|
# (m, n_mels)
|
||||||
|
mel_energies = paddle.mm(spectrum, mel_energies.T)
|
||||||
|
if use_log_fbank:
|
||||||
|
mel_energies = paddle.maximum(mel_energies, _get_epsilon(dtype)).log()
|
||||||
|
|
||||||
|
if use_energy:
|
||||||
|
signal_log_energy = signal_log_energy.unsqueeze(1)
|
||||||
|
if htk_compat:
|
||||||
|
mel_energies = paddle.concat(
|
||||||
|
(mel_energies, signal_log_energy), axis=1)
|
||||||
|
else:
|
||||||
|
mel_energies = paddle.concat(
|
||||||
|
(signal_log_energy, mel_energies), axis=1)
|
||||||
|
|
||||||
|
# (m, n_mels + 1)
|
||||||
|
mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
|
||||||
|
return mel_energies
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dct_matrix(n_mfcc: int, n_mels: int) -> Tensor:
|
||||||
|
dct_matrix = create_dct(n_mels, n_mels, 'ortho')
|
||||||
|
dct_matrix[:, 0] = math.sqrt(1 / float(n_mels))
|
||||||
|
dct_matrix = dct_matrix[:, :n_mfcc] # (n_mels, n_mfcc)
|
||||||
|
return dct_matrix
|
||||||
|
|
||||||
|
|
||||||
|
def _get_lifter_coeffs(n_mfcc: int, cepstral_lifter: float) -> Tensor:
|
||||||
|
i = paddle.arange(n_mfcc)
|
||||||
|
return 1.0 + 0.5 * cepstral_lifter * paddle.sin(math.pi * i /
|
||||||
|
cepstral_lifter)
|
||||||
|
|
||||||
|
|
||||||
|
def mfcc(waveform: Tensor,
|
||||||
|
blackman_coeff: float=0.42,
|
||||||
|
cepstral_lifter: float=22.0,
|
||||||
|
channel: int=-1,
|
||||||
|
dither: float=0.0,
|
||||||
|
energy_floor: float=1.0,
|
||||||
|
frame_length: float=25.0,
|
||||||
|
frame_shift: float=10.0,
|
||||||
|
high_freq: float=0.0,
|
||||||
|
htk_compat: bool=False,
|
||||||
|
low_freq: float=20.0,
|
||||||
|
n_mfcc: int=13,
|
||||||
|
n_mels: int=23,
|
||||||
|
preemphasis_coefficient: float=0.97,
|
||||||
|
raw_energy: bool=True,
|
||||||
|
remove_dc_offset: bool=True,
|
||||||
|
round_to_power_of_two: bool=True,
|
||||||
|
sr: int=16000,
|
||||||
|
snip_edges: bool=True,
|
||||||
|
subtract_mean: bool=False,
|
||||||
|
use_energy: bool=False,
|
||||||
|
vtln_high: float=-500.0,
|
||||||
|
vtln_low: float=100.0,
|
||||||
|
vtln_warp: float=1.0,
|
||||||
|
window_type: str=POVEY) -> Tensor:
|
||||||
|
"""Compute and return mel frequency cepstral coefficients from a waveform. The output is
|
||||||
|
identical to Kaldi's.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
waveform (Tensor): A waveform tensor with shape [C, T].
|
||||||
|
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
|
||||||
|
cepstral_lifter (float, optional): Scaling of output mfccs. Defaults to 22.0.
|
||||||
|
channel (int, optional): Select the channel of waveform. Defaults to -1.
|
||||||
|
dither (float, optional): Dithering constant . Defaults to 0.0.
|
||||||
|
energy_floor (float, optional): Floor on energy of the output Spectrogram. Defaults to 1.0.
|
||||||
|
frame_length (float, optional): Frame length in milliseconds. Defaults to 25.0.
|
||||||
|
frame_shift (float, optional): Shift between adjacent frames in milliseconds. Defaults to 10.0.
|
||||||
|
high_freq (float, optional): The upper cut-off frequency. Defaults to 0.0.
|
||||||
|
htk_compat (bool, optional): Put energy to the last when it is set True. Defaults to False.
|
||||||
|
low_freq (float, optional): The lower cut-off frequency. Defaults to 20.0.
|
||||||
|
n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 13.
|
||||||
|
n_mels (int, optional): Number of output mel bins. Defaults to 23.
|
||||||
|
preemphasis_coefficient (float, optional): Preemphasis coefficient for input waveform. Defaults to 0.97.
|
||||||
|
raw_energy (bool, optional): Whether to compute before preemphasis and windowing. Defaults to True.
|
||||||
|
remove_dc_offset (bool, optional): Whether to subtract mean from waveform on frames. Defaults to True.
|
||||||
|
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||||
|
to FFT. Defaults to True.
|
||||||
|
sr (int, optional): Sample rate of input waveform. Defaults to 16000.
|
||||||
|
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it
|
||||||
|
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
|
||||||
|
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
|
||||||
|
use_energy (bool, optional): Add an dimension with energy of spectrogram to the output. Defaults to False.
|
||||||
|
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
|
||||||
|
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
|
||||||
|
vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
|
||||||
|
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: A mel frequency cepstral coefficients tensor with shape (m, n_mfcc).
|
||||||
|
"""
|
||||||
|
assert n_mfcc <= n_mels, 'n_mfcc cannot be larger than n_mels: %d vs %d' % (
|
||||||
|
n_mfcc, n_mels)
|
||||||
|
|
||||||
|
dtype = waveform.dtype
|
||||||
|
|
||||||
|
# (m, n_mels + use_energy)
|
||||||
|
feature = fbank(
|
||||||
|
waveform=waveform,
|
||||||
|
blackman_coeff=blackman_coeff,
|
||||||
|
channel=channel,
|
||||||
|
dither=dither,
|
||||||
|
energy_floor=energy_floor,
|
||||||
|
frame_length=frame_length,
|
||||||
|
frame_shift=frame_shift,
|
||||||
|
high_freq=high_freq,
|
||||||
|
htk_compat=htk_compat,
|
||||||
|
low_freq=low_freq,
|
||||||
|
n_mels=n_mels,
|
||||||
|
preemphasis_coefficient=preemphasis_coefficient,
|
||||||
|
raw_energy=raw_energy,
|
||||||
|
remove_dc_offset=remove_dc_offset,
|
||||||
|
round_to_power_of_two=round_to_power_of_two,
|
||||||
|
sr=sr,
|
||||||
|
snip_edges=snip_edges,
|
||||||
|
subtract_mean=False,
|
||||||
|
use_energy=use_energy,
|
||||||
|
use_log_fbank=True,
|
||||||
|
use_power=True,
|
||||||
|
vtln_high=vtln_high,
|
||||||
|
vtln_low=vtln_low,
|
||||||
|
vtln_warp=vtln_warp,
|
||||||
|
window_type=window_type)
|
||||||
|
|
||||||
|
if use_energy:
|
||||||
|
# (m)
|
||||||
|
signal_log_energy = feature[:, n_mels if htk_compat else 0]
|
||||||
|
mel_offset = int(not htk_compat)
|
||||||
|
feature = feature[:, mel_offset:(n_mels + mel_offset)]
|
||||||
|
|
||||||
|
# (n_mels, n_mfcc)
|
||||||
|
dct_matrix = _get_dct_matrix(n_mfcc, n_mels).astype(dtype=dtype)
|
||||||
|
|
||||||
|
# (m, n_mfcc)
|
||||||
|
feature = feature.matmul(dct_matrix)
|
||||||
|
|
||||||
|
if cepstral_lifter != 0.0:
|
||||||
|
# (1, n_mfcc)
|
||||||
|
lifter_coeffs = _get_lifter_coeffs(n_mfcc, cepstral_lifter).unsqueeze(0)
|
||||||
|
feature *= lifter_coeffs.astype(dtype=dtype)
|
||||||
|
|
||||||
|
if use_energy:
|
||||||
|
feature[:, 0] = signal_log_energy
|
||||||
|
|
||||||
|
if htk_compat:
|
||||||
|
energy = feature[:, 0].unsqueeze(1) # (m, 1)
|
||||||
|
feature = feature[:, 1:] # (m, n_mfcc - 1)
|
||||||
|
if not use_energy:
|
||||||
|
energy *= math.sqrt(2)
|
||||||
|
|
||||||
|
feature = paddle.concat((feature, energy), axis=1)
|
||||||
|
|
||||||
|
feature = _subtract_column_mean(feature, subtract_mean)
|
||||||
|
return feature
|
@ -0,0 +1,350 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from functools import partial
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
|
||||||
|
from ..functional import compute_fbank_matrix
|
||||||
|
from ..functional import create_dct
|
||||||
|
from ..functional import power_to_db
|
||||||
|
from ..functional.window import get_window
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Spectrogram',
|
||||||
|
'MelSpectrogram',
|
||||||
|
'LogMelSpectrogram',
|
||||||
|
'MFCC',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Spectrogram(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
n_fft: int=512,
|
||||||
|
hop_length: Optional[int]=None,
|
||||||
|
win_length: Optional[int]=None,
|
||||||
|
window: str='hann',
|
||||||
|
power: float=2.0,
|
||||||
|
center: bool=True,
|
||||||
|
pad_mode: str='reflect',
|
||||||
|
dtype: str=paddle.float32):
|
||||||
|
"""Compute spectrogram of a given signal, typically an audio waveform.
|
||||||
|
The spectorgram is defined as the complex norm of the short-time
|
||||||
|
Fourier transformation.
|
||||||
|
Parameters:
|
||||||
|
n_fft (int): the number of frequency components of the discrete Fourier transform.
|
||||||
|
The default value is 2048,
|
||||||
|
hop_length (int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
|
||||||
|
The default value is None.
|
||||||
|
win_length: the window length of the short time FFt. If None, it is set to same as n_fft.
|
||||||
|
The default value is None.
|
||||||
|
window (str): the name of the window function applied to the single before the Fourier transform.
|
||||||
|
The folllowing window names are supported: 'hamming','hann','kaiser','gaussian',
|
||||||
|
'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
|
||||||
|
The default value is 'hann'
|
||||||
|
power (float): Exponent for the magnitude spectrogram. The default value is 2.0.
|
||||||
|
center (bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
|
||||||
|
If False, frame t begins at x[t * hop_length]
|
||||||
|
The default value is True
|
||||||
|
pad_mode (str): the mode to pad the signal if necessary. The supported modes are 'reflect'
|
||||||
|
and 'constant'. The default value is 'reflect'.
|
||||||
|
dtype (str): the data type of input and window.
|
||||||
|
Notes:
|
||||||
|
The Spectrogram transform relies on STFT transform to compute the spectrogram.
|
||||||
|
By default, the weights are not learnable. To fine-tune the Fourier coefficients,
|
||||||
|
set stop_gradient=False before training.
|
||||||
|
For more information, see STFT().
|
||||||
|
"""
|
||||||
|
super(Spectrogram, self).__init__()
|
||||||
|
|
||||||
|
assert power > 0, 'Power of spectrogram must be > 0.'
|
||||||
|
self.power = power
|
||||||
|
|
||||||
|
if win_length is None:
|
||||||
|
win_length = n_fft
|
||||||
|
|
||||||
|
self.fft_window = get_window(
|
||||||
|
window, win_length, fftbins=True, dtype=dtype)
|
||||||
|
self._stft = partial(
|
||||||
|
paddle.signal.stft,
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
window=self.fft_window,
|
||||||
|
center=center,
|
||||||
|
pad_mode=pad_mode)
|
||||||
|
self.register_buffer('fft_window', self.fft_window)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
stft = self._stft(x)
|
||||||
|
spectrogram = paddle.pow(paddle.abs(stft), self.power)
|
||||||
|
return spectrogram
|
||||||
|
|
||||||
|
|
||||||
|
class MelSpectrogram(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
sr: int=22050,
|
||||||
|
n_fft: int=512,
|
||||||
|
hop_length: Optional[int]=None,
|
||||||
|
win_length: Optional[int]=None,
|
||||||
|
window: str='hann',
|
||||||
|
power: float=2.0,
|
||||||
|
center: bool=True,
|
||||||
|
pad_mode: str='reflect',
|
||||||
|
n_mels: int=64,
|
||||||
|
f_min: float=50.0,
|
||||||
|
f_max: Optional[float]=None,
|
||||||
|
htk: bool=False,
|
||||||
|
norm: Union[str, float]='slaney',
|
||||||
|
dtype: str=paddle.float32):
|
||||||
|
"""Compute the melspectrogram of a given signal, typically an audio waveform.
|
||||||
|
The melspectrogram is also known as filterbank or fbank feature in audio community.
|
||||||
|
It is computed by multiplying spectrogram with Mel filter bank matrix.
|
||||||
|
Parameters:
|
||||||
|
sr(int): the audio sample rate.
|
||||||
|
The default value is 22050.
|
||||||
|
n_fft(int): the number of frequency components of the discrete Fourier transform.
|
||||||
|
The default value is 2048,
|
||||||
|
hop_length(int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
|
||||||
|
The default value is None.
|
||||||
|
win_length: the window length of the short time FFt. If None, it is set to same as n_fft.
|
||||||
|
The default value is None.
|
||||||
|
window(str): the name of the window function applied to the single before the Fourier transform.
|
||||||
|
The folllowing window names are supported: 'hamming','hann','kaiser','gaussian',
|
||||||
|
'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
|
||||||
|
The default value is 'hann'
|
||||||
|
power (float): Exponent for the magnitude spectrogram. The default value is 2.0.
|
||||||
|
center(bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
|
||||||
|
If False, frame t begins at x[t * hop_length]
|
||||||
|
The default value is True
|
||||||
|
pad_mode(str): the mode to pad the signal if necessary. The supported modes are 'reflect'
|
||||||
|
and 'constant'.
|
||||||
|
The default value is 'reflect'.
|
||||||
|
n_mels(int): the mel bins.
|
||||||
|
f_min(float): the lower cut-off frequency, below which the filter response is zero.
|
||||||
|
f_max(float): the upper cut-off frequency, above which the filter response is zeros.
|
||||||
|
htk(bool): whether to use HTK formula in computing fbank matrix.
|
||||||
|
norm(str|float): the normalization type in computing fbank matrix. Slaney-style is used by default.
|
||||||
|
You can specify norm=1.0/2.0 to use customized p-norm normalization.
|
||||||
|
dtype(str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical
|
||||||
|
accuracy. Note that the final transform will be conducted in float32 regardless of dtype of fbank matrix.
|
||||||
|
"""
|
||||||
|
super(MelSpectrogram, self).__init__()
|
||||||
|
|
||||||
|
self._spectrogram = Spectrogram(
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
window=window,
|
||||||
|
power=power,
|
||||||
|
center=center,
|
||||||
|
pad_mode=pad_mode,
|
||||||
|
dtype=dtype)
|
||||||
|
self.n_mels = n_mels
|
||||||
|
self.f_min = f_min
|
||||||
|
self.f_max = f_max
|
||||||
|
self.htk = htk
|
||||||
|
self.norm = norm
|
||||||
|
if f_max is None:
|
||||||
|
f_max = sr // 2
|
||||||
|
self.fbank_matrix = compute_fbank_matrix(
|
||||||
|
sr=sr,
|
||||||
|
n_fft=n_fft,
|
||||||
|
n_mels=n_mels,
|
||||||
|
f_min=f_min,
|
||||||
|
f_max=f_max,
|
||||||
|
htk=htk,
|
||||||
|
norm=norm,
|
||||||
|
dtype=dtype) # float64 for better numerical results
|
||||||
|
self.register_buffer('fbank_matrix', self.fbank_matrix)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
spect_feature = self._spectrogram(x)
|
||||||
|
mel_feature = paddle.matmul(self.fbank_matrix, spect_feature)
|
||||||
|
return mel_feature
|
||||||
|
|
||||||
|
|
||||||
|
class LogMelSpectrogram(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
sr: int=22050,
|
||||||
|
n_fft: int=512,
|
||||||
|
hop_length: Optional[int]=None,
|
||||||
|
win_length: Optional[int]=None,
|
||||||
|
window: str='hann',
|
||||||
|
power: float=2.0,
|
||||||
|
center: bool=True,
|
||||||
|
pad_mode: str='reflect',
|
||||||
|
n_mels: int=64,
|
||||||
|
f_min: float=50.0,
|
||||||
|
f_max: Optional[float]=None,
|
||||||
|
htk: bool=False,
|
||||||
|
norm: Union[str, float]='slaney',
|
||||||
|
ref_value: float=1.0,
|
||||||
|
amin: float=1e-10,
|
||||||
|
top_db: Optional[float]=None,
|
||||||
|
dtype: str=paddle.float32):
|
||||||
|
"""Compute log-mel-spectrogram(also known as LogFBank) feature of a given signal,
|
||||||
|
typically an audio waveform.
|
||||||
|
Parameters:
|
||||||
|
sr (int): the audio sample rate.
|
||||||
|
The default value is 22050.
|
||||||
|
n_fft (int): the number of frequency components of the discrete Fourier transform.
|
||||||
|
The default value is 2048,
|
||||||
|
hop_length (int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
|
||||||
|
The default value is None.
|
||||||
|
win_length: the window length of the short time FFt. If None, it is set to same as n_fft.
|
||||||
|
The default value is None.
|
||||||
|
window (str): the name of the window function applied to the single before the Fourier transform.
|
||||||
|
The folllowing window names are supported: 'hamming','hann','kaiser','gaussian',
|
||||||
|
'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
|
||||||
|
The default value is 'hann'
|
||||||
|
center (bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
|
||||||
|
If False, frame t begins at x[t * hop_length]
|
||||||
|
The default value is True
|
||||||
|
pad_mode (str): the mode to pad the signal if necessary. The supported modes are 'reflect'
|
||||||
|
and 'constant'.
|
||||||
|
The default value is 'reflect'.
|
||||||
|
n_mels (int): the mel bins.
|
||||||
|
f_min (float): the lower cut-off frequency, below which the filter response is zero.
|
||||||
|
f_max (float): the upper cut-off frequency, above which the filter response is zeros.
|
||||||
|
htk (bool): whether to use HTK formula in computing fbank matrix.
|
||||||
|
norm (str|float): the normalization type in computing fbank matrix. Slaney-style is used by default.
|
||||||
|
You can specify norm=1.0/2.0 to use customized p-norm normalization.
|
||||||
|
ref_value (float): the reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down.
|
||||||
|
amin (float): the minimum value of input magnitude, below which the input magnitude is clipped(to amin).
|
||||||
|
top_db (float): the maximum db value of resulting spectrum, above which the
|
||||||
|
spectrum is clipped(to top_db).
|
||||||
|
dtype (str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical
|
||||||
|
accuracy. Note that the final transform will be conducted in float32 regardless of dtype of fbank matrix.
|
||||||
|
"""
|
||||||
|
super(LogMelSpectrogram, self).__init__()
|
||||||
|
|
||||||
|
self._melspectrogram = MelSpectrogram(
|
||||||
|
sr=sr,
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
window=window,
|
||||||
|
power=power,
|
||||||
|
center=center,
|
||||||
|
pad_mode=pad_mode,
|
||||||
|
n_mels=n_mels,
|
||||||
|
f_min=f_min,
|
||||||
|
f_max=f_max,
|
||||||
|
htk=htk,
|
||||||
|
norm=norm,
|
||||||
|
dtype=dtype)
|
||||||
|
|
||||||
|
self.ref_value = ref_value
|
||||||
|
self.amin = amin
|
||||||
|
self.top_db = top_db
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
mel_feature = self._melspectrogram(x)
|
||||||
|
log_mel_feature = power_to_db(
|
||||||
|
mel_feature,
|
||||||
|
ref_value=self.ref_value,
|
||||||
|
amin=self.amin,
|
||||||
|
top_db=self.top_db)
|
||||||
|
return log_mel_feature
|
||||||
|
|
||||||
|
|
||||||
|
class MFCC(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
sr: int=22050,
|
||||||
|
n_mfcc: int=40,
|
||||||
|
n_fft: int=512,
|
||||||
|
hop_length: Optional[int]=None,
|
||||||
|
win_length: Optional[int]=None,
|
||||||
|
window: str='hann',
|
||||||
|
power: float=2.0,
|
||||||
|
center: bool=True,
|
||||||
|
pad_mode: str='reflect',
|
||||||
|
n_mels: int=64,
|
||||||
|
f_min: float=50.0,
|
||||||
|
f_max: Optional[float]=None,
|
||||||
|
htk: bool=False,
|
||||||
|
norm: Union[str, float]='slaney',
|
||||||
|
ref_value: float=1.0,
|
||||||
|
amin: float=1e-10,
|
||||||
|
top_db: Optional[float]=None,
|
||||||
|
dtype: str=paddle.float32):
|
||||||
|
"""Compute mel frequency cepstral coefficients(MFCCs) feature of given waveforms.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
sr(int): the audio sample rate.
|
||||||
|
The default value is 22050.
|
||||||
|
n_mfcc (int, optional): Number of cepstra in MFCC. Defaults to 40.
|
||||||
|
n_fft (int): the number of frequency components of the discrete Fourier transform.
|
||||||
|
The default value is 2048,
|
||||||
|
hop_length (int|None): the hop length of the short time FFT. If None, it is set to win_length//4.
|
||||||
|
The default value is None.
|
||||||
|
win_length: the window length of the short time FFt. If None, it is set to same as n_fft.
|
||||||
|
The default value is None.
|
||||||
|
window (str): the name of the window function applied to the single before the Fourier transform.
|
||||||
|
The folllowing window names are supported: 'hamming','hann','kaiser','gaussian',
|
||||||
|
'exponential','triang','bohman','blackman','cosine','tukey','taylor'.
|
||||||
|
The default value is 'hann'
|
||||||
|
power (float): Exponent for the magnitude spectrogram. The default value is 2.0.
|
||||||
|
center (bool): if True, the signal is padded so that frame t is centered at x[t * hop_length].
|
||||||
|
If False, frame t begins at x[t * hop_length]
|
||||||
|
The default value is True
|
||||||
|
pad_mode (str): the mode to pad the signal if necessary. The supported modes are 'reflect'
|
||||||
|
and 'constant'.
|
||||||
|
The default value is 'reflect'.
|
||||||
|
n_mels (int): the mel bins.
|
||||||
|
f_min (float): the lower cut-off frequency, below which the filter response is zero.
|
||||||
|
f_max (float): the upper cut-off frequency, above which the filter response is zeros.
|
||||||
|
htk (bool): whether to use HTK formula in computing fbank matrix.
|
||||||
|
norm (str|float): the normalization type in computing fbank matrix. Slaney-style is used by default.
|
||||||
|
You can specify norm=1.0/2.0 to use customized p-norm normalization.
|
||||||
|
ref_value (float): the reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down.
|
||||||
|
amin (float): the minimum value of input magnitude, below which the input magnitude is clipped(to amin).
|
||||||
|
top_db (float): the maximum db value of resulting spectrum, above which the
|
||||||
|
spectrum is clipped(to top_db).
|
||||||
|
dtype (str): the datatype of fbank matrix used in the transform. Use float64 to increase numerical
|
||||||
|
accuracy. Note that the final transform will be conducted in float32 regardless of dtype of fbank matrix.
|
||||||
|
"""
|
||||||
|
super(MFCC, self).__init__()
|
||||||
|
assert n_mfcc <= n_mels, 'n_mfcc cannot be larger than n_mels: %d vs %d' % (
|
||||||
|
n_mfcc, n_mels)
|
||||||
|
self._log_melspectrogram = LogMelSpectrogram(
|
||||||
|
sr=sr,
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
window=window,
|
||||||
|
power=power,
|
||||||
|
center=center,
|
||||||
|
pad_mode=pad_mode,
|
||||||
|
n_mels=n_mels,
|
||||||
|
f_min=f_min,
|
||||||
|
f_max=f_max,
|
||||||
|
htk=htk,
|
||||||
|
norm=norm,
|
||||||
|
ref_value=ref_value,
|
||||||
|
amin=amin,
|
||||||
|
top_db=top_db,
|
||||||
|
dtype=dtype)
|
||||||
|
self.dct_matrix = create_dct(n_mfcc=n_mfcc, n_mels=n_mels, dtype=dtype)
|
||||||
|
self.register_buffer('dct_matrix', self.dct_matrix)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
log_mel_feature = self._log_melspectrogram(x)
|
||||||
|
mfcc = paddle.matmul(
|
||||||
|
log_mel_feature.transpose((0, 2, 1)), self.dct_matrix).transpose(
|
||||||
|
(0, 2, 1)) # (B, n_mels, L)
|
||||||
|
return mfcc
|
@ -0,0 +1,20 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .functional import compute_fbank_matrix
|
||||||
|
from .functional import create_dct
|
||||||
|
from .functional import fft_frequencies
|
||||||
|
from .functional import hz_to_mel
|
||||||
|
from .functional import mel_frequencies
|
||||||
|
from .functional import mel_to_hz
|
||||||
|
from .functional import power_to_db
|
@ -0,0 +1,265 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Modified from librosa(https://github.com/librosa/librosa)
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'hz_to_mel',
|
||||||
|
'mel_to_hz',
|
||||||
|
'mel_frequencies',
|
||||||
|
'fft_frequencies',
|
||||||
|
'compute_fbank_matrix',
|
||||||
|
'power_to_db',
|
||||||
|
'create_dct',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def hz_to_mel(freq: Union[paddle.Tensor, float],
|
||||||
|
htk: bool=False) -> Union[paddle.Tensor, float]:
|
||||||
|
"""Convert Hz to Mels.
|
||||||
|
Parameters:
|
||||||
|
freq: the input tensor of arbitrary shape, or a single floating point number.
|
||||||
|
htk: use HTK formula to do the conversion.
|
||||||
|
The default value is False.
|
||||||
|
Returns:
|
||||||
|
The frequencies represented in Mel-scale.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if htk:
|
||||||
|
if isinstance(freq, paddle.Tensor):
|
||||||
|
return 2595.0 * paddle.log10(1.0 + freq / 700.0)
|
||||||
|
else:
|
||||||
|
return 2595.0 * math.log10(1.0 + freq / 700.0)
|
||||||
|
|
||||||
|
# Fill in the linear part
|
||||||
|
f_min = 0.0
|
||||||
|
f_sp = 200.0 / 3
|
||||||
|
|
||||||
|
mels = (freq - f_min) / f_sp
|
||||||
|
|
||||||
|
# Fill in the log-scale part
|
||||||
|
|
||||||
|
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||||
|
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||||
|
logstep = math.log(6.4) / 27.0 # step size for log region
|
||||||
|
|
||||||
|
if isinstance(freq, paddle.Tensor):
|
||||||
|
target = min_log_mel + paddle.log(
|
||||||
|
freq / min_log_hz + 1e-10) / logstep # prevent nan with 1e-10
|
||||||
|
mask = (freq > min_log_hz).astype(freq.dtype)
|
||||||
|
mels = target * mask + mels * (
|
||||||
|
1 - mask) # will replace by masked_fill OP in future
|
||||||
|
else:
|
||||||
|
if freq >= min_log_hz:
|
||||||
|
mels = min_log_mel + math.log(freq / min_log_hz + 1e-10) / logstep
|
||||||
|
|
||||||
|
return mels
|
||||||
|
|
||||||
|
|
||||||
|
def mel_to_hz(mel: Union[float, paddle.Tensor],
|
||||||
|
htk: bool=False) -> Union[float, paddle.Tensor]:
|
||||||
|
"""Convert mel bin numbers to frequencies.
|
||||||
|
Parameters:
|
||||||
|
mel: the mel frequency represented as a tensor of arbitrary shape, or a floating point number.
|
||||||
|
htk: use HTK formula to do the conversion.
|
||||||
|
Returns:
|
||||||
|
The frequencies represented in hz.
|
||||||
|
"""
|
||||||
|
if htk:
|
||||||
|
return 700.0 * (10.0**(mel / 2595.0) - 1.0)
|
||||||
|
|
||||||
|
f_min = 0.0
|
||||||
|
f_sp = 200.0 / 3
|
||||||
|
freqs = f_min + f_sp * mel
|
||||||
|
# And now the nonlinear scale
|
||||||
|
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||||
|
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||||
|
logstep = math.log(6.4) / 27.0 # step size for log region
|
||||||
|
if isinstance(mel, paddle.Tensor):
|
||||||
|
target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel))
|
||||||
|
mask = (mel > min_log_mel).astype(mel.dtype)
|
||||||
|
freqs = target * mask + freqs * (
|
||||||
|
1 - mask) # will replace by masked_fill OP in future
|
||||||
|
else:
|
||||||
|
if mel >= min_log_mel:
|
||||||
|
freqs = min_log_hz * math.exp(logstep * (mel - min_log_mel))
|
||||||
|
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
|
||||||
|
def mel_frequencies(n_mels: int=64,
|
||||||
|
f_min: float=0.0,
|
||||||
|
f_max: float=11025.0,
|
||||||
|
htk: bool=False,
|
||||||
|
dtype: str=paddle.float32):
|
||||||
|
"""Compute mel frequencies.
|
||||||
|
Parameters:
|
||||||
|
n_mels(int): number of Mel bins.
|
||||||
|
f_min(float): the lower cut-off frequency, below which the filter response is zero.
|
||||||
|
f_max(float): the upper cut-off frequency, above which the filter response is zero.
|
||||||
|
htk(bool): whether to use htk formula.
|
||||||
|
dtype(str): the datatype of the return frequencies.
|
||||||
|
Returns:
|
||||||
|
The frequencies represented in Mel-scale
|
||||||
|
"""
|
||||||
|
# 'Center freqs' of mel bands - uniformly spaced between limits
|
||||||
|
min_mel = hz_to_mel(f_min, htk=htk)
|
||||||
|
max_mel = hz_to_mel(f_max, htk=htk)
|
||||||
|
mels = paddle.linspace(min_mel, max_mel, n_mels, dtype=dtype)
|
||||||
|
freqs = mel_to_hz(mels, htk=htk)
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
|
||||||
|
def fft_frequencies(sr: int, n_fft: int, dtype: str=paddle.float32):
|
||||||
|
"""Compute fourier frequencies.
|
||||||
|
Parameters:
|
||||||
|
sr(int): the audio sample rate.
|
||||||
|
n_fft(float): the number of fft bins.
|
||||||
|
dtype(str): the datatype of the return frequencies.
|
||||||
|
Returns:
|
||||||
|
The frequencies represented in hz.
|
||||||
|
"""
|
||||||
|
return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2), dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_fbank_matrix(sr: int,
|
||||||
|
n_fft: int,
|
||||||
|
n_mels: int=64,
|
||||||
|
f_min: float=0.0,
|
||||||
|
f_max: Optional[float]=None,
|
||||||
|
htk: bool=False,
|
||||||
|
norm: Union[str, float]='slaney',
|
||||||
|
dtype: str=paddle.float32):
|
||||||
|
"""Compute fbank matrix.
|
||||||
|
Parameters:
|
||||||
|
sr(int): the audio sample rate.
|
||||||
|
n_fft(int): the number of fft bins.
|
||||||
|
n_mels(int): the number of Mel bins.
|
||||||
|
f_min(float): the lower cut-off frequency, below which the filter response is zero.
|
||||||
|
f_max(float): the upper cut-off frequency, above which the filter response is zero.
|
||||||
|
htk: whether to use htk formula.
|
||||||
|
return_complex(bool): whether to return complex matrix. If True, the matrix will
|
||||||
|
be complex type. Otherwise, the real and image part will be stored in the last
|
||||||
|
axis of returned tensor.
|
||||||
|
dtype(str): the datatype of the returned fbank matrix.
|
||||||
|
Returns:
|
||||||
|
The fbank matrix of shape (n_mels, int(1+n_fft//2)).
|
||||||
|
Shape:
|
||||||
|
output: (n_mels, int(1+n_fft//2))
|
||||||
|
"""
|
||||||
|
|
||||||
|
if f_max is None:
|
||||||
|
f_max = float(sr) / 2
|
||||||
|
|
||||||
|
# Initialize the weights
|
||||||
|
weights = paddle.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
|
||||||
|
|
||||||
|
# Center freqs of each FFT bin
|
||||||
|
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft, dtype=dtype)
|
||||||
|
|
||||||
|
# 'Center freqs' of mel bands - uniformly spaced between limits
|
||||||
|
mel_f = mel_frequencies(
|
||||||
|
n_mels + 2, f_min=f_min, f_max=f_max, htk=htk, dtype=dtype)
|
||||||
|
|
||||||
|
fdiff = mel_f[1:] - mel_f[:-1] #np.diff(mel_f)
|
||||||
|
ramps = mel_f.unsqueeze(1) - fftfreqs.unsqueeze(0)
|
||||||
|
#ramps = np.subtract.outer(mel_f, fftfreqs)
|
||||||
|
|
||||||
|
for i in range(n_mels):
|
||||||
|
# lower and upper slopes for all bins
|
||||||
|
lower = -ramps[i] / fdiff[i]
|
||||||
|
upper = ramps[i + 2] / fdiff[i + 1]
|
||||||
|
|
||||||
|
# .. then intersect them with each other and zero
|
||||||
|
weights[i] = paddle.maximum(
|
||||||
|
paddle.zeros_like(lower), paddle.minimum(lower, upper))
|
||||||
|
|
||||||
|
# Slaney-style mel is scaled to be approx constant energy per channel
|
||||||
|
if norm == 'slaney':
|
||||||
|
enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
|
||||||
|
weights *= enorm.unsqueeze(1)
|
||||||
|
elif isinstance(norm, int) or isinstance(norm, float):
|
||||||
|
weights = paddle.nn.functional.normalize(weights, p=norm, axis=-1)
|
||||||
|
|
||||||
|
return weights
|
||||||
|
|
||||||
|
|
||||||
|
def power_to_db(magnitude: paddle.Tensor,
|
||||||
|
ref_value: float=1.0,
|
||||||
|
amin: float=1e-10,
|
||||||
|
top_db: Optional[float]=None) -> paddle.Tensor:
|
||||||
|
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units.
|
||||||
|
The function computes the scaling ``10 * log10(x / ref)`` in a numerically
|
||||||
|
stable way.
|
||||||
|
Parameters:
|
||||||
|
magnitude(Tensor): the input magnitude tensor of any shape.
|
||||||
|
ref_value(float): the reference value. If smaller than 1.0, the db level
|
||||||
|
of the signal will be pulled up accordingly. Otherwise, the db level
|
||||||
|
is pushed down.
|
||||||
|
amin(float): the minimum value of input magnitude, below which the input
|
||||||
|
magnitude is clipped(to amin).
|
||||||
|
top_db(float): the maximum db value of resulting spectrum, above which the
|
||||||
|
spectrum is clipped(to top_db).
|
||||||
|
Returns:
|
||||||
|
The spectrogram in log-scale.
|
||||||
|
shape:
|
||||||
|
input: any shape
|
||||||
|
output: same as input
|
||||||
|
"""
|
||||||
|
if amin <= 0:
|
||||||
|
raise Exception("amin must be strictly positive")
|
||||||
|
|
||||||
|
if ref_value <= 0:
|
||||||
|
raise Exception("ref_value must be strictly positive")
|
||||||
|
|
||||||
|
ones = paddle.ones_like(magnitude)
|
||||||
|
log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, magnitude))
|
||||||
|
log_spec -= 10.0 * math.log10(max(ref_value, amin))
|
||||||
|
|
||||||
|
if top_db is not None:
|
||||||
|
if top_db < 0:
|
||||||
|
raise Exception("top_db must be non-negative")
|
||||||
|
log_spec = paddle.maximum(log_spec, ones * (log_spec.max() - top_db))
|
||||||
|
|
||||||
|
return log_spec
|
||||||
|
|
||||||
|
|
||||||
|
def create_dct(n_mfcc: int,
|
||||||
|
n_mels: int,
|
||||||
|
norm: Optional[str]='ortho',
|
||||||
|
dtype: Optional[str]=paddle.float32) -> paddle.Tensor:
|
||||||
|
"""Create a discrete cosine transform(DCT) matrix.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
n_mfcc (int): Number of mel frequency cepstral coefficients.
|
||||||
|
n_mels (int): Number of mel filterbanks.
|
||||||
|
norm (str, optional): Normalizaiton type. Defaults to 'ortho'.
|
||||||
|
Returns:
|
||||||
|
Tensor: The DCT matrix with shape (n_mels, n_mfcc).
|
||||||
|
"""
|
||||||
|
n = paddle.arange(n_mels, dtype=dtype)
|
||||||
|
k = paddle.arange(n_mfcc, dtype=dtype).unsqueeze(1)
|
||||||
|
dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) *
|
||||||
|
k) # size (n_mfcc, n_mels)
|
||||||
|
if norm is None:
|
||||||
|
dct *= 2.0
|
||||||
|
else:
|
||||||
|
assert norm == "ortho"
|
||||||
|
dct[0] *= 1.0 / math.sqrt(2.0)
|
||||||
|
dct *= math.sqrt(2.0 / float(n_mels))
|
||||||
|
return dct.T
|
@ -0,0 +1,15 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .dtw import dtw_distance
|
||||||
|
from .mcd import mcd_distance
|
@ -0,0 +1,42 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import numpy as np
|
||||||
|
from dtaidistance import dtw_ndim
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'dtw_distance',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def dtw_distance(xs: np.ndarray, ys: np.ndarray) -> float:
|
||||||
|
"""dtw distance
|
||||||
|
|
||||||
|
Dynamic Time Warping.
|
||||||
|
This function keeps a compact matrix, not the full warping paths matrix.
|
||||||
|
Uses dynamic programming to compute:
|
||||||
|
|
||||||
|
wps[i, j] = (s1[i]-s2[j])**2 + min(
|
||||||
|
wps[i-1, j ] + penalty, // vertical / insertion / expansion
|
||||||
|
wps[i , j-1] + penalty, // horizontal / deletion / compression
|
||||||
|
wps[i-1, j-1]) // diagonal / match
|
||||||
|
dtw = sqrt(wps[-1, -1])
|
||||||
|
|
||||||
|
Args:
|
||||||
|
xs (np.ndarray): ref sequence, [T,D]
|
||||||
|
ys (np.ndarray): hyp sequence, [T,D]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: dtw distance
|
||||||
|
"""
|
||||||
|
return dtw_ndim.distance(xs, ys)
|
@ -0,0 +1,48 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import mcd.metrics_fast as mt
|
||||||
|
import numpy as np
|
||||||
|
from mcd import dtw
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'mcd_distance',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def mcd_distance(xs: np.ndarray, ys: np.ndarray, cost_fn=mt.logSpecDbDist):
|
||||||
|
"""Mel cepstral distortion (MCD), dtw distance.
|
||||||
|
|
||||||
|
Dynamic Time Warping.
|
||||||
|
Uses dynamic programming to compute:
|
||||||
|
wps[i, j] = cost_fn(xs[i], ys[j]) + min(
|
||||||
|
wps[i-1, j ], // vertical / insertion / expansion
|
||||||
|
wps[i , j-1], // horizontal / deletion / compression
|
||||||
|
wps[i-1, j-1]) // diagonal / match
|
||||||
|
dtw = sqrt(wps[-1, -1])
|
||||||
|
|
||||||
|
Cost Function:
|
||||||
|
logSpecDbConst = 10.0 / math.log(10.0) * math.sqrt(2.0)
|
||||||
|
def logSpecDbDist(x, y):
|
||||||
|
diff = x - y
|
||||||
|
return logSpecDbConst * math.sqrt(np.inner(diff, diff))
|
||||||
|
|
||||||
|
Args:
|
||||||
|
xs (np.ndarray): ref sequence, [T,D]
|
||||||
|
ys (np.ndarray): hyp sequence, [T,D]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: dtw distance
|
||||||
|
"""
|
||||||
|
min_cost, path = dtw.dtw(xs, ys, cost_fn)
|
||||||
|
return min_cost
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,25 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .download import decompress
|
||||||
|
from .download import download_and_decompress
|
||||||
|
from .download import load_state_dict_from_url
|
||||||
|
from .env import DATA_HOME
|
||||||
|
from .env import MODEL_HOME
|
||||||
|
from .env import PPAUDIO_HOME
|
||||||
|
from .env import USER_HOME
|
||||||
|
from .error import ParameterError
|
||||||
|
from .log import Logger
|
||||||
|
from .log import logger
|
||||||
|
from .time import seconds_to_hms
|
||||||
|
from .time import Timer
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,34 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
mono_channel_wav = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
|
||||||
|
multi_channels_wav = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav'
|
||||||
|
|
||||||
|
|
||||||
|
class BackendTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.initWavInput()
|
||||||
|
|
||||||
|
def initWavInput(self):
|
||||||
|
self.files = []
|
||||||
|
for url in [mono_channel_wav, multi_channels_wav]:
|
||||||
|
if not os.path.isfile(os.path.basename(url)):
|
||||||
|
urllib.request.urlretrieve(url, os.path.basename(url))
|
||||||
|
self.files.append(os.path.basename(url))
|
||||||
|
|
||||||
|
def initParmas(self):
|
||||||
|
raise NotImplementedError
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,73 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import filecmp
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
import paddleaudio
|
||||||
|
from ..base import BackendTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestIO(BackendTest):
|
||||||
|
def test_load_mono_channel(self):
|
||||||
|
sf_data, sf_sr = sf.read(self.files[0])
|
||||||
|
pa_data, pa_sr = paddleaudio.load(
|
||||||
|
self.files[0], normal=False, dtype='float64')
|
||||||
|
|
||||||
|
self.assertEqual(sf_data.dtype, pa_data.dtype)
|
||||||
|
self.assertEqual(sf_sr, pa_sr)
|
||||||
|
np.testing.assert_array_almost_equal(sf_data, pa_data)
|
||||||
|
|
||||||
|
def test_load_multi_channels(self):
|
||||||
|
sf_data, sf_sr = sf.read(self.files[1])
|
||||||
|
sf_data = sf_data.T # Channel dim first
|
||||||
|
pa_data, pa_sr = paddleaudio.load(
|
||||||
|
self.files[1], mono=False, normal=False, dtype='float64')
|
||||||
|
|
||||||
|
self.assertEqual(sf_data.dtype, pa_data.dtype)
|
||||||
|
self.assertEqual(sf_sr, pa_sr)
|
||||||
|
np.testing.assert_array_almost_equal(sf_data, pa_data)
|
||||||
|
|
||||||
|
def test_save_mono_channel(self):
|
||||||
|
waveform, sr = np.random.randint(
|
||||||
|
low=-32768, high=32768, size=(48000), dtype=np.int16), 16000
|
||||||
|
sf_tmp_file = 'sf_tmp.wav'
|
||||||
|
pa_tmp_file = 'pa_tmp.wav'
|
||||||
|
|
||||||
|
sf.write(sf_tmp_file, waveform, sr)
|
||||||
|
paddleaudio.save(waveform, sr, pa_tmp_file)
|
||||||
|
|
||||||
|
self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file))
|
||||||
|
for file in [sf_tmp_file, pa_tmp_file]:
|
||||||
|
os.remove(file)
|
||||||
|
|
||||||
|
def test_save_multi_channels(self):
|
||||||
|
waveform, sr = np.random.randint(
|
||||||
|
low=-32768, high=32768, size=(2, 48000), dtype=np.int16), 16000
|
||||||
|
sf_tmp_file = 'sf_tmp.wav'
|
||||||
|
pa_tmp_file = 'pa_tmp.wav'
|
||||||
|
|
||||||
|
sf.write(sf_tmp_file, waveform.T, sr)
|
||||||
|
paddleaudio.save(waveform.T, sr, pa_tmp_file)
|
||||||
|
|
||||||
|
self.assertTrue(filecmp.cmp(sf_tmp_file, pa_tmp_file))
|
||||||
|
for file in [sf_tmp_file, pa_tmp_file]:
|
||||||
|
os.remove(file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
@ -0,0 +1,39 @@
|
|||||||
|
# 1. Prepare
|
||||||
|
First, install `pytest-benchmark` via pip.
|
||||||
|
```sh
|
||||||
|
pip install pytest-benchmark
|
||||||
|
```
|
||||||
|
|
||||||
|
# 2. Run
|
||||||
|
Run the specific script for profiling.
|
||||||
|
```sh
|
||||||
|
pytest melspectrogram.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Result:
|
||||||
|
```sh
|
||||||
|
========================================================================== test session starts ==========================================================================
|
||||||
|
platform linux -- Python 3.7.7, pytest-7.0.1, pluggy-1.0.0
|
||||||
|
benchmark: 3.4.1 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
|
||||||
|
rootdir: /ssd3/chenxiaojie06/PaddleSpeech/DeepSpeech/paddleaudio
|
||||||
|
plugins: typeguard-2.12.1, benchmark-3.4.1, anyio-3.5.0
|
||||||
|
collected 4 items
|
||||||
|
|
||||||
|
melspectrogram.py .... [100%]
|
||||||
|
|
||||||
|
|
||||||
|
-------------------------------------------------------------------------------------------------- benchmark: 4 tests -------------------------------------------------------------------------------------------------
|
||||||
|
Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
|
||||||
|
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||||
|
test_melspect_gpu_torchaudio 202.0765 (1.0) 360.6230 (1.0) 218.1168 (1.0) 16.3022 (1.0) 214.2871 (1.0) 21.8451 (1.0) 40;3 4,584.7001 (1.0) 286 1
|
||||||
|
test_melspect_gpu 657.8509 (3.26) 908.0470 (2.52) 724.2545 (3.32) 106.5771 (6.54) 669.9096 (3.13) 113.4719 (5.19) 1;0 1,380.7300 (0.30) 5 1
|
||||||
|
test_melspect_cpu_torchaudio 1,247.6053 (6.17) 2,892.5799 (8.02) 1,443.2853 (6.62) 345.3732 (21.19) 1,262.7263 (5.89) 221.6385 (10.15) 56;53 692.8637 (0.15) 399 1
|
||||||
|
test_melspect_cpu 20,326.2549 (100.59) 20,607.8682 (57.15) 20,473.4125 (93.86) 63.8654 (3.92) 20,467.0429 (95.51) 68.4294 (3.13) 8;1 48.8438 (0.01) 29 1
|
||||||
|
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
Legend:
|
||||||
|
Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
|
||||||
|
OPS: Operations Per Second, computed as 1 / Mean
|
||||||
|
========================================================================== 4 passed in 21.12s ===========================================================================
|
||||||
|
|
||||||
|
```
|
@ -0,0 +1,124 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
import paddleaudio
|
||||||
|
|
||||||
|
wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
|
||||||
|
if not os.path.isfile(os.path.basename(wav_url)):
|
||||||
|
urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
|
||||||
|
|
||||||
|
waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url)))
|
||||||
|
waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
|
||||||
|
waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
|
||||||
|
|
||||||
|
# Feature conf
|
||||||
|
mel_conf = {
|
||||||
|
'sr': sr,
|
||||||
|
'n_fft': 512,
|
||||||
|
'hop_length': 128,
|
||||||
|
'n_mels': 40,
|
||||||
|
}
|
||||||
|
|
||||||
|
mel_conf_torchaudio = {
|
||||||
|
'sample_rate': sr,
|
||||||
|
'n_fft': 512,
|
||||||
|
'hop_length': 128,
|
||||||
|
'n_mels': 40,
|
||||||
|
'norm': 'slaney',
|
||||||
|
'mel_scale': 'slaney',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def enable_cpu_device():
|
||||||
|
paddle.set_device('cpu')
|
||||||
|
|
||||||
|
|
||||||
|
def enable_gpu_device():
|
||||||
|
paddle.set_device('gpu')
|
||||||
|
|
||||||
|
|
||||||
|
log_mel_extractor = paddleaudio.features.LogMelSpectrogram(
|
||||||
|
**mel_conf, f_min=0.0, top_db=80.0, dtype=waveform_tensor.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def log_melspectrogram():
|
||||||
|
return log_mel_extractor(waveform_tensor).squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_melspect_cpu(benchmark):
|
||||||
|
enable_cpu_device()
|
||||||
|
feature_paddleaudio = benchmark(log_melspectrogram)
|
||||||
|
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||||
|
feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
feature_librosa, feature_paddleaudio, decimal=3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_melspect_gpu(benchmark):
|
||||||
|
enable_gpu_device()
|
||||||
|
feature_paddleaudio = benchmark(log_melspectrogram)
|
||||||
|
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||||
|
feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
feature_librosa, feature_paddleaudio, decimal=2)
|
||||||
|
|
||||||
|
|
||||||
|
mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram(
|
||||||
|
**mel_conf_torchaudio, f_min=0.0)
|
||||||
|
amplitude_to_DB = torchaudio.transforms.AmplitudeToDB('power', top_db=80.0)
|
||||||
|
|
||||||
|
|
||||||
|
def melspectrogram_torchaudio():
|
||||||
|
return mel_extractor_torchaudio(waveform_tensor_torch).squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def log_melspectrogram_torchaudio():
|
||||||
|
mel_specgram = mel_extractor_torchaudio(waveform_tensor_torch)
|
||||||
|
return amplitude_to_DB(mel_specgram).squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_melspect_cpu_torchaudio(benchmark):
|
||||||
|
global waveform_tensor_torch, mel_extractor_torchaudio, amplitude_to_DB
|
||||||
|
|
||||||
|
mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu')
|
||||||
|
waveform_tensor_torch = waveform_tensor_torch.to('cpu')
|
||||||
|
amplitude_to_DB = amplitude_to_DB.to('cpu')
|
||||||
|
|
||||||
|
feature_paddleaudio = benchmark(log_melspectrogram_torchaudio)
|
||||||
|
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||||
|
feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
feature_librosa, feature_paddleaudio, decimal=3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_melspect_gpu_torchaudio(benchmark):
|
||||||
|
global waveform_tensor_torch, mel_extractor_torchaudio, amplitude_to_DB
|
||||||
|
|
||||||
|
mel_extractor_torchaudio = mel_extractor_torchaudio.to('cuda')
|
||||||
|
waveform_tensor_torch = waveform_tensor_torch.to('cuda')
|
||||||
|
amplitude_to_DB = amplitude_to_DB.to('cuda')
|
||||||
|
|
||||||
|
feature_torchaudio = benchmark(log_melspectrogram_torchaudio)
|
||||||
|
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||||
|
feature_librosa = librosa.power_to_db(feature_librosa, top_db=80.0)
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
feature_librosa, feature_torchaudio.cpu(), decimal=2)
|
@ -0,0 +1,108 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
import paddleaudio
|
||||||
|
|
||||||
|
wav_url = 'https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav'
|
||||||
|
if not os.path.isfile(os.path.basename(wav_url)):
|
||||||
|
urllib.request.urlretrieve(wav_url, os.path.basename(wav_url))
|
||||||
|
|
||||||
|
waveform, sr = paddleaudio.load(os.path.abspath(os.path.basename(wav_url)))
|
||||||
|
waveform_tensor = paddle.to_tensor(waveform).unsqueeze(0)
|
||||||
|
waveform_tensor_torch = torch.from_numpy(waveform).unsqueeze(0)
|
||||||
|
|
||||||
|
# Feature conf
|
||||||
|
mel_conf = {
|
||||||
|
'sr': sr,
|
||||||
|
'n_fft': 512,
|
||||||
|
'hop_length': 128,
|
||||||
|
'n_mels': 40,
|
||||||
|
}
|
||||||
|
|
||||||
|
mel_conf_torchaudio = {
|
||||||
|
'sample_rate': sr,
|
||||||
|
'n_fft': 512,
|
||||||
|
'hop_length': 128,
|
||||||
|
'n_mels': 40,
|
||||||
|
'norm': 'slaney',
|
||||||
|
'mel_scale': 'slaney',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def enable_cpu_device():
|
||||||
|
paddle.set_device('cpu')
|
||||||
|
|
||||||
|
|
||||||
|
def enable_gpu_device():
|
||||||
|
paddle.set_device('gpu')
|
||||||
|
|
||||||
|
|
||||||
|
mel_extractor = paddleaudio.features.MelSpectrogram(
|
||||||
|
**mel_conf, f_min=0.0, dtype=waveform_tensor.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def melspectrogram():
|
||||||
|
return mel_extractor(waveform_tensor).squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_melspect_cpu(benchmark):
|
||||||
|
enable_cpu_device()
|
||||||
|
feature_paddleaudio = benchmark(melspectrogram)
|
||||||
|
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
feature_librosa, feature_paddleaudio, decimal=3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_melspect_gpu(benchmark):
|
||||||
|
enable_gpu_device()
|
||||||
|
feature_paddleaudio = benchmark(melspectrogram)
|
||||||
|
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
feature_librosa, feature_paddleaudio, decimal=3)
|
||||||
|
|
||||||
|
|
||||||
|
mel_extractor_torchaudio = torchaudio.transforms.MelSpectrogram(
|
||||||
|
**mel_conf_torchaudio, f_min=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def melspectrogram_torchaudio():
|
||||||
|
return mel_extractor_torchaudio(waveform_tensor_torch).squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_melspect_cpu_torchaudio(benchmark):
|
||||||
|
global waveform_tensor_torch, mel_extractor_torchaudio
|
||||||
|
mel_extractor_torchaudio = mel_extractor_torchaudio.to('cpu')
|
||||||
|
waveform_tensor_torch = waveform_tensor_torch.to('cpu')
|
||||||
|
feature_paddleaudio = benchmark(melspectrogram_torchaudio)
|
||||||
|
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
feature_librosa, feature_paddleaudio, decimal=3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_melspect_gpu_torchaudio(benchmark):
|
||||||
|
global waveform_tensor_torch, mel_extractor_torchaudio
|
||||||
|
mel_extractor_torchaudio = mel_extractor_torchaudio.to('cuda')
|
||||||
|
waveform_tensor_torch = waveform_tensor_torch.to('cuda')
|
||||||
|
feature_torchaudio = benchmark(melspectrogram_torchaudio)
|
||||||
|
feature_librosa = librosa.feature.melspectrogram(waveform, **mel_conf)
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
feature_librosa, feature_torchaudio.cpu(), decimal=3)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue