Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSpeech into change_init
commit
e991d82ae7
@ -0,0 +1,88 @@
|
||||
version: '3.5'
|
||||
|
||||
services:
|
||||
etcd:
|
||||
container_name: milvus-etcd
|
||||
image: quay.io/coreos/etcd:v3.5.0
|
||||
networks:
|
||||
app_net:
|
||||
environment:
|
||||
- ETCD_AUTO_COMPACTION_MODE=revision
|
||||
- ETCD_AUTO_COMPACTION_RETENTION=1000
|
||||
- ETCD_QUOTA_BACKEND_BYTES=4294967296
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
|
||||
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
|
||||
|
||||
minio:
|
||||
container_name: milvus-minio
|
||||
image: minio/minio:RELEASE.2020-12-03T00-03-10Z
|
||||
networks:
|
||||
app_net:
|
||||
environment:
|
||||
MINIO_ACCESS_KEY: minioadmin
|
||||
MINIO_SECRET_KEY: minioadmin
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
|
||||
command: minio server /minio_data
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
standalone:
|
||||
container_name: milvus-standalone
|
||||
image: milvusdb/milvus:v2.0.1
|
||||
networks:
|
||||
app_net:
|
||||
ipv4_address: 172.16.23.10
|
||||
command: ["milvus", "run", "standalone"]
|
||||
environment:
|
||||
ETCD_ENDPOINTS: etcd:2379
|
||||
MINIO_ADDRESS: minio:9000
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
|
||||
ports:
|
||||
- "19530:19530"
|
||||
depends_on:
|
||||
- "etcd"
|
||||
- "minio"
|
||||
|
||||
mysql:
|
||||
container_name: audio-mysql
|
||||
image: mysql:5.7
|
||||
networks:
|
||||
app_net:
|
||||
ipv4_address: 172.16.23.11
|
||||
environment:
|
||||
- MYSQL_ROOT_PASSWORD=123456
|
||||
volumes:
|
||||
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/mysql:/var/lib/mysql
|
||||
ports:
|
||||
- "3306:3306"
|
||||
|
||||
webclient:
|
||||
container_name: audio-webclient
|
||||
image: qingen1/paddlespeech-audio-search-client:2.3
|
||||
networks:
|
||||
app_net:
|
||||
ipv4_address: 172.16.23.13
|
||||
environment:
|
||||
API_URL: 'http://127.0.0.1:8002'
|
||||
ports:
|
||||
- "8068:80"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost/"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
networks:
|
||||
app_net:
|
||||
driver: bridge
|
||||
ipam:
|
||||
driver: default
|
||||
config:
|
||||
- subnet: 172.16.23.0/24
|
||||
gateway: 172.16.23.1
|
After Width: | Height: | Size: 29 KiB |
After Width: | Height: | Size: 80 KiB |
After Width: | Height: | Size: 33 KiB |
After Width: | Height: | Size: 84 KiB |
@ -0,0 +1,12 @@
|
||||
soundfile==0.10.3.post1
|
||||
librosa==0.8.0
|
||||
numpy
|
||||
pymysql
|
||||
fastapi
|
||||
uvicorn
|
||||
diskcache==5.2.1
|
||||
pymilvus==2.0.1
|
||||
python-multipart
|
||||
typing
|
||||
starlette
|
||||
pydantic
|
@ -0,0 +1,37 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
############### Milvus Configuration ###############
|
||||
MILVUS_HOST = os.getenv("MILVUS_HOST", "127.0.0.1")
|
||||
MILVUS_PORT = int(os.getenv("MILVUS_PORT", "19530"))
|
||||
VECTOR_DIMENSION = int(os.getenv("VECTOR_DIMENSION", "2048"))
|
||||
INDEX_FILE_SIZE = int(os.getenv("INDEX_FILE_SIZE", "1024"))
|
||||
METRIC_TYPE = os.getenv("METRIC_TYPE", "L2")
|
||||
DEFAULT_TABLE = os.getenv("DEFAULT_TABLE", "audio_table")
|
||||
TOP_K = int(os.getenv("TOP_K", "10"))
|
||||
|
||||
############### MySQL Configuration ###############
|
||||
MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
|
||||
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||
MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
MYSQL_PWD = os.getenv("MYSQL_PWD", "123456")
|
||||
MYSQL_DB = os.getenv("MYSQL_DB", "mysql")
|
||||
|
||||
############### Data Path ###############
|
||||
UPLOAD_PATH = os.getenv("UPLOAD_PATH", "tmp/audio-data")
|
||||
|
||||
############### Number of Log Files ###############
|
||||
LOGS_NUM = int(os.getenv("logs_num", "0"))
|
@ -0,0 +1,39 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
def get_audio_embedding(path):
|
||||
"""
|
||||
Use vpr_inference to generate embedding of audio
|
||||
"""
|
||||
try:
|
||||
RESAMPLE_RATE = 16000
|
||||
audio, _ = librosa.load(path, sr=RESAMPLE_RATE, mono=True)
|
||||
|
||||
# TODO add infer/python interface to get embedding, now fake it by rand
|
||||
# vpr = ECAPATDNN(checkpoint_path=None, device='cuda')
|
||||
# embedding = vpr.inference(audio)
|
||||
np.random.seed(hash(os.path.basename(path)) % 1000000)
|
||||
embedding = np.random.rand(1, 2048)
|
||||
embedding = embedding / np.linalg.norm(embedding)
|
||||
embedding = embedding.tolist()[0]
|
||||
return embedding
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error with embedding:{e}")
|
||||
return None
|
@ -0,0 +1,168 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from config import UPLOAD_PATH
|
||||
from diskcache import Cache
|
||||
from fastapi import FastAPI
|
||||
from fastapi import File
|
||||
from fastapi import UploadFile
|
||||
from logs import LOGGER
|
||||
from milvus_helpers import MilvusHelper
|
||||
from mysql_helpers import MySQLHelper
|
||||
from operations.count import do_count
|
||||
from operations.drop import do_drop
|
||||
from operations.load import do_load
|
||||
from operations.search import do_search
|
||||
from pydantic import BaseModel
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"])
|
||||
|
||||
MODEL = None
|
||||
MILVUS_CLI = MilvusHelper()
|
||||
MYSQL_CLI = MySQLHelper()
|
||||
|
||||
# Mkdir 'tmp/audio-data'
|
||||
if not os.path.exists(UPLOAD_PATH):
|
||||
os.makedirs(UPLOAD_PATH)
|
||||
LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")
|
||||
|
||||
|
||||
@app.get('/data')
|
||||
def audio_path(audio_path):
|
||||
# Get the audio file
|
||||
try:
|
||||
LOGGER.info(f"Successfully load audio: {audio_path}")
|
||||
return FileResponse(audio_path)
|
||||
except Exception as e:
|
||||
LOGGER.error(f"upload audio error: {e}")
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.get('/progress')
|
||||
def get_progress():
|
||||
# Get the progress of dealing with data
|
||||
try:
|
||||
cache = Cache('./tmp')
|
||||
return f"current: {cache['current']}, total: {cache['total']}"
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Upload data error: {e}")
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
Table: Optional[str] = None
|
||||
File: str
|
||||
|
||||
|
||||
@app.post('/audio/load')
|
||||
async def load_audios(item: Item):
|
||||
# Insert all the audio files under the file path to Milvus/MySQL
|
||||
try:
|
||||
total_num = do_load(item.Table, item.File, MILVUS_CLI, MYSQL_CLI)
|
||||
LOGGER.info(f"Successfully loaded data, total count: {total_num}")
|
||||
return {'status': True, 'msg': "Successfully loaded data!"}
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.post('/audio/search')
|
||||
async def search_audio(request: Request,
|
||||
table_name: str=None,
|
||||
audio: UploadFile=File(...)):
|
||||
# Search the uploaded audio in Milvus/MySQL
|
||||
try:
|
||||
# Save the upload data to server.
|
||||
content = await audio.read()
|
||||
query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
|
||||
with open(query_audio_path, "wb+") as f:
|
||||
f.write(content)
|
||||
host = request.headers['host']
|
||||
_, paths, distances = do_search(host, table_name, query_audio_path,
|
||||
MILVUS_CLI, MYSQL_CLI)
|
||||
names = []
|
||||
for path, score in zip(paths, distances):
|
||||
names.append(os.path.basename(path))
|
||||
LOGGER.info(f"search result {path}, score {score}")
|
||||
res = dict(zip(paths, zip(names, distances)))
|
||||
# Sort results by distance metric, closest distances first
|
||||
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
|
||||
LOGGER.info("Successfully searched similar audio!")
|
||||
return res
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.post('/audio/search/local')
|
||||
async def search_local_audio(request: Request,
|
||||
query_audio_path: str,
|
||||
table_name: str=None):
|
||||
# Search the uploaded audio in Milvus/MySQL
|
||||
try:
|
||||
host = request.headers['host']
|
||||
_, paths, distances = do_search(host, table_name, query_audio_path,
|
||||
MILVUS_CLI, MYSQL_CLI)
|
||||
names = []
|
||||
for path, score in zip(paths, distances):
|
||||
names.append(os.path.basename(path))
|
||||
LOGGER.info(f"search result {path}, score {score}")
|
||||
res = dict(zip(paths, zip(names, distances)))
|
||||
# Sort results by distance metric, closest distances first
|
||||
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
|
||||
LOGGER.info("Successfully searched similar audio!")
|
||||
return res
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.get('/audio/count')
|
||||
async def count_audio(table_name: str=None):
|
||||
# Returns the total number of vectors in the system
|
||||
try:
|
||||
num = do_count(table_name, MILVUS_CLI)
|
||||
LOGGER.info("Successfully count the number of data!")
|
||||
return num
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.post('/audio/drop')
|
||||
async def drop_tables(table_name: str=None):
|
||||
# Delete the collection of Milvus and MySQL
|
||||
try:
|
||||
status = do_drop(table_name, MILVUS_CLI, MYSQL_CLI)
|
||||
LOGGER.info("Successfully drop tables in Milvus and MySQL!")
|
||||
return status
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run(app=app, host='0.0.0.0', port=8002)
|
@ -0,0 +1,185 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
from config import METRIC_TYPE
|
||||
from config import MILVUS_HOST
|
||||
from config import MILVUS_PORT
|
||||
from config import VECTOR_DIMENSION
|
||||
from logs import LOGGER
|
||||
from pymilvus import Collection
|
||||
from pymilvus import CollectionSchema
|
||||
from pymilvus import connections
|
||||
from pymilvus import DataType
|
||||
from pymilvus import FieldSchema
|
||||
from pymilvus import utility
|
||||
|
||||
|
||||
class MilvusHelper:
|
||||
"""
|
||||
the basic operations of PyMilvus
|
||||
|
||||
# This example shows how to:
|
||||
# 1. connect to Milvus server
|
||||
# 2. create a collection
|
||||
# 3. insert entities
|
||||
# 4. create index
|
||||
# 5. search
|
||||
# 6. delete a collection
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
self.collection = None
|
||||
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
|
||||
LOGGER.debug(
|
||||
f"Successfully connect to Milvus with IP:{MILVUS_HOST} and PORT:{MILVUS_PORT}"
|
||||
)
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to connect Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def set_collection(self, collection_name):
|
||||
try:
|
||||
if self.has_collection(collection_name):
|
||||
self.collection = Collection(name=collection_name)
|
||||
else:
|
||||
raise Exception(
|
||||
f"There is no collection named:{collection_name}")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to set collection in Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def has_collection(self, collection_name):
|
||||
# Return if Milvus has the collection
|
||||
try:
|
||||
return utility.has_collection(collection_name)
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to check state of collection in Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def create_collection(self, collection_name):
|
||||
# Create milvus collection if not exists
|
||||
try:
|
||||
if not self.has_collection(collection_name):
|
||||
field1 = FieldSchema(
|
||||
name="id",
|
||||
dtype=DataType.INT64,
|
||||
descrition="int64",
|
||||
is_primary=True,
|
||||
auto_id=True)
|
||||
field2 = FieldSchema(
|
||||
name="embedding",
|
||||
dtype=DataType.FLOAT_VECTOR,
|
||||
descrition="speaker embeddings",
|
||||
dim=VECTOR_DIMENSION,
|
||||
is_primary=False)
|
||||
schema = CollectionSchema(
|
||||
fields=[field1, field2], description="embeddings info")
|
||||
self.collection = Collection(
|
||||
name=collection_name, schema=schema)
|
||||
LOGGER.debug(f"Create Milvus collection: {collection_name}")
|
||||
else:
|
||||
self.set_collection(collection_name)
|
||||
return "OK"
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to create collection in Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def insert(self, collection_name, vectors):
|
||||
# Batch insert vectors to milvus collection
|
||||
try:
|
||||
self.create_collection(collection_name)
|
||||
data = [vectors]
|
||||
self.set_collection(collection_name)
|
||||
mr = self.collection.insert(data)
|
||||
ids = mr.primary_keys
|
||||
self.collection.load()
|
||||
LOGGER.debug(
|
||||
f"Insert vectors to Milvus in collection: {collection_name} with {len(vectors)} rows"
|
||||
)
|
||||
return ids
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to insert data to Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def create_index(self, collection_name):
|
||||
# Create IVF_FLAT index on milvus collection
|
||||
try:
|
||||
self.set_collection(collection_name)
|
||||
default_index = {
|
||||
"index_type": "IVF_SQ8",
|
||||
"metric_type": METRIC_TYPE,
|
||||
"params": {
|
||||
"nlist": 16384
|
||||
}
|
||||
}
|
||||
status = self.collection.create_index(
|
||||
field_name="embedding", index_params=default_index)
|
||||
if not status.code:
|
||||
LOGGER.debug(
|
||||
f"Successfully create index in collection:{collection_name} with param:{default_index}"
|
||||
)
|
||||
return status
|
||||
else:
|
||||
raise Exception(status.message)
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to create index: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def delete_collection(self, collection_name):
|
||||
# Delete Milvus collection
|
||||
try:
|
||||
self.set_collection(collection_name)
|
||||
self.collection.drop()
|
||||
LOGGER.debug("Successfully drop collection!")
|
||||
return "ok"
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to drop collection: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def search_vectors(self, collection_name, vectors, top_k):
|
||||
# Search vector in milvus collection
|
||||
try:
|
||||
self.set_collection(collection_name)
|
||||
search_params = {
|
||||
"metric_type": METRIC_TYPE,
|
||||
"params": {
|
||||
"nprobe": 16
|
||||
}
|
||||
}
|
||||
res = self.collection.search(
|
||||
vectors,
|
||||
anns_field="embedding",
|
||||
param=search_params,
|
||||
limit=top_k)
|
||||
LOGGER.debug(f"Successfully search in collection: {res}")
|
||||
return res
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to search vectors in Milvus: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
def count(self, collection_name):
|
||||
# Get the number of milvus collection
|
||||
try:
|
||||
self.set_collection(collection_name)
|
||||
num = self.collection.num_entities
|
||||
LOGGER.debug(
|
||||
f"Successfully get the num:{num} of the collection:{collection_name}"
|
||||
)
|
||||
return num
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Failed to count vectors in Milvus: {e}")
|
||||
sys.exit(1)
|
@ -0,0 +1,133 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
import pymysql
|
||||
from config import MYSQL_DB
|
||||
from config import MYSQL_HOST
|
||||
from config import MYSQL_PORT
|
||||
from config import MYSQL_PWD
|
||||
from config import MYSQL_USER
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
class MySQLHelper():
|
||||
"""
|
||||
the basic operations of PyMySQL
|
||||
|
||||
# This example shows how to:
|
||||
# 1. connect to MySQL server
|
||||
# 2. create a table
|
||||
# 3. insert data to table
|
||||
# 4. search by milvus ids
|
||||
# 5. delete table
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.conn = pymysql.connect(
|
||||
host=MYSQL_HOST,
|
||||
user=MYSQL_USER,
|
||||
port=MYSQL_PORT,
|
||||
password=MYSQL_PWD,
|
||||
database=MYSQL_DB,
|
||||
local_infile=True)
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
def test_connection(self):
|
||||
try:
|
||||
self.conn.ping()
|
||||
except Exception:
|
||||
self.conn = pymysql.connect(
|
||||
host=MYSQL_HOST,
|
||||
user=MYSQL_USER,
|
||||
port=MYSQL_PORT,
|
||||
password=MYSQL_PWD,
|
||||
database=MYSQL_DB,
|
||||
local_infile=True)
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
def create_mysql_table(self, table_name):
|
||||
# Create mysql table if not exists
|
||||
self.test_connection()
|
||||
sql = "create table if not exists " + table_name + "(milvus_id TEXT, audio_path TEXT);"
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
LOGGER.debug(f"MYSQL create table: {table_name} with sql: {sql}")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def load_data_to_mysql(self, table_name, data):
|
||||
# Batch insert (Milvus_ids, img_path) to mysql
|
||||
self.test_connection()
|
||||
sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);"
|
||||
try:
|
||||
self.cursor.executemany(sql, data)
|
||||
self.conn.commit()
|
||||
LOGGER.debug(
|
||||
f"MYSQL loads data to table: {table_name} successfully")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def search_by_milvus_ids(self, ids, table_name):
|
||||
# Get the img_path according to the milvus ids
|
||||
self.test_connection()
|
||||
str_ids = str(ids).replace('[', '').replace(']', '')
|
||||
sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");"
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
results = self.cursor.fetchall()
|
||||
results = [res[0] for res in results]
|
||||
LOGGER.debug("MYSQL search by milvus id.")
|
||||
return results
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def delete_table(self, table_name):
|
||||
# Delete mysql table if exists
|
||||
self.test_connection()
|
||||
sql = "drop table if exists " + table_name + ";"
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
LOGGER.debug(f"MYSQL delete table:{table_name}")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def delete_all_data(self, table_name):
|
||||
# Delete all the data in mysql table
|
||||
self.test_connection()
|
||||
sql = 'delete from ' + table_name + ';'
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
self.conn.commit()
|
||||
LOGGER.debug(f"MYSQL delete all data in table:{table_name}")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
||||
|
||||
def count_table(self, table_name):
|
||||
# Get the number of mysql table
|
||||
self.test_connection()
|
||||
sql = "select count(milvus_id) from " + table_name + ";"
|
||||
try:
|
||||
self.cursor.execute(sql)
|
||||
results = self.cursor.fetchall()
|
||||
LOGGER.debug(f"MYSQL count table:{table_name}")
|
||||
return results[0][0]
|
||||
except Exception as e:
|
||||
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
|
||||
sys.exit(1)
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,33 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
from config import DEFAULT_TABLE
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
def do_count(table_name, milvus_cli):
|
||||
"""
|
||||
Returns the total number of vectors in the system
|
||||
"""
|
||||
if not table_name:
|
||||
table_name = DEFAULT_TABLE
|
||||
try:
|
||||
if not milvus_cli.has_collection(table_name):
|
||||
return None
|
||||
num = milvus_cli.count(table_name)
|
||||
return num
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error attempting to count table {e}")
|
||||
sys.exit(1)
|
@ -0,0 +1,34 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
from config import DEFAULT_TABLE
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
def do_drop(table_name, milvus_cli, mysql_cli):
|
||||
"""
|
||||
Delete the collection of Milvus and MySQL
|
||||
"""
|
||||
if not table_name:
|
||||
table_name = DEFAULT_TABLE
|
||||
try:
|
||||
if not milvus_cli.has_collection(table_name):
|
||||
return "Collection is not exist"
|
||||
status = milvus_cli.delete_collection(table_name)
|
||||
mysql_cli.delete_table(table_name)
|
||||
return status
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error attempting to drop table: {e}")
|
||||
sys.exit(1)
|
@ -0,0 +1,85 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
from config import DEFAULT_TABLE
|
||||
from diskcache import Cache
|
||||
from encode import get_audio_embedding
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
def get_audios(path):
|
||||
"""
|
||||
List all wav and aif files recursively under the path folder.
|
||||
"""
|
||||
supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
|
||||
return [
|
||||
item
|
||||
for sublist in [[os.path.join(dir, file) for file in files]
|
||||
for dir, _, files in list(os.walk(path))]
|
||||
for item in sublist if os.path.splitext(item)[1] in supported_formats
|
||||
]
|
||||
|
||||
|
||||
def extract_features(audio_dir):
|
||||
"""
|
||||
Get the vector of audio
|
||||
"""
|
||||
try:
|
||||
cache = Cache('./tmp')
|
||||
feats = []
|
||||
names = []
|
||||
audio_list = get_audios(audio_dir)
|
||||
total = len(audio_list)
|
||||
cache['total'] = total
|
||||
for i, audio_path in enumerate(audio_list):
|
||||
norm_feat = get_audio_embedding(audio_path)
|
||||
if norm_feat is None:
|
||||
continue
|
||||
feats.append(norm_feat)
|
||||
names.append(audio_path.encode())
|
||||
cache['current'] = i + 1
|
||||
print(
|
||||
f"Extracting feature from audio No. {i + 1} , {total} audios in total"
|
||||
)
|
||||
return feats, names
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error with extracting feature from audio {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def format_data(ids, names):
|
||||
"""
|
||||
Combine the id of the vector and the name of the audio into a list
|
||||
"""
|
||||
data = []
|
||||
for i in range(len(ids)):
|
||||
value = (str(ids[i]), names[i])
|
||||
data.append(value)
|
||||
return data
|
||||
|
||||
|
||||
def do_load(table_name, audio_dir, milvus_cli, mysql_cli):
|
||||
"""
|
||||
Import vectors to Milvus and data to Mysql respectively
|
||||
"""
|
||||
if not table_name:
|
||||
table_name = DEFAULT_TABLE
|
||||
vectors, names = extract_features(audio_dir)
|
||||
ids = milvus_cli.insert(table_name, vectors)
|
||||
milvus_cli.create_index(table_name)
|
||||
mysql_cli.create_mysql_table(table_name)
|
||||
mysql_cli.load_data_to_mysql(table_name, format_data(ids, names))
|
||||
return len(ids)
|
@ -0,0 +1,41 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
from config import DEFAULT_TABLE
|
||||
from config import TOP_K
|
||||
from encode import get_audio_embedding
|
||||
from logs import LOGGER
|
||||
|
||||
|
||||
def do_search(host, table_name, audio_path, milvus_cli, mysql_cli):
|
||||
"""
|
||||
Search the uploaded audio in Milvus/MySQL
|
||||
"""
|
||||
try:
|
||||
if not table_name:
|
||||
table_name = DEFAULT_TABLE
|
||||
feat = get_audio_embedding(audio_path)
|
||||
vectors = milvus_cli.search_vectors(table_name, [feat], TOP_K)
|
||||
vids = [str(x.id) for x in vectors[0]]
|
||||
paths = mysql_cli.search_by_milvus_ids(vids, table_name)
|
||||
distances = [x.distance for x in vectors[0]]
|
||||
for i in range(len(paths)):
|
||||
tmp = "http://" + str(host) + "/data?audio_path=" + str(paths[i])
|
||||
paths[i] = tmp
|
||||
distances[i] = (1 - distances[i]) * 100
|
||||
return vids, paths, distances
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error with search: {e}")
|
||||
sys.exit(1)
|
@ -0,0 +1,95 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import zipfile
|
||||
|
||||
import gdown
|
||||
from fastapi.testclient import TestClient
|
||||
from main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def download_audio_data():
|
||||
"""
|
||||
download audio data
|
||||
"""
|
||||
url = 'https://drive.google.com/uc?id=1bKu21JWBfcZBuEuzFEvPoAX6PmRrgnUp'
|
||||
gdown.download(url)
|
||||
|
||||
with zipfile.ZipFile('example_audio.zip', 'r') as zip_ref:
|
||||
zip_ref.extractall('./example_audio')
|
||||
|
||||
|
||||
def test_drop():
|
||||
"""
|
||||
Delete the collection of Milvus and MySQL
|
||||
"""
|
||||
response = client.post("/audio/drop")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_load():
|
||||
"""
|
||||
Insert all the audio files under the file path to Milvus/MySQL
|
||||
"""
|
||||
response = client.post("/audio/load", json={"File": "./example_audio"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
'status': True,
|
||||
'msg': "Successfully loaded data!"
|
||||
}
|
||||
|
||||
|
||||
def test_progress():
|
||||
"""
|
||||
Get the progress of dealing with data
|
||||
"""
|
||||
response = client.get("/progress")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "current: 20, total: 20"
|
||||
|
||||
|
||||
def test_count():
|
||||
"""
|
||||
Returns the total number of vectors in the system
|
||||
"""
|
||||
response = client.get("audio/count")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == 20
|
||||
|
||||
|
||||
def test_search():
|
||||
"""
|
||||
Search the uploaded audio in Milvus/MySQL
|
||||
"""
|
||||
response = client.post(
|
||||
"/audio/search/local?query_audio_path=.%2Fexample_audio%2Ftest.wav")
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 10
|
||||
|
||||
|
||||
def test_data():
|
||||
"""
|
||||
Get the audio file
|
||||
"""
|
||||
response = client.get("/data?audio_path=.%2Fexample_audio%2Ftest.wav")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
download_audio_data()
|
||||
test_load()
|
||||
test_count()
|
||||
test_search()
|
||||
test_drop()
|
@ -0,0 +1 @@
|
||||
*.wav
|
@ -0,0 +1 @@
|
||||
tools/valgrind*
|
@ -0,0 +1,61 @@
|
||||
# SpeechX -- All in One Speech Task Inference
|
||||
|
||||
## Environment
|
||||
|
||||
We develop under:
|
||||
* docker - registry.baidubce.com/paddlepaddle/paddle:2.1.1-gpu-cuda10.2-cudnn7
|
||||
* os - Ubuntu 16.04.7 LTS
|
||||
* gcc/g++ - 8.2.0
|
||||
* cmake - 3.16.0
|
||||
|
||||
> We make sure all things work fun under docker, and recommend using it to develop and deploy.
|
||||
|
||||
* [How to Install Docker](https://docs.docker.com/engine/install/)
|
||||
* [A Docker Tutorial for Beginners](https://docker-curriculum.com/)
|
||||
* [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/overview.html)
|
||||
|
||||
## Build
|
||||
|
||||
1. First to launch docker container.
|
||||
|
||||
```
|
||||
nvidia-docker run --privileged --net=host --ipc=host -it --rm -v $PWD:/workspace --name=dev registry.baidubce.com/paddlepaddle/paddle:2.1.1-gpu-cuda10.2-cudnn7 /bin/bash
|
||||
```
|
||||
|
||||
* More `Paddle` docker images you can see [here](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html).
|
||||
|
||||
* If you want only work under cpu, please download corresponded [image](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html), and using `docker` instead `nviida-docker`.
|
||||
|
||||
|
||||
2. Build `speechx` and `examples`.
|
||||
|
||||
```
|
||||
pushd /path/to/speechx
|
||||
./build.sh
|
||||
```
|
||||
|
||||
3. Go to `examples` to have a fun.
|
||||
|
||||
More details please see `README.md` under `examples`.
|
||||
|
||||
|
||||
## Valgrind (Optional)
|
||||
|
||||
> If using docker please check `--privileged` is set when `docker run`.
|
||||
|
||||
* Fatal error at startup: `a function redirection which is mandatory for this platform-tool combination cannot be set up`
|
||||
```
|
||||
apt-get install libc6-dbg
|
||||
```
|
||||
|
||||
* Install
|
||||
|
||||
```
|
||||
pushd tools
|
||||
./setup_valgrind.sh
|
||||
popd
|
||||
```
|
||||
|
||||
## TODO
|
||||
|
||||
* DecibelNormalizer: there is a little bit difference between offline and online db norm. The computation of online db norm read feature chunk by chunk, which causes the feature size is different with offline db norm. In normalizer.cc:73, the samples.size() is different, which causes the difference of result.
|
@ -0,0 +1,28 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# the build script had verified in the paddlepaddle docker image.
|
||||
# please follow the instruction below to install PaddlePaddle image.
|
||||
# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html
|
||||
|
||||
boost_SOURCE_DIR=$PWD/fc_patch/boost-src
|
||||
if [ ! -d ${boost_SOURCE_DIR} ]; then wget -c https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz
|
||||
tar xzfv boost_1_75_0.tar.gz
|
||||
mkdir -p $PWD/fc_patch
|
||||
mv boost_1_75_0 ${boost_SOURCE_DIR}
|
||||
cd ${boost_SOURCE_DIR}
|
||||
bash ./bootstrap.sh
|
||||
./b2
|
||||
cd -
|
||||
echo -e "\n"
|
||||
fi
|
||||
|
||||
#rm -rf build
|
||||
mkdir -p build
|
||||
cd build
|
||||
|
||||
cmake .. -DBOOST_ROOT:STRING=${boost_SOURCE_DIR}
|
||||
#cmake ..
|
||||
|
||||
make -j1
|
||||
|
||||
cd -
|
@ -0,0 +1 @@
|
||||
cmake_policy(SET CMP0048 NEW)
|
@ -0,0 +1,16 @@
|
||||
include(FetchContent)
|
||||
|
||||
|
||||
set(BUILD_SHARED_LIBS OFF) # up to you
|
||||
set(BUILD_TESTING OFF) # to disable abseil test, or gtest will fail.
|
||||
set(ABSL_ENABLE_INSTALL ON) # now you can enable install rules even in subproject...
|
||||
|
||||
FetchContent_Declare(
|
||||
absl
|
||||
GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git"
|
||||
GIT_TAG "20210324.1"
|
||||
)
|
||||
FetchContent_MakeAvailable(absl)
|
||||
|
||||
set(EIGEN3_INCLUDE_DIR ${Eigen3_SOURCE_DIR})
|
||||
include_directories(${absl_SOURCE_DIR})
|
@ -0,0 +1,27 @@
|
||||
include(FetchContent)
|
||||
set(Boost_DEBUG ON)
|
||||
|
||||
set(Boost_PREFIX_DIR ${fc_patch}/boost)
|
||||
set(Boost_SOURCE_DIR ${fc_patch}/boost-src)
|
||||
|
||||
FetchContent_Declare(
|
||||
Boost
|
||||
URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz
|
||||
URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
|
||||
PREFIX ${Boost_PREFIX_DIR}
|
||||
SOURCE_DIR ${Boost_SOURCE_DIR}
|
||||
)
|
||||
|
||||
execute_process(COMMAND bootstrap.sh WORKING_DIRECTORY ${Boost_SOURCE_DIR})
|
||||
execute_process(COMMAND b2 WORKING_DIRECTORY ${Boost_SOURCE_DIR})
|
||||
|
||||
FetchContent_MakeAvailable(Boost)
|
||||
|
||||
message(STATUS "boost src dir: ${Boost_SOURCE_DIR}")
|
||||
message(STATUS "boost inc dir: ${Boost_INCLUDE_DIR}")
|
||||
message(STATUS "boost bin dir: ${Boost_BINARY_DIR}")
|
||||
|
||||
set(BOOST_ROOT ${Boost_SOURCE_DIR})
|
||||
message(STATUS "boost root dir: ${BOOST_ROOT}")
|
||||
|
||||
include_directories(${Boost_SOURCE_DIR})
|
@ -0,0 +1,27 @@
|
||||
include(FetchContent)
|
||||
|
||||
# update eigen to the commit id f612df27 on 03/16/2021
|
||||
set(EIGEN_PREFIX_DIR ${fc_patch}/eigen3)
|
||||
|
||||
FetchContent_Declare(
|
||||
Eigen3
|
||||
GIT_REPOSITORY https://gitlab.com/libeigen/eigen.git
|
||||
GIT_TAG master
|
||||
PREFIX ${EIGEN_PREFIX_DIR}
|
||||
GIT_SHALLOW TRUE
|
||||
GIT_PROGRESS TRUE)
|
||||
|
||||
set(EIGEN_BUILD_DOC OFF)
|
||||
# note: To disable eigen tests,
|
||||
# you should put this code in a add_subdirectory to avoid to change
|
||||
# BUILD_TESTING for your own project too since variables are directory
|
||||
# scoped
|
||||
set(BUILD_TESTING OFF)
|
||||
set(EIGEN_BUILD_PKGCONFIG OFF)
|
||||
set( OFF)
|
||||
FetchContent_MakeAvailable(Eigen3)
|
||||
|
||||
message(STATUS "eigen src dir: ${Eigen3_SOURCE_DIR}")
|
||||
message(STATUS "eigen bin dir: ${Eigen3_BINARY_DIR}")
|
||||
#include_directories(${Eigen3_SOURCE_DIR})
|
||||
#link_directories(${Eigen3_BINARY_DIR})
|
@ -0,0 +1,12 @@
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
gflags
|
||||
URL https://github.com/gflags/gflags/archive/v2.2.1.zip
|
||||
URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(gflags)
|
||||
|
||||
# openfst need
|
||||
include_directories(${gflags_BINARY_DIR}/include)
|
@ -0,0 +1,8 @@
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
glog
|
||||
URL https://github.com/google/glog/archive/v0.4.0.zip
|
||||
URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
|
||||
)
|
||||
FetchContent_MakeAvailable(glog)
|
||||
include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src)
|
@ -0,0 +1,9 @@
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
gtest
|
||||
URL https://github.com/google/googletest/archive/release-1.10.0.zip
|
||||
URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91
|
||||
)
|
||||
FetchContent_MakeAvailable(gtest)
|
||||
|
||||
include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src)
|
@ -0,0 +1,10 @@
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
kenlm
|
||||
GIT_REPOSITORY "https://github.com/kpu/kenlm.git"
|
||||
GIT_TAG "df2d717e95183f79a90b2fa6e4307083a351ca6a"
|
||||
)
|
||||
# https://github.com/kpu/kenlm/blob/master/cmake/modules/FindEigen3.cmake
|
||||
set(EIGEN3_INCLUDE_DIR ${Eigen3_SOURCE_DIR})
|
||||
FetchContent_MakeAvailable(kenlm)
|
||||
include_directories(${kenlm_SOURCE_DIR})
|
@ -0,0 +1,56 @@
|
||||
include(FetchContent)
|
||||
|
||||
# https://github.com/pongasoft/vst-sam-spl-64/blob/master/libsndfile.cmake
|
||||
# https://github.com/popojan/goban/blob/master/CMakeLists.txt#L38
|
||||
# https://github.com/ddiakopoulos/libnyquist/blob/master/CMakeLists.txt
|
||||
|
||||
if(LIBSNDFILE_ROOT_DIR)
|
||||
# instructs FetchContent to not download or update but use the location instead
|
||||
set(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE ${LIBSNDFILE_ROOT_DIR})
|
||||
else()
|
||||
set(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE "")
|
||||
endif()
|
||||
|
||||
set(LIBSNDFILE_GIT_REPO "https://github.com/libsndfile/libsndfile.git" CACHE STRING "libsndfile git repository url" FORCE)
|
||||
set(LIBSNDFILE_GIT_TAG 1.0.31 CACHE STRING "libsndfile git tag" FORCE)
|
||||
|
||||
FetchContent_Declare(libsndfile
|
||||
GIT_REPOSITORY ${LIBSNDFILE_GIT_REPO}
|
||||
GIT_TAG ${LIBSNDFILE_GIT_TAG}
|
||||
GIT_CONFIG advice.detachedHead=false
|
||||
# GIT_SHALLOW true
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
INSTALL_COMMAND ""
|
||||
TEST_COMMAND ""
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(libsndfile)
|
||||
if(NOT libsndfile_POPULATED)
|
||||
if(FETCHCONTENT_SOURCE_DIR_LIBSNDFILE)
|
||||
message(STATUS "Using libsndfile from local ${FETCHCONTENT_SOURCE_DIR_LIBSNDFILE}")
|
||||
else()
|
||||
message(STATUS "Fetching libsndfile ${LIBSNDFILE_GIT_REPO}/tree/${LIBSNDFILE_GIT_TAG}")
|
||||
endif()
|
||||
FetchContent_Populate(libsndfile)
|
||||
endif()
|
||||
|
||||
set(LIBSNDFILE_ROOT_DIR ${libsndfile_SOURCE_DIR})
|
||||
set(LIBSNDFILE_INCLUDE_DIR "${libsndfile_BINARY_DIR}/src")
|
||||
|
||||
function(libsndfile_build)
|
||||
option(BUILD_PROGRAMS "Build programs" OFF)
|
||||
option(BUILD_EXAMPLES "Build examples" OFF)
|
||||
option(BUILD_TESTING "Build examples" OFF)
|
||||
option(ENABLE_CPACK "Enable CPack support" OFF)
|
||||
option(ENABLE_PACKAGE_CONFIG "Generate and install package config file" OFF)
|
||||
option(BUILD_REGTEST "Build regtest" OFF)
|
||||
# finally we include libsndfile itself
|
||||
add_subdirectory(${libsndfile_SOURCE_DIR} ${libsndfile_BINARY_DIR} EXCLUDE_FROM_ALL)
|
||||
# copying .hh for c++ support
|
||||
#file(COPY "${libsndfile_SOURCE_DIR}/src/sndfile.hh" DESTINATION ${LIBSNDFILE_INCLUDE_DIR})
|
||||
endfunction()
|
||||
|
||||
libsndfile_build()
|
||||
|
||||
include_directories(${LIBSNDFILE_INCLUDE_DIR})
|
@ -0,0 +1,37 @@
|
||||
include(FetchContent)
|
||||
|
||||
set(OpenBLAS_SOURCE_DIR ${fc_patch}/OpenBLAS-src)
|
||||
set(OpenBLAS_PREFIX ${fc_patch}/OpenBLAS-prefix)
|
||||
|
||||
# ######################################################################################################################
|
||||
# OPENBLAS https://github.com/lattice/quda/blob/develop/CMakeLists.txt#L575
|
||||
# ######################################################################################################################
|
||||
enable_language(Fortran)
|
||||
#TODO: switch to CPM
|
||||
include(GNUInstallDirs)
|
||||
ExternalProject_Add(
|
||||
OPENBLAS
|
||||
GIT_REPOSITORY https://github.com/xianyi/OpenBLAS.git
|
||||
GIT_TAG v0.3.10
|
||||
GIT_SHALLOW YES
|
||||
PREFIX ${OpenBLAS_PREFIX}
|
||||
SOURCE_DIR ${OpenBLAS_SOURCE_DIR}
|
||||
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
|
||||
CMAKE_GENERATOR "Unix Makefiles")
|
||||
|
||||
|
||||
# https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition
|
||||
ExternalProject_Get_Property(OPENBLAS INSTALL_DIR)
|
||||
set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR})
|
||||
add_library(openblas STATIC IMPORTED)
|
||||
add_dependencies(openblas OPENBLAS)
|
||||
set_target_properties(openblas PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES Fortran)
|
||||
# ${CMAKE_INSTALL_LIBDIR} lib
|
||||
set_target_properties(openblas PROPERTIES IMPORTED_LOCATION ${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}/libopenblas.a)
|
||||
|
||||
|
||||
# https://cmake.org/cmake/help/latest/command/install.html?highlight=cmake_install_libdir#installing-targets
|
||||
# ${CMAKE_INSTALL_LIBDIR} lib
|
||||
# ${CMAKE_INSTALL_INCLUDEDIR} include
|
||||
link_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
|
||||
include_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR})
|
@ -0,0 +1,19 @@
|
||||
include(FetchContent)
|
||||
set(openfst_SOURCE_DIR ${fc_patch}/openfst-src)
|
||||
set(openfst_BINARY_DIR ${fc_patch}/openfst-build)
|
||||
|
||||
ExternalProject_Add(openfst
|
||||
URL https://github.com/mjansche/openfst/archive/refs/tags/1.7.2.zip
|
||||
URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6
|
||||
# #PREFIX ${openfst_PREFIX_DIR}
|
||||
# SOURCE_DIR ${openfst_SOURCE_DIR}
|
||||
# BINARY_DIR ${openfst_BINARY_DIR}
|
||||
CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR}
|
||||
"CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}"
|
||||
"LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}"
|
||||
"LIBS=-lgflags_nothreads -lglog -lpthread"
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
|
||||
BUILD_COMMAND make -j 4
|
||||
)
|
||||
link_directories(${openfst_PREFIX_DIR}/lib)
|
||||
include_directories(${openfst_PREFIX_DIR}/include)
|
@ -0,0 +1,2 @@
|
||||
*.ark
|
||||
paddle_asr_model/
|
@ -0,0 +1,5 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_subdirectory(feat)
|
||||
add_subdirectory(nnet)
|
||||
add_subdirectory(decoder)
|
@ -0,0 +1,16 @@
|
||||
# Examples
|
||||
|
||||
* decoder - online decoder to work as offline
|
||||
* feat - mfcc, linear
|
||||
* nnet - ds2 nn
|
||||
|
||||
## How to run
|
||||
|
||||
`run.sh` is the entry point.
|
||||
|
||||
Example to play `decoder`:
|
||||
|
||||
```
|
||||
pushd decoder
|
||||
bash run.sh
|
||||
```
|
@ -0,0 +1,5 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
|
||||
target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
|
@ -0,0 +1,101 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// todo refactor, repalce with gtest
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "decoder/ctc_beam_search_decoder.h"
|
||||
#include "frontend/raw_audio.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/decodable.h"
|
||||
#include "nnet/paddle_nnet.h"
|
||||
|
||||
DEFINE_string(feature_respecifier, "", "test feature rspecifier");
|
||||
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
|
||||
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
|
||||
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
|
||||
DEFINE_string(lm_path, "lm.klm", "language model");
|
||||
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Matrix;
|
||||
using std::vector;
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
kaldi::SequentialBaseFloatMatrixReader feature_reader(
|
||||
FLAGS_feature_respecifier);
|
||||
std::string model_graph = FLAGS_model_path;
|
||||
std::string model_params = FLAGS_param_path;
|
||||
std::string dict_file = FLAGS_dict_file;
|
||||
std::string lm_path = FLAGS_lm_path;
|
||||
|
||||
int32 num_done = 0, num_err = 0;
|
||||
|
||||
ppspeech::CTCBeamSearchOptions opts;
|
||||
opts.dict_file = dict_file;
|
||||
opts.lm_path = lm_path;
|
||||
ppspeech::CTCBeamSearch decoder(opts);
|
||||
|
||||
ppspeech::ModelOptions model_opts;
|
||||
model_opts.model_path = model_graph;
|
||||
model_opts.params_path = model_params;
|
||||
std::shared_ptr<ppspeech::PaddleNnet> nnet(
|
||||
new ppspeech::PaddleNnet(model_opts));
|
||||
std::shared_ptr<ppspeech::RawDataCache> raw_data(
|
||||
new ppspeech::RawDataCache());
|
||||
std::shared_ptr<ppspeech::Decodable> decodable(
|
||||
new ppspeech::Decodable(nnet, raw_data));
|
||||
|
||||
int32 chunk_size = 35;
|
||||
decoder.InitDecoder();
|
||||
|
||||
for (; !feature_reader.Done(); feature_reader.Next()) {
|
||||
string utt = feature_reader.Key();
|
||||
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
|
||||
raw_data->SetDim(feature.NumCols());
|
||||
int32 row_idx = 0;
|
||||
int32 num_chunks = feature.NumRows() / chunk_size;
|
||||
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
|
||||
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
|
||||
feature.NumCols());
|
||||
for (int row_id = 0; row_id < chunk_size; ++row_id) {
|
||||
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, row_idx);
|
||||
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
|
||||
feature_chunk.Data() + row_id * feature.NumCols(),
|
||||
feature.NumCols());
|
||||
f_chunk_tmp.CopyFromVec(tmp);
|
||||
row_idx++;
|
||||
}
|
||||
raw_data->Accept(feature_chunk);
|
||||
if (chunk_idx == num_chunks - 1) {
|
||||
raw_data->SetFinished();
|
||||
}
|
||||
decoder.AdvanceDecode(decodable);
|
||||
}
|
||||
std::string result;
|
||||
result = decoder.GetFinalBestPath();
|
||||
KALDI_LOG << " the result of " << utt << " is " << result;
|
||||
decodable->Reset();
|
||||
decoder.Reset();
|
||||
++num_done;
|
||||
}
|
||||
|
||||
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
|
||||
<< " with errors.";
|
||||
return (num_done != 0 ? 0 : 1);
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
# This contains the locations of binarys build required for running the examples.
|
||||
|
||||
SPEECHX_ROOT=$PWD/../..
|
||||
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
|
||||
|
||||
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
|
||||
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
|
||||
|
||||
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
|
||||
|
||||
export LC_AL=C
|
||||
|
||||
SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder
|
||||
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
|
@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
|
||||
# 2. download model
|
||||
if [ ! -d ../paddle_asr_model ]; then
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
|
||||
tar xzfv paddle_asr_model.tar.gz
|
||||
mv ./paddle_asr_model ../
|
||||
# produce wav scp
|
||||
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
|
||||
fi
|
||||
|
||||
model_dir=../paddle_asr_model
|
||||
feat_wspecifier=./feats.ark
|
||||
cmvn=./cmvn.ark
|
||||
|
||||
# 3. run feat
|
||||
linear_spectrogram_main \
|
||||
--wav_rspecifier=scp:$model_dir/wav.scp \
|
||||
--feature_wspecifier=ark,t:$feat_wspecifier \
|
||||
--cmvn_write_path=$cmvn
|
||||
|
||||
# 4. run decoder
|
||||
offline_decoder_main \
|
||||
--feature_respecifier=ark:$feat_wspecifier \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdparams \
|
||||
--dict_file=$model_dir/vocab.txt \
|
||||
--lm_path=$model_dir/avg_1.jit.klm
|
@ -0,0 +1,26 @@
|
||||
#!/bin/bash
|
||||
|
||||
# this script is for memory check, so please run ./run.sh first.
|
||||
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. ./path.sh
|
||||
|
||||
if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
|
||||
echo "please install valgrind in the speechx tools dir.\n"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
model_dir=../paddle_asr_model
|
||||
feat_wspecifier=./feats.ark
|
||||
cmvn=./cmvn.ark
|
||||
|
||||
valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
|
||||
offline_decoder_main \
|
||||
--feature_respecifier=ark:$feat_wspecifier \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdparams \
|
||||
--dict_file=$model_dir/vocab.txt \
|
||||
--lm_path=$model_dir/avg_1.jit.klm
|
||||
|
@ -0,0 +1,10 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
|
||||
add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
|
||||
target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(mfcc-test kaldi-mfcc)
|
||||
|
||||
add_executable(linear_spectrogram_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_main.cc)
|
||||
target_include_directories(linear_spectrogram_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
|
@ -0,0 +1,720 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// feat/feature-mfcc-test.cc
|
||||
|
||||
// Copyright 2009-2011 Karel Vesely; Petr Motlicek
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "base/kaldi-math.h"
|
||||
#include "feat/feature-mfcc.h"
|
||||
#include "feat/wave-reader.h"
|
||||
#include "matrix/kaldi-matrix-inl.h"
|
||||
|
||||
using namespace kaldi;
|
||||
|
||||
|
||||
static void UnitTestReadWave() {
|
||||
std::cout << "=== UnitTestReadWave() ===\n";
|
||||
|
||||
Vector<BaseFloat> v, v2;
|
||||
|
||||
std::cout << "<<<=== Reading waveform\n";
|
||||
|
||||
{
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
const Matrix<BaseFloat> data(wave.Data());
|
||||
KALDI_ASSERT(data.NumRows() == 1);
|
||||
v.Resize(data.NumCols());
|
||||
v.CopyFromVec(data.Row(0));
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
|
||||
std::ifstream input("test_data/test_matlab.ascii");
|
||||
KALDI_ASSERT(input.good());
|
||||
v2.Read(input, false);
|
||||
input.close();
|
||||
|
||||
std::cout
|
||||
<< "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
|
||||
KALDI_ASSERT(v.Dim() == v2.Dim());
|
||||
for (int32 i = 0; i < v.Dim(); i++) {
|
||||
KALDI_ASSERT(v(i) == v2(i));
|
||||
}
|
||||
std::cout << "<<<=== Comparing done\n";
|
||||
|
||||
// std::cout << "== The Waveform Samples == \n";
|
||||
// std::cout << v;
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
*/
|
||||
static void UnitTestSimple() {
|
||||
std::cout << "=== UnitTestSimple() ===\n";
|
||||
|
||||
Vector<BaseFloat> v(100000);
|
||||
Matrix<BaseFloat> m;
|
||||
|
||||
// init with noise
|
||||
for (int32 i = 0; i < v.Dim(); i++) {
|
||||
v(i) = (abs(i * 433024253) % 65535) - (65535 / 2);
|
||||
}
|
||||
|
||||
std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
|
||||
// the parametrization object
|
||||
MfccOptions op;
|
||||
// trying to have same opts as baseline.
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "rectangular";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
op.htk_compat = true;
|
||||
|
||||
Mfcc mfcc(op);
|
||||
// use default parameters
|
||||
|
||||
// compute mfccs.
|
||||
mfcc.Compute(v, 1.0, &m);
|
||||
|
||||
// possibly dump
|
||||
// std::cout << "== Output features == \n" << m;
|
||||
std::cout << "Test passed :)\n\n";
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare1() {
|
||||
std::cout << "=== UnitTestHTKCompare1() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.1",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = false; // C0 not energy.
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (i_old != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.1",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.1");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare2() {
|
||||
std::cout << "=== UnitTestHTKCompare2() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.2",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (i_old != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.2",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.2");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare3() {
|
||||
std::cout << "=== UnitTestHTKCompare3() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.3",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
op.mel_opts.low_freq = 20.0;
|
||||
// op.mel_opts.debug_mel = true;
|
||||
op.mel_opts.htk_mode = true;
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.3",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.3");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare4() {
|
||||
std::cout << "=== UnitTestHTKCompare4() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.4",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
op.mel_opts.htk_mode = true;
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.4",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.4");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare5() {
|
||||
std::cout << "=== UnitTestHTKCompare5() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.5",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.vtln_low = 100.0;
|
||||
op.mel_opts.vtln_high = 7500.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
|
||||
BaseFloat vtln_warp =
|
||||
1.1; // our approach identical to htk for warp factor >1,
|
||||
// differs slightly for higher mel bins if warp_factor <0.9
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.5",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.5");
|
||||
}
|
||||
|
||||
static void UnitTestHTKCompare6() {
|
||||
std::cout << "=== UnitTestHTKCompare6() ===\n";
|
||||
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.6",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.97;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.num_bins = 24;
|
||||
op.mel_opts.low_freq = 125.0;
|
||||
op.mel_opts.high_freq = 7800.0;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = false; // C0 not energy.
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.6",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.6");
|
||||
}
|
||||
|
||||
void UnitTestVtln() {
|
||||
// Test the function VtlnWarpFreq.
|
||||
BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20,
|
||||
vtln_high_cutoff = 7400;
|
||||
|
||||
for (size_t i = 0; i < 100; i++) {
|
||||
BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
|
||||
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
freq),
|
||||
freq / warp_factor);
|
||||
|
||||
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
low_freq),
|
||||
low_freq);
|
||||
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
high_freq),
|
||||
high_freq);
|
||||
BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(),
|
||||
freq3 = freq2 +
|
||||
(high_freq - freq2) * RandUniform(); // freq3>=freq2
|
||||
BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
freq2);
|
||||
BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
freq3);
|
||||
KALDI_ASSERT(w3 >= w2); // increasing function.
|
||||
BaseFloat w3dash = MelBanks::VtlnWarpFreq(
|
||||
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, 1.0, freq3);
|
||||
AssertEqual(w3dash, freq3);
|
||||
}
|
||||
}
|
||||
|
||||
static void UnitTestFeat() {
|
||||
UnitTestVtln();
|
||||
UnitTestReadWave();
|
||||
UnitTestSimple();
|
||||
UnitTestHTKCompare1();
|
||||
UnitTestHTKCompare2();
|
||||
// commenting out this one as it doesn't compare right now I normalized
|
||||
// the way the FFT bins are treated (removed offset of 0.5)... this seems
|
||||
// to relate to the way frequency zero behaves.
|
||||
UnitTestHTKCompare3();
|
||||
UnitTestHTKCompare4();
|
||||
UnitTestHTKCompare5();
|
||||
UnitTestHTKCompare6();
|
||||
std::cout << "Tests succeeded.\n";
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
try {
|
||||
for (int i = 0; i < 5; i++) UnitTestFeat();
|
||||
std::cout << "Tests succeeded.\n";
|
||||
return 0;
|
||||
} catch (const std::exception &e) {
|
||||
std::cerr << e.what();
|
||||
return 1;
|
||||
}
|
||||
}
|
@ -0,0 +1,248 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// todo refactor, repalce with gtest
|
||||
|
||||
#include "frontend/linear_spectrogram.h"
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "frontend/feature_cache.h"
|
||||
#include "frontend/feature_extractor_interface.h"
|
||||
#include "frontend/normalizer.h"
|
||||
#include "frontend/raw_audio.h"
|
||||
#include "kaldi/feat/wave-reader.h"
|
||||
#include "kaldi/util/kaldi-io.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
|
||||
DEFINE_string(wav_rspecifier, "", "test wav scp path");
|
||||
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
|
||||
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
|
||||
|
||||
|
||||
std::vector<float> mean_{
|
||||
-13730251.531853663, -12982852.199316509, -13673844.299583456,
|
||||
-13089406.559646806, -12673095.524938712, -12823859.223276224,
|
||||
-13590267.158903603, -14257618.467152044, -14374605.116185192,
|
||||
-14490009.21822485, -14849827.158924166, -15354435.470563512,
|
||||
-15834149.206532761, -16172971.985514281, -16348740.496746974,
|
||||
-16423536.699409386, -16556246.263649225, -16744088.772748645,
|
||||
-16916184.08510357, -17054034.840031497, -17165612.509455364,
|
||||
-17255955.470915023, -17322572.527648456, -17408943.862033736,
|
||||
-17521554.799865916, -17620623.254924215, -17699792.395918526,
|
||||
-17723364.411134344, -17741483.4433254, -17747426.888704527,
|
||||
-17733315.928209435, -17748780.160905756, -17808336.883775543,
|
||||
-17895918.671983004, -18009812.59173023, -18098188.66548325,
|
||||
-18195798.958462656, -18293617.62980999, -18397432.92077201,
|
||||
-18505834.787318766, -18585451.8100908, -18652438.235649142,
|
||||
-18700960.306275308, -18734944.58792185, -18737426.313365128,
|
||||
-18735347.165987637, -18738813.444170244, -18737086.848890636,
|
||||
-18731576.2474336, -18717405.44095871, -18703089.25545657,
|
||||
-18691014.546456724, -18692460.568905357, -18702119.628629155,
|
||||
-18727710.621126678, -18761582.72034647, -18806745.835547544,
|
||||
-18850674.8692112, -18884431.510951452, -18919999.992506847,
|
||||
-18939303.799078144, -18952946.273760635, -18980289.22996379,
|
||||
-19011610.17803294, -19040948.61805145, -19061021.429847397,
|
||||
-19112055.53768819, -19149667.414264943, -19201127.05091321,
|
||||
-19270250.82564605, -19334606.883057203, -19390513.336589377,
|
||||
-19444176.259208687, -19502755.000038862, -19544333.014549147,
|
||||
-19612668.183176614, -19681902.19006569, -19771969.951249883,
|
||||
-19873329.723376893, -19996752.59235844, -20110031.131400537,
|
||||
-20231658.612529557, -20319378.894054495, -20378534.45718066,
|
||||
-20413332.089584175, -20438147.844177883, -20443710.248040095,
|
||||
-20465457.02238927, -20488610.969337028, -20516295.16424432,
|
||||
-20541423.795738827, -20553192.874953747, -20573605.50701977,
|
||||
-20577871.61936797, -20571807.008916274, -20556242.38912231,
|
||||
-20542199.30819195, -20521239.063551214, -20519150.80004532,
|
||||
-20527204.80248933, -20536933.769257784, -20543470.522332076,
|
||||
-20549700.089992985, -20551525.24958494, -20554873.406493705,
|
||||
-20564277.65794227, -20572211.740052115, -20574305.69550465,
|
||||
-20575494.450104576, -20567092.577932164, -20549302.929608088,
|
||||
-20545445.11878376, -20546625.326603737, -20549190.03499401,
|
||||
-20554824.947828256, -20568341.378989458, -20577582.331383612,
|
||||
-20577980.519402675, -20566603.03458152, -20560131.592262644,
|
||||
-20552166.469060015, -20549063.06763577, -20544490.562339947,
|
||||
-20539817.82346569, -20528747.715731595, -20518026.24576161,
|
||||
-20510977.844974525, -20506874.36087992, -20506731.11977665,
|
||||
-20510482.133420516, -20507760.92101862, -20494644.834457114,
|
||||
-20480107.89304893, -20461312.091867123, -20442941.75080173,
|
||||
-20426123.02834838, -20424607.675283, -20426810.369107097,
|
||||
-20434024.50097819, -20437404.75544205, -20447688.63916367,
|
||||
-20460893.335563846, -20482922.735127095, -20503610.119434915,
|
||||
-20527062.76448319, -20557830.035128627, -20593274.72068722,
|
||||
-20632528.452965066, -20673637.471334763, -20733106.97143075,
|
||||
-20842921.0447562, -21054357.83621519, -21416569.534189366,
|
||||
-21978460.272811692, -22753170.052172784, -23671344.10563395,
|
||||
-24613499.293358143, -25406477.12230188, -25884377.82156489,
|
||||
-26049040.62791664, -26996879.104431007};
|
||||
std::vector<float> variance_{
|
||||
213747175.10846674, 188395815.34302503, 212706429.10966414,
|
||||
199109025.81461075, 189235901.23864496, 194901336.53253657,
|
||||
217481594.29306737, 238689869.12327808, 243977501.24115244,
|
||||
248479623.6431067, 259766741.47116545, 275516766.7790273,
|
||||
291271202.3691234, 302693239.8220509, 308627358.3997694,
|
||||
311143911.38788426, 315446105.07731867, 321705430.9341829,
|
||||
327458907.4659941, 332245072.43223983, 336251717.5935284,
|
||||
339694069.7639722, 342188204.4322228, 345587110.31313115,
|
||||
349903086.2875232, 353660214.20643026, 356700344.5270885,
|
||||
357665362.3529641, 358493352.05658793, 358857951.620328,
|
||||
358375239.52774596, 358899733.6342954, 361051818.3511561,
|
||||
364361716.05025816, 368750322.3771452, 372047800.6462831,
|
||||
375655861.1349018, 379358519.1980013, 383327605.3935181,
|
||||
387458599.282341, 390434692.3406868, 392994486.35057056,
|
||||
394874418.04603153, 396230525.79763395, 396365592.0414835,
|
||||
396334819.8242737, 396488353.19250053, 396438877.00744957,
|
||||
396197980.4459586, 395590921.6672991, 395001107.62072515,
|
||||
394528291.7318225, 394593110.424006, 395018405.59353715,
|
||||
396110577.5415993, 397506704.0371068, 399400197.4657644,
|
||||
401243568.2468382, 402687134.7805103, 404136047.2872507,
|
||||
404883170.001883, 405522253.219517, 406660365.3626476,
|
||||
407919346.0991902, 409045348.5384909, 409759588.7889818,
|
||||
411974821.8564483, 413489718.78201455, 415535392.56684107,
|
||||
418466481.97674364, 421104678.35678065, 423405392.5200779,
|
||||
425550570.40798235, 427929423.9579701, 429585274.253478,
|
||||
432368493.55181056, 435193587.13513297, 438886855.20476013,
|
||||
443058876.8633751, 448181232.5093362, 452883835.6332396,
|
||||
458056721.77926534, 461816531.22735566, 464363620.1970998,
|
||||
465886343.5057493, 466928872.0651, 467180536.42647296,
|
||||
468111848.70714295, 469138695.3071312, 470378429.6930793,
|
||||
471517958.7132626, 472109050.4262365, 473087417.0177867,
|
||||
473381322.04648733, 473220195.85483915, 472666071.8998819,
|
||||
472124669.87879956, 471298571.411737, 471251033.2902761,
|
||||
471672676.43128747, 472177147.2193172, 472572361.7711908,
|
||||
472968783.7751127, 473156295.4164052, 473398034.82676554,
|
||||
473897703.5203811, 474328271.33112127, 474452670.98002136,
|
||||
474549003.99284613, 474252887.13567275, 473557462.909069,
|
||||
473483385.85193115, 473609738.04855174, 473746944.82085115,
|
||||
474016729.91696435, 474617321.94138587, 475045097.237122,
|
||||
475125402.586558, 474664112.9824912, 474426247.5800283,
|
||||
474104075.42796475, 473978219.7273978, 473773171.7798875,
|
||||
473578534.69508696, 473102924.16904145, 472651240.5232615,
|
||||
472374383.1810912, 472209479.6956096, 472202298.8921673,
|
||||
472370090.76781124, 472220933.99374026, 471625467.37106377,
|
||||
470994646.51883453, 470182428.9637543, 469348211.5939578,
|
||||
468570387.4467277, 468540442.7225135, 468672018.90414184,
|
||||
468994346.9533251, 469138757.58201426, 469553915.95710236,
|
||||
470134523.38582784, 471082421.62055486, 471962316.51804745,
|
||||
472939745.1708408, 474250621.5944825, 475773933.43199486,
|
||||
477465399.71087736, 479218782.61382693, 481752299.7930922,
|
||||
486608947.8984568, 496119403.2067917, 512730085.5704984,
|
||||
539048915.2641417, 576285298.3548826, 621610270.2240586,
|
||||
669308196.4436442, 710656993.5957186, 736344437.3725077,
|
||||
745481288.0241544, 801121432.9925804};
|
||||
int count_ = 912592;
|
||||
|
||||
void WriteMatrix() {
|
||||
kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
|
||||
for (size_t idx = 0; idx < mean_.size(); ++idx) {
|
||||
cmvn_stats(0, idx) = mean_[idx];
|
||||
cmvn_stats(1, idx) = variance_[idx];
|
||||
}
|
||||
cmvn_stats(0, mean_.size()) = count_;
|
||||
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
|
||||
FLAGS_wav_rspecifier);
|
||||
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
|
||||
WriteMatrix();
|
||||
|
||||
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning
|
||||
// window -->linear_spectrogram --> cmvn
|
||||
int32 num_done = 0, num_err = 0;
|
||||
// std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(new
|
||||
// ppspeech::RawDataCache());
|
||||
std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(
|
||||
new ppspeech::RawAudioCache());
|
||||
|
||||
ppspeech::LinearSpectrogramOptions opt;
|
||||
opt.frame_opts.frame_length_ms = 20;
|
||||
opt.frame_opts.frame_shift_ms = 10;
|
||||
ppspeech::DecibelNormalizerOptions db_norm_opt;
|
||||
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
|
||||
new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));
|
||||
|
||||
std::unique_ptr<ppspeech::FeatureExtractorInterface> linear_spectrogram(
|
||||
new ppspeech::LinearSpectrogram(opt,
|
||||
std::move(base_feature_extractor)));
|
||||
|
||||
std::unique_ptr<ppspeech::FeatureExtractorInterface> cmvn(
|
||||
new ppspeech::CMVN(FLAGS_cmvn_write_path,
|
||||
std::move(linear_spectrogram)));
|
||||
|
||||
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
|
||||
|
||||
float streaming_chunk = 0.36;
|
||||
int sample_rate = 16000;
|
||||
int chunk_sample_size = streaming_chunk * sample_rate;
|
||||
|
||||
for (; !wav_reader.Done(); wav_reader.Next()) {
|
||||
std::string utt = wav_reader.Key();
|
||||
const kaldi::WaveData& wave_data = wav_reader.Value();
|
||||
|
||||
int32 this_channel = 0;
|
||||
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
|
||||
this_channel);
|
||||
int tot_samples = waveform.Dim();
|
||||
int sample_offset = 0;
|
||||
std::vector<kaldi::Vector<BaseFloat>> feats;
|
||||
int feature_rows = 0;
|
||||
while (sample_offset < tot_samples) {
|
||||
int cur_chunk_size =
|
||||
std::min(chunk_sample_size, tot_samples - sample_offset);
|
||||
|
||||
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
|
||||
for (int i = 0; i < cur_chunk_size; ++i) {
|
||||
wav_chunk(i) = waveform(sample_offset + i);
|
||||
}
|
||||
kaldi::Vector<BaseFloat> features;
|
||||
feature_cache.Accept(wav_chunk);
|
||||
if (cur_chunk_size < chunk_sample_size) {
|
||||
feature_cache.SetFinished();
|
||||
}
|
||||
feature_cache.Read(&features);
|
||||
if (features.Dim() == 0) break;
|
||||
|
||||
feats.push_back(features);
|
||||
sample_offset += cur_chunk_size;
|
||||
feature_rows += features.Dim() / feature_cache.Dim();
|
||||
}
|
||||
|
||||
int cur_idx = 0;
|
||||
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
|
||||
feature_cache.Dim());
|
||||
for (auto feat : feats) {
|
||||
int num_rows = feat.Dim() / feature_cache.Dim();
|
||||
for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
|
||||
for (size_t col_idx = 0; col_idx < feature_cache.Dim();
|
||||
++col_idx) {
|
||||
features(cur_idx, col_idx) =
|
||||
feat(row_idx * feature_cache.Dim() + col_idx);
|
||||
}
|
||||
++cur_idx;
|
||||
}
|
||||
}
|
||||
feat_writer.Write(utt, features);
|
||||
|
||||
if (num_done % 50 == 0 && num_done != 0)
|
||||
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
|
||||
num_done++;
|
||||
}
|
||||
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
|
||||
<< " with errors.";
|
||||
return (num_done != 0 ? 0 : 1);
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
# This contains the locations of binarys build required for running the examples.
|
||||
|
||||
SPEECHX_ROOT=$PWD/../..
|
||||
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
|
||||
|
||||
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
|
||||
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
|
||||
|
||||
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
|
||||
|
||||
export LC_AL=C
|
||||
|
||||
SPEECHX_BIN=$SPEECHX_EXAMPLES/feat
|
||||
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
|
@ -0,0 +1,31 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. ./path.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
# 2. download model
|
||||
if [ ! -d ../paddle_asr_model ]; then
|
||||
wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
|
||||
tar xzfv paddle_asr_model.tar.gz
|
||||
mv ./paddle_asr_model ../
|
||||
# produce wav scp
|
||||
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
|
||||
fi
|
||||
|
||||
model_dir=../paddle_asr_model
|
||||
feat_wspecifier=./feats.ark
|
||||
cmvn=./cmvn.ark
|
||||
|
||||
# 3. run feat
|
||||
linear_spectrogram_main \
|
||||
--wav_rspecifier=scp:$model_dir/wav.scp \
|
||||
--feature_wspecifier=ark,t:$feat_wspecifier \
|
||||
--cmvn_write_path=$cmvn
|
@ -0,0 +1,24 @@
|
||||
#!/bin/bash
|
||||
|
||||
# this script is for memory check, so please run ./run.sh first.
|
||||
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. ./path.sh
|
||||
|
||||
if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
|
||||
echo "please install valgrind in the speechx tools dir.\n"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
model_dir=../paddle_asr_model
|
||||
feat_wspecifier=./feats.ark
|
||||
cmvn=./cmvn.ark
|
||||
|
||||
valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
|
||||
linear_spectrogram_main \
|
||||
--wav_rspecifier=scp:$model_dir/wav.scp \
|
||||
--feature_wspecifier=ark,t:$feat_wspecifier \
|
||||
--cmvn_write_path=$cmvn
|
||||
|
@ -0,0 +1,5 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_executable(pp-model-test ${CMAKE_CURRENT_SOURCE_DIR}/pp-model-test.cc)
|
||||
target_include_directories(pp-model-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(pp-model-test PUBLIC nnet gflags ${DEPS})
|
@ -0,0 +1,14 @@
|
||||
# This contains the locations of binarys build required for running the examples.
|
||||
|
||||
SPEECHX_ROOT=$PWD/../..
|
||||
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
|
||||
|
||||
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
|
||||
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
|
||||
|
||||
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
|
||||
|
||||
export LC_AL=C
|
||||
|
||||
SPEECHX_BIN=$SPEECHX_EXAMPLES/nnet
|
||||
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
|
@ -0,0 +1,193 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <thread>
|
||||
#include "paddle_inference_api.h"
|
||||
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
|
||||
DEFINE_string(model_path, "avg_1.jit.pdmodel", "xxx.pdmodel");
|
||||
DEFINE_string(param_path, "avg_1.jit.pdiparams", "xxx.pdiparams");
|
||||
|
||||
|
||||
void produce_data(std::vector<std::vector<float>>* data);
|
||||
void model_forward_test();
|
||||
|
||||
void produce_data(std::vector<std::vector<float>>* data) {
|
||||
int chunk_size = 35; // chunk_size in frame
|
||||
int col_size = 161; // feat dim
|
||||
cout << "chunk size: " << chunk_size << endl;
|
||||
cout << "feat dim: " << col_size << endl;
|
||||
|
||||
data->reserve(chunk_size);
|
||||
data->back().reserve(col_size);
|
||||
for (int row = 0; row < chunk_size; ++row) {
|
||||
data->push_back(std::vector<float>());
|
||||
for (int col_idx = 0; col_idx < col_size; ++col_idx) {
|
||||
data->back().push_back(0.201);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void model_forward_test() {
|
||||
std::cout << "1. read the data" << std::endl;
|
||||
std::vector<std::vector<float>> feats;
|
||||
produce_data(&feats);
|
||||
|
||||
std::cout << "2. load the model" << std::endl;
|
||||
;
|
||||
std::string model_graph = FLAGS_model_path;
|
||||
std::string model_params = FLAGS_param_path;
|
||||
cout << "model path: " << model_graph << endl;
|
||||
cout << "model param path : " << model_params << endl;
|
||||
|
||||
paddle_infer::Config config;
|
||||
config.SetModel(model_graph, model_params);
|
||||
config.SwitchIrOptim(false);
|
||||
cout << "SwitchIrOptim: " << false << endl;
|
||||
config.DisableFCPadding();
|
||||
cout << "DisableFCPadding: " << endl;
|
||||
auto predictor = paddle_infer::CreatePredictor(config);
|
||||
|
||||
std::cout << "3. feat shape, row=" << feats.size()
|
||||
<< ",col=" << feats[0].size() << std::endl;
|
||||
std::vector<float> pp_input_mat;
|
||||
for (const auto& item : feats) {
|
||||
pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
|
||||
}
|
||||
|
||||
std::cout << "4. fead the data to model" << std::endl;
|
||||
int row = feats.size();
|
||||
int col = feats[0].size();
|
||||
std::vector<std::string> input_names = predictor->GetInputNames();
|
||||
std::vector<std::string> output_names = predictor->GetOutputNames();
|
||||
for (auto name : input_names) {
|
||||
cout << "model input names: " << name << endl;
|
||||
}
|
||||
for (auto name : output_names) {
|
||||
cout << "model output names: " << name << endl;
|
||||
}
|
||||
|
||||
// input
|
||||
std::unique_ptr<paddle_infer::Tensor> input_tensor =
|
||||
predictor->GetInputHandle(input_names[0]);
|
||||
std::vector<int> INPUT_SHAPE = {1, row, col};
|
||||
input_tensor->Reshape(INPUT_SHAPE);
|
||||
input_tensor->CopyFromCpu(pp_input_mat.data());
|
||||
|
||||
// input length
|
||||
std::unique_ptr<paddle_infer::Tensor> input_len =
|
||||
predictor->GetInputHandle(input_names[1]);
|
||||
std::vector<int> input_len_size = {1};
|
||||
input_len->Reshape(input_len_size);
|
||||
std::vector<int64_t> audio_len;
|
||||
audio_len.push_back(row);
|
||||
input_len->CopyFromCpu(audio_len.data());
|
||||
|
||||
// state_h
|
||||
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
|
||||
predictor->GetInputHandle(input_names[2]);
|
||||
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
|
||||
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
|
||||
int chunk_state_h_box_size =
|
||||
std::accumulate(chunk_state_h_box_shape.begin(),
|
||||
chunk_state_h_box_shape.end(),
|
||||
1,
|
||||
std::multiplies<int>());
|
||||
std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
|
||||
chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
|
||||
|
||||
// state_c
|
||||
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
|
||||
predictor->GetInputHandle(input_names[3]);
|
||||
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
|
||||
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
|
||||
int chunk_state_c_box_size =
|
||||
std::accumulate(chunk_state_c_box_shape.begin(),
|
||||
chunk_state_c_box_shape.end(),
|
||||
1,
|
||||
std::multiplies<int>());
|
||||
std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
|
||||
chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
|
||||
|
||||
// run
|
||||
bool success = predictor->Run();
|
||||
|
||||
// state_h out
|
||||
std::unique_ptr<paddle_infer::Tensor> h_out =
|
||||
predictor->GetOutputHandle(output_names[2]);
|
||||
std::vector<int> h_out_shape = h_out->shape();
|
||||
int h_out_size = std::accumulate(
|
||||
h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
|
||||
std::vector<float> h_out_data(h_out_size);
|
||||
h_out->CopyToCpu(h_out_data.data());
|
||||
|
||||
// stage_c out
|
||||
std::unique_ptr<paddle_infer::Tensor> c_out =
|
||||
predictor->GetOutputHandle(output_names[3]);
|
||||
std::vector<int> c_out_shape = c_out->shape();
|
||||
int c_out_size = std::accumulate(
|
||||
c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
|
||||
std::vector<float> c_out_data(c_out_size);
|
||||
c_out->CopyToCpu(c_out_data.data());
|
||||
|
||||
// output tensor
|
||||
std::unique_ptr<paddle_infer::Tensor> output_tensor =
|
||||
predictor->GetOutputHandle(output_names[0]);
|
||||
std::vector<int> output_shape = output_tensor->shape();
|
||||
std::vector<float> output_probs;
|
||||
int output_size = std::accumulate(
|
||||
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
|
||||
output_probs.resize(output_size);
|
||||
output_tensor->CopyToCpu(output_probs.data());
|
||||
row = output_shape[1];
|
||||
col = output_shape[2];
|
||||
|
||||
// probs
|
||||
std::vector<std::vector<float>> probs;
|
||||
probs.reserve(row);
|
||||
for (int i = 0; i < row; i++) {
|
||||
probs.push_back(std::vector<float>());
|
||||
probs.back().reserve(col);
|
||||
|
||||
for (int j = 0; j < col; j++) {
|
||||
probs.back().push_back(output_probs[i * col + j]);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> log_feat = probs;
|
||||
std::cout << "probs, row: " << log_feat.size()
|
||||
<< " col: " << log_feat[0].size() << std::endl;
|
||||
for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
|
||||
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
|
||||
++col_idx) {
|
||||
std::cout << log_feat[row_idx][col_idx] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
model_forward_test();
|
||||
return 0;
|
||||
}
|
@ -0,0 +1,29 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
# 2. download model
|
||||
if [ ! -d ../paddle_asr_model ]; then
|
||||
wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
|
||||
tar xzfv paddle_asr_model.tar.gz
|
||||
mv ./paddle_asr_model ../
|
||||
# produce wav scp
|
||||
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
|
||||
fi
|
||||
|
||||
model_dir=../paddle_asr_model
|
||||
|
||||
# 4. run decoder
|
||||
pp-model-test \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdparams
|
||||
|
@ -0,0 +1,20 @@
|
||||
#!/bin/bash
|
||||
|
||||
# this script is for memory check, so please run ./run.sh first.
|
||||
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. ./path.sh
|
||||
|
||||
if [ ! -d ${SPEECHX_TOOLS}/valgrind/install ]; then
|
||||
echo "please install valgrind in the speechx tools dir.\n"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
model_dir=../paddle_asr_model
|
||||
|
||||
valgrind --tool=memcheck --track-origins=yes --leak-check=full --show-leak-kinds=all \
|
||||
pp-model-test \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdparams
|
@ -0,0 +1 @@
|
||||
exclude_files=.*
|
@ -0,0 +1,228 @@
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// See www.openfst.org for extensive documentation on this weighted
|
||||
// finite-state transducer library.
|
||||
//
|
||||
// Google-style flag handling declarations and inline definitions.
|
||||
|
||||
#ifndef FST_LIB_FLAGS_H_
|
||||
#define FST_LIB_FLAGS_H_
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include <fst/types.h>
|
||||
#include <fst/lock.h>
|
||||
|
||||
#include "gflags/gflags.h"
|
||||
#include "glog/logging.h"
|
||||
|
||||
using std::string;
|
||||
|
||||
// FLAGS USAGE:
|
||||
//
|
||||
// Definition example:
|
||||
//
|
||||
// DEFINE_int32(length, 0, "length");
|
||||
//
|
||||
// This defines variable FLAGS_length, initialized to 0.
|
||||
//
|
||||
// Declaration example:
|
||||
//
|
||||
// DECLARE_int32(length);
|
||||
//
|
||||
// SET_FLAGS() can be used to set flags from the command line
|
||||
// using, for example, '--length=2'.
|
||||
//
|
||||
// ShowUsage() can be used to print out command and flag usage.
|
||||
|
||||
// #define DECLARE_bool(name) extern bool FLAGS_ ## name
|
||||
// #define DECLARE_string(name) extern string FLAGS_ ## name
|
||||
// #define DECLARE_int32(name) extern int32 FLAGS_ ## name
|
||||
// #define DECLARE_int64(name) extern int64 FLAGS_ ## name
|
||||
// #define DECLARE_double(name) extern double FLAGS_ ## name
|
||||
|
||||
template <typename T>
|
||||
struct FlagDescription {
|
||||
FlagDescription(T *addr, const char *doc, const char *type,
|
||||
const char *file, const T val)
|
||||
: address(addr),
|
||||
doc_string(doc),
|
||||
type_name(type),
|
||||
file_name(file),
|
||||
default_value(val) {}
|
||||
|
||||
T *address;
|
||||
const char *doc_string;
|
||||
const char *type_name;
|
||||
const char *file_name;
|
||||
const T default_value;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class FlagRegister {
|
||||
public:
|
||||
static FlagRegister<T> *GetRegister() {
|
||||
static auto reg = new FlagRegister<T>;
|
||||
return reg;
|
||||
}
|
||||
|
||||
const FlagDescription<T> &GetFlagDescription(const string &name) const {
|
||||
fst::MutexLock l(&flag_lock_);
|
||||
auto it = flag_table_.find(name);
|
||||
return it != flag_table_.end() ? it->second : 0;
|
||||
}
|
||||
|
||||
void SetDescription(const string &name,
|
||||
const FlagDescription<T> &desc) {
|
||||
fst::MutexLock l(&flag_lock_);
|
||||
flag_table_.insert(make_pair(name, desc));
|
||||
}
|
||||
|
||||
bool SetFlag(const string &val, bool *address) const {
|
||||
if (val == "true" || val == "1" || val.empty()) {
|
||||
*address = true;
|
||||
return true;
|
||||
} else if (val == "false" || val == "0") {
|
||||
*address = false;
|
||||
return true;
|
||||
}
|
||||
else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool SetFlag(const string &val, string *address) const {
|
||||
*address = val;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SetFlag(const string &val, int32 *address) const {
|
||||
char *p = 0;
|
||||
*address = strtol(val.c_str(), &p, 0);
|
||||
return !val.empty() && *p == '\0';
|
||||
}
|
||||
|
||||
bool SetFlag(const string &val, int64 *address) const {
|
||||
char *p = 0;
|
||||
*address = strtoll(val.c_str(), &p, 0);
|
||||
return !val.empty() && *p == '\0';
|
||||
}
|
||||
|
||||
bool SetFlag(const string &val, double *address) const {
|
||||
char *p = 0;
|
||||
*address = strtod(val.c_str(), &p);
|
||||
return !val.empty() && *p == '\0';
|
||||
}
|
||||
|
||||
bool SetFlag(const string &arg, const string &val) const {
|
||||
for (typename std::map< string, FlagDescription<T> >::const_iterator it =
|
||||
flag_table_.begin();
|
||||
it != flag_table_.end();
|
||||
++it) {
|
||||
const string &name = it->first;
|
||||
const FlagDescription<T> &desc = it->second;
|
||||
if (arg == name)
|
||||
return SetFlag(val, desc.address);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void GetUsage(std::set<std::pair<string, string>> *usage_set) const {
|
||||
for (auto it = flag_table_.begin(); it != flag_table_.end(); ++it) {
|
||||
const string &name = it->first;
|
||||
const FlagDescription<T> &desc = it->second;
|
||||
string usage = " --" + name;
|
||||
usage += ": type = ";
|
||||
usage += desc.type_name;
|
||||
usage += ", default = ";
|
||||
usage += GetDefault(desc.default_value) + "\n ";
|
||||
usage += desc.doc_string;
|
||||
usage_set->insert(make_pair(desc.file_name, usage));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
string GetDefault(bool default_value) const {
|
||||
return default_value ? "true" : "false";
|
||||
}
|
||||
|
||||
string GetDefault(const string &default_value) const {
|
||||
return "\"" + default_value + "\"";
|
||||
}
|
||||
|
||||
template <class V>
|
||||
string GetDefault(const V &default_value) const {
|
||||
std::ostringstream strm;
|
||||
strm << default_value;
|
||||
return strm.str();
|
||||
}
|
||||
|
||||
mutable fst::Mutex flag_lock_; // Multithreading lock.
|
||||
std::map<string, FlagDescription<T>> flag_table_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class FlagRegisterer {
|
||||
public:
|
||||
FlagRegisterer(const string &name, const FlagDescription<T> &desc) {
|
||||
auto registr = FlagRegister<T>::GetRegister();
|
||||
registr->SetDescription(name, desc);
|
||||
}
|
||||
|
||||
private:
|
||||
FlagRegisterer(const FlagRegisterer &) = delete;
|
||||
FlagRegisterer &operator=(const FlagRegisterer &) = delete;
|
||||
};
|
||||
|
||||
|
||||
#define DEFINE_VAR(type, name, value, doc) \
|
||||
type FLAGS_ ## name = value; \
|
||||
static FlagRegisterer<type> \
|
||||
name ## _flags_registerer(#name, FlagDescription<type>(&FLAGS_ ## name, \
|
||||
doc, \
|
||||
#type, \
|
||||
__FILE__, \
|
||||
value))
|
||||
|
||||
// #define DEFINE_bool(name, value, doc) DEFINE_VAR(bool, name, value, doc)
|
||||
// #define DEFINE_string(name, value, doc) \
|
||||
// DEFINE_VAR(string, name, value, doc)
|
||||
// #define DEFINE_int32(name, value, doc) DEFINE_VAR(int32, name, value, doc)
|
||||
// #define DEFINE_int64(name, value, doc) DEFINE_VAR(int64, name, value, doc)
|
||||
// #define DEFINE_double(name, value, doc) DEFINE_VAR(double, name, value, doc)
|
||||
|
||||
|
||||
// Temporary directory.
|
||||
DECLARE_string(tmpdir);
|
||||
|
||||
void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags,
|
||||
const char *src = "");
|
||||
|
||||
#define SET_FLAGS(usage, argc, argv, rmflags) \
|
||||
gflags::ParseCommandLineFlags(argc, argv, true)
|
||||
// SetFlags(usage, argc, argv, rmflags, __FILE__)
|
||||
|
||||
// Deprecated; for backward compatibility.
|
||||
inline void InitFst(const char *usage, int *argc, char ***argv, bool rmflags) {
|
||||
return SetFlags(usage, argc, argv, rmflags);
|
||||
}
|
||||
|
||||
void ShowUsage(bool long_usage = true);
|
||||
|
||||
#endif // FST_LIB_FLAGS_H_
|
@ -0,0 +1,82 @@
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// See www.openfst.org for extensive documentation on this weighted
|
||||
// finite-state transducer library.
|
||||
//
|
||||
// Google-style logging declarations and inline definitions.
|
||||
|
||||
#ifndef FST_LIB_LOG_H_
|
||||
#define FST_LIB_LOG_H_
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include <fst/types.h>
|
||||
#include <fst/flags.h>
|
||||
|
||||
using std::string;
|
||||
|
||||
DECLARE_int32(v);
|
||||
|
||||
class LogMessage {
|
||||
public:
|
||||
LogMessage(const string &type) : fatal_(type == "FATAL") {
|
||||
std::cerr << type << ": ";
|
||||
}
|
||||
~LogMessage() {
|
||||
std::cerr << std::endl;
|
||||
if(fatal_)
|
||||
exit(1);
|
||||
}
|
||||
std::ostream &stream() { return std::cerr; }
|
||||
|
||||
private:
|
||||
bool fatal_;
|
||||
};
|
||||
|
||||
// #define LOG(type) LogMessage(#type).stream()
|
||||
// #define VLOG(level) if ((level) <= FLAGS_v) LOG(INFO)
|
||||
|
||||
// Checks
|
||||
inline void FstCheck(bool x, const char* expr,
|
||||
const char *file, int line) {
|
||||
if (!x) {
|
||||
LOG(FATAL) << "Check failed: \"" << expr
|
||||
<< "\" file: " << file
|
||||
<< " line: " << line;
|
||||
}
|
||||
}
|
||||
|
||||
// #define CHECK(x) FstCheck(static_cast<bool>(x), #x, __FILE__, __LINE__)
|
||||
// #define CHECK_EQ(x, y) CHECK((x) == (y))
|
||||
// #define CHECK_LT(x, y) CHECK((x) < (y))
|
||||
// #define CHECK_GT(x, y) CHECK((x) > (y))
|
||||
// #define CHECK_LE(x, y) CHECK((x) <= (y))
|
||||
// #define CHECK_GE(x, y) CHECK((x) >= (y))
|
||||
// #define CHECK_NE(x, y) CHECK((x) != (y))
|
||||
|
||||
// Debug checks
|
||||
// #define DCHECK(x) assert(x)
|
||||
// #define DCHECK_EQ(x, y) DCHECK((x) == (y))
|
||||
// #define DCHECK_LT(x, y) DCHECK((x) < (y))
|
||||
// #define DCHECK_GT(x, y) DCHECK((x) > (y))
|
||||
// #define DCHECK_LE(x, y) DCHECK((x) <= (y))
|
||||
// #define DCHECK_GE(x, y) DCHECK((x) >= (y))
|
||||
// #define DCHECK_NE(x, y) DCHECK((x) != (y))
|
||||
|
||||
|
||||
// Ports
|
||||
#define ATTRIBUTE_DEPRECATED __attribute__((deprecated))
|
||||
|
||||
#endif // FST_LIB_LOG_H_
|
@ -0,0 +1,166 @@
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Google-style flag handling definitions.
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#if _MSC_VER
|
||||
#include <io.h>
|
||||
#include <fcntl.h>
|
||||
#endif
|
||||
|
||||
#include <fst/compat.h>
|
||||
#include <fst/flags.h>
|
||||
|
||||
static const char *private_tmpdir = getenv("TMPDIR");
|
||||
|
||||
// DEFINE_int32(v, 0, "verbosity level");
|
||||
// DEFINE_bool(help, false, "show usage information");
|
||||
// DEFINE_bool(helpshort, false, "show brief usage information");
|
||||
#ifndef _MSC_VER
|
||||
DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : "/tmp",
|
||||
"temporary directory");
|
||||
#else
|
||||
DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : getenv("TEMP"),
|
||||
"temporary directory");
|
||||
#endif // !_MSC_VER
|
||||
|
||||
using namespace std;
|
||||
|
||||
static string flag_usage;
|
||||
static string prog_src;
|
||||
|
||||
// Sets prog_src to src.
|
||||
static void SetProgSrc(const char *src) {
|
||||
prog_src = src;
|
||||
#if _MSC_VER
|
||||
// This common code is invoked by all FST binaries, and only by them. Switch
|
||||
// stdin and stdout into "binary" mode, so that 0x0A won't be translated into
|
||||
// a 0x0D 0x0A byte pair in a pipe or a shell redirect. Other streams are
|
||||
// already using ios::binary where binary files are read or written.
|
||||
// Kudos to @daanzu for the suggested fix.
|
||||
// https://github.com/kkm000/openfst/issues/20
|
||||
// https://github.com/kkm000/openfst/pull/23
|
||||
// https://github.com/kkm000/openfst/pull/32
|
||||
_setmode(_fileno(stdin), O_BINARY);
|
||||
_setmode(_fileno(stdout), O_BINARY);
|
||||
#endif
|
||||
// Remove "-main" in src filename. Flags are defined in fstx.cc but SetFlags()
|
||||
// is called in fstx-main.cc, which results in a filename mismatch in
|
||||
// ShowUsageRestrict() below.
|
||||
static constexpr char kMainSuffix[] = "-main.cc";
|
||||
const int prefix_length = prog_src.size() - strlen(kMainSuffix);
|
||||
if (prefix_length > 0 && prog_src.substr(prefix_length) == kMainSuffix) {
|
||||
prog_src.erase(prefix_length, strlen("-main"));
|
||||
}
|
||||
}
|
||||
|
||||
void SetFlags(const char *usage, int *argc, char ***argv,
|
||||
bool remove_flags, const char *src) {
|
||||
flag_usage = usage;
|
||||
SetProgSrc(src);
|
||||
|
||||
int index = 1;
|
||||
for (; index < *argc; ++index) {
|
||||
string argval = (*argv)[index];
|
||||
if (argval[0] != '-' || argval == "-") break;
|
||||
while (argval[0] == '-') argval = argval.substr(1); // Removes initial '-'.
|
||||
string arg = argval;
|
||||
string val = "";
|
||||
// Splits argval (arg=val) into arg and val.
|
||||
auto pos = argval.find("=");
|
||||
if (pos != string::npos) {
|
||||
arg = argval.substr(0, pos);
|
||||
val = argval.substr(pos + 1);
|
||||
}
|
||||
auto bool_register = FlagRegister<bool>::GetRegister();
|
||||
if (bool_register->SetFlag(arg, val))
|
||||
continue;
|
||||
auto string_register = FlagRegister<string>::GetRegister();
|
||||
if (string_register->SetFlag(arg, val))
|
||||
continue;
|
||||
auto int32_register = FlagRegister<int32>::GetRegister();
|
||||
if (int32_register->SetFlag(arg, val))
|
||||
continue;
|
||||
auto int64_register = FlagRegister<int64>::GetRegister();
|
||||
if (int64_register->SetFlag(arg, val))
|
||||
continue;
|
||||
auto double_register = FlagRegister<double>::GetRegister();
|
||||
if (double_register->SetFlag(arg, val))
|
||||
continue;
|
||||
LOG(FATAL) << "SetFlags: Bad option: " << (*argv)[index];
|
||||
}
|
||||
if (remove_flags) {
|
||||
for (auto i = 0; i < *argc - index; ++i) {
|
||||
(*argv)[i + 1] = (*argv)[i + index];
|
||||
}
|
||||
*argc -= index - 1;
|
||||
}
|
||||
// if (FLAGS_help) {
|
||||
// ShowUsage(true);
|
||||
// exit(1);
|
||||
// }
|
||||
// if (FLAGS_helpshort) {
|
||||
// ShowUsage(false);
|
||||
// exit(1);
|
||||
// }
|
||||
}
|
||||
|
||||
// If flag is defined in file 'src' and 'in_src' true or is not
|
||||
// defined in file 'src' and 'in_src' is false, then print usage.
|
||||
static void
|
||||
ShowUsageRestrict(const std::set<pair<string, string>> &usage_set,
|
||||
const string &src, bool in_src, bool show_file) {
|
||||
string old_file;
|
||||
bool file_out = false;
|
||||
bool usage_out = false;
|
||||
for (const auto &pair : usage_set) {
|
||||
const auto &file = pair.first;
|
||||
const auto &usage = pair.second;
|
||||
bool match = file == src;
|
||||
if ((match && !in_src) || (!match && in_src)) continue;
|
||||
if (file != old_file) {
|
||||
if (show_file) {
|
||||
if (file_out) cout << "\n";
|
||||
cout << "Flags from: " << file << "\n";
|
||||
file_out = true;
|
||||
}
|
||||
old_file = file;
|
||||
}
|
||||
cout << usage << "\n";
|
||||
usage_out = true;
|
||||
}
|
||||
if (usage_out) cout << "\n";
|
||||
}
|
||||
|
||||
void ShowUsage(bool long_usage) {
|
||||
std::set<pair<string, string>> usage_set;
|
||||
cout << flag_usage << "\n";
|
||||
auto bool_register = FlagRegister<bool>::GetRegister();
|
||||
bool_register->GetUsage(&usage_set);
|
||||
auto string_register = FlagRegister<string>::GetRegister();
|
||||
string_register->GetUsage(&usage_set);
|
||||
auto int32_register = FlagRegister<int32>::GetRegister();
|
||||
int32_register->GetUsage(&usage_set);
|
||||
auto int64_register = FlagRegister<int64>::GetRegister();
|
||||
int64_register->GetUsage(&usage_set);
|
||||
auto double_register = FlagRegister<double>::GetRegister();
|
||||
double_register->GetUsage(&usage_set);
|
||||
if (!prog_src.empty()) {
|
||||
cout << "PROGRAM FLAGS:\n\n";
|
||||
ShowUsageRestrict(usage_set, prog_src, true, false);
|
||||
}
|
||||
if (!long_usage) return;
|
||||
if (!prog_src.empty()) cout << "LIBRARY FLAGS:\n\n";
|
||||
ShowUsageRestrict(usage_set, prog_src, false, true);
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <condition_variable>
|
||||
#include <deque>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <istream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <ostream>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "base/basic_types.h"
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "base/macros.h"
|
@ -0,0 +1,17 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fst/flags.h"
|
@ -0,0 +1,17 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fst/log.h"
|
@ -0,0 +1,110 @@
|
||||
// Copyright (c) 2012 Jakob Progsch, Václav Zeman
|
||||
|
||||
// This software is provided 'as-is', without any express or implied
|
||||
// warranty. In no event will the authors be held liable for any damages
|
||||
// arising from the use of this software.
|
||||
|
||||
// Permission is granted to anyone to use this software for any purpose,
|
||||
// including commercial applications, and to alter it and redistribute it
|
||||
// freely, subject to the following restrictions:
|
||||
|
||||
// 1. The origin of this software must not be misrepresented; you must not
|
||||
// claim that you wrote the original software. If you use this software
|
||||
// in a product, an acknowledgment in the product documentation would be
|
||||
// appreciated but is not required.
|
||||
|
||||
// 2. Altered source versions must be plainly marked as such, and must not be
|
||||
// misrepresented as being the original software.
|
||||
|
||||
// 3. This notice may not be removed or altered from any source
|
||||
// distribution.
|
||||
// this code is from https://github.com/progschj/ThreadPool
|
||||
|
||||
#ifndef BASE_THREAD_POOL_H
|
||||
#define BASE_THREAD_POOL_H
|
||||
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <stdexcept>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
class ThreadPool {
|
||||
public:
|
||||
ThreadPool(size_t);
|
||||
template <class F, class... Args>
|
||||
auto enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>;
|
||||
~ThreadPool();
|
||||
|
||||
private:
|
||||
// need to keep track of threads so we can join them
|
||||
std::vector<std::thread> workers;
|
||||
// the task queue
|
||||
std::queue<std::function<void()>> tasks;
|
||||
|
||||
// synchronization
|
||||
std::mutex queue_mutex;
|
||||
std::condition_variable condition;
|
||||
bool stop;
|
||||
};
|
||||
|
||||
// the constructor just launches some amount of workers
|
||||
inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
|
||||
for (size_t i = 0; i < threads; ++i)
|
||||
workers.emplace_back([this] {
|
||||
for (;;) {
|
||||
std::function<void()> task;
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->queue_mutex);
|
||||
this->condition.wait(lock, [this] {
|
||||
return this->stop || !this->tasks.empty();
|
||||
});
|
||||
if (this->stop && this->tasks.empty()) return;
|
||||
task = std::move(this->tasks.front());
|
||||
this->tasks.pop();
|
||||
}
|
||||
|
||||
task();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// add new work item to the pool
|
||||
template <class F, class... Args>
|
||||
auto ThreadPool::enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type> {
|
||||
using return_type = typename std::result_of<F(Args...)>::type;
|
||||
|
||||
auto task = std::make_shared<std::packaged_task<return_type()>>(
|
||||
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
|
||||
|
||||
std::future<return_type> res = task->get_future();
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
|
||||
// don't allow enqueueing after stopping the pool
|
||||
if (stop) throw std::runtime_error("enqueue on stopped ThreadPool");
|
||||
|
||||
tasks.emplace([task]() { (*task)(); });
|
||||
}
|
||||
condition.notify_one();
|
||||
return res;
|
||||
}
|
||||
|
||||
// the destructor joins all threads
|
||||
inline ThreadPool::~ThreadPool() {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
stop = true;
|
||||
}
|
||||
condition.notify_all();
|
||||
for (std::thread& worker : workers) worker.join();
|
||||
}
|
||||
|
||||
#endif
|
@ -1,4 +0,0 @@
|
||||
# codelab
|
||||
|
||||
This directory is here for testing some funcitons temporaril.
|
||||
|
@ -1,686 +0,0 @@
|
||||
// feat/feature-mfcc-test.cc
|
||||
|
||||
// Copyright 2009-2011 Karel Vesely; Petr Motlicek
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "feat/feature-mfcc.h"
|
||||
#include "base/kaldi-math.h"
|
||||
#include "matrix/kaldi-matrix-inl.h"
|
||||
#include "feat/wave-reader.h"
|
||||
|
||||
using namespace kaldi;
|
||||
|
||||
|
||||
|
||||
static void UnitTestReadWave() {
|
||||
|
||||
std::cout << "=== UnitTestReadWave() ===\n";
|
||||
|
||||
Vector<BaseFloat> v, v2;
|
||||
|
||||
std::cout << "<<<=== Reading waveform\n";
|
||||
|
||||
{
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
const Matrix<BaseFloat> data(wave.Data());
|
||||
KALDI_ASSERT(data.NumRows() == 1);
|
||||
v.Resize(data.NumCols());
|
||||
v.CopyFromVec(data.Row(0));
|
||||
}
|
||||
|
||||
std::cout << "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
|
||||
std::ifstream input(
|
||||
"test_data/test_matlab.ascii"
|
||||
);
|
||||
KALDI_ASSERT(input.good());
|
||||
v2.Read(input, false);
|
||||
input.close();
|
||||
|
||||
std::cout << "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
|
||||
KALDI_ASSERT(v.Dim() == v2.Dim());
|
||||
for (int32 i = 0; i < v.Dim(); i++) {
|
||||
KALDI_ASSERT(v(i) == v2(i));
|
||||
}
|
||||
std::cout << "<<<=== Comparing done\n";
|
||||
|
||||
// std::cout << "== The Waveform Samples == \n";
|
||||
// std::cout << v;
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
*/
|
||||
static void UnitTestSimple() {
|
||||
std::cout << "=== UnitTestSimple() ===\n";
|
||||
|
||||
Vector<BaseFloat> v(100000);
|
||||
Matrix<BaseFloat> m;
|
||||
|
||||
// init with noise
|
||||
for (int32 i = 0; i < v.Dim(); i++) {
|
||||
v(i) = (abs( i * 433024253 ) % 65535) - (65535 / 2);
|
||||
}
|
||||
|
||||
std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
|
||||
// the parametrization object
|
||||
MfccOptions op;
|
||||
// trying to have same opts as baseline.
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "rectangular";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
op.htk_compat = true;
|
||||
|
||||
Mfcc mfcc(op);
|
||||
// use default parameters
|
||||
|
||||
// compute mfccs.
|
||||
mfcc.Compute(v, 1.0, &m);
|
||||
|
||||
// possibly dump
|
||||
// std::cout << "== Output features == \n" << m;
|
||||
std::cout << "Test passed :)\n\n";
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare1() {
|
||||
std::cout << "=== UnitTestHTKCompare1() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.1",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = false; // C0 not energy.
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts,
|
||||
kaldi_raw_features,
|
||||
&kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (i_old != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}}}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.1",
|
||||
std::ios::out|std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.1");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare2() {
|
||||
std::cout << "=== UnitTestHTKCompare2() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.2",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts,
|
||||
kaldi_raw_features,
|
||||
&kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (i_old != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}}}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.2",
|
||||
std::ios::out|std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.2");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare3() {
|
||||
std::cout << "=== UnitTestHTKCompare3() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.3",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
op.mel_opts.low_freq = 20.0;
|
||||
//op.mel_opts.debug_mel = true;
|
||||
op.mel_opts.htk_mode = true;
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts,
|
||||
kaldi_raw_features,
|
||||
&kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}}}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.3",
|
||||
std::ios::out|std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.3");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare4() {
|
||||
std::cout << "=== UnitTestHTKCompare4() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.4",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
op.mel_opts.htk_mode = true;
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts,
|
||||
kaldi_raw_features,
|
||||
&kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}}}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.4",
|
||||
std::ios::out|std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.4");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare5() {
|
||||
std::cout << "=== UnitTestHTKCompare5() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.5",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.vtln_low = 100.0;
|
||||
op.mel_opts.vtln_high = 7500.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
|
||||
BaseFloat vtln_warp = 1.1; // our approach identical to htk for warp factor >1,
|
||||
// differs slightly for higher mel bins if warp_factor <0.9
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts,
|
||||
kaldi_raw_features,
|
||||
&kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}}}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.5",
|
||||
std::ios::out|std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.5");
|
||||
}
|
||||
|
||||
static void UnitTestHTKCompare6() {
|
||||
std::cout << "=== UnitTestHTKCompare6() ===\n";
|
||||
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.6",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.97;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.num_bins = 24;
|
||||
op.mel_opts.low_freq = 125.0;
|
||||
op.mel_opts.high_freq = 7800.0;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = false; // C0 not energy.
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts,
|
||||
kaldi_raw_features,
|
||||
&kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}}}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.6",
|
||||
std::ios::out|std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.6");
|
||||
}
|
||||
|
||||
void UnitTestVtln() {
|
||||
// Test the function VtlnWarpFreq.
|
||||
BaseFloat low_freq = 10, high_freq = 7800,
|
||||
vtln_low_cutoff = 20, vtln_high_cutoff = 7400;
|
||||
|
||||
for (size_t i = 0; i < 100; i++) {
|
||||
BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
|
||||
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
|
||||
low_freq, high_freq, warp_factor,
|
||||
freq),
|
||||
freq / warp_factor);
|
||||
|
||||
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
|
||||
low_freq, high_freq, warp_factor,
|
||||
low_freq),
|
||||
low_freq);
|
||||
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
|
||||
low_freq, high_freq, warp_factor,
|
||||
high_freq),
|
||||
high_freq);
|
||||
BaseFloat freq2 = low_freq + (high_freq-low_freq) * RandUniform(),
|
||||
freq3 = freq2 + (high_freq-freq2) * RandUniform(); // freq3>=freq2
|
||||
BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
|
||||
low_freq, high_freq, warp_factor,
|
||||
freq2);
|
||||
BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
|
||||
low_freq, high_freq, warp_factor,
|
||||
freq3);
|
||||
KALDI_ASSERT(w3 >= w2); // increasing function.
|
||||
BaseFloat w3dash = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
|
||||
low_freq, high_freq, 1.0,
|
||||
freq3);
|
||||
AssertEqual(w3dash, freq3);
|
||||
}
|
||||
}
|
||||
|
||||
static void UnitTestFeat() {
|
||||
UnitTestVtln();
|
||||
UnitTestReadWave();
|
||||
UnitTestSimple();
|
||||
UnitTestHTKCompare1();
|
||||
UnitTestHTKCompare2();
|
||||
// commenting out this one as it doesn't compare right now I normalized
|
||||
// the way the FFT bins are treated (removed offset of 0.5)... this seems
|
||||
// to relate to the way frequency zero behaves.
|
||||
UnitTestHTKCompare3();
|
||||
UnitTestHTKCompare4();
|
||||
UnitTestHTKCompare5();
|
||||
UnitTestHTKCompare6();
|
||||
std::cout << "Tests succeeded.\n";
|
||||
}
|
||||
|
||||
|
||||
|
||||
int main() {
|
||||
try {
|
||||
for (int i = 0; i < 5; i++)
|
||||
UnitTestFeat();
|
||||
std::cout << "Tests succeeded.\n";
|
||||
return 0;
|
||||
} catch (const std::exception &e) {
|
||||
std::cerr << e.what();
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,2 +1,10 @@
|
||||
aux_source_directory(. DIR_LIB_SRCS)
|
||||
add_library(decoder STATIC ${DIR_LIB_SRCS})
|
||||
project(decoder)
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
|
||||
add_library(decoder STATIC
|
||||
ctc_beam_search_decoder.cc
|
||||
ctc_decoders/decoder_utils.cpp
|
||||
ctc_decoders/path_trie.cpp
|
||||
ctc_decoders/scorer.cpp
|
||||
)
|
||||
target_link_libraries(decoder PUBLIC kenlm utils fst)
|
@ -0,0 +1,21 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "base/basic_types.h"
|
||||
|
||||
struct DecoderResult {
|
||||
BaseFloat acoustic_score;
|
||||
std::vector<int32> words_idx;
|
||||
std::vector<pair<int32, int32>> time_stamp;
|
||||
};
|
@ -0,0 +1,314 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "decoder/ctc_beam_search_decoder.h"
|
||||
|
||||
#include "base/basic_types.h"
|
||||
#include "decoder/ctc_decoders/decoder_utils.h"
|
||||
#include "utils/file_utils.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
using std::vector;
|
||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||
|
||||
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
|
||||
: opts_(opts),
|
||||
init_ext_scorer_(nullptr),
|
||||
blank_id_(-1),
|
||||
space_id_(-1),
|
||||
num_frame_decoded_(0),
|
||||
root_(nullptr) {
|
||||
LOG(INFO) << "dict path: " << opts_.dict_file;
|
||||
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
|
||||
LOG(INFO) << "load the dict failed";
|
||||
}
|
||||
LOG(INFO) << "read the vocabulary success, dict size: "
|
||||
<< vocabulary_.size();
|
||||
|
||||
LOG(INFO) << "language model path: " << opts_.lm_path;
|
||||
init_ext_scorer_ = std::make_shared<Scorer>(
|
||||
opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
|
||||
|
||||
blank_id_ = 0;
|
||||
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
|
||||
|
||||
space_id_ = it - vocabulary_.begin();
|
||||
// if no space in vocabulary
|
||||
if ((size_t)space_id_ >= vocabulary_.size()) {
|
||||
space_id_ = -2;
|
||||
}
|
||||
}
|
||||
|
||||
void CTCBeamSearch::Reset() {
|
||||
// num_frame_decoded_ = 0;
|
||||
// ResetPrefixes();
|
||||
InitDecoder();
|
||||
}
|
||||
|
||||
void CTCBeamSearch::InitDecoder() {
|
||||
num_frame_decoded_ = 0;
|
||||
// ResetPrefixes();
|
||||
prefixes_.clear();
|
||||
|
||||
root_ = std::make_shared<PathTrie>();
|
||||
root_->score = root_->log_prob_b_prev = 0.0;
|
||||
prefixes_.push_back(root_.get());
|
||||
if (init_ext_scorer_ != nullptr &&
|
||||
!init_ext_scorer_->is_character_based()) {
|
||||
auto fst_dict =
|
||||
static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
|
||||
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
|
||||
root_->set_dictionary(dict_ptr);
|
||||
|
||||
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
|
||||
root_->set_matcher(matcher);
|
||||
}
|
||||
}
|
||||
|
||||
void CTCBeamSearch::Decode(
|
||||
std::shared_ptr<kaldi::DecodableInterface> decodable) {
|
||||
return;
|
||||
}
|
||||
|
||||
int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 1; }
|
||||
|
||||
// todo rename, refactor
|
||||
void CTCBeamSearch::AdvanceDecode(
|
||||
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
|
||||
while (1) {
|
||||
vector<vector<BaseFloat>> likelihood;
|
||||
vector<BaseFloat> frame_prob;
|
||||
bool flag =
|
||||
decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob);
|
||||
if (flag == false) break;
|
||||
likelihood.push_back(frame_prob);
|
||||
AdvanceDecoding(likelihood);
|
||||
}
|
||||
}
|
||||
|
||||
void CTCBeamSearch::ResetPrefixes() {
|
||||
for (size_t i = 0; i < prefixes_.size(); i++) {
|
||||
if (prefixes_[i] != nullptr) {
|
||||
delete prefixes_[i];
|
||||
prefixes_[i] = nullptr;
|
||||
}
|
||||
}
|
||||
prefixes_.clear();
|
||||
}
|
||||
|
||||
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
|
||||
vector<string>& nbest_words) {
|
||||
kaldi::Timer timer;
|
||||
timer.Reset();
|
||||
AdvanceDecoding(probs);
|
||||
LOG(INFO) << "ctc decoding elapsed time(s) "
|
||||
<< static_cast<float>(timer.Elapsed()) / 1000.0f;
|
||||
return 0;
|
||||
}
|
||||
|
||||
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
|
||||
return get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size);
|
||||
}
|
||||
|
||||
string CTCBeamSearch::GetBestPath() {
|
||||
std::vector<std::pair<double, std::string>> result;
|
||||
result = get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size);
|
||||
return result[0].second;
|
||||
}
|
||||
|
||||
string CTCBeamSearch::GetFinalBestPath() {
|
||||
CalculateApproxScore();
|
||||
LMRescore();
|
||||
return GetBestPath();
|
||||
}
|
||||
|
||||
void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
|
||||
size_t num_time_steps = probs.size();
|
||||
size_t beam_size = opts_.beam_size;
|
||||
double cutoff_prob = opts_.cutoff_prob;
|
||||
size_t cutoff_top_n = opts_.cutoff_top_n;
|
||||
|
||||
vector<vector<double>> probs_seq(probs.size(),
|
||||
vector<double>(probs[0].size(), 0));
|
||||
|
||||
int row = probs.size();
|
||||
int col = probs[0].size();
|
||||
for (int i = 0; i < row; i++) {
|
||||
for (int j = 0; j < col; j++) {
|
||||
probs_seq[i][j] = static_cast<double>(probs[i][j]);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
|
||||
const auto& prob = probs_seq[time_step];
|
||||
|
||||
float min_cutoff = -NUM_FLT_INF;
|
||||
bool full_beam = false;
|
||||
if (init_ext_scorer_ != nullptr) {
|
||||
size_t num_prefixes_ = std::min(prefixes_.size(), beam_size);
|
||||
std::sort(prefixes_.begin(),
|
||||
prefixes_.begin() + num_prefixes_,
|
||||
prefix_compare);
|
||||
|
||||
if (num_prefixes_ == 0) {
|
||||
continue;
|
||||
}
|
||||
min_cutoff = prefixes_[num_prefixes_ - 1]->score +
|
||||
std::log(prob[blank_id_]) -
|
||||
std::max(0.0, init_ext_scorer_->beta);
|
||||
|
||||
full_beam = (num_prefixes_ == beam_size);
|
||||
}
|
||||
|
||||
vector<std::pair<size_t, float>> log_prob_idx =
|
||||
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
|
||||
|
||||
// loop over chars
|
||||
size_t log_prob_idx_len = log_prob_idx.size();
|
||||
for (size_t index = 0; index < log_prob_idx_len; index++) {
|
||||
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
|
||||
}
|
||||
|
||||
prefixes_.clear();
|
||||
|
||||
// update log probs
|
||||
root_->iterate_to_vec(prefixes_);
|
||||
// only preserve top beam_size prefixes_
|
||||
if (prefixes_.size() >= beam_size) {
|
||||
std::nth_element(prefixes_.begin(),
|
||||
prefixes_.begin() + beam_size,
|
||||
prefixes_.end(),
|
||||
prefix_compare);
|
||||
for (size_t i = beam_size; i < prefixes_.size(); ++i) {
|
||||
prefixes_[i]->remove();
|
||||
}
|
||||
} // if
|
||||
num_frame_decoded_++;
|
||||
} // for probs_seq
|
||||
}
|
||||
|
||||
int32 CTCBeamSearch::SearchOneChar(
|
||||
const bool& full_beam,
|
||||
const std::pair<size_t, BaseFloat>& log_prob_idx,
|
||||
const BaseFloat& min_cutoff) {
|
||||
size_t beam_size = opts_.beam_size;
|
||||
const auto& c = log_prob_idx.first;
|
||||
const auto& log_prob_c = log_prob_idx.second;
|
||||
size_t prefixes_len = std::min(prefixes_.size(), beam_size);
|
||||
|
||||
for (size_t i = 0; i < prefixes_len; ++i) {
|
||||
auto prefix = prefixes_[i];
|
||||
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (c == blank_id_) {
|
||||
prefix->log_prob_b_cur =
|
||||
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
|
||||
continue;
|
||||
}
|
||||
|
||||
// repeated character
|
||||
if (c == prefix->character) {
|
||||
// p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
|
||||
prefix->log_prob_nb_cur = log_sum_exp(
|
||||
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
|
||||
}
|
||||
|
||||
// get new prefix
|
||||
auto prefix_new = prefix->get_path_trie(c);
|
||||
if (prefix_new != nullptr) {
|
||||
float log_p = -NUM_FLT_INF;
|
||||
if (c == prefix->character &&
|
||||
prefix->log_prob_b_prev > -NUM_FLT_INF) {
|
||||
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
|
||||
log_p = log_prob_c + prefix->log_prob_b_prev;
|
||||
} else if (c != prefix->character) {
|
||||
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
|
||||
log_p = log_prob_c + prefix->score;
|
||||
}
|
||||
|
||||
// language model scoring
|
||||
if (init_ext_scorer_ != nullptr &&
|
||||
(c == space_id_ || init_ext_scorer_->is_character_based())) {
|
||||
PathTrie* prefix_to_score = nullptr;
|
||||
// skip scoring the space
|
||||
if (init_ext_scorer_->is_character_based()) {
|
||||
prefix_to_score = prefix_new;
|
||||
} else {
|
||||
prefix_to_score = prefix;
|
||||
}
|
||||
|
||||
float score = 0.0;
|
||||
vector<string> ngram;
|
||||
ngram = init_ext_scorer_->make_ngram(prefix_to_score);
|
||||
// lm score: p_{lm}(W)^{\alpha} + \beta
|
||||
score = init_ext_scorer_->get_log_cond_prob(ngram) *
|
||||
init_ext_scorer_->alpha;
|
||||
log_p += score;
|
||||
log_p += init_ext_scorer_->beta;
|
||||
}
|
||||
// p_{nb}(l;x_{1:t})
|
||||
prefix_new->log_prob_nb_cur =
|
||||
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
||||
}
|
||||
} // end of loop over prefix
|
||||
return 0;
|
||||
}
|
||||
|
||||
void CTCBeamSearch::CalculateApproxScore() {
|
||||
size_t beam_size = opts_.beam_size;
|
||||
size_t num_prefixes_ = std::min(prefixes_.size(), beam_size);
|
||||
std::sort(
|
||||
prefixes_.begin(), prefixes_.begin() + num_prefixes_, prefix_compare);
|
||||
|
||||
// compute aproximate ctc score as the return score, without affecting the
|
||||
// return order of decoding result. To delete when decoder gets stable.
|
||||
for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) {
|
||||
double approx_ctc = prefixes_[i]->score;
|
||||
if (init_ext_scorer_ != nullptr) {
|
||||
vector<int> output;
|
||||
prefixes_[i]->get_path_vec(output);
|
||||
auto prefix_length = output.size();
|
||||
auto words = init_ext_scorer_->split_labels(output);
|
||||
// remove word insert
|
||||
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
|
||||
// remove language model weight:
|
||||
approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
|
||||
init_ext_scorer_->alpha;
|
||||
}
|
||||
prefixes_[i]->approx_ctc = approx_ctc;
|
||||
}
|
||||
}
|
||||
|
||||
void CTCBeamSearch::LMRescore() {
|
||||
size_t beam_size = opts_.beam_size;
|
||||
if (init_ext_scorer_ != nullptr &&
|
||||
!init_ext_scorer_->is_character_based()) {
|
||||
for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) {
|
||||
auto prefix = prefixes_[i];
|
||||
if (!prefix->is_empty() && prefix->character != space_id_) {
|
||||
float score = 0.0;
|
||||
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
|
||||
score = init_ext_scorer_->get_log_cond_prob(ngram) *
|
||||
init_ext_scorer_->alpha;
|
||||
score += init_ext_scorer_->beta;
|
||||
prefix->score += score;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,94 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "base/common.h"
|
||||
#include "decoder/ctc_decoders/path_trie.h"
|
||||
#include "decoder/ctc_decoders/scorer.h"
|
||||
#include "nnet/decodable-itf.h"
|
||||
#include "util/parse-options.h"
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
struct CTCBeamSearchOptions {
|
||||
std::string dict_file;
|
||||
std::string lm_path;
|
||||
BaseFloat alpha;
|
||||
BaseFloat beta;
|
||||
BaseFloat cutoff_prob;
|
||||
int beam_size;
|
||||
int cutoff_top_n;
|
||||
int num_proc_bsearch;
|
||||
CTCBeamSearchOptions()
|
||||
: dict_file("vocab.txt"),
|
||||
lm_path("lm.klm"),
|
||||
alpha(1.9f),
|
||||
beta(5.0),
|
||||
beam_size(300),
|
||||
cutoff_prob(0.99f),
|
||||
cutoff_top_n(40),
|
||||
num_proc_bsearch(0) {}
|
||||
|
||||
void Register(kaldi::OptionsItf* opts) {
|
||||
opts->Register("dict", &dict_file, "dict file ");
|
||||
opts->Register("lm-path", &lm_path, "language model file");
|
||||
opts->Register("alpha", &alpha, "alpha");
|
||||
opts->Register("beta", &beta, "beta");
|
||||
opts->Register(
|
||||
"beam-size", &beam_size, "beam size for beam search method");
|
||||
opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
|
||||
opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
|
||||
opts->Register(
|
||||
"num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
|
||||
}
|
||||
};
|
||||
|
||||
class CTCBeamSearch {
|
||||
public:
|
||||
explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
|
||||
~CTCBeamSearch() {}
|
||||
void InitDecoder();
|
||||
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
|
||||
std::string GetBestPath();
|
||||
std::vector<std::pair<double, std::string>> GetNBestPath();
|
||||
std::string GetFinalBestPath();
|
||||
int NumFrameDecoded();
|
||||
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
|
||||
std::vector<std::string>& nbest_words);
|
||||
void AdvanceDecode(
|
||||
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
|
||||
void Reset();
|
||||
|
||||
private:
|
||||
void ResetPrefixes();
|
||||
int32 SearchOneChar(const bool& full_beam,
|
||||
const std::pair<size_t, BaseFloat>& log_prob_idx,
|
||||
const BaseFloat& min_cutoff);
|
||||
void CalculateApproxScore();
|
||||
void LMRescore();
|
||||
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
|
||||
|
||||
CTCBeamSearchOptions opts_;
|
||||
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
|
||||
std::vector<std::string> vocabulary_; // todo remove later
|
||||
size_t blank_id_;
|
||||
int space_id_;
|
||||
std::shared_ptr<PathTrie> root_;
|
||||
std::vector<PathTrie*> prefixes_;
|
||||
int num_frame_decoded_;
|
||||
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
|
||||
};
|
||||
|
||||
} // namespace basr
|
@ -0,0 +1 @@
|
||||
../../../third_party/ctc_decoders
|
@ -0,0 +1,10 @@
|
||||
project(frontend)
|
||||
|
||||
add_library(frontend STATIC
|
||||
normalizer.cc
|
||||
linear_spectrogram.cc
|
||||
raw_audio.cc
|
||||
feature_cache.cc
|
||||
)
|
||||
|
||||
target_link_libraries(frontend PUBLIC kaldi-matrix)
|
@ -0,0 +1,37 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// wrap the fbank feat of kaldi, todo (SmileGoat)
|
||||
|
||||
#include "kaldi/feat/feature-mfcc.h"
|
||||
|
||||
#incldue "kaldi/matrix/kaldi-vector.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
class FbankExtractor : FeatureExtractorInterface {
|
||||
public:
|
||||
explicit FbankExtractor(const FbankOptions& opts,
|
||||
share_ptr<FeatureExtractorInterface> pre_extractor);
|
||||
virtual void AcceptWaveform(
|
||||
const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
|
||||
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
|
||||
virtual size_t Dim() const = 0;
|
||||
|
||||
private:
|
||||
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& wave,
|
||||
kaldi::Vector<kaldi::BaseFloat>* feat) const;
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,83 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "frontend/feature_cache.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
using kaldi::Vector;
|
||||
using kaldi::VectorBase;
|
||||
using kaldi::BaseFloat;
|
||||
using std::vector;
|
||||
using kaldi::SubVector;
|
||||
using std::unique_ptr;
|
||||
|
||||
FeatureCache::FeatureCache(
|
||||
int max_size, unique_ptr<FeatureExtractorInterface> base_extractor) {
|
||||
max_size_ = max_size;
|
||||
base_extractor_ = std::move(base_extractor);
|
||||
}
|
||||
|
||||
void FeatureCache::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
|
||||
base_extractor_->Accept(inputs);
|
||||
// feed current data
|
||||
bool result = false;
|
||||
do {
|
||||
result = Compute();
|
||||
} while (result);
|
||||
}
|
||||
|
||||
// pop feature chunk
|
||||
bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
|
||||
kaldi::Timer timer;
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
while (cache_.empty() && base_extractor_->IsFinished() == false) {
|
||||
ready_read_condition_.wait(lock);
|
||||
BaseFloat elapsed = timer.Elapsed() * 1000;
|
||||
// todo replace 1.0 with timeout_
|
||||
if (elapsed > 1.0) {
|
||||
return false;
|
||||
}
|
||||
usleep(1000); // sleep 1 ms
|
||||
}
|
||||
if (cache_.empty()) return false;
|
||||
feats->Resize(cache_.front().Dim());
|
||||
feats->CopyFromVec(cache_.front());
|
||||
cache_.pop();
|
||||
ready_feed_condition_.notify_one();
|
||||
return true;
|
||||
}
|
||||
|
||||
// read all data from base_feature_extractor_ into cache_
|
||||
bool FeatureCache::Compute() {
|
||||
// compute and feed
|
||||
Vector<BaseFloat> feature_chunk;
|
||||
bool result = base_extractor_->Read(&feature_chunk);
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
while (cache_.size() >= max_size_) {
|
||||
ready_feed_condition_.wait(lock);
|
||||
}
|
||||
if (feature_chunk.Dim() != 0) {
|
||||
cache_.push(feature_chunk);
|
||||
}
|
||||
ready_read_condition_.notify_one();
|
||||
return result;
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
// std::lock_guard<std::mutex> lock(mutex_);
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,57 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/common.h"
|
||||
#include "frontend/feature_extractor_interface.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
class FeatureCache : public FeatureExtractorInterface {
|
||||
public:
|
||||
explicit FeatureCache(
|
||||
int32 max_size = kint16max,
|
||||
std::unique_ptr<FeatureExtractorInterface> base_extractor = NULL);
|
||||
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
|
||||
// feats dim = num_frames * feature_dim
|
||||
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
|
||||
// feature cache only cache feature which from base extractor
|
||||
virtual size_t Dim() const { return base_extractor_->Dim(); }
|
||||
virtual void SetFinished() {
|
||||
base_extractor_->SetFinished();
|
||||
// read the last chunk data
|
||||
Compute();
|
||||
}
|
||||
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
|
||||
virtual void Reset() {
|
||||
base_extractor_->Reset();
|
||||
while (!cache_.empty()) {
|
||||
cache_.pop();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool Compute();
|
||||
|
||||
std::mutex mutex_;
|
||||
size_t max_size_;
|
||||
std::queue<kaldi::Vector<BaseFloat>> cache_;
|
||||
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
|
||||
std::condition_variable ready_feed_condition_;
|
||||
std::condition_variable ready_read_condition_;
|
||||
// DISALLOW_COPY_AND_ASSGIN(FeatureCache);
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,13 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
@ -0,0 +1,13 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/basic_types.h"
|
||||
#include "kaldi/matrix/kaldi-vector.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
class FeatureExtractorInterface {
|
||||
public:
|
||||
// accept input data, accept feature or raw waves which decided
|
||||
// by the base_extractor
|
||||
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) = 0;
|
||||
// get the processed result
|
||||
// the length of output = feature_row * feature_dim,
|
||||
// the Matrix is squashed into Vector
|
||||
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* outputs) = 0;
|
||||
// the Dim is the feature dim
|
||||
virtual size_t Dim() const = 0;
|
||||
virtual void SetFinished() = 0;
|
||||
virtual bool IsFinished() const = 0;
|
||||
virtual void Reset() = 0;
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,156 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "frontend/linear_spectrogram.h"
|
||||
#include "kaldi/base/kaldi-math.h"
|
||||
#include "kaldi/matrix/matrix-functions.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
using kaldi::int32;
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Vector;
|
||||
using kaldi::VectorBase;
|
||||
using kaldi::Matrix;
|
||||
using std::vector;
|
||||
|
||||
LinearSpectrogram::LinearSpectrogram(
|
||||
const LinearSpectrogramOptions& opts,
|
||||
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
|
||||
opts_ = opts;
|
||||
base_extractor_ = std::move(base_extractor);
|
||||
int32 window_size = opts.frame_opts.WindowSize();
|
||||
int32 window_shift = opts.frame_opts.WindowShift();
|
||||
fft_points_ = window_size;
|
||||
chunk_sample_size_ =
|
||||
static_cast<int32>(opts.streaming_chunk * opts.frame_opts.samp_freq);
|
||||
hanning_window_.resize(window_size);
|
||||
|
||||
double a = M_2PI / (window_size - 1);
|
||||
hanning_window_energy_ = 0;
|
||||
for (int i = 0; i < window_size; ++i) {
|
||||
hanning_window_[i] = 0.5 - 0.5 * cos(a * i);
|
||||
hanning_window_energy_ += hanning_window_[i] * hanning_window_[i];
|
||||
}
|
||||
|
||||
dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz
|
||||
}
|
||||
|
||||
void LinearSpectrogram::Accept(const VectorBase<BaseFloat>& inputs) {
|
||||
base_extractor_->Accept(inputs);
|
||||
}
|
||||
|
||||
bool LinearSpectrogram::Read(Vector<BaseFloat>* feats) {
|
||||
Vector<BaseFloat> input_feats(chunk_sample_size_);
|
||||
bool flag = base_extractor_->Read(&input_feats);
|
||||
if (flag == false || input_feats.Dim() == 0) return false;
|
||||
|
||||
vector<BaseFloat> input_feats_vec(input_feats.Dim());
|
||||
std::memcpy(input_feats_vec.data(),
|
||||
input_feats.Data(),
|
||||
input_feats.Dim() * sizeof(BaseFloat));
|
||||
vector<vector<BaseFloat>> result;
|
||||
Compute(input_feats_vec, result);
|
||||
int32 feat_size = 0;
|
||||
if (result.size() != 0) {
|
||||
feat_size = result.size() * result[0].size();
|
||||
}
|
||||
feats->Resize(feat_size);
|
||||
// todo refactor (SimleGoat)
|
||||
for (size_t idx = 0; idx < feat_size; ++idx) {
|
||||
(*feats)(idx) = result[idx / dim_][idx % dim_];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void LinearSpectrogram::Hanning(vector<float>* data) const {
|
||||
CHECK_GE(data->size(), hanning_window_.size());
|
||||
|
||||
for (size_t i = 0; i < hanning_window_.size(); ++i) {
|
||||
data->at(i) *= hanning_window_[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v,
|
||||
vector<BaseFloat>* real,
|
||||
vector<BaseFloat>* img) const {
|
||||
Vector<BaseFloat> v_tmp;
|
||||
v_tmp.Resize(v->size());
|
||||
std::memcpy(v_tmp.Data(), v->data(), sizeof(BaseFloat) * (v->size()));
|
||||
RealFft(&v_tmp, true);
|
||||
v->resize(v_tmp.Dim());
|
||||
std::memcpy(v->data(), v_tmp.Data(), sizeof(BaseFloat) * (v->size()));
|
||||
|
||||
real->push_back(v->at(0));
|
||||
img->push_back(0);
|
||||
for (int i = 1; i < v->size() / 2; i++) {
|
||||
real->push_back(v->at(2 * i));
|
||||
img->push_back(v->at(2 * i + 1));
|
||||
}
|
||||
real->push_back(v->at(1));
|
||||
img->push_back(0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Compute spectrogram feat
|
||||
// todo: refactor later (SmileGoat)
|
||||
bool LinearSpectrogram::Compute(const vector<float>& waves,
|
||||
vector<vector<float>>& feats) {
|
||||
int num_samples = waves.size();
|
||||
const int& frame_length = opts_.frame_opts.WindowSize();
|
||||
const int& sample_rate = opts_.frame_opts.samp_freq;
|
||||
const int& frame_shift = opts_.frame_opts.WindowShift();
|
||||
const int& fft_points = fft_points_;
|
||||
const float scale = hanning_window_energy_ * sample_rate;
|
||||
|
||||
if (num_samples < frame_length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
int num_frames = 1 + ((num_samples - frame_length) / frame_shift);
|
||||
feats.resize(num_frames);
|
||||
vector<float> fft_real((fft_points_ / 2 + 1), 0);
|
||||
vector<float> fft_img((fft_points_ / 2 + 1), 0);
|
||||
vector<float> v(frame_length, 0);
|
||||
vector<float> power((fft_points / 2 + 1));
|
||||
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
vector<float> data(waves.data() + i * frame_shift,
|
||||
waves.data() + i * frame_shift + frame_length);
|
||||
Hanning(&data);
|
||||
fft_img.clear();
|
||||
fft_real.clear();
|
||||
v.assign(data.begin(), data.end());
|
||||
NumpyFft(&v, &fft_real, &fft_img);
|
||||
|
||||
feats[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz
|
||||
for (int j = 0; j < (fft_points / 2 + 1); ++j) {
|
||||
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
|
||||
feats[i][j] = power[j];
|
||||
|
||||
if (j == 0 || j == feats[0].size() - 1) {
|
||||
feats[i][j] /= scale;
|
||||
} else {
|
||||
feats[i][j] *= (2.0 / scale);
|
||||
}
|
||||
|
||||
// log added eps=1e-14
|
||||
feats[i][j] = std::log(feats[i][j] + 1e-14);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,68 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/common.h"
|
||||
#include "frontend/feature_extractor_interface.h"
|
||||
#include "kaldi/feat/feature-window.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
struct LinearSpectrogramOptions {
|
||||
kaldi::FrameExtractionOptions frame_opts;
|
||||
kaldi::BaseFloat streaming_chunk;
|
||||
LinearSpectrogramOptions() : streaming_chunk(0.36), frame_opts() {}
|
||||
|
||||
void Register(kaldi::OptionsItf* opts) {
|
||||
opts->Register(
|
||||
"streaming-chunk", &streaming_chunk, "streaming chunk size");
|
||||
frame_opts.Register(opts);
|
||||
}
|
||||
};
|
||||
|
||||
class LinearSpectrogram : public FeatureExtractorInterface {
|
||||
public:
|
||||
explicit LinearSpectrogram(
|
||||
const LinearSpectrogramOptions& opts,
|
||||
std::unique_ptr<FeatureExtractorInterface> base_extractor);
|
||||
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
|
||||
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
|
||||
// the dim_ is the dim of single frame feature
|
||||
virtual size_t Dim() const { return dim_; }
|
||||
virtual void SetFinished() { base_extractor_->SetFinished(); }
|
||||
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
|
||||
virtual void Reset() { base_extractor_->Reset(); }
|
||||
|
||||
private:
|
||||
void Hanning(std::vector<kaldi::BaseFloat>* data) const;
|
||||
bool Compute(const std::vector<kaldi::BaseFloat>& waves,
|
||||
std::vector<std::vector<kaldi::BaseFloat>>& feats);
|
||||
bool NumpyFft(std::vector<kaldi::BaseFloat>* v,
|
||||
std::vector<kaldi::BaseFloat>* real,
|
||||
std::vector<kaldi::BaseFloat>* img) const;
|
||||
|
||||
kaldi::int32 fft_points_;
|
||||
size_t dim_;
|
||||
std::vector<kaldi::BaseFloat> hanning_window_;
|
||||
kaldi::BaseFloat hanning_window_energy_;
|
||||
LinearSpectrogramOptions opts_;
|
||||
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
|
||||
int chunk_sample_size_;
|
||||
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
|
||||
};
|
||||
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,16 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// wrap the mfcc feat of kaldi, todo (SmileGoat)
|
||||
#include "kaldi/feat/feature-mfcc.h"
|
@ -0,0 +1,188 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include "frontend/normalizer.h"
|
||||
#include "kaldi/feat/cmvn.h"
|
||||
#include "kaldi/util/kaldi-io.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
using kaldi::Vector;
|
||||
using kaldi::VectorBase;
|
||||
using kaldi::BaseFloat;
|
||||
using std::vector;
|
||||
using kaldi::SubVector;
|
||||
using std::unique_ptr;
|
||||
|
||||
DecibelNormalizer::DecibelNormalizer(
|
||||
const DecibelNormalizerOptions& opts,
|
||||
std::unique_ptr<FeatureExtractorInterface> base_extractor) {
|
||||
base_extractor_ = std::move(base_extractor);
|
||||
opts_ = opts;
|
||||
dim_ = 1;
|
||||
}
|
||||
|
||||
void DecibelNormalizer::Accept(const kaldi::VectorBase<BaseFloat>& waves) {
|
||||
base_extractor_->Accept(waves);
|
||||
}
|
||||
|
||||
bool DecibelNormalizer::Read(kaldi::Vector<BaseFloat>* waves) {
|
||||
if (base_extractor_->Read(waves) == false || waves->Dim() == 0) {
|
||||
return false;
|
||||
}
|
||||
Compute(waves);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DecibelNormalizer::Compute(VectorBase<BaseFloat>* waves) const {
|
||||
// calculate db rms
|
||||
BaseFloat rms_db = 0.0;
|
||||
BaseFloat mean_square = 0.0;
|
||||
BaseFloat gain = 0.0;
|
||||
BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1));
|
||||
|
||||
vector<BaseFloat> samples;
|
||||
samples.resize(waves->Dim());
|
||||
for (size_t i = 0; i < samples.size(); ++i) {
|
||||
samples[i] = (*waves)(i);
|
||||
}
|
||||
|
||||
// square
|
||||
for (auto& d : samples) {
|
||||
if (opts_.convert_int_float) {
|
||||
d = d * wave_float_normlization;
|
||||
}
|
||||
mean_square += d * d;
|
||||
}
|
||||
|
||||
// mean
|
||||
mean_square /= samples.size();
|
||||
rms_db = 10 * std::log10(mean_square);
|
||||
gain = opts_.target_db - rms_db;
|
||||
|
||||
if (gain > opts_.max_gain_db) {
|
||||
LOG(ERROR)
|
||||
<< "Unable to normalize segment to " << opts_.target_db << "dB,"
|
||||
<< "because the the probable gain have exceeds opts_.max_gain_db"
|
||||
<< opts_.max_gain_db << "dB.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Note that this is an in-place transformation.
|
||||
for (auto& item : samples) {
|
||||
// python item *= 10.0 ** (gain / 20.0)
|
||||
item *= std::pow(10.0, gain / 20.0);
|
||||
}
|
||||
|
||||
std::memcpy(
|
||||
waves->Data(), samples.data(), sizeof(BaseFloat) * samples.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
CMVN::CMVN(std::string cmvn_file,
|
||||
unique_ptr<FeatureExtractorInterface> base_extractor)
|
||||
: var_norm_(true) {
|
||||
base_extractor_ = std::move(base_extractor);
|
||||
bool binary;
|
||||
kaldi::Input ki(cmvn_file, &binary);
|
||||
stats_.Read(ki.Stream(), binary);
|
||||
dim_ = stats_.NumCols() - 1;
|
||||
}
|
||||
|
||||
void CMVN::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
|
||||
base_extractor_->Accept(inputs);
|
||||
return;
|
||||
}
|
||||
|
||||
bool CMVN::Read(kaldi::Vector<BaseFloat>* feats) {
|
||||
if (base_extractor_->Read(feats) == false) {
|
||||
return false;
|
||||
}
|
||||
Compute(feats);
|
||||
return true;
|
||||
}
|
||||
|
||||
// feats contain num_frames feature.
|
||||
void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
|
||||
KALDI_ASSERT(feats != NULL);
|
||||
int32 dim = stats_.NumCols() - 1;
|
||||
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
|
||||
feats->Dim() % dim != 0) {
|
||||
KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x'
|
||||
<< stats_.NumCols() << ", feats " << feats->Dim() << 'x';
|
||||
}
|
||||
if (stats_.NumRows() == 1 && var_norm_) {
|
||||
KALDI_ERR
|
||||
<< "You requested variance normalization but no variance stats_ "
|
||||
<< "are supplied.";
|
||||
}
|
||||
|
||||
double count = stats_(0, dim);
|
||||
// Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
|
||||
// computing an offset and representing it as stats_, we use a count of one.
|
||||
if (count < 1.0)
|
||||
KALDI_ERR << "Insufficient stats_ for cepstral mean and variance "
|
||||
"normalization: "
|
||||
<< "count = " << count;
|
||||
|
||||
if (!var_norm_) {
|
||||
Vector<BaseFloat> offset(feats->Dim());
|
||||
SubVector<double> mean_stats(stats_.RowData(0), dim);
|
||||
Vector<double> mean_stats_apply(feats->Dim());
|
||||
// fill the datat of mean_stats in mean_stats_appy whose dim is equal
|
||||
// with the dim of feature.
|
||||
// the dim of feats = dim * num_frames;
|
||||
for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) {
|
||||
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim * idx,
|
||||
dim);
|
||||
stats_tmp.CopyFromVec(mean_stats);
|
||||
}
|
||||
offset.AddVec(-1.0 / count, mean_stats_apply);
|
||||
feats->AddVec(1.0, offset);
|
||||
return;
|
||||
}
|
||||
// norm(0, d) = mean offset;
|
||||
// norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
|
||||
kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
|
||||
for (int32 d = 0; d < dim; d++) {
|
||||
double mean, offset, scale;
|
||||
mean = stats_(0, d) / count;
|
||||
double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20;
|
||||
if (var < floor) {
|
||||
KALDI_WARN << "Flooring cepstral variance from " << var << " to "
|
||||
<< floor;
|
||||
var = floor;
|
||||
}
|
||||
scale = 1.0 / sqrt(var);
|
||||
if (scale != scale || 1 / scale == 0.0)
|
||||
KALDI_ERR
|
||||
<< "NaN or infinity in cepstral mean/variance computation";
|
||||
offset = -(mean * scale);
|
||||
for (int32 d_skip = d; d_skip < feats->Dim();) {
|
||||
norm(0, d_skip) = offset;
|
||||
norm(1, d_skip) = scale;
|
||||
d_skip = d_skip + dim;
|
||||
}
|
||||
}
|
||||
// Apply the normalization.
|
||||
feats->MulElements(norm.Row(1));
|
||||
feats->AddVec(1.0, norm.Row(0));
|
||||
}
|
||||
|
||||
void CMVN::ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats) {
|
||||
ApplyCmvn(stats_, var_norm_, feats);
|
||||
}
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,89 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/common.h"
|
||||
#include "frontend/feature_extractor_interface.h"
|
||||
#include "kaldi/matrix/kaldi-matrix.h"
|
||||
#include "kaldi/util/options-itf.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
struct DecibelNormalizerOptions {
|
||||
float target_db;
|
||||
float max_gain_db;
|
||||
bool convert_int_float;
|
||||
DecibelNormalizerOptions()
|
||||
: target_db(-20), max_gain_db(300.0), convert_int_float(false) {}
|
||||
|
||||
void Register(kaldi::OptionsItf* opts) {
|
||||
opts->Register(
|
||||
"target-db", &target_db, "target db for db normalization");
|
||||
opts->Register(
|
||||
"max-gain-db", &max_gain_db, "max gain db for db normalization");
|
||||
opts->Register("convert-int-float",
|
||||
&convert_int_float,
|
||||
"if convert int samples to float");
|
||||
}
|
||||
};
|
||||
|
||||
class DecibelNormalizer : public FeatureExtractorInterface {
|
||||
public:
|
||||
explicit DecibelNormalizer(
|
||||
const DecibelNormalizerOptions& opts,
|
||||
std::unique_ptr<FeatureExtractorInterface> base_extractor);
|
||||
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
|
||||
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
|
||||
// noramlize audio, the dim is 1.
|
||||
virtual size_t Dim() const { return dim_; }
|
||||
virtual void SetFinished() { base_extractor_->SetFinished(); }
|
||||
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
|
||||
virtual void Reset() { base_extractor_->Reset(); }
|
||||
|
||||
private:
|
||||
bool Compute(kaldi::VectorBase<kaldi::BaseFloat>* waves) const;
|
||||
DecibelNormalizerOptions opts_;
|
||||
size_t dim_;
|
||||
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
|
||||
kaldi::Vector<kaldi::BaseFloat> waveform_;
|
||||
};
|
||||
|
||||
|
||||
class CMVN : public FeatureExtractorInterface {
|
||||
public:
|
||||
explicit CMVN(std::string cmvn_file,
|
||||
std::unique_ptr<FeatureExtractorInterface> base_extractor);
|
||||
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
|
||||
|
||||
// the length of feats = feature_row * feature_dim,
|
||||
// the Matrix is squashed into Vector
|
||||
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
|
||||
// the dim_ is the feautre dim.
|
||||
virtual size_t Dim() const { return dim_; }
|
||||
virtual void SetFinished() { base_extractor_->SetFinished(); }
|
||||
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
|
||||
virtual void Reset() { base_extractor_->Reset(); }
|
||||
|
||||
private:
|
||||
void Compute(kaldi::VectorBase<kaldi::BaseFloat>* feats) const;
|
||||
void ApplyCMVN(kaldi::MatrixBase<BaseFloat>* feats);
|
||||
kaldi::Matrix<double> stats_;
|
||||
std::unique_ptr<FeatureExtractorInterface> base_extractor_;
|
||||
size_t dim_;
|
||||
bool var_norm_;
|
||||
};
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,78 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "frontend/raw_audio.h"
|
||||
#include "kaldi/base/timer.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::VectorBase;
|
||||
using kaldi::Vector;
|
||||
|
||||
RawAudioCache::RawAudioCache(int buffer_size)
|
||||
: finished_(false), data_length_(0), start_(0), timeout_(1) {
|
||||
ring_buffer_.resize(buffer_size);
|
||||
}
|
||||
|
||||
void RawAudioCache::Accept(const VectorBase<BaseFloat>& waves) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
while (data_length_ + waves.Dim() > ring_buffer_.size()) {
|
||||
ready_feed_condition_.wait(lock);
|
||||
}
|
||||
for (size_t idx = 0; idx < waves.Dim(); ++idx) {
|
||||
int32 buffer_idx = (idx + start_) % ring_buffer_.size();
|
||||
ring_buffer_[buffer_idx] = waves(idx);
|
||||
}
|
||||
data_length_ += waves.Dim();
|
||||
}
|
||||
|
||||
bool RawAudioCache::Read(Vector<BaseFloat>* waves) {
|
||||
size_t chunk_size = waves->Dim();
|
||||
kaldi::Timer timer;
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
while (chunk_size > data_length_) {
|
||||
// when audio is empty and no more data feed
|
||||
// ready_read_condition will block in dead lock. so replace with
|
||||
// timeout_
|
||||
// ready_read_condition_.wait(lock);
|
||||
int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000);
|
||||
if (elapsed > timeout_) {
|
||||
if (finished_ == true) { // read last chunk data
|
||||
break;
|
||||
}
|
||||
if (chunk_size > data_length_) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
usleep(100); // sleep 0.1 ms
|
||||
}
|
||||
|
||||
// read last chunk data
|
||||
if (chunk_size > data_length_) {
|
||||
chunk_size = data_length_;
|
||||
waves->Resize(chunk_size);
|
||||
}
|
||||
|
||||
for (size_t idx = 0; idx < chunk_size; ++idx) {
|
||||
int buff_idx = (start_ + idx) % ring_buffer_.size();
|
||||
waves->Data()[idx] = ring_buffer_[buff_idx];
|
||||
}
|
||||
data_length_ -= chunk_size;
|
||||
start_ = (start_ + chunk_size) % ring_buffer_.size();
|
||||
ready_feed_condition_.notify_one();
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace ppspeech
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue