Merge branch 'develop' into cluster

pull/1681/head
qingen 2 years ago committed by GitHub
commit 159d8fd628
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -50,13 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker
name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
hooks:

@ -90,7 +90,7 @@ Then to start the system server, and it provides HTTP backend services.
```bash
export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio
python src/main.py
python src/audio_search.py
```
Then you will see the Application is started:
@ -111,7 +111,7 @@ Then to start the system server, and it provides HTTP backend services.
```bash
wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz
```
**Note**: If you want to build a quick demo, you can use ./src/test_main.py:download_audio_data function, it downloads 20 audio files , Subsequent results show this collection as an example
**Note**: If you want to build a quick demo, you can use ./src/test_audio_search.py:download_audio_data function, it downloads 20 audio files , Subsequent results show this collection as an example
- Prepare model(Skip this step if you use the default model.)
```bash
@ -123,7 +123,7 @@ Then to start the system server, and it provides HTTP backend services.
The internal process is downloading data, loading the paddlespeech model, extracting embedding, storing library, retrieving and deleting library
```bash
python ./src/test_main.py
python ./src/test_audio_search.py
```
Output

@ -92,7 +92,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…"
```bash
export PYTHONPATH=$PYTHONPATH:./src:../../paddleaudio
python src/main.py
python src/audio_search.py
```
然后你会看到应用程序启动:
@ -113,7 +113,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…"
```bash
wget -c https://www.openslr.org/resources/82/cn-celeb_v2.tar.gz && tar -xvf cn-celeb_v2.tar.gz
```
**注**:如果希望快速搭建 demo可以采用 ./src/test_main.py:download_audio_data 内部的 20 条音频,另外后续结果展示以该集合为例
**注**:如果希望快速搭建 demo可以采用 ./src/test_audio_search.py:download_audio_data 内部的 20 条音频,另外后续结果展示以该集合为例
- 准备模型(如果使用默认模型,可以跳过此步骤)
```bash
@ -124,7 +124,7 @@ ffce340b3790 minio/minio:RELEASE.2020-12-03T00-03-10Z "/usr/bin/docker-ent…"
- 脚本测试(推荐)
```bash
python ./src/test_main.py
python ./src/test_audio_search.py
```
注:内部将依次下载数据,加载 paddlespeech 模型,提取 embedding存储建库检索删库

@ -40,7 +40,6 @@ app.add_middleware(
allow_methods=["*"],
allow_headers=["*"])
MODEL = None
MILVUS_CLI = MilvusHelper()
MYSQL_CLI = MySQLHelper()

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from logs import LOGGER
from paddlespeech.cli import VectorExecutor
vector_executor = VectorExecutor()

@ -13,6 +13,7 @@
# limitations under the License.
import sys
import numpy
import pymysql
from config import MYSQL_DB
from config import MYSQL_HOST
@ -69,7 +70,7 @@ class MySQLHelper():
sys.exit(1)
def load_data_to_mysql(self, table_name, data):
# Batch insert (Milvus_ids, img_path) to mysql
# Batch insert (Milvus_ids, audio_path) to mysql
self.test_connection()
sql = "insert into " + table_name + " (milvus_id,audio_path) values (%s,%s);"
try:
@ -82,7 +83,7 @@ class MySQLHelper():
sys.exit(1)
def search_by_milvus_ids(self, ids, table_name):
# Get the img_path according to the milvus ids
# Get the audio_path according to the milvus ids
self.test_connection()
str_ids = str(ids).replace('[', '').replace(']', '')
sql = "select audio_path from " + table_name + " where milvus_id in (" + str_ids + ") order by field (milvus_id," + str_ids + ");"
@ -120,14 +121,83 @@ class MySQLHelper():
sys.exit(1)
def count_table(self, table_name):
# Get the number of mysql table
# Get the number of spk in mysql table
self.test_connection()
sql = "select count(milvus_id) from " + table_name + ";"
sql = "select count(spk_id) from " + table_name + ";"
try:
self.cursor.execute(sql)
results = self.cursor.fetchall()
LOGGER.debug(f"MYSQL count table:{table_name}")
LOGGER.debug(f"MYSQL count table:{results[0][0]}")
return results[0][0]
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def create_mysql_table_vpr(self, table_name):
# Create mysql table if not exists
self.test_connection()
sql = "create table if not exists " + table_name + "(spk_id TEXT, audio_path TEXT, embedding TEXT);"
try:
self.cursor.execute(sql)
LOGGER.debug(f"MYSQL create table: {table_name} with sql: {sql}")
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def load_data_to_mysql_vpr(self, table_name, data):
# Insert (spk, audio, embedding) to mysql
self.test_connection()
sql = "insert into " + table_name + " (spk_id,audio_path,embedding) values (%s,%s,%s);"
try:
self.cursor.execute(sql, data)
LOGGER.debug(
f"MYSQL loads data to table: {table_name} successfully")
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def list_vpr(self, table_name):
# Get all records in mysql
self.test_connection()
sql = "select * from " + table_name + " ;"
try:
self.cursor.execute(sql)
results = self.cursor.fetchall()
self.conn.commit()
spk_ids = [res[0] for res in results]
audio_paths = [res[1] for res in results]
embeddings = [
numpy.array(
str(res[2]).replace('[', '').replace(']', '').split(","))
for res in results
]
return spk_ids, audio_paths, embeddings
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def search_audio_vpr(self, table_name, spk_id):
# Get the audio_path according to the spk_id
self.test_connection()
sql = "select audio_path from " + table_name + " where spk_id='" + spk_id + "' ;"
try:
self.cursor.execute(sql)
results = self.cursor.fetchall()
LOGGER.debug(
f"MYSQL search by spk id {spk_id} to get audio {results[0][0]}.")
return results[0][0]
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)
def delete_data_vpr(self, table_name, spk_id):
# Delete a record by spk_id in mysql table
self.test_connection()
sql = "delete from " + table_name + " where spk_id='" + spk_id + "';"
try:
self.cursor.execute(sql)
LOGGER.debug(
f"MYSQL delete a record {spk_id} in table {table_name}")
except Exception as e:
LOGGER.error(f"MYSQL ERROR: {e} with sql: {sql}")
sys.exit(1)

@ -31,3 +31,45 @@ def do_count(table_name, milvus_cli):
except Exception as e:
LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1)
def do_count_vpr(table_name, mysql_cli):
"""
Returns the total number of spk in the system
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
num = mysql_cli.count_table(table_name)
return num
except Exception as e:
LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1)
def do_list(table_name, mysql_cli):
"""
Returns the total records of vpr in the system
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
spk_ids, audio_paths, _ = mysql_cli.list_vpr(table_name)
return spk_ids, audio_paths
except Exception as e:
LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1)
def do_get(table_name, spk_id, mysql_cli):
"""
Returns the audio path by spk_id in the system
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
audio_apth = mysql_cli.search_audio_vpr(table_name, spk_id)
return audio_apth
except Exception as e:
LOGGER.error(f"Error attempting to count table {e}")
sys.exit(1)

@ -32,3 +32,31 @@ def do_drop(table_name, milvus_cli, mysql_cli):
except Exception as e:
LOGGER.error(f"Error attempting to drop table: {e}")
sys.exit(1)
def do_drop_vpr(table_name, mysql_cli):
"""
Delete the table of MySQL
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
mysql_cli.delete_table(table_name)
return "OK"
except Exception as e:
LOGGER.error(f"Error attempting to drop table: {e}")
sys.exit(1)
def do_delete(table_name, spk_id, mysql_cli):
"""
Delete a record by spk_id in MySQL
"""
if not table_name:
table_name = DEFAULT_TABLE
try:
mysql_cli.delete_data_vpr(table_name, spk_id)
return "OK"
except Exception as e:
LOGGER.error(f"Error attempting to drop table: {e}")
sys.exit(1)

@ -82,3 +82,16 @@ def do_load(table_name, audio_dir, milvus_cli, mysql_cli):
mysql_cli.create_mysql_table(table_name)
mysql_cli.load_data_to_mysql(table_name, format_data(ids, names))
return len(ids)
def do_enroll(table_name, spk_id, audio_path, mysql_cli):
"""
Import spk_id,audio_path,embedding to Mysql
"""
if not table_name:
table_name = DEFAULT_TABLE
embedding = get_audio_embedding(audio_path)
mysql_cli.create_mysql_table_vpr(table_name)
data = (spk_id, audio_path, str(embedding))
mysql_cli.load_data_to_mysql_vpr(table_name, data)
return "OK"

@ -13,6 +13,7 @@
# limitations under the License.
import sys
import numpy
from config import DEFAULT_TABLE
from config import TOP_K
from encode import get_audio_embedding
@ -39,3 +40,26 @@ def do_search(host, table_name, audio_path, milvus_cli, mysql_cli):
except Exception as e:
LOGGER.error(f"Error with search: {e}")
sys.exit(1)
def do_search_vpr(host, table_name, audio_path, mysql_cli):
"""
Search the uploaded audio in MySQL
"""
try:
if not table_name:
table_name = DEFAULT_TABLE
emb = get_audio_embedding(audio_path)
emb = numpy.array(emb)
spk_ids, paths, vectors = mysql_cli.list_vpr(table_name)
scores = [numpy.dot(emb, x.astype(numpy.float64)) for x in vectors]
spk_ids = [str(x) for x in spk_ids]
paths = [str(x) for x in paths]
for i in range(len(paths)):
tmp = "http://" + str(host) + "/data?audio_path=" + str(paths[i])
paths[i] = tmp
scores[i] = scores[i] * 100
return spk_ids, paths, scores
except Exception as e:
LOGGER.error(f"Error with search: {e}")
sys.exit(1)

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from audio_search import app
from fastapi.testclient import TestClient
from main import app
from utils.utility import download
from utils.utility import unpack
@ -22,7 +22,7 @@ client = TestClient(app)
def download_audio_data():
"""
download audio data
Download audio data
"""
url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz"
md5sum = "52ac69316c1aa1fdef84da7dd2c67b39"
@ -64,7 +64,7 @@ def test_count():
"""
Returns the total number of vectors in the system
"""
response = client.get("audio/count")
response = client.get("/audio/count")
assert response.status_code == 200
assert response.json() == 20

@ -0,0 +1,115 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi.testclient import TestClient
from vpr_search import app
from utils.utility import download
from utils.utility import unpack
client = TestClient(app)
def download_audio_data():
"""
Download audio data
"""
url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz"
md5sum = "52ac69316c1aa1fdef84da7dd2c67b39"
target_dir = "./"
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir, True)
def test_drop():
"""
Delete the table of MySQL
"""
response = client.post("/vpr/drop")
assert response.status_code == 200
def test_enroll_local(spk: str, audio: str):
"""
Enroll the audio to MySQL
"""
response = client.post("/vpr/enroll/local?spk_id=" + spk +
"&audio_path=.%2Fexample_audio%2F" + audio + ".wav")
assert response.status_code == 200
assert response.json() == {
'status': True,
'msg': "Successfully enroll data!"
}
def test_search_local():
"""
Search the spk in MySQL by audio
"""
response = client.post(
"/vpr/recog/local?audio_path=.%2Fexample_audio%2Ftest.wav")
assert response.status_code == 200
def test_list():
"""
Get all records in MySQL
"""
response = client.get("/vpr/list")
assert response.status_code == 200
def test_data(spk: str):
"""
Get the audio file by spk_id in MySQL
"""
response = client.get("/vpr/data?spk_id=" + spk)
assert response.status_code == 200
def test_del(spk: str):
"""
Delete the record in MySQL by spk_id
"""
response = client.post("/vpr/del?spk_id=" + spk)
assert response.status_code == 200
def test_count():
"""
Get the number of spk in MySQL
"""
response = client.get("/vpr/count")
assert response.status_code == 200
if __name__ == "__main__":
download_audio_data()
test_enroll_local("spk1", "arms_strikes")
test_enroll_local("spk2", "sword_wielding")
test_enroll_local("spk3", "test")
test_list()
test_data("spk1")
test_count()
test_search_local()
test_del("spk1")
test_count()
test_search_local()
test_enroll_local("spk1", "arms_strikes")
test_count()
test_search_local()
test_drop()

@ -0,0 +1,206 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import uvicorn
from config import UPLOAD_PATH
from fastapi import FastAPI
from fastapi import File
from fastapi import UploadFile
from logs import LOGGER
from mysql_helpers import MySQLHelper
from operations.count import do_count_vpr
from operations.count import do_get
from operations.count import do_list
from operations.drop import do_delete
from operations.drop import do_drop_vpr
from operations.load import do_enroll
from operations.search import do_search_vpr
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import FileResponse
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"])
MYSQL_CLI = MySQLHelper()
# Mkdir 'tmp/audio-data'
if not os.path.exists(UPLOAD_PATH):
os.makedirs(UPLOAD_PATH)
LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")
@app.post('/vpr/enroll')
async def vpr_enroll(table_name: str=None,
spk_id: str=None,
audio: UploadFile=File(...)):
# Enroll the uploaded audio with spk-id into MySQL
try:
# Save the upload data to server.
content = await audio.read()
audio_path = os.path.join(UPLOAD_PATH, audio.filename)
with open(audio_path, "wb+") as f:
f.write(content)
do_enroll(table_name, spk_id, audio_path, MYSQL_CLI)
LOGGER.info(f"Successfully enrolled {spk_id} online!")
return {'status': True, 'msg': "Successfully enroll data!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/enroll/local')
async def vpr_enroll_local(table_name: str=None,
spk_id: str=None,
audio_path: str=None):
# Enroll the local audio with spk-id into MySQL
try:
do_enroll(table_name, spk_id, audio_path, MYSQL_CLI)
LOGGER.info(f"Successfully enrolled {spk_id} locally!")
return {'status': True, 'msg': "Successfully enroll data!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/recog')
async def vpr_recog(request: Request,
table_name: str=None,
audio: UploadFile=File(...)):
# Voice print recognition online
try:
# Save the upload data to server.
content = await audio.read()
query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
with open(query_audio_path, "wb+") as f:
f.write(content)
host = request.headers['host']
spk_ids, paths, scores = do_search_vpr(host, table_name,
query_audio_path, MYSQL_CLI)
for spk_id, path, score in zip(spk_ids, paths, scores):
LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}, ")
res = dict(zip(spk_ids, zip(paths, scores)))
# Sort results by distance metric, closest distances first
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
LOGGER.info("Successfully speaker recognition online!")
return res
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/recog/local')
async def vpr_recog_local(request: Request,
table_name: str=None,
audio_path: str=None):
# Voice print recognition locally
try:
host = request.headers['host']
spk_ids, paths, scores = do_search_vpr(host, table_name, audio_path,
MYSQL_CLI)
for spk_id, path, score in zip(spk_ids, paths, scores):
LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}, ")
res = dict(zip(spk_ids, zip(paths, scores)))
# Sort results by distance metric, closest distances first
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
LOGGER.info("Successfully speaker recognition locally!")
return res
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/del')
async def vpr_del(table_name: str=None, spk_id: str=None):
# Delete a record by spk_id in MySQL
try:
do_delete(table_name, spk_id, MYSQL_CLI)
LOGGER.info("Successfully delete a record by spk_id in MySQL")
return {'status': True, 'msg': "Successfully delete data!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/vpr/list')
async def vpr_list(table_name: str=None):
# Get all records in MySQL
try:
spk_ids, audio_paths = do_list(table_name, MYSQL_CLI)
for i in range(len(spk_ids)):
LOGGER.debug(f"spk {spk_ids[i]}, audio path {audio_paths[i]}")
LOGGER.info("Successfully list all records from mysql!")
return spk_ids, audio_paths
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/vpr/data')
async def vpr_data(
table_name: str=None,
spk_id: str=None, ):
# Get the audio file from path by spk_id in MySQL
try:
audio_path = do_get(table_name, spk_id, MYSQL_CLI)
LOGGER.info(f"Successfully get audio path {audio_path}!")
return FileResponse(audio_path)
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/vpr/count')
async def vpr_count(table_name: str=None):
# Get the total number of spk in MySQL
try:
num = do_count_vpr(table_name, MYSQL_CLI)
LOGGER.info("Successfully count the number of spk!")
return num
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.post('/vpr/drop')
async def drop_tables(table_name: str=None):
# Delete the table of MySQL
try:
do_drop_vpr(table_name, MYSQL_CLI)
LOGGER.info("Successfully drop tables in MySQL!")
return {'status': True, 'msg': "Successfully drop tables!"}
except Exception as e:
LOGGER.error(e)
return {'status': False, 'msg': e}, 400
@app.get('/data')
def audio_path(audio_path):
# Get the audio file from path
try:
LOGGER.info(f"Successfully get audio: {audio_path}")
return FileResponse(audio_path)
except Exception as e:
LOGGER.error(f"get audio error: {e}")
return {'status': False, 'msg': e}, 400
if __name__ == '__main__':
uvicorn.run(app=app, host='0.0.0.0', port=8002)

@ -7,4 +7,4 @@ paddlespeech asr --input ./zh.wav
# asr + punc
paddlespeech asr --input ./zh.wav | paddlespeech text --task punc
paddlespeech asr --input ./zh.wav | paddlespeech text --task punc

@ -85,6 +85,10 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
- 命令行 (推荐使用)
```
paddlespeech_client asr --server_ip 127.0.0.1 --port 8090 --input ./zh.wav
# 流式ASR
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8091 --input ./zh.wav
```
使用帮助:
@ -191,7 +195,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespee
```
### 5. CLS 客户端使用方法
### 6. CLS 客户端使用方法
**注意:** 初次使用客户端时响应时间会略长
- 命令行 (推荐使用)
```

@ -84,7 +84,7 @@ setuptools.setup(
install_requires=[
'numpy >= 1.15.0', 'scipy >= 1.0.0', 'resampy >= 0.2.2',
'soundfile >= 0.9.0', 'colorlog', 'dtaidistance == 2.3.1', 'pathos'
],
],
extras_require={
'test': [
'nose', 'librosa==0.8.1', 'soundfile==0.10.3.post1',

@ -79,7 +79,6 @@ class U2Infer():
ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)
decode_config = self.config.decode
result_transcripts = self.model.decode(
xs,
@ -129,6 +128,7 @@ if __name__ == "__main__":
args = parser.parse_args()
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:

@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import base64
import io
import json
import logging
import os
import random
import time
@ -28,6 +30,7 @@ from ..executor import BaseExecutor
from ..util import cli_client_register
from ..util import stats_wrapper
from paddlespeech.cli.log import logger
from paddlespeech.server.tests.asr.online.websocket_client import ASRAudioHandler
from paddlespeech.server.utils.audio_process import wav2pcm
from paddlespeech.server.utils.util import wav2base64
@ -230,6 +233,76 @@ class ASRClientExecutor(BaseExecutor):
return res
@cli_client_register(
name='paddlespeech_client.asr_online',
description='visit asr online service')
class ASRClientExecutor(BaseExecutor):
def __init__(self):
super(ASRClientExecutor, self).__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech_client.asr', add_help=True)
self.parser.add_argument(
'--server_ip', type=str, default='127.0.0.1', help='server ip')
self.parser.add_argument(
'--port', type=int, default=8091, help='server port')
self.parser.add_argument(
'--input',
type=str,
default=None,
help='Audio file to be recognized',
required=True)
self.parser.add_argument(
'--sample_rate', type=int, default=16000, help='audio sample rate')
self.parser.add_argument(
'--lang', type=str, default="zh_cn", help='language')
self.parser.add_argument(
'--audio_format', type=str, default="wav", help='audio format')
def execute(self, argv: List[str]) -> bool:
args = self.parser.parse_args(argv)
input_ = args.input
server_ip = args.server_ip
port = args.port
sample_rate = args.sample_rate
lang = args.lang
audio_format = args.audio_format
try:
time_start = time.time()
res = self(
input=input_,
server_ip=server_ip,
port=port,
sample_rate=sample_rate,
lang=lang,
audio_format=audio_format)
time_end = time.time()
logger.info(res.json())
logger.info("Response time %f s." % (time_end - time_start))
return True
except Exception as e:
logger.error("Failed to speech recognition.")
return False
@stats_wrapper
def __call__(self,
input: str,
server_ip: str="127.0.0.1",
port: int=8091,
sample_rate: int=16000,
lang: str="zh_cn",
audio_format: str="wav"):
"""
Python API to call an executor.
"""
logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start")
handler = ASRAudioHandler(server_ip, port)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(input))
logging.info("asr websocket client finished")
@cli_client_register(
name='paddlespeech_client.cls', description='visit cls service')
class CLSClientExecutor(BaseExecutor):

@ -4,7 +4,7 @@
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8091
port: 8090
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']

@ -0,0 +1,49 @@
([简体中文](./README_cn.md)|English)
# 语音服务
## 介绍
本文档介绍如何使用流式ASR的三种不同客户端:网页、麦克风、Python模拟流式服务。
## 使用方法
### 1. 安装
请看 [安装文档](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install.md).
推荐使用 **paddlepaddle 2.2.1** 或以上版本。
你可以从 mediumhard 三中方式中选择一种方式安装 PaddleSpeech。
### 2. 准备测试文件
这个 ASR client 的输入应该是一个 WAV 文件(`.wav`),并且采样率必须与模型的采样率相同。
可以下载此 ASR client的示例音频
```bash
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
```
### 2. 流式 ASR 客户端使用方法
- Python模拟流式服务命令行
```
# 流式ASR
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8091 --input ./zh.wav
```
- 麦克风
```
# 直接调用麦克风设备
python microphone_client.py
```
- 网页
```
# 进入web目录后参考相关readme.md
```

@ -1,12 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Author: zhendong.peng@mobvoi.com (Zhendong Peng)
import argparse
from flask import Flask, render_template
from flask import Flask
from flask import render_template
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--port', default=19999, type=int, help='port id')
@ -14,9 +13,11 @@ args = parser.parse_args()
app = Flask(__name__)
@app.route('/')
def index():
return render_template('index.html')
if __name__ == '__main__':
app.run(host='0.0.0.0', port=args.port, debug=True)

Binary file not shown.

After

Width:  |  Height:  |  Size: 949 KiB

@ -0,0 +1,18 @@
# paddlespeech serving 网页Demo
- 感谢[wenet](https://github.com/wenet-e2e/wenet)团队的前端demo代码.
## 使用方法
### 1. 在本地电脑启动网页服务
```
python app.py
```
### 2. 本地电脑浏览器
在浏览器中输入127.0.0.1:19999 即可看到相关网页Demo。
![图片](./paddle_web_demo.png)

@ -15,8 +15,10 @@
# -*- coding: UTF-8 -*-
import argparse
import asyncio
import codecs
import json
import logging
import os
import numpy as np
import soundfile
@ -32,34 +34,30 @@ class ASRAudioHandler:
def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples)
chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
# chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
if (x_len - chunk_size) % chunk_stride != 0:
padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride
if x_len % chunk_size != 0:
padding_len_x = chunk_size - x_len % chunk_size
else:
padding_len_x = 0
padding = np.zeros((padding_len_x), dtype=samples.dtype)
padded_x = np.concatenate([samples, padding], axis=0)
num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
assert (x_len + padding_len_x) % chunk_size == 0
num_chunk = (x_len + padding_len_x) / chunk_size
num_chunk = int(num_chunk)
for i in range(0, num_chunk):
start = i * chunk_stride
start = i * chunk_size
end = start + chunk_size
x_chunk = padded_x[start:end]
yield x_chunk
async def run(self, wavfile_path: str):
logging.info("send a message to the server")
# 读取音频
# self.read_wave()
# 发送 websocket 的 handshake 协议头
async with websockets.connect(self.url) as ws:
# server 端已经接收到 handshake 协议头
# 发送开始指令
audio_info = json.dumps(
{
"name": "test.wav",
@ -77,8 +75,10 @@ class ASRAudioHandler:
for chunk_data in self.read_wave(wavfile_path):
await ws.send(chunk_data.tobytes())
msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
result = msg
# finished
audio_info = json.dumps(
{
@ -91,16 +91,35 @@ class ASRAudioHandler:
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
return result
def main(args):
logging.basicConfig(level=logging.INFO)
logging.info("asr websocket client start")
handler = ASRAudioHandler("127.0.0.1", 8091)
handler = ASRAudioHandler("127.0.0.1", 8090)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(args.wavfile))
logging.info("asr websocket client finished")
# support to process single audio file
if args.wavfile and os.path.exists(args.wavfile):
logging.info(f"start to process the wavscp: {args.wavfile}")
result = loop.run_until_complete(handler.run(args.wavfile))
result = result["asr_results"]
logging.info(f"asr websocket client finished : {result}")
# support to process batch audios from wav.scp
if args.wavscp and os.path.exists(args.wavscp):
logging.info(f"start to process the wavscp: {args.wavscp}")
with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\
codecs.open("result.txt", 'w', encoding='utf-8') as w:
for line in f:
utt_name, utt_path = line.strip().split()
result = loop.run_until_complete(handler.run(utt_path))
result = result["asr_results"]
w.write(f"{utt_name} {result}\n")
if __name__ == "__main__":
@ -110,6 +129,8 @@ if __name__ == "__main__":
action="store",
help="wav file path ",
default="./16_audio.wav")
parser.add_argument(
"--wavscp", type=str, default=None, help="The batch audios dict text")
args = parser.parse_args()
main(args)

@ -24,15 +24,38 @@ class Frame(object):
class ChunkBuffer(object):
def __init__(self,
frame_duration_ms=80,
shift_ms=40,
window_n=7,
shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=16000,
sample_width=2):
self.sample_rate = sample_rate
self.frame_duration_ms = frame_duration_ms
"""audio sample data point buffer
Args:
window_n (int, optional): decode window frame length. Defaults to 7 frame.
shift_n (int, optional): decode shift frame length. Defaults to 4 frame.
window_ms (int, optional): frame length, ms. Defaults to 20 ms.
shift_ms (int, optional): shift length, ms. Defaults to 10 ms.
sample_rate (int, optional): audio sample rate. Defaults to 16000.
sample_width (int, optional): sample point bytes. Defaults to 2 bytes.
"""
self.window_n = window_n
self.shift_n = shift_n
self.window_ms = window_ms
self.shift_ms = shift_ms
self.remained_audio = b''
self.sample_rate = sample_rate
self.sample_width = sample_width # int16 = 2; float32 = 4
self.remained_audio = b''
self.window_sec = float((self.window_n - 1) * self.shift_ms +
self.window_ms) / 1000.0
self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)
self.window_bytes = int(self.window_sec * self.sample_rate *
self.sample_width)
self.shift_bytes = int(self.shift_sec * self.sample_rate *
self.sample_width)
def frame_generator(self, audio):
"""Generates audio frames from PCM audio data.
@ -43,17 +66,13 @@ class ChunkBuffer(object):
audio = self.remained_audio + audio
self.remained_audio = b''
n = int(self.sample_rate * (self.frame_duration_ms / 1000.0) *
self.sample_width)
shift_n = int(self.sample_rate * (self.shift_ms / 1000.0) *
self.sample_width)
offset = 0
timestamp = 0.0
duration = (float(n) / self.sample_rate) / self.sample_width
shift_duration = (float(shift_n) / self.sample_rate) / self.sample_width
while offset + n <= len(audio):
yield Frame(audio[offset:offset + n], timestamp, duration)
timestamp += shift_duration
offset += shift_n
while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], timestamp,
self.window_sec)
timestamp += self.shift_sec
offset += self.shift_bytes
self.remained_audio += audio[offset:]

@ -36,6 +36,10 @@ async def websocket_endpoint(websocket: WebSocket):
# init buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer(
window_n=7,
shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=chunk_buffer_conf['sample_rate'],
sample_width=chunk_buffer_conf['sample_width'])
# init vad
@ -75,11 +79,6 @@ async def websocket_endpoint(websocket: WebSocket):
elif "bytes" in message:
message = message["bytes"]
# vad for input bytes audio
vad.add_audio(message)
message = b''.join(f for f in vad.vad_collector()
if f is not None)
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_results = ""

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from speechbrain(https://github.com/speechbrain/speechbrain)
"""
This script contains basic functions used for speaker diarization.
This script has an optional dependency on open source sklearn library.
@ -19,11 +20,11 @@ A few sklearn functions are modified in this script as per requirement.
import argparse
import copy
import warnings
from distutils.util import strtobool
import numpy as np
import scipy
import sklearn
from distutils.util import strtobool
from scipy import linalg
from scipy import sparse
from scipy.sparse.csgraph import connected_components

@ -13,6 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from dataclasses import fields
from paddle.io import Dataset
from paddleaudio import load as load_audio

@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import dataclass
from dataclasses import fields
from paddle.io import Dataset
from paddleaudio import load as load_audio

@ -1,7 +1,4 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)
add_subdirectory(glog)
add_subdirectory(ds2_ol)
add_subdirectory(dev)

@ -1,17 +1,25 @@
# Examples
# Examples for SpeechX
* dev - for speechx developer, using for test.
* ngram - using to build NGram ARPA lm.
* ds2_ol - ds2 streaming test under `aishell-1` test dataset.
The entrypoint is `ds2_ol/aishell/run.sh`
* glog - glog usage
* feat - mfcc, linear
* nnet - ds2 nn
* decoder - online decoder to work as offline
## How to run
`run.sh` is the entry point.
Example to play `decoder`:
Example to play `ds2_ol`:
```
pushd decoder
pushd ds2_ol/aishell
bash run.sh
```
## Display Model with [Netron](https://github.com/lutzroeder/netron)
```
pip install netron
netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host 10.21.55.20
```

@ -1,18 +0,0 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(offline_decoder_sliding_chunk_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_sliding_chunk_main.cc)
target_include_directories(offline_decoder_sliding_chunk_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_sliding_chunk_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(offline_wfst_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_wfst_decoder_main.cc)
target_include_directories(offline_wfst_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_wfst_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc)
target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})

@ -1,121 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
DEFINE_string(feature_respecifier, "", "feature matrix rspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model");
DEFINE_int32(chunk_size, 35, "feat chunk size");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test decoder by feeding speech feature, deprecated.
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_respecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
int32 chunk_size = FLAGS_chunk_size;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
LOG(INFO) << "chunk size (frame): " << chunk_size;
int32 num_done = 0, num_err = 0;
// frontend + nnet is decodable
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
LOG(INFO) << "Init decodeable.";
// init decoder
ppspeech::CTCBeamSearchOptions opts;
opts.dict_file = dict_file;
opts.lm_path = lm_path;
ppspeech::CTCBeamSearch decoder(opts);
LOG(INFO) << "Init decoder.";
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
LOG(INFO) << "utt: " << utt;
// feat dim
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "dim: " << raw_data->Dim();
int32 row_idx = 0;
int32 num_chunks = feature.NumRows() / chunk_size;
LOG(INFO) << "n chunks: " << num_chunks;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
// feat chunk
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> feat_one_row(feature,
row_idx);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(feat_one_row);
row_idx++;
}
// feed to raw cache
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
// decode step
decoder.AdvanceDecode(decodable);
}
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
decoder.Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -1,43 +0,0 @@
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
model_dir=../paddle_asr_model
feat_wspecifier=./feats.ark
cmvn=./cmvn.ark
export GLOG_logtostderr=1
# 3. gen linear feat
linear_spectrogram_main \
--wav_rspecifier=scp:$model_dir/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_write_path=$cmvn
# 4. run decoder
offline_decoder_main \
--feature_respecifier=ark:$feat_wspecifier \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdparams \
--dict_file=$model_dir/vocab.txt \
--lm_path=$model_dir/avg_1.jit.klm

@ -0,0 +1,3 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(glog)

@ -1,14 +1,15 @@
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../..
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_ROOT=$PWD/../../../
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
export LC_AL=C
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
SPEECHX_BIN=$SPEECHX_EXAMPLES/nnet
SPEECHX_BIN=$SPEECHX_EXAMPLES/dev/glog
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
export LC_AL=C

@ -0,0 +1,5 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)

@ -0,0 +1,11 @@
# Deepspeech2 Streaming
Please go to `aishell` to test it.
* aishell
Deepspeech2 Streaming Decoding under aishell dataset.
The below is for developing and offline testing:
* nnet
* feat
* decoder

@ -0,0 +1,3 @@
data
exp
aishell_*

@ -0,0 +1,21 @@
# Aishell - Deepspeech2 Streaming
## CTC Prefix Beam Search w/o LM
```
Overall -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
Mandarin -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
```
## CTC Prefix Beam Search w LM
```
```
## CTC WFST
```
```

@ -1,6 +1,6 @@
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../..
SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder:$SPEECHX_EXAMPLES/feat
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN

@ -4,6 +4,9 @@ set -e
. path.sh
nj=40
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
@ -11,52 +14,59 @@ if [ ! -d ${SPEECHX_EXAMPLES} ]; then
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
# output
mkdir -p exp
exp=$PWD/exp
aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip -d $data aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
model_dir=$PWD/aishell_ds2_online_model
if [ ! -d $model_dir ]; then
mkdir -p $model_dir
wget -P $model_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $model_dir
if [ ! -d $ckpt_dir ]; then
mkdir -p $ckpt_dir
wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir
fi
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
if [ ! -f $lm ]; then
pushd $data
wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
popd
fi
# 3. make feature
aishell_online_model=$model_dir/exp/deepspeech2_online/checkpoints
lm_model_dir=../paddle_asr_model
label_file=./aishell_result
wer=./aishell_wer
nj=40
export GLOG_logtostderr=1
#./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
data=$PWD/data
# 3. gen linear feat
cmvn=$PWD/cmvn.ark
cmvn_json2binary_main --json_file=$model_dir/data/mean_std.json --cmvn_write_path=$cmvn
cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat_log \
linear_spectrogram_without_db_norm_main \
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
linear-spectrogram-wo-db-norm-ol \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \
--cmvn_file=$cmvn \
@ -65,31 +75,33 @@ linear_spectrogram_without_db_norm_main \
text=$data/test/text
# 4. recognizer
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log \
offline_decoder_sliding_chunk_main \
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \
ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$lm_model_dir/vocab.txt \
--dict_file=$vocb_dir/vocab.txt \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result
cat $data/split${nj}/*/result > ${label_file}
local/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer}
utils/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer}
# 4. decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_lm \
offline_decoder_sliding_chunk_main \
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \
ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$lm_model_dir/vocab.txt \
--lm_path=$lm_model_dir/avg_1.jit.klm \
--dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm
cat $data/split${nj}/*/result_lm > ${label_file}_lm
local/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm
utils/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm
graph_dir=./aishell_graph
if [ ! -d $ ]; then
@ -97,17 +109,19 @@ if [ ! -d $ ]; then
unzip -d aishell_graph.zip
fi
# 5. test TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_tlg \
offline_wfst_decoder_main \
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \
wfst-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \
--acoustic_scale=1.2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg
cat $data/split${nj}/*/result_tlg > ${label_file}_tlg
local/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg
utils/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg

@ -0,0 +1,19 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name ctc-prefix-beam-search-decoder-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
set(bin_name wfst-decoder-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
set(bin_name nnet-logprob-decoder-test)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})

@ -0,0 +1,12 @@
# ASR Decoder
ASR Decoder test bins. We using theses bins to test CTC BeamSearch decoder and WFST decoder.
* decoder_test_main.cc
feed nnet output logprob, and only test decoder
* offline_decoder_sliding_chunk_main.cc
feed streaming audio feature, decode as streaming manner.
* offline_wfst_decoder_main.cc
feed streaming audio feature, decode using WFST as streaming manner.

@ -34,10 +34,12 @@ DEFINE_int32(receptive_field_length,
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
@ -50,9 +52,13 @@ int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
CHECK(FLAGS_result_wspecifier != "");
CHECK(FLAGS_feature_rspecifier != "");
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
@ -73,6 +79,7 @@ int main(int argc, char* argv[]) {
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));

@ -1,6 +1,6 @@
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../..
SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/decoder:$SPEECHX_EXAMPLES/feat
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN

@ -0,0 +1,79 @@
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
# output
exp_dir=./exp
mkdir -p $exp_dir
# 2. download model
if [[ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
# download lm
if [ ! -f $lm ]; then
pushd data
wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
popd
fi
feat_wspecifier=$exp_dir/feats.ark
cmvn=$exp_dir/cmvn.ark
export GLOG_logtostderr=1
# dump json cmvn to kaldi
cmvn-json2kaldi \
--json_file $ckpt_dir/data/mean_std.json \
--cmvn_write_path $exp_dir/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
# generate linear feature as streaming
linear-spectrogram-wo-db-norm-ol \
--wav_rspecifier=scp:$data/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature."
# run ctc beam search decoder as streaming
ctc-prefix-beam-search-decoder-ol \
--result_wspecifier=ark,t:$exp_dir/result.txt \
--feature_rspecifier=ark:$feat_wspecifier \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm

@ -28,6 +28,7 @@ DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "decoder graph");
DEFINE_int32(receptive_field_length,

@ -0,0 +1,12 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name linear-spectrogram-wo-db-norm-ol)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags glog)
set(bin_name cmvn-json2kaldi)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog)

@ -0,0 +1,7 @@
# Deepspeech2 Straming Audio Feature
ASR audio feature test bins. We using theses bins to test linaer/fbank/mfcc asr feature as streaming manner.
* linear_spectrogram_without_db_norm_main.cc
compute linear spectrogram w/o db norm in streaming manner.

@ -0,0 +1,81 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Note: Do not print/log ondemand object.
#include "base/flags.h"
#include "base/log.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/kaldi-io.h"
#include "utils/file_utils.h"
#include "utils/simdjson.h"
DEFINE_string(json_file, "", "cmvn json file");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
using namespace simdjson;
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;
try {
padded_string json = padded_string::load(FLAGS_json_file);
ondemand::parser parser;
ondemand::document doc = parser.iterate(json);
ondemand::value val = doc;
ondemand::array mean_stat = val["mean_stat"];
std::vector<kaldi::BaseFloat> mean_stat_vec;
for (double x : mean_stat) {
mean_stat_vec.push_back(x);
}
// LOG(INFO) << mean_stat; this line will casue
// simdjson::simdjson_error("Objects and arrays can only be iterated
// when
// they are first encountered")
ondemand::array var_stat = val["var_stat"];
std::vector<kaldi::BaseFloat> var_stat_vec;
for (double x : var_stat) {
var_stat_vec.push_back(x);
}
kaldi::int32 frame_num = uint64_t(val["frame_num"]);
LOG(INFO) << "nframe: " << frame_num;
size_t mean_size = mean_stat_vec.size();
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
for (size_t idx = 0; idx < mean_size; ++idx) {
cmvn_stats(0, idx) = mean_stat_vec[idx];
cmvn_stats(1, idx) = var_stat_vec[idx];
}
cmvn_stats(0, mean_size) = frame_num;
LOG(INFO) << cmvn_stats;
kaldi::WriteKaldiObject(
cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
LOG(INFO) << "Binary: " << FLAGS_binary;
} catch (simdjson::simdjson_error& err) {
LOG(ERR) << err.what();
}
return 0;
}

@ -32,6 +32,7 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);

@ -1,6 +1,6 @@
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../..
SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/feat
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN

@ -0,0 +1,57 @@
#!/bin/bash
set +x
set -e
. ./path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -e data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
mkdir -p data
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
# input
data_dir=./data
exp_dir=./exp
model_dir=$data_dir/model/
mkdir -p $exp_dir
# 3. run feat
export GLOG_logtostderr=1
cmvn-json2kaldi \
--json_file $model_dir/data/mean_std.json \
--cmvn_write_path $exp_dir/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
linear-spectrogram-wo-db-norm-ol \
--wav_rspecifier=scp:$data_dir/wav.scp \
--feature_wspecifier=ark,t:$exp_dir/feats.ark \
--cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature."

@ -0,0 +1,6 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
set(bin_name ds2-model-ol-test)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS})

@ -0,0 +1,3 @@
# Deepspeech2 Streaming NNet Test
Using for ds2 streaming nnet inference test.

@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
// deepspeech2 online model info
#include <algorithm>
#include <fstream>
#include <functional>
@ -20,21 +21,26 @@
#include <iterator>
#include <numeric>
#include <thread>
#include "base/flags.h"
#include "base/log.h"
#include "paddle_inference_api.h"
using std::cout;
using std::endl;
DEFINE_string(model_path, "avg_1.jit.pdmodel", "xxx.pdmodel");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "xxx.pdiparams");
DEFINE_string(model_path, "", "xxx.pdmodel");
DEFINE_string(param_path, "", "xxx.pdiparams");
DEFINE_int32(chunk_size, 35, "feature chunk size, unit:frame");
DEFINE_int32(feat_dim, 161, "feature dim");
void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test();
void produce_data(std::vector<std::vector<float>>* data) {
int chunk_size = 35; // chunk_size in frame
int col_size = 161; // feat dim
int chunk_size = FLAGS_chunk_size; // chunk_size in frame
int col_size = FLAGS_feat_dim; // feat dim
cout << "chunk size: " << chunk_size << endl;
cout << "feat dim: " << col_size << endl;
@ -57,6 +63,8 @@ void model_forward_test() {
;
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
CHECK(model_graph != "");
CHECK(model_params != "");
cout << "model path: " << model_graph << endl;
cout << "model param path : " << model_params << endl;
@ -106,7 +114,7 @@ void model_forward_test() {
// state_h
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
predictor->GetInputHandle(input_names[2]);
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
std::vector<int> chunk_state_h_box_shape = {5, 1, 1024};
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
int chunk_state_h_box_size =
std::accumulate(chunk_state_h_box_shape.begin(),
@ -119,7 +127,7 @@ void model_forward_test() {
// state_c
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
predictor->GetInputHandle(input_names[3]);
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
std::vector<int> chunk_state_c_box_shape = {5, 1, 1024};
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
int chunk_state_c_box_size =
std::accumulate(chunk_state_c_box_shape.begin(),
@ -187,7 +195,9 @@ void model_forward_test() {
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
model_forward_test();
return 0;
}

@ -1,6 +1,6 @@
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../..
SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/glog
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/nnet
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN

@ -0,0 +1,38 @@
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
mkdir -p data
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
ckpt_dir=./data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
ds2-model-ol-test \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams

@ -1,18 +0,0 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear_spectrogram_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_main.cc)
target_include_directories(linear_spectrogram_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
add_executable(linear_spectrogram_without_db_norm_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_without_db_norm_main.cc)
target_include_directories(linear_spectrogram_without_db_norm_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(linear_spectrogram_without_db_norm_main frontend kaldi-util kaldi-feat-common gflags glog)
add_executable(cmvn_json2binary_main ${CMAKE_CURRENT_SOURCE_DIR}/cmvn_json2binary_main.cc)
target_include_directories(cmvn_json2binary_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(cmvn_json2binary_main utils kaldi-util kaldi-matrix gflags glog)

@ -1,58 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/flags.h"
#include "base/log.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/kaldi-io.h"
#include "utils/file_utils.h"
#include "utils/simdjson.h"
DEFINE_string(json_file, "", "cmvn json file");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
using namespace simdjson;
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ondemand::parser parser;
padded_string json = padded_string::load(FLAGS_json_file);
ondemand::document val = parser.iterate(json);
ondemand::object doc = val;
kaldi::int32 frame_num = uint64_t(doc["frame_num"]);
auto mean_stat = doc["mean_stat"];
std::vector<kaldi::BaseFloat> mean_stat_vec;
for (double x : mean_stat) {
mean_stat_vec.push_back(x);
}
auto var_stat = doc["var_stat"];
std::vector<kaldi::BaseFloat> var_stat_vec;
for (double x : var_stat) {
var_stat_vec.push_back(x);
}
size_t mean_size = mean_stat_vec.size();
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
for (size_t idx = 0; idx < mean_size; ++idx) {
cmvn_stats(0, idx) = mean_stat_vec[idx];
cmvn_stats(1, idx) = var_stat_vec[idx];
}
cmvn_stats(0, mean_size) = frame_num;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "the json file have write into " << FLAGS_cmvn_write_path;
return 0;
}

@ -1,719 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// feat/feature-mfcc-test.cc
// Copyright 2009-2011 Karel Vesely; Petr Motlicek
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "base/kaldi-math.h"
#include "feat/feature-mfcc.h"
#include "feat/wave-reader.h"
#include "matrix/kaldi-matrix-inl.h"
using namespace kaldi;
static void UnitTestReadWave() {
std::cout << "=== UnitTestReadWave() ===\n";
Vector<BaseFloat> v, v2;
std::cout << "<<<=== Reading waveform\n";
{
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
const Matrix<BaseFloat> data(wave.Data());
KALDI_ASSERT(data.NumRows() == 1);
v.Resize(data.NumCols());
v.CopyFromVec(data.Row(0));
}
std::cout
<< "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
std::ifstream input("test_data/test_matlab.ascii");
KALDI_ASSERT(input.good());
v2.Read(input, false);
input.close();
std::cout
<< "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
KALDI_ASSERT(v.Dim() == v2.Dim());
for (int32 i = 0; i < v.Dim(); i++) {
KALDI_ASSERT(v(i) == v2(i));
}
std::cout << "<<<=== Comparing done\n";
// std::cout << "== The Waveform Samples == \n";
// std::cout << v;
std::cout << "Test passed :)\n\n";
}
/**
*/
static void UnitTestSimple() {
std::cout << "=== UnitTestSimple() ===\n";
Vector<BaseFloat> v(100000);
Matrix<BaseFloat> m;
// init with noise
for (int32 i = 0; i < v.Dim(); i++) {
v(i) = (abs(i * 433024253) % 65535) - (65535 / 2);
}
std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
// the parametrization object
MfccOptions op;
// trying to have same opts as baseline.
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "rectangular";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
Mfcc mfcc(op);
// use default parameters
// compute mfccs.
mfcc.Compute(v, 1.0, &m);
// possibly dump
// std::cout << "== Output features == \n" << m;
std::cout << "Test passed :)\n\n";
}
static void UnitTestHTKCompare1() {
std::cout << "=== UnitTestHTKCompare1() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.1",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.1",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.1");
}
static void UnitTestHTKCompare2() {
std::cout << "=== UnitTestHTKCompare2() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.2",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.2",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.2");
}
static void UnitTestHTKCompare3() {
std::cout << "=== UnitTestHTKCompare3() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.3",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 20.0;
// op.mel_opts.debug_mel = true;
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.3",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.3");
}
static void UnitTestHTKCompare4() {
std::cout << "=== UnitTestHTKCompare4() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.4",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.4",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.4");
}
static void UnitTestHTKCompare5() {
std::cout << "=== UnitTestHTKCompare5() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.5",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 0.0;
op.mel_opts.vtln_low = 100.0;
op.mel_opts.vtln_high = 7500.0;
op.mel_opts.htk_mode = true;
BaseFloat vtln_warp =
1.1; // our approach identical to htk for warp factor >1,
// differs slightly for higher mel bins if warp_factor <0.9
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.5",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.5");
}
static void UnitTestHTKCompare6() {
std::cout << "=== UnitTestHTKCompare6() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.6",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.97;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.num_bins = 24;
op.mel_opts.low_freq = 125.0;
op.mel_opts.high_freq = 7800.0;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.6",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.6");
}
void UnitTestVtln() {
// Test the function VtlnWarpFreq.
BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20,
vtln_high_cutoff = 7400;
for (size_t i = 0; i < 100; i++) {
BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq),
freq / warp_factor);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
low_freq),
low_freq);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
high_freq),
high_freq);
BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(),
freq3 = freq2 +
(high_freq - freq2) * RandUniform(); // freq3>=freq2
BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq2);
BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq3);
KALDI_ASSERT(w3 >= w2); // increasing function.
BaseFloat w3dash = MelBanks::VtlnWarpFreq(
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, 1.0, freq3);
AssertEqual(w3dash, freq3);
}
}
static void UnitTestFeat() {
UnitTestVtln();
UnitTestReadWave();
UnitTestSimple();
UnitTestHTKCompare1();
UnitTestHTKCompare2();
// commenting out this one as it doesn't compare right now I normalized
// the way the FFT bins are treated (removed offset of 0.5)... this seems
// to relate to the way frequency zero behaves.
UnitTestHTKCompare3();
UnitTestHTKCompare4();
UnitTestHTKCompare5();
UnitTestHTKCompare6();
std::cout << "Tests succeeded.\n";
}
int main() {
try {
for (int i = 0; i < 5; i++) UnitTestFeat();
std::cout << "Tests succeeded.\n";
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return 1;
}
}

@ -1,270 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
std::vector<float> mean_{
-13730251.531853663, -12982852.199316509, -13673844.299583456,
-13089406.559646806, -12673095.524938712, -12823859.223276224,
-13590267.158903603, -14257618.467152044, -14374605.116185192,
-14490009.21822485, -14849827.158924166, -15354435.470563512,
-15834149.206532761, -16172971.985514281, -16348740.496746974,
-16423536.699409386, -16556246.263649225, -16744088.772748645,
-16916184.08510357, -17054034.840031497, -17165612.509455364,
-17255955.470915023, -17322572.527648456, -17408943.862033736,
-17521554.799865916, -17620623.254924215, -17699792.395918526,
-17723364.411134344, -17741483.4433254, -17747426.888704527,
-17733315.928209435, -17748780.160905756, -17808336.883775543,
-17895918.671983004, -18009812.59173023, -18098188.66548325,
-18195798.958462656, -18293617.62980999, -18397432.92077201,
-18505834.787318766, -18585451.8100908, -18652438.235649142,
-18700960.306275308, -18734944.58792185, -18737426.313365128,
-18735347.165987637, -18738813.444170244, -18737086.848890636,
-18731576.2474336, -18717405.44095871, -18703089.25545657,
-18691014.546456724, -18692460.568905357, -18702119.628629155,
-18727710.621126678, -18761582.72034647, -18806745.835547544,
-18850674.8692112, -18884431.510951452, -18919999.992506847,
-18939303.799078144, -18952946.273760635, -18980289.22996379,
-19011610.17803294, -19040948.61805145, -19061021.429847397,
-19112055.53768819, -19149667.414264943, -19201127.05091321,
-19270250.82564605, -19334606.883057203, -19390513.336589377,
-19444176.259208687, -19502755.000038862, -19544333.014549147,
-19612668.183176614, -19681902.19006569, -19771969.951249883,
-19873329.723376893, -19996752.59235844, -20110031.131400537,
-20231658.612529557, -20319378.894054495, -20378534.45718066,
-20413332.089584175, -20438147.844177883, -20443710.248040095,
-20465457.02238927, -20488610.969337028, -20516295.16424432,
-20541423.795738827, -20553192.874953747, -20573605.50701977,
-20577871.61936797, -20571807.008916274, -20556242.38912231,
-20542199.30819195, -20521239.063551214, -20519150.80004532,
-20527204.80248933, -20536933.769257784, -20543470.522332076,
-20549700.089992985, -20551525.24958494, -20554873.406493705,
-20564277.65794227, -20572211.740052115, -20574305.69550465,
-20575494.450104576, -20567092.577932164, -20549302.929608088,
-20545445.11878376, -20546625.326603737, -20549190.03499401,
-20554824.947828256, -20568341.378989458, -20577582.331383612,
-20577980.519402675, -20566603.03458152, -20560131.592262644,
-20552166.469060015, -20549063.06763577, -20544490.562339947,
-20539817.82346569, -20528747.715731595, -20518026.24576161,
-20510977.844974525, -20506874.36087992, -20506731.11977665,
-20510482.133420516, -20507760.92101862, -20494644.834457114,
-20480107.89304893, -20461312.091867123, -20442941.75080173,
-20426123.02834838, -20424607.675283, -20426810.369107097,
-20434024.50097819, -20437404.75544205, -20447688.63916367,
-20460893.335563846, -20482922.735127095, -20503610.119434915,
-20527062.76448319, -20557830.035128627, -20593274.72068722,
-20632528.452965066, -20673637.471334763, -20733106.97143075,
-20842921.0447562, -21054357.83621519, -21416569.534189366,
-21978460.272811692, -22753170.052172784, -23671344.10563395,
-24613499.293358143, -25406477.12230188, -25884377.82156489,
-26049040.62791664, -26996879.104431007};
std::vector<float> variance_{
213747175.10846674, 188395815.34302503, 212706429.10966414,
199109025.81461075, 189235901.23864496, 194901336.53253657,
217481594.29306737, 238689869.12327808, 243977501.24115244,
248479623.6431067, 259766741.47116545, 275516766.7790273,
291271202.3691234, 302693239.8220509, 308627358.3997694,
311143911.38788426, 315446105.07731867, 321705430.9341829,
327458907.4659941, 332245072.43223983, 336251717.5935284,
339694069.7639722, 342188204.4322228, 345587110.31313115,
349903086.2875232, 353660214.20643026, 356700344.5270885,
357665362.3529641, 358493352.05658793, 358857951.620328,
358375239.52774596, 358899733.6342954, 361051818.3511561,
364361716.05025816, 368750322.3771452, 372047800.6462831,
375655861.1349018, 379358519.1980013, 383327605.3935181,
387458599.282341, 390434692.3406868, 392994486.35057056,
394874418.04603153, 396230525.79763395, 396365592.0414835,
396334819.8242737, 396488353.19250053, 396438877.00744957,
396197980.4459586, 395590921.6672991, 395001107.62072515,
394528291.7318225, 394593110.424006, 395018405.59353715,
396110577.5415993, 397506704.0371068, 399400197.4657644,
401243568.2468382, 402687134.7805103, 404136047.2872507,
404883170.001883, 405522253.219517, 406660365.3626476,
407919346.0991902, 409045348.5384909, 409759588.7889818,
411974821.8564483, 413489718.78201455, 415535392.56684107,
418466481.97674364, 421104678.35678065, 423405392.5200779,
425550570.40798235, 427929423.9579701, 429585274.253478,
432368493.55181056, 435193587.13513297, 438886855.20476013,
443058876.8633751, 448181232.5093362, 452883835.6332396,
458056721.77926534, 461816531.22735566, 464363620.1970998,
465886343.5057493, 466928872.0651, 467180536.42647296,
468111848.70714295, 469138695.3071312, 470378429.6930793,
471517958.7132626, 472109050.4262365, 473087417.0177867,
473381322.04648733, 473220195.85483915, 472666071.8998819,
472124669.87879956, 471298571.411737, 471251033.2902761,
471672676.43128747, 472177147.2193172, 472572361.7711908,
472968783.7751127, 473156295.4164052, 473398034.82676554,
473897703.5203811, 474328271.33112127, 474452670.98002136,
474549003.99284613, 474252887.13567275, 473557462.909069,
473483385.85193115, 473609738.04855174, 473746944.82085115,
474016729.91696435, 474617321.94138587, 475045097.237122,
475125402.586558, 474664112.9824912, 474426247.5800283,
474104075.42796475, 473978219.7273978, 473773171.7798875,
473578534.69508696, 473102924.16904145, 472651240.5232615,
472374383.1810912, 472209479.6956096, 472202298.8921673,
472370090.76781124, 472220933.99374026, 471625467.37106377,
470994646.51883453, 470182428.9637543, 469348211.5939578,
468570387.4467277, 468540442.7225135, 468672018.90414184,
468994346.9533251, 469138757.58201426, 469553915.95710236,
470134523.38582784, 471082421.62055486, 471962316.51804745,
472939745.1708408, 474250621.5944825, 475773933.43199486,
477465399.71087736, 479218782.61382693, 481752299.7930922,
486608947.8984568, 496119403.2067917, 512730085.5704984,
539048915.2641417, 576285298.3548826, 621610270.2240586,
669308196.4436442, 710656993.5957186, 736344437.3725077,
745481288.0241544, 801121432.9925804};
int count_ = 912592;
void WriteMatrix() {
kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
for (size_t idx = 0; idx < mean_.size(); ++idx) {
cmvn_stats(0, idx) = mean_[idx];
cmvn_stats(1, idx) = variance_[idx];
}
cmvn_stats(0, mean_.size()) = count_;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, false);
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
WriteMatrix();
int32 num_done = 0, num_err = 0;
// feature pipeline: wave cache --> decibel_normalizer --> hanning
// window -->linear_spectrogram --> global cmvn -> feat cache
// std::unique_ptr<ppspeech::FrontendInterface> data_source(new
// ppspeech::DataCache());
std::unique_ptr<ppspeech::FrontendInterface> data_source(
new ppspeech::AudioCache());
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FrontendInterface> db_norm(
new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
opt.streaming_chunk = FLAGS_streaming_chunk;
opt.frame_opts.dither = 0.0;
opt.frame_opts.remove_dc_offset = false;
opt.frame_opts.window_type = "hanning";
opt.frame_opts.preemph_coeff = 0.0;
LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms;
LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms;
std::unique_ptr<ppspeech::FrontendInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opt, std::move(db_norm)));
std::unique_ptr<ppspeech::FrontendInterface> cmvn(new ppspeech::CMVN(
FLAGS_cmvn_write_path, std::move(linear_spectrogram)));
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "process utt: " << utt;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Vector<BaseFloat> features;
feature_cache.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
feature_cache.SetFinished();
}
feature_cache.Read(&features);
if (features.Dim() == 0) break;
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.Dim() / feature_cache.Dim();
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
feature_cache.Dim());
for (auto feat : feats) {
int num_rows = feat.Dim() / feature_cache.Dim();
for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
for (size_t col_idx = 0; col_idx < feature_cache.Dim();
++col_idx) {
features(cur_idx, col_idx) =
feat(row_idx * feature_cache.Dim() + col_idx);
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
feature_cache.Reset();
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -1,32 +0,0 @@
#!/bin/bash
set +x
set -e
. ./path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
model_dir=../paddle_asr_model
feat_wspecifier=./feats.ark
cmvn=./cmvn.ark
# 3. run feat
export GLOG_logtostderr=1
linear_spectrogram_main \
--wav_rspecifier=scp:$model_dir/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_write_path=$cmvn

@ -0,0 +1,2 @@
data
exp

@ -1,5 +0,0 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(pp-model-test ${CMAKE_CURRENT_SOURCE_DIR}/pp-model-test.cc)
target_include_directories(pp-model-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(pp-model-test PUBLIC nnet gflags ${DEPS})

@ -1,29 +0,0 @@
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -d ../paddle_asr_model ]; then
wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
tar xzfv paddle_asr_model.tar.gz
mv ./paddle_asr_model ../
# produce wav scp
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
fi
model_dir=../paddle_asr_model
# 4. run decoder
pp-model-test \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdparams

@ -92,8 +92,7 @@ void CTCBeamSearch::AdvanceDecode(
while (1) {
vector<vector<BaseFloat>> likelihood;
vector<BaseFloat> frame_prob;
bool flag =
decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
if (flag == false) break;
likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood);

@ -46,10 +46,10 @@ class LinearSpectrogram : public FrontendInterface {
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
virtual void Reset() {
base_extractor_->Reset();
reminded_wav_.Resize(0);
}
}
private:
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,

@ -49,19 +49,19 @@ bool Decodable::IsLastFrame(int32 frame) {
int32 Decodable::NumIndices() const { return 0; }
// the ilable(TokenId) of wfst(TLG) insert <eps>(id = 0) in front of Nnet prob id.
int32 Decodable::TokenId2NnetId(int32 token_id) {
return token_id - 1;
}
// the ilable(TokenId) of wfst(TLG) insert <eps>(id = 0) in front of Nnet prob
// id.
int32 Decodable::TokenId2NnetId(int32 token_id) { return token_id - 1; }
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
CHECK_LE(index, nnet_cache_.NumCols());
CHECK_LE(frame, frames_ready_);
int32 frame_idx = frame - frame_offset_;
// the nnet output is prob ranther than log prob
// the index - 1, because the ilabel
return acoustic_scale_ * std::log(nnet_cache_(frame_idx, TokenId2NnetId(index)) +
std::numeric_limits<float>::min());
// the index - 1, because the ilabel
return acoustic_scale_ *
std::log(nnet_cache_(frame_idx, TokenId2NnetId(index)) +
std::numeric_limits<float>::min());
}
bool Decodable::EnsureFrameHaveComputed(int32 frame) {

@ -37,8 +37,7 @@ std::string ReadFile2String(const std::string& path) {
if (!input_file.is_open()) {
std::cerr << "please input a valid file" << std::endl;
}
return std::string((std::istreambuf_iterator<char>(input_file)),
std::istreambuf_iterator<char>());
return std::string((std::istreambuf_iterator<char>(input_file)),
std::istreambuf_iterator<char>());
}
}

@ -20,5 +20,4 @@ bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* data);
std::string ReadFile2String(const std::string& path);
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -35,66 +35,68 @@
*/
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
const char *usage =
"Adds self-loops to states of an FST to propagate disambiguation "
"symbols through it\n"
"They are added on each final state and each state with non-epsilon "
"output symbols\n"
"on at least one arc out of the state. Useful in conjunction with "
"predeterminize\n"
"\n"
"Usage: fstaddselfloops in-disambig-list out-disambig-list [in.fst "
"[out.fst] ]\n"
"E.g: fstaddselfloops in.list out.list < in.fst > withloops.fst\n"
"in.list and out.list are lists of integers, one per line, of the\n"
"same length.\n";
ParseOptions po(usage);
po.Read(argc, argv);
if (po.NumArgs() < 2 || po.NumArgs() > 4) {
po.PrintUsage();
exit(1);
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
const char *usage =
"Adds self-loops to states of an FST to propagate disambiguation "
"symbols through it\n"
"They are added on each final state and each state with "
"non-epsilon "
"output symbols\n"
"on at least one arc out of the state. Useful in conjunction with "
"predeterminize\n"
"\n"
"Usage: fstaddselfloops in-disambig-list out-disambig-list "
"[in.fst "
"[out.fst] ]\n"
"E.g: fstaddselfloops in.list out.list < in.fst > withloops.fst\n"
"in.list and out.list are lists of integers, one per line, of the\n"
"same length.\n";
ParseOptions po(usage);
po.Read(argc, argv);
if (po.NumArgs() < 2 || po.NumArgs() > 4) {
po.PrintUsage();
exit(1);
}
std::string disambig_in_rxfilename = po.GetArg(1),
disambig_out_rxfilename = po.GetArg(2),
fst_in_filename = po.GetOptArg(3),
fst_out_filename = po.GetOptArg(4);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
std::vector<int32> disambig_in;
if (!ReadIntegerVectorSimple(disambig_in_rxfilename, &disambig_in))
KALDI_ERR << "fstaddselfloops: Could not read disambiguation "
"symbols from "
<< kaldi::PrintableRxfilename(disambig_in_rxfilename);
std::vector<int32> disambig_out;
if (!ReadIntegerVectorSimple(disambig_out_rxfilename, &disambig_out))
KALDI_ERR << "fstaddselfloops: Could not read disambiguation "
"symbols from "
<< kaldi::PrintableRxfilename(disambig_out_rxfilename);
if (disambig_in.size() != disambig_out.size())
KALDI_ERR << "fstaddselfloops: mismatch in size of disambiguation "
"symbols";
AddSelfLoops(fst, disambig_in, disambig_out);
WriteFstKaldi(*fst, fst_out_filename);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
std::string disambig_in_rxfilename = po.GetArg(1),
disambig_out_rxfilename = po.GetArg(2),
fst_in_filename = po.GetOptArg(3),
fst_out_filename = po.GetOptArg(4);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
std::vector<int32> disambig_in;
if (!ReadIntegerVectorSimple(disambig_in_rxfilename, &disambig_in))
KALDI_ERR
<< "fstaddselfloops: Could not read disambiguation symbols from "
<< kaldi::PrintableRxfilename(disambig_in_rxfilename);
std::vector<int32> disambig_out;
if (!ReadIntegerVectorSimple(disambig_out_rxfilename, &disambig_out))
KALDI_ERR
<< "fstaddselfloops: Could not read disambiguation symbols from "
<< kaldi::PrintableRxfilename(disambig_out_rxfilename);
if (disambig_in.size() != disambig_out.size())
KALDI_ERR
<< "fstaddselfloops: mismatch in size of disambiguation symbols";
AddSelfLoops(fst, disambig_in, disambig_out);
WriteFstKaldi(*fst, fst_out_filename);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
return 0;
}

@ -56,59 +56,61 @@ bool debug_location = false;
void signal_handler(int) { debug_location = true; }
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
const char *usage =
"Removes epsilons and determinizes in one step\n"
"\n"
"Usage: fstdeterminizestar [in.fst [out.fst] ]\n"
"\n"
"See also: fstdeterminizelog, lattice-determinize\n";
const char *usage =
"Removes epsilons and determinizes in one step\n"
"\n"
"Usage: fstdeterminizestar [in.fst [out.fst] ]\n"
"\n"
"See also: fstdeterminizelog, lattice-determinize\n";
float delta = kDelta;
int max_states = -1;
bool use_log = false;
ParseOptions po(usage);
po.Register("use-log", &use_log, "Determinize in log semiring.");
po.Register("delta", &delta,
"Delta value used to determine equivalence of weights.");
po.Register(
"max-states", &max_states,
"Maximum number of states in determinized FST before it will abort.");
po.Read(argc, argv);
float delta = kDelta;
int max_states = -1;
bool use_log = false;
ParseOptions po(usage);
po.Register("use-log", &use_log, "Determinize in log semiring.");
po.Register("delta",
&delta,
"Delta value used to determine equivalence of weights.");
po.Register("max-states",
&max_states,
"Maximum number of states in determinized FST before it "
"will abort.");
po.Read(argc, argv);
if (po.NumArgs() > 2) {
po.PrintUsage();
exit(1);
}
if (po.NumArgs() > 2) {
po.PrintUsage();
exit(1);
}
std::string fst_in_str = po.GetOptArg(1), fst_out_str = po.GetOptArg(2);
std::string fst_in_str = po.GetOptArg(1), fst_out_str = po.GetOptArg(2);
// This enables us to get traceback info from determinization that is
// not seeming to terminate.
// This enables us to get traceback info from determinization that is
// not seeming to terminate.
#if !defined(_MSC_VER) && !defined(__APPLE__)
signal(SIGUSR1, signal_handler);
signal(SIGUSR1, signal_handler);
#endif
// Normal case: just files.
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_str);
// Normal case: just files.
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_str);
ArcSort(fst, ILabelCompare<StdArc>()); // improves speed.
if (use_log) {
DeterminizeStarInLog(fst, delta, &debug_location, max_states);
} else {
VectorFst<StdArc> det_fst;
DeterminizeStar(*fst, &det_fst, delta, &debug_location, max_states);
*fst = det_fst; // will do shallow copy and then det_fst goes
// out of scope anyway.
ArcSort(fst, ILabelCompare<StdArc>()); // improves speed.
if (use_log) {
DeterminizeStarInLog(fst, delta, &debug_location, max_states);
} else {
VectorFst<StdArc> det_fst;
DeterminizeStar(*fst, &det_fst, delta, &debug_location, max_states);
*fst = det_fst; // will do shallow copy and then det_fst goes
// out of scope anyway.
}
WriteFstKaldi(*fst, fst_out_str);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
WriteFstKaldi(*fst, fst_out_str);
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
}

@ -42,50 +42,51 @@
// though not stochastic because we gave it an absurdly large delta.
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
const char *usage =
"Checks whether an FST is stochastic and exits with success if so.\n"
"Prints out maximum error (in log units).\n"
"\n"
"Usage: fstisstochastic [ in.fst ]\n";
const char *usage =
"Checks whether an FST is stochastic and exits with success if "
"so.\n"
"Prints out maximum error (in log units).\n"
"\n"
"Usage: fstisstochastic [ in.fst ]\n";
float delta = 0.01;
bool test_in_log = true;
float delta = 0.01;
bool test_in_log = true;
ParseOptions po(usage);
po.Register("delta", &delta, "Maximum error to accept.");
po.Register("test-in-log", &test_in_log,
"Test stochasticity in log semiring.");
po.Read(argc, argv);
ParseOptions po(usage);
po.Register("delta", &delta, "Maximum error to accept.");
po.Register(
"test-in-log", &test_in_log, "Test stochasticity in log semiring.");
po.Read(argc, argv);
if (po.NumArgs() > 1) {
po.PrintUsage();
exit(1);
}
if (po.NumArgs() > 1) {
po.PrintUsage();
exit(1);
}
std::string fst_in_filename = po.GetOptArg(1);
std::string fst_in_filename = po.GetOptArg(1);
Fst<StdArc> *fst = ReadFstKaldiGeneric(fst_in_filename);
Fst<StdArc> *fst = ReadFstKaldiGeneric(fst_in_filename);
bool ans;
StdArc::Weight min, max;
if (test_in_log)
ans = IsStochasticFstInLog(*fst, delta, &min, &max);
else
ans = IsStochasticFst(*fst, delta, &min, &max);
bool ans;
StdArc::Weight min, max;
if (test_in_log)
ans = IsStochasticFstInLog(*fst, delta, &min, &max);
else
ans = IsStochasticFst(*fst, delta, &min, &max);
std::cout << min.Value() << " " << max.Value() << '\n';
delete fst;
if (ans)
return 0; // success;
else
return 1;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
std::cout << min.Value() << " " << max.Value() << '\n';
delete fst;
if (ans)
return 0; // success;
else
return 1;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
}

@ -33,42 +33,43 @@
*/
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
const char *usage =
"Minimizes FST after encoding [similar to fstminimize, but no "
"weight-pushing]\n"
"\n"
"Usage: fstminimizeencoded [in.fst [out.fst] ]\n";
const char *usage =
"Minimizes FST after encoding [similar to fstminimize, but no "
"weight-pushing]\n"
"\n"
"Usage: fstminimizeencoded [in.fst [out.fst] ]\n";
float delta = kDelta;
ParseOptions po(usage);
po.Register("delta", &delta,
"Delta likelihood used for quantization of weights");
po.Read(argc, argv);
float delta = kDelta;
ParseOptions po(usage);
po.Register("delta",
&delta,
"Delta likelihood used for quantization of weights");
po.Read(argc, argv);
if (po.NumArgs() > 2) {
po.PrintUsage();
exit(1);
}
if (po.NumArgs() > 2) {
po.PrintUsage();
exit(1);
}
std::string fst_in_filename = po.GetOptArg(1),
fst_out_filename = po.GetOptArg(2);
std::string fst_in_filename = po.GetOptArg(1),
fst_out_filename = po.GetOptArg(2);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
MinimizeEncoded(fst, delta);
MinimizeEncoded(fst, delta);
WriteFstKaldi(*fst, fst_out_filename);
WriteFstKaldi(*fst, fst_out_filename);
delete fst;
delete fst;
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
return 0;
}

@ -37,97 +37,104 @@
*/
int main(int argc, char *argv[]) {
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
/*
fsttablecompose should always give equivalent results to compose,
but it is more efficient for certain kinds of inputs.
In particular, it is useful when, say, the left FST has states
that typically either have epsilon olabels, or
one transition out for each of the possible symbols (as the
olabel). The same with the input symbols of the right-hand FST
is possible.
*/
const char *usage =
"Composition algorithm [between two FSTs of standard type, in "
"tropical\n"
"semiring] that is more efficient for certain cases-- in particular,\n"
"where one of the FSTs (the left one, if --match-side=left) has large\n"
"out-degree\n"
"\n"
"Usage: fsttablecompose (fst1-rxfilename|fst1-rspecifier) "
"(fst2-rxfilename|fst2-rspecifier) [(out-rxfilename|out-rspecifier)]\n";
ParseOptions po(usage);
TableComposeOptions opts;
std::string match_side = "left";
std::string compose_filter = "sequence";
po.Register("connect", &opts.connect, "If true, trim FST before output.");
po.Register("match-side", &match_side,
"Side of composition to do table "
"match, one of: \"left\" or \"right\".");
po.Register("compose-filter", &compose_filter,
"Composition filter to use, "
"one of: \"alt_sequence\", \"auto\", \"match\", \"sequence\"");
po.Read(argc, argv);
if (match_side == "left") {
opts.table_match_type = MATCH_OUTPUT;
} else if (match_side == "right") {
opts.table_match_type = MATCH_INPUT;
} else {
KALDI_ERR << "Invalid match-side option: " << match_side;
try {
using namespace kaldi; // NOLINT
using namespace fst; // NOLINT
using kaldi::int32;
/*
fsttablecompose should always give equivalent results to compose,
but it is more efficient for certain kinds of inputs.
In particular, it is useful when, say, the left FST has states
that typically either have epsilon olabels, or
one transition out for each of the possible symbols (as the
olabel). The same with the input symbols of the right-hand FST
is possible.
*/
const char *usage =
"Composition algorithm [between two FSTs of standard type, in "
"tropical\n"
"semiring] that is more efficient for certain cases-- in "
"particular,\n"
"where one of the FSTs (the left one, if --match-side=left) has "
"large\n"
"out-degree\n"
"\n"
"Usage: fsttablecompose (fst1-rxfilename|fst1-rspecifier) "
"(fst2-rxfilename|fst2-rspecifier) "
"[(out-rxfilename|out-rspecifier)]\n";
ParseOptions po(usage);
TableComposeOptions opts;
std::string match_side = "left";
std::string compose_filter = "sequence";
po.Register(
"connect", &opts.connect, "If true, trim FST before output.");
po.Register("match-side",
&match_side,
"Side of composition to do table "
"match, one of: \"left\" or \"right\".");
po.Register(
"compose-filter",
&compose_filter,
"Composition filter to use, "
"one of: \"alt_sequence\", \"auto\", \"match\", \"sequence\"");
po.Read(argc, argv);
if (match_side == "left") {
opts.table_match_type = MATCH_OUTPUT;
} else if (match_side == "right") {
opts.table_match_type = MATCH_INPUT;
} else {
KALDI_ERR << "Invalid match-side option: " << match_side;
}
if (compose_filter == "alt_sequence") {
opts.filter_type = ALT_SEQUENCE_FILTER;
} else if (compose_filter == "auto") {
opts.filter_type = AUTO_FILTER;
} else if (compose_filter == "match") {
opts.filter_type = MATCH_FILTER;
} else if (compose_filter == "sequence") {
opts.filter_type = SEQUENCE_FILTER;
} else {
KALDI_ERR << "Invalid compose-filter option: " << compose_filter;
}
if (po.NumArgs() < 2 || po.NumArgs() > 3) {
po.PrintUsage();
exit(1);
}
std::string fst1_in_str = po.GetArg(1), fst2_in_str = po.GetArg(2),
fst_out_str = po.GetOptArg(3);
VectorFst<StdArc> *fst1 = ReadFstKaldi(fst1_in_str);
VectorFst<StdArc> *fst2 = ReadFstKaldi(fst2_in_str);
// Checks if <fst1> is olabel sorted and <fst2> is ilabel sorted.
if (fst1->Properties(fst::kOLabelSorted, true) == 0) {
KALDI_WARN << "The first FST is not olabel sorted.";
}
if (fst2->Properties(fst::kILabelSorted, true) == 0) {
KALDI_WARN << "The second FST is not ilabel sorted.";
}
VectorFst<StdArc> composed_fst;
TableCompose(*fst1, *fst2, &composed_fst, opts);
delete fst1;
delete fst2;
WriteFstKaldi(composed_fst, fst_out_str);
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
if (compose_filter == "alt_sequence") {
opts.filter_type = ALT_SEQUENCE_FILTER;
} else if (compose_filter == "auto") {
opts.filter_type = AUTO_FILTER;
} else if (compose_filter == "match") {
opts.filter_type = MATCH_FILTER;
} else if (compose_filter == "sequence") {
opts.filter_type = SEQUENCE_FILTER;
} else {
KALDI_ERR << "Invalid compose-filter option: " << compose_filter;
}
if (po.NumArgs() < 2 || po.NumArgs() > 3) {
po.PrintUsage();
exit(1);
}
std::string fst1_in_str = po.GetArg(1), fst2_in_str = po.GetArg(2),
fst_out_str = po.GetOptArg(3);
VectorFst<StdArc> *fst1 = ReadFstKaldi(fst1_in_str);
VectorFst<StdArc> *fst2 = ReadFstKaldi(fst2_in_str);
// Checks if <fst1> is olabel sorted and <fst2> is ilabel sorted.
if (fst1->Properties(fst::kOLabelSorted, true) == 0) {
KALDI_WARN << "The first FST is not olabel sorted.";
}
if (fst2->Properties(fst::kILabelSorted, true) == 0) {
KALDI_WARN << "The second FST is not ilabel sorted.";
}
VectorFst<StdArc> composed_fst;
TableCompose(*fst1, *fst2, &composed_fst, opts);
delete fst1;
delete fst2;
WriteFstKaldi(composed_fst, fst_out_str);
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
}

@ -24,122 +24,130 @@
#include "util/parse-options.h"
int main(int argc, char *argv[]) {
using namespace kaldi; // NOLINT
try {
const char *usage =
"Convert an ARPA format language model into an FST\n"
"Usage: arpa2fst [opts] <input-arpa> <output-fst>\n"
" e.g.: arpa2fst --disambig-symbol=#0 --read-symbol-table="
"data/lang/words.txt lm/input.arpa G.fst\n\n"
"Note: When called without switches, the output G.fst will contain\n"
"an embedded symbol table. This is compatible with the way a previous\n"
"version of arpa2fst worked.\n";
ParseOptions po(usage);
ArpaParseOptions options;
options.Register(&po);
// Option flags.
std::string bos_symbol = "<s>";
std::string eos_symbol = "</s>";
std::string disambig_symbol;
std::string read_syms_filename;
std::string write_syms_filename;
bool keep_symbols = false;
bool ilabel_sort = true;
po.Register("bos-symbol", &bos_symbol, "Beginning of sentence symbol");
po.Register("eos-symbol", &eos_symbol, "End of sentence symbol");
po.Register("disambig-symbol", &disambig_symbol,
"Disambiguator. If provided (e. g. #0), used on input side of "
"backoff links, and <s> and </s> are replaced with epsilons");
po.Register("read-symbol-table", &read_syms_filename,
"Use existing symbol table");
po.Register("write-symbol-table", &write_syms_filename,
"Write generated symbol table to a file");
po.Register("keep-symbols", &keep_symbols,
"Store symbol table with FST. Symbols always saved to FST if "
"symbol tables are neither read or written (otherwise symbols "
"would be lost entirely)");
po.Register("ilabel-sort", &ilabel_sort, "Ilabel-sort the output FST");
po.Read(argc, argv);
if (po.NumArgs() != 1 && po.NumArgs() != 2) {
po.PrintUsage();
exit(1);
using namespace kaldi; // NOLINT
try {
const char *usage =
"Convert an ARPA format language model into an FST\n"
"Usage: arpa2fst [opts] <input-arpa> <output-fst>\n"
" e.g.: arpa2fst --disambig-symbol=#0 --read-symbol-table="
"data/lang/words.txt lm/input.arpa G.fst\n\n"
"Note: When called without switches, the output G.fst will "
"contain\n"
"an embedded symbol table. This is compatible with the way a "
"previous\n"
"version of arpa2fst worked.\n";
ParseOptions po(usage);
ArpaParseOptions options;
options.Register(&po);
// Option flags.
std::string bos_symbol = "<s>";
std::string eos_symbol = "</s>";
std::string disambig_symbol;
std::string read_syms_filename;
std::string write_syms_filename;
bool keep_symbols = false;
bool ilabel_sort = true;
po.Register("bos-symbol", &bos_symbol, "Beginning of sentence symbol");
po.Register("eos-symbol", &eos_symbol, "End of sentence symbol");
po.Register(
"disambig-symbol",
&disambig_symbol,
"Disambiguator. If provided (e. g. #0), used on input side of "
"backoff links, and <s> and </s> are replaced with epsilons");
po.Register("read-symbol-table",
&read_syms_filename,
"Use existing symbol table");
po.Register("write-symbol-table",
&write_syms_filename,
"Write generated symbol table to a file");
po.Register(
"keep-symbols",
&keep_symbols,
"Store symbol table with FST. Symbols always saved to FST if "
"symbol tables are neither read or written (otherwise symbols "
"would be lost entirely)");
po.Register("ilabel-sort", &ilabel_sort, "Ilabel-sort the output FST");
po.Read(argc, argv);
if (po.NumArgs() != 1 && po.NumArgs() != 2) {
po.PrintUsage();
exit(1);
}
std::string arpa_rxfilename = po.GetArg(1),
fst_wxfilename = po.GetOptArg(2);
int64 disambig_symbol_id = 0;
fst::SymbolTable *symbols;
if (!read_syms_filename.empty()) {
// Use existing symbols. Required symbols must be in the table.
kaldi::Input kisym(read_syms_filename);
symbols = fst::SymbolTable::ReadText(
kisym.Stream(), PrintableWxfilename(read_syms_filename));
if (symbols == NULL)
KALDI_ERR << "Could not read symbol table from file "
<< read_syms_filename;
options.oov_handling = ArpaParseOptions::kSkipNGram;
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->Find(disambig_symbol);
if (disambig_symbol_id == -1) // fst::kNoSymbol
KALDI_ERR << "Symbol table " << read_syms_filename
<< " has no symbol for " << disambig_symbol;
}
} else {
// Create a new symbol table and populate it from ARPA file.
symbols = new fst::SymbolTable(PrintableWxfilename(fst_wxfilename));
options.oov_handling = ArpaParseOptions::kAddToSymbols;
symbols->AddSymbol("<eps>", 0);
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->AddSymbol(disambig_symbol);
}
}
// Add or use existing BOS and EOS.
options.bos_symbol = symbols->AddSymbol(bos_symbol);
options.eos_symbol = symbols->AddSymbol(eos_symbol);
// If producing new (not reading existing) symbols and not saving them,
// need to keep symbols with FST, otherwise they would be lost.
if (read_syms_filename.empty() && write_syms_filename.empty())
keep_symbols = true;
// Actually compile LM.
KALDI_ASSERT(symbols != NULL);
ArpaLmCompiler lm_compiler(options, disambig_symbol_id, symbols);
{
Input ki(arpa_rxfilename);
lm_compiler.Read(ki.Stream());
}
// Sort the FST in-place if requested by options.
if (ilabel_sort) {
fst::ArcSort(lm_compiler.MutableFst(), fst::StdILabelCompare());
}
// Write symbols if requested.
if (!write_syms_filename.empty()) {
kaldi::Output kosym(write_syms_filename, false);
symbols->WriteText(kosym.Stream());
}
// Write LM FST.
bool write_binary = true, write_header = false;
kaldi::Output kofst(fst_wxfilename, write_binary, write_header);
fst::FstWriteOptions wopts(PrintableWxfilename(fst_wxfilename));
wopts.write_isymbols = wopts.write_osymbols = keep_symbols;
lm_compiler.Fst().Write(kofst.Stream(), wopts);
delete symbols;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
std::string arpa_rxfilename = po.GetArg(1),
fst_wxfilename = po.GetOptArg(2);
int64 disambig_symbol_id = 0;
fst::SymbolTable *symbols;
if (!read_syms_filename.empty()) {
// Use existing symbols. Required symbols must be in the table.
kaldi::Input kisym(read_syms_filename);
symbols = fst::SymbolTable::ReadText(
kisym.Stream(), PrintableWxfilename(read_syms_filename));
if (symbols == NULL)
KALDI_ERR << "Could not read symbol table from file "
<< read_syms_filename;
options.oov_handling = ArpaParseOptions::kSkipNGram;
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->Find(disambig_symbol);
if (disambig_symbol_id == -1) // fst::kNoSymbol
KALDI_ERR << "Symbol table " << read_syms_filename
<< " has no symbol for " << disambig_symbol;
}
} else {
// Create a new symbol table and populate it from ARPA file.
symbols = new fst::SymbolTable(PrintableWxfilename(fst_wxfilename));
options.oov_handling = ArpaParseOptions::kAddToSymbols;
symbols->AddSymbol("<eps>", 0);
if (!disambig_symbol.empty()) {
disambig_symbol_id = symbols->AddSymbol(disambig_symbol);
}
}
// Add or use existing BOS and EOS.
options.bos_symbol = symbols->AddSymbol(bos_symbol);
options.eos_symbol = symbols->AddSymbol(eos_symbol);
// If producing new (not reading existing) symbols and not saving them,
// need to keep symbols with FST, otherwise they would be lost.
if (read_syms_filename.empty() && write_syms_filename.empty())
keep_symbols = true;
// Actually compile LM.
KALDI_ASSERT(symbols != NULL);
ArpaLmCompiler lm_compiler(options, disambig_symbol_id, symbols);
{
Input ki(arpa_rxfilename);
lm_compiler.Read(ki.Stream());
}
// Sort the FST in-place if requested by options.
if (ilabel_sort) {
fst::ArcSort(lm_compiler.MutableFst(), fst::StdILabelCompare());
}
// Write symbols if requested.
if (!write_syms_filename.empty()) {
kaldi::Output kosym(write_syms_filename, false);
symbols->WriteText(kosym.Stream());
}
// Write LM FST.
bool write_binary = true, write_header = false;
kaldi::Output kofst(fst_wxfilename, write_binary, write_header);
fst::FstWriteOptions wopts(PrintableWxfilename(fst_wxfilename));
wopts.write_isymbols = wopts.write_osymbols = keep_symbols;
lm_compiler.Fst().Write(kofst.Stream(), wopts);
delete symbols;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
}

@ -26,9 +26,9 @@ import argparse
import os
import re
import subprocess
from distutils.util import strtobool
import numpy as np
from distutils.util import strtobool
FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")
SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+")

@ -1,6 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# CopyRight WeNet Apache-2.0 License
import re, sys, unicodedata
import codecs
Loading…
Cancel
Save