commit
9d20a10b5a
@ -0,0 +1,115 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from fastapi.testclient import TestClient
|
||||
from vpr_search import app
|
||||
|
||||
from utils.utility import download
|
||||
from utils.utility import unpack
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def download_audio_data():
    """
    Download audio data.

    Fetches the example audio archive, verifies its md5 checksum, and
    unpacks it into the current working directory.
    """
    archive_url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz"
    expected_md5 = "52ac69316c1aa1fdef84da7dd2c67b39"
    dest_dir = "./"
    archive_path = download(archive_url, expected_md5, dest_dir)
    unpack(archive_path, dest_dir, True)
|
||||
|
||||
|
||||
def test_drop():
    """
    Delete the table of MySQL.

    Only the HTTP status code is asserted; the response body is ignored.
    """
    response = client.post("/vpr/drop")
    assert response.status_code == 200
|
||||
|
||||
|
||||
def test_enroll_local(spk: str, audio: str):
    """
    Enroll the audio to MySQL.

    Args:
        spk: speaker id to enroll.
        audio: base name (without ``.wav``) of a file under ./example_audio.
    """
    # Let the test client build and percent-encode the query string
    # instead of hand-concatenating pre-encoded fragments.
    response = client.post(
        "/vpr/enroll/local",
        params={
            "spk_id": spk,
            "audio_path": f"./example_audio/{audio}.wav",
        })
    assert response.status_code == 200
    assert response.json() == {
        'status': True,
        'msg': "Successfully enroll data!"
    }
|
||||
|
||||
|
||||
def test_search_local():
    """
    Search the spk in MySQL by audio.

    Queries /vpr/recog/local with a percent-encoded local path
    (./example_audio/test.wav); only the status code is asserted.
    """
    response = client.post(
        "/vpr/recog/local?audio_path=.%2Fexample_audio%2Ftest.wav")
    assert response.status_code == 200
|
||||
|
||||
|
||||
def test_list():
    """
    Get all records in MySQL.

    Only the HTTP status code is asserted.
    """
    response = client.get("/vpr/list")
    assert response.status_code == 200
|
||||
|
||||
|
||||
def test_data(spk: str):
    """
    Get the audio file by spk_id in MySQL.

    Args:
        spk: speaker id whose enrolled audio file is fetched.
    """
    response = client.get("/vpr/data?spk_id=" + spk)
    assert response.status_code == 200
|
||||
|
||||
|
||||
def test_del(spk: str):
    """
    Delete the record in MySQL by spk_id.

    Args:
        spk: speaker id whose record is removed.
    """
    response = client.post("/vpr/del?spk_id=" + spk)
    assert response.status_code == 200
|
||||
|
||||
|
||||
def test_count():
    """
    Get the number of spk in MySQL.

    Only the HTTP status code is asserted.
    """
    response = client.get("/vpr/count")
    assert response.status_code == 200
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Fetch the example audio archive used by the tests below.
    download_audio_data()

    # Enroll three speakers, then exercise list/data/count/search.
    test_enroll_local("spk1", "arms_strikes")
    test_enroll_local("spk2", "sword_wielding")
    test_enroll_local("spk3", "test")
    test_list()
    test_data("spk1")
    test_count()
    test_search_local()

    # Delete one speaker and re-check count and search.
    test_del("spk1")
    test_count()
    test_search_local()

    # Re-enroll the deleted speaker and verify again.
    test_enroll_local("spk1", "arms_strikes")
    test_count()
    test_search_local()

    # Finally drop the whole table.
    test_drop()
|
@ -0,0 +1,206 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
from config import UPLOAD_PATH
|
||||
from fastapi import FastAPI
|
||||
from fastapi import File
|
||||
from fastapi import UploadFile
|
||||
from logs import LOGGER
|
||||
from mysql_helpers import MySQLHelper
|
||||
from operations.count import do_count_vpr
|
||||
from operations.count import do_get
|
||||
from operations.count import do_list
|
||||
from operations.drop import do_delete
|
||||
from operations.drop import do_drop_vpr
|
||||
from operations.load import do_enroll
|
||||
from operations.search import do_search_vpr
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
# FastAPI application serving the voice-print recognition (VPR) REST API.
app = FastAPI()
# Demo setting: allow cross-origin requests from any host.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"])

# Shared MySQL helper used by every endpoint below.
MYSQL_CLI = MySQLHelper()

# Mkdir 'tmp/audio-data': uploaded audio is stored here before processing.
if not os.path.exists(UPLOAD_PATH):
    os.makedirs(UPLOAD_PATH)
    LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")
|
||||
|
||||
|
||||
@app.post('/vpr/enroll')
async def vpr_enroll(table_name: str=None,
                     spk_id: str=None,
                     audio: UploadFile=File(...)):
    """Enroll an uploaded audio file with spk-id into MySQL.

    The upload is first persisted under UPLOAD_PATH, then enrolled via
    do_enroll. Returns a status dict; on failure a (dict, 400) tuple.
    """
    try:
        # Save the upload data to server.
        content = await audio.read()
        audio_path = os.path.join(UPLOAD_PATH, audio.filename)
        with open(audio_path, "wb+") as f:
            f.write(content)
        do_enroll(table_name, spk_id, audio_path, MYSQL_CLI)
        LOGGER.info(f"Successfully enrolled {spk_id} online!")
        return {'status': True, 'msg': "Successfully enroll data!"}
    except Exception as e:
        LOGGER.error(e)
        # str(e): a raw Exception instance is not reliably JSON-serializable
        # by FastAPI's encoder, so the error text would be lost.
        return {'status': False, 'msg': str(e)}, 400
|
||||
|
||||
|
||||
@app.post('/vpr/enroll/local')
async def vpr_enroll_local(table_name: str=None,
                           spk_id: str=None,
                           audio_path: str=None):
    """Enroll a server-local audio file with spk-id into MySQL."""
    try:
        do_enroll(table_name, spk_id, audio_path, MYSQL_CLI)
        LOGGER.info(f"Successfully enrolled {spk_id} locally!")
        return {'status': True, 'msg': "Successfully enroll data!"}
    except Exception as e:
        LOGGER.error(e)
        # str(e): a raw Exception instance is not reliably JSON-serializable.
        return {'status': False, 'msg': str(e)}, 400
|
||||
|
||||
|
||||
@app.post('/vpr/recog')
async def vpr_recog(request: Request,
                    table_name: str=None,
                    audio: UploadFile=File(...)):
    """Voice print recognition for an uploaded audio file.

    Saves the upload under UPLOAD_PATH, searches the enrolled speakers,
    and returns (spk_id, (audio_path, score)) pairs, best score first.
    """
    try:
        # Save the upload data to server.
        content = await audio.read()
        query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
        with open(query_audio_path, "wb+") as f:
            f.write(content)
        host = request.headers['host']
        spk_ids, paths, scores = do_search_vpr(host, table_name,
                                               query_audio_path, MYSQL_CLI)
        for spk_id, path, score in zip(spk_ids, paths, scores):
            LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}, ")
        res = dict(zip(spk_ids, zip(paths, scores)))
        # Sort results by score, best match first.
        res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
        LOGGER.info("Successfully speaker recognition online!")
        return res
    except Exception as e:
        LOGGER.error(e)
        # str(e): a raw Exception instance is not reliably JSON-serializable.
        return {'status': False, 'msg': str(e)}, 400
|
||||
|
||||
|
||||
@app.post('/vpr/recog/local')
async def vpr_recog_local(request: Request,
                          table_name: str=None,
                          audio_path: str=None):
    """Voice print recognition for a server-local audio file.

    Returns (spk_id, (audio_path, score)) pairs, best score first.
    """
    try:
        host = request.headers['host']
        spk_ids, paths, scores = do_search_vpr(host, table_name, audio_path,
                                               MYSQL_CLI)
        for spk_id, path, score in zip(spk_ids, paths, scores):
            LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}, ")
        res = dict(zip(spk_ids, zip(paths, scores)))
        # Sort results by score, best match first.
        res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
        LOGGER.info("Successfully speaker recognition locally!")
        return res
    except Exception as e:
        LOGGER.error(e)
        # str(e): a raw Exception instance is not reliably JSON-serializable.
        return {'status': False, 'msg': str(e)}, 400
|
||||
|
||||
|
||||
@app.post('/vpr/del')
async def vpr_del(table_name: str=None, spk_id: str=None):
    """Delete a record by spk_id in MySQL."""
    try:
        do_delete(table_name, spk_id, MYSQL_CLI)
        LOGGER.info("Successfully delete a record by spk_id in MySQL")
        return {'status': True, 'msg': "Successfully delete data!"}
    except Exception as e:
        LOGGER.error(e)
        # str(e): a raw Exception instance is not reliably JSON-serializable.
        return {'status': False, 'msg': str(e)}, 400
|
||||
|
||||
|
||||
@app.get('/vpr/list')
async def vpr_list(table_name: str=None):
    """Get all records in MySQL.

    Returns the (spk_ids, audio_paths) pair produced by do_list.
    """
    try:
        spk_ids, audio_paths = do_list(table_name, MYSQL_CLI)
        for i in range(len(spk_ids)):
            LOGGER.debug(f"spk {spk_ids[i]}, audio path {audio_paths[i]}")
        LOGGER.info("Successfully list all records from mysql!")
        return spk_ids, audio_paths
    except Exception as e:
        LOGGER.error(e)
        # str(e): a raw Exception instance is not reliably JSON-serializable.
        return {'status': False, 'msg': str(e)}, 400
|
||||
|
||||
|
||||
@app.get('/vpr/data')
async def vpr_data(
        table_name: str=None,
        spk_id: str=None, ):
    """Get the audio file from path by spk_id in MySQL."""
    try:
        audio_path = do_get(table_name, spk_id, MYSQL_CLI)
        LOGGER.info(f"Successfully get audio path {audio_path}!")
        return FileResponse(audio_path)
    except Exception as e:
        LOGGER.error(e)
        # str(e): a raw Exception instance is not reliably JSON-serializable.
        return {'status': False, 'msg': str(e)}, 400
|
||||
|
||||
|
||||
@app.get('/vpr/count')
async def vpr_count(table_name: str=None):
    """Get the total number of spk in MySQL."""
    try:
        num = do_count_vpr(table_name, MYSQL_CLI)
        LOGGER.info("Successfully count the number of spk!")
        return num
    except Exception as e:
        LOGGER.error(e)
        # str(e): a raw Exception instance is not reliably JSON-serializable.
        return {'status': False, 'msg': str(e)}, 400
|
||||
|
||||
|
||||
@app.post('/vpr/drop')
async def drop_tables(table_name: str=None):
    """Delete the table of MySQL."""
    try:
        do_drop_vpr(table_name, MYSQL_CLI)
        LOGGER.info("Successfully drop tables in MySQL!")
        return {'status': True, 'msg': "Successfully drop tables!"}
    except Exception as e:
        LOGGER.error(e)
        # str(e): a raw Exception instance is not reliably JSON-serializable.
        return {'status': False, 'msg': str(e)}, 400
|
||||
|
||||
|
||||
@app.get('/data')
def audio_path(audio_path):
    """Get the audio file from path.

    NOTE(review): serves an arbitrary caller-supplied path with no
    validation — fine for a local demo, a path-traversal risk otherwise.
    """
    try:
        LOGGER.info(f"Successfully get audio: {audio_path}")
        return FileResponse(audio_path)
    except Exception as e:
        LOGGER.error(f"get audio error: {e}")
        # str(e): a raw Exception instance is not reliably JSON-serializable.
        return {'status': False, 'msg': str(e)}, 400
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Serve the VPR API on all interfaces, port 8002.
    uvicorn.run(app=app, host='0.0.0.0', port=8002)
|
@ -0,0 +1,32 @@
|
||||
# ONNXRuntime inference for CSMSC speedyspeech + hifigan.
# Usage: this_script.sh <train_output_path>
train_output_path=$1

stage=0
stop_stage=0

# only support default_fastspeech2/speedyspeech + hifigan/mb_melgan now!

# synthesize from metadata
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    python3 ${BIN_DIR}/../ort_predict.py \
        --inference_dir=${train_output_path}/inference_onnx \
        --am=speedyspeech_csmsc \
        --voc=hifigan_csmsc \
        --test_metadata=dump/test/norm/metadata.jsonl \
        --output_dir=${train_output_path}/onnx_infer_out \
        --device=cpu \
        --cpu_threads=2
fi

# e2e, synthesize from text
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    python3 ${BIN_DIR}/../ort_predict_e2e.py \
        --inference_dir=${train_output_path}/inference_onnx \
        --am=speedyspeech_csmsc \
        --voc=hifigan_csmsc \
        --output_dir=${train_output_path}/onnx_infer_out_e2e \
        --text=${BIN_DIR}/../csmsc_test.txt \
        --phones_dict=dump/phone_id_map.txt \
        --tones_dict=dump/tone_id_map.txt \
        --device=cpu \
        --cpu_threads=2
fi
|
@ -0,0 +1 @@
|
||||
../../tts3/local/paddle2onnx.sh
|
After Width: | Height: | Size: 949 KiB |
@ -0,0 +1,18 @@
|
||||
# paddlespeech serving 网页Demo
|
||||
|
||||
- 感谢[wenet](https://github.com/wenet-e2e/wenet)团队的前端demo代码.
|
||||
|
||||
|
||||
## 使用方法
|
||||
### 1. 在本地电脑启动网页服务
|
||||
```
|
||||
python app.py
|
||||
|
||||
```
|
||||
|
||||
### 2. 本地电脑浏览器
|
||||
|
||||
在浏览器中输入127.0.0.1:19999 即可看到相关网页Demo。
|
||||
|
||||
![图片](./paddle_web_demo.png)
|
||||
|
@ -1,7 +1,4 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_subdirectory(feat)
|
||||
add_subdirectory(nnet)
|
||||
add_subdirectory(decoder)
|
||||
|
||||
add_subdirectory(glog)
|
||||
add_subdirectory(ds2_ol)
|
||||
add_subdirectory(dev)
|
@ -1,17 +1,25 @@
|
||||
# Examples
|
||||
# Examples for SpeechX
|
||||
|
||||
* dev - for speechx developers, used for testing.
|
||||
* ngram - using to build NGram ARPA lm.
|
||||
* ds2_ol - ds2 streaming test under `aishell-1` test dataset.
|
||||
The entrypoint is `ds2_ol/aishell/run.sh`
|
||||
|
||||
* glog - glog usage
|
||||
* feat - mfcc, linear
|
||||
* nnet - ds2 nn
|
||||
* decoder - online decoder to work as offline
|
||||
|
||||
## How to run
|
||||
|
||||
`run.sh` is the entry point.
|
||||
|
||||
Example to play `decoder`:
|
||||
Example to play `ds2_ol`:
|
||||
|
||||
```
|
||||
pushd decoder
|
||||
pushd ds2_ol/aishell
|
||||
bash run.sh
|
||||
```
|
||||
|
||||
## Display Model with [Netron](https://github.com/lutzroeder/netron)
|
||||
|
||||
```
|
||||
pip install netron
|
||||
netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host 10.21.55.20
|
||||
```
|
||||
|
@ -1 +0,0 @@
|
||||
../../../utils
|
@ -1,18 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_executable(offline_decoder_sliding_chunk_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_sliding_chunk_main.cc)
|
||||
target_include_directories(offline_decoder_sliding_chunk_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(offline_decoder_sliding_chunk_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
|
||||
|
||||
add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
|
||||
target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
|
||||
|
||||
add_executable(offline_wfst_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_wfst_decoder_main.cc)
|
||||
target_include_directories(offline_wfst_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(offline_wfst_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
|
||||
|
||||
add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc)
|
||||
target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
|
||||
|
@ -1,121 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// todo refactor, repalce with gtest
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "decoder/ctc_beam_search_decoder.h"
|
||||
#include "frontend/audio/data_cache.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/decodable.h"
|
||||
#include "nnet/paddle_nnet.h"
|
||||
|
||||
DEFINE_string(feature_respecifier, "", "feature matrix rspecifier");
|
||||
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
|
||||
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
|
||||
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
|
||||
DEFINE_string(lm_path, "lm.klm", "language model");
|
||||
DEFINE_int32(chunk_size, 35, "feat chunk size");
|
||||
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Matrix;
|
||||
using std::vector;
|
||||
|
||||
// test decoder by feeding speech feature, deprecated.
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
kaldi::SequentialBaseFloatMatrixReader feature_reader(
|
||||
FLAGS_feature_respecifier);
|
||||
std::string model_graph = FLAGS_model_path;
|
||||
std::string model_params = FLAGS_param_path;
|
||||
std::string dict_file = FLAGS_dict_file;
|
||||
std::string lm_path = FLAGS_lm_path;
|
||||
int32 chunk_size = FLAGS_chunk_size;
|
||||
LOG(INFO) << "model path: " << model_graph;
|
||||
LOG(INFO) << "model param: " << model_params;
|
||||
LOG(INFO) << "dict path: " << dict_file;
|
||||
LOG(INFO) << "lm path: " << lm_path;
|
||||
LOG(INFO) << "chunk size (frame): " << chunk_size;
|
||||
|
||||
int32 num_done = 0, num_err = 0;
|
||||
|
||||
// frontend + nnet is decodable
|
||||
ppspeech::ModelOptions model_opts;
|
||||
model_opts.model_path = model_graph;
|
||||
model_opts.params_path = model_params;
|
||||
std::shared_ptr<ppspeech::PaddleNnet> nnet(
|
||||
new ppspeech::PaddleNnet(model_opts));
|
||||
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
|
||||
std::shared_ptr<ppspeech::Decodable> decodable(
|
||||
new ppspeech::Decodable(nnet, raw_data));
|
||||
LOG(INFO) << "Init decodeable.";
|
||||
|
||||
// init decoder
|
||||
ppspeech::CTCBeamSearchOptions opts;
|
||||
opts.dict_file = dict_file;
|
||||
opts.lm_path = lm_path;
|
||||
ppspeech::CTCBeamSearch decoder(opts);
|
||||
LOG(INFO) << "Init decoder.";
|
||||
|
||||
decoder.InitDecoder();
|
||||
for (; !feature_reader.Done(); feature_reader.Next()) {
|
||||
string utt = feature_reader.Key();
|
||||
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
|
||||
LOG(INFO) << "utt: " << utt;
|
||||
|
||||
// feat dim
|
||||
raw_data->SetDim(feature.NumCols());
|
||||
LOG(INFO) << "dim: " << raw_data->Dim();
|
||||
|
||||
int32 row_idx = 0;
|
||||
int32 num_chunks = feature.NumRows() / chunk_size;
|
||||
LOG(INFO) << "n chunks: " << num_chunks;
|
||||
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
|
||||
// feat chunk
|
||||
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
|
||||
feature.NumCols());
|
||||
for (int row_id = 0; row_id < chunk_size; ++row_id) {
|
||||
kaldi::SubVector<kaldi::BaseFloat> feat_one_row(feature,
|
||||
row_idx);
|
||||
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
|
||||
feature_chunk.Data() + row_id * feature.NumCols(),
|
||||
feature.NumCols());
|
||||
f_chunk_tmp.CopyFromVec(feat_one_row);
|
||||
row_idx++;
|
||||
}
|
||||
// feed to raw cache
|
||||
raw_data->Accept(feature_chunk);
|
||||
if (chunk_idx == num_chunks - 1) {
|
||||
raw_data->SetFinished();
|
||||
}
|
||||
// decode step
|
||||
decoder.AdvanceDecode(decodable);
|
||||
}
|
||||
|
||||
std::string result;
|
||||
result = decoder.GetFinalBestPath();
|
||||
KALDI_LOG << " the result of " << utt << " is " << result;
|
||||
decodable->Reset();
|
||||
decoder.Reset();
|
||||
++num_done;
|
||||
}
|
||||
|
||||
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
|
||||
<< " with errors.";
|
||||
return (num_done != 0 ? 0 : 1);
|
||||
}
|
@ -1,43 +0,0 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
|
||||
# 2. download model
|
||||
if [ ! -d ../paddle_asr_model ]; then
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
|
||||
tar xzfv paddle_asr_model.tar.gz
|
||||
mv ./paddle_asr_model ../
|
||||
# produce wav scp
|
||||
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
|
||||
fi
|
||||
|
||||
model_dir=../paddle_asr_model
|
||||
feat_wspecifier=./feats.ark
|
||||
cmvn=./cmvn.ark
|
||||
|
||||
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
# 3. gen linear feat
|
||||
linear_spectrogram_main \
|
||||
--wav_rspecifier=scp:$model_dir/wav.scp \
|
||||
--feature_wspecifier=ark,t:$feat_wspecifier \
|
||||
--cmvn_write_path=$cmvn
|
||||
|
||||
# 4. run decoder
|
||||
offline_decoder_main \
|
||||
--feature_respecifier=ark:$feat_wspecifier \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdparams \
|
||||
--dict_file=$model_dir/vocab.txt \
|
||||
--lm_path=$model_dir/avg_1.jit.klm
|
@ -0,0 +1,3 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_subdirectory(glog)
|
@ -1,14 +1,15 @@
|
||||
# This contains the locations of binarys build required for running the examples.
|
||||
|
||||
SPEECHX_ROOT=$PWD/../..
|
||||
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
|
||||
SPEECHX_ROOT=$PWD/../../../
|
||||
|
||||
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
|
||||
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
|
||||
|
||||
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
|
||||
|
||||
export LC_AL=C
|
||||
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
|
||||
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
|
||||
|
||||
SPEECHX_BIN=$SPEECHX_EXAMPLES/nnet
|
||||
SPEECHX_BIN=$SPEECHX_EXAMPLES/dev/glog
|
||||
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
|
||||
|
||||
export LC_AL=C
|
@ -0,0 +1,5 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_subdirectory(feat)
|
||||
add_subdirectory(nnet)
|
||||
add_subdirectory(decoder)
|
@ -0,0 +1,11 @@
|
||||
# Deepspeech2 Streaming
|
||||
|
||||
Please go to `aishell` to test it.
|
||||
|
||||
* aishell
|
||||
Deepspeech2 Streaming Decoding under aishell dataset.
|
||||
|
||||
The below is for developing and offline testing:
|
||||
* nnet
|
||||
* feat
|
||||
* decoder
|
@ -0,0 +1,3 @@
|
||||
data
|
||||
exp
|
||||
aishell_*
|
@ -0,0 +1,21 @@
|
||||
# Aishell - Deepspeech2 Streaming
|
||||
|
||||
## CTC Prefix Beam Search w/o LM
|
||||
|
||||
```
|
||||
Overall -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
|
||||
Mandarin -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
|
||||
Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
|
||||
```
|
||||
|
||||
## CTC Prefix Beam Search w LM
|
||||
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
## CTC WFST
|
||||
|
||||
```
|
||||
|
||||
```
|
@ -0,0 +1 @@
|
||||
../../../../utils/
|
@ -0,0 +1,2 @@
|
||||
data
|
||||
exp
|
@ -0,0 +1,19 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
set(bin_name ctc-prefix-beam-search-decoder-ol)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
|
||||
|
||||
|
||||
set(bin_name wfst-decoder-ol)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
|
||||
|
||||
|
||||
set(bin_name nnet-logprob-decoder-test)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
|
||||
|
@ -0,0 +1,12 @@
|
||||
# ASR Decoder
|
||||
|
||||
ASR Decoder test bins. We use these bins to test the CTC beam search decoder and the WFST decoder.
|
||||
|
||||
* decoder_test_main.cc
|
||||
feed nnet output logprob, and only test decoder
|
||||
|
||||
* offline_decoder_sliding_chunk_main.cc
|
||||
feed streaming audio feature, decode as streaming manner.
|
||||
|
||||
* offline_wfst_decoder_main.cc
|
||||
feed streaming audio feature, decode using WFST as streaming manner.
|
@ -0,0 +1,79 @@
|
||||
#!/bin/bash
# End-to-end DeepSpeech2 streaming demo on aishell:
# build -> download model/wav/LM -> cmvn convert -> feature -> CTC decode.
set +x
set -e

. path.sh

# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
    pushd ${SPEECHX_ROOT}
    bash build.sh
    popd
fi

# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/

lm=$data/zh_giga.no_cna_cmn.prune01244.klm

# output
exp_dir=./exp
mkdir -p $exp_dir

# 2. download model
if [[ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]]; then
    mkdir -p data/model
    pushd data/model
    wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    popd
fi

# produce wav scp
if [ ! -f data/wav.scp ]; then
    pushd data
    wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
    echo "utt1 " $PWD/zh.wav > wav.scp
    popd
fi

# download lm
if [ ! -f $lm ]; then
    pushd data
    wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
    popd
fi


feat_wspecifier=$exp_dir/feats.ark
cmvn=$exp_dir/cmvn.ark

export GLOG_logtostderr=1

# dump json cmvn to kaldi
cmvn-json2kaldi \
    --json_file $ckpt_dir/data/mean_std.json \
    --cmvn_write_path $exp_dir/cmvn.ark \
    --binary=false
echo "convert json cmvn to kaldi ark."


# generate linear feature as streaming
linear-spectrogram-wo-db-norm-ol \
    --wav_rspecifier=scp:$data/wav.scp \
    --feature_wspecifier=ark,t:$feat_wspecifier \
    --cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature."

# run ctc beam search decoder as streaming
ctc-prefix-beam-search-decoder-ol \
    --result_wspecifier=ark,t:$exp_dir/result.txt \
    --feature_rspecifier=ark:$feat_wspecifier \
    --model_path=$model_dir/avg_1.jit.pdmodel \
    --param_path=$model_dir/avg_1.jit.pdiparams \
    --dict_file=$vocb_dir/vocab.txt \
    --lm_path=$lm
|
@ -0,0 +1,2 @@
|
||||
exp
|
||||
data
|
@ -0,0 +1,12 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
set(bin_name linear-spectrogram-wo-db-norm-ol)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags glog)
|
||||
|
||||
|
||||
set(bin_name cmvn-json2kaldi)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog)
|
@ -0,0 +1,7 @@
|
||||
# Deepspeech2 Streaming Audio Feature
|
||||
|
||||
ASR audio feature test bins. We use these bins to test linear/fbank/mfcc ASR features in a streaming manner.
|
||||
|
||||
* linear_spectrogram_without_db_norm_main.cc
|
||||
|
||||
compute linear spectrogram w/o db norm in streaming manner.
|
@ -0,0 +1,81 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Note: Do not print/log ondemand object.
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "kaldi/matrix/kaldi-matrix.h"
|
||||
#include "kaldi/util/kaldi-io.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "utils/simdjson.h"
|
||||
|
||||
DEFINE_string(json_file, "", "cmvn json file");
|
||||
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
|
||||
DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
|
||||
|
||||
using namespace simdjson;
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;
|
||||
|
||||
try {
|
||||
padded_string json = padded_string::load(FLAGS_json_file);
|
||||
|
||||
ondemand::parser parser;
|
||||
ondemand::document doc = parser.iterate(json);
|
||||
ondemand::value val = doc;
|
||||
|
||||
ondemand::array mean_stat = val["mean_stat"];
|
||||
std::vector<kaldi::BaseFloat> mean_stat_vec;
|
||||
for (double x : mean_stat) {
|
||||
mean_stat_vec.push_back(x);
|
||||
}
|
||||
// LOG(INFO) << mean_stat; this line will casue
|
||||
// simdjson::simdjson_error("Objects and arrays can only be iterated
|
||||
// when
|
||||
// they are first encountered")
|
||||
|
||||
ondemand::array var_stat = val["var_stat"];
|
||||
std::vector<kaldi::BaseFloat> var_stat_vec;
|
||||
for (double x : var_stat) {
|
||||
var_stat_vec.push_back(x);
|
||||
}
|
||||
|
||||
kaldi::int32 frame_num = uint64_t(val["frame_num"]);
|
||||
LOG(INFO) << "nframe: " << frame_num;
|
||||
|
||||
size_t mean_size = mean_stat_vec.size();
|
||||
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
|
||||
for (size_t idx = 0; idx < mean_size; ++idx) {
|
||||
cmvn_stats(0, idx) = mean_stat_vec[idx];
|
||||
cmvn_stats(1, idx) = var_stat_vec[idx];
|
||||
}
|
||||
cmvn_stats(0, mean_size) = frame_num;
|
||||
LOG(INFO) << cmvn_stats;
|
||||
|
||||
kaldi::WriteKaldiObject(
|
||||
cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
|
||||
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
|
||||
LOG(INFO) << "Binary: " << FLAGS_binary;
|
||||
} catch (simdjson::simdjson_error& err) {
|
||||
LOG(ERR) << err.what();
|
||||
}
|
||||
|
||||
|
||||
return 0;
|
||||
}
|
@ -0,0 +1,57 @@
|
||||
#!/bin/bash
set +x
set -e

. ./path.sh

# 1. compile speechx binaries if they are not built yet
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
    pushd ${SPEECHX_ROOT}
    bash build.sh
    popd
fi

# 2. download pretrained model (contains cmvn mean_std.json)
if [ ! -e data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
    mkdir -p data/model
    pushd data/model
    wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    popd
fi

# produce wav scp
if [ ! -f data/wav.scp ]; then
    mkdir -p data
    pushd data
    wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
    # quote the whole entry: keeps a single separator between utt-id and
    # path, and survives a $PWD that contains whitespace
    echo "utt1 $PWD/zh.wav" > wav.scp
    popd
fi


# input/output locations
data_dir=./data
exp_dir=./exp
model_dir=$data_dir/model/

mkdir -p $exp_dir


# 3. run feature extraction
export GLOG_logtostderr=1

# convert the model's JSON cmvn stats into a Kaldi ark
cmvn-json2kaldi \
    --json_file $model_dir/data/mean_std.json \
    --cmvn_write_path $exp_dir/cmvn.ark \
    --binary=false
echo "convert json cmvn to kaldi ark."


# compute streaming linear-spectrogram features with cmvn applied
linear-spectrogram-wo-db-norm-ol \
    --wav_rspecifier=scp:$data_dir/wav.scp \
    --feature_wspecifier=ark,t:$exp_dir/feats.ark \
    --cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature."
@ -0,0 +1,2 @@
|
||||
data
|
||||
exp
|
@ -0,0 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

# Build the streaming DeepSpeech2 nnet inference test binary from the
# source file of the same name (ds2-model-ol-test.cc).
set(bin_name ds2-model-ol-test)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
# SPEECHX_ROOT / kaldi headers; link against the project's nnet library
# plus gflags/glog and any extra deps collected in ${DEPS}.
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS})
|
@ -0,0 +1,3 @@
|
||||
# Deepspeech2 Streaming NNet Test
|
||||
|
||||
Used for DeepSpeech2 (DS2) streaming NNet inference testing.
|
@ -0,0 +1,38 @@
|
||||
#!/bin/bash
set +x
set -e

. path.sh

# 1. compile speechx binaries if they are not built yet
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
    pushd ${SPEECHX_ROOT}
    bash build.sh
    popd
fi

# 2. download pretrained DeepSpeech2 online model
if [ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
    mkdir -p data/model
    pushd data/model
    wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
    popd
fi

# produce wav scp
if [ ! -f data/wav.scp ]; then
    mkdir -p data
    pushd data
    wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
    # quote the whole entry: keeps a single separator between utt-id and
    # path, and survives a $PWD that contains whitespace
    echo "utt1 $PWD/zh.wav" > wav.scp
    popd
fi

ckpt_dir=./data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/

# 3. run the streaming nnet inference test on the exported jit model
ds2-model-ol-test \
    --model_path=$model_dir/avg_1.jit.pdmodel \
    --param_path=$model_dir/avg_1.jit.pdiparams
@ -1,18 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
|
||||
add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
|
||||
target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(mfcc-test kaldi-mfcc)
|
||||
|
||||
add_executable(linear_spectrogram_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_main.cc)
|
||||
target_include_directories(linear_spectrogram_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
|
||||
|
||||
add_executable(linear_spectrogram_without_db_norm_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_without_db_norm_main.cc)
|
||||
target_include_directories(linear_spectrogram_without_db_norm_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(linear_spectrogram_without_db_norm_main frontend kaldi-util kaldi-feat-common gflags glog)
|
||||
|
||||
add_executable(cmvn_json2binary_main ${CMAKE_CURRENT_SOURCE_DIR}/cmvn_json2binary_main.cc)
|
||||
target_include_directories(cmvn_json2binary_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(cmvn_json2binary_main utils kaldi-util kaldi-matrix gflags glog)
|
@ -1,58 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "kaldi/matrix/kaldi-matrix.h"
|
||||
#include "kaldi/util/kaldi-io.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "utils/simdjson.h"
|
||||
|
||||
DEFINE_string(json_file, "", "cmvn json file");
|
||||
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
|
||||
DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
|
||||
|
||||
using namespace simdjson;
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
ondemand::parser parser;
|
||||
padded_string json = padded_string::load(FLAGS_json_file);
|
||||
ondemand::document val = parser.iterate(json);
|
||||
ondemand::object doc = val;
|
||||
kaldi::int32 frame_num = uint64_t(doc["frame_num"]);
|
||||
auto mean_stat = doc["mean_stat"];
|
||||
std::vector<kaldi::BaseFloat> mean_stat_vec;
|
||||
for (double x : mean_stat) {
|
||||
mean_stat_vec.push_back(x);
|
||||
}
|
||||
auto var_stat = doc["var_stat"];
|
||||
std::vector<kaldi::BaseFloat> var_stat_vec;
|
||||
for (double x : var_stat) {
|
||||
var_stat_vec.push_back(x);
|
||||
}
|
||||
|
||||
size_t mean_size = mean_stat_vec.size();
|
||||
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
|
||||
for (size_t idx = 0; idx < mean_size; ++idx) {
|
||||
cmvn_stats(0, idx) = mean_stat_vec[idx];
|
||||
cmvn_stats(1, idx) = var_stat_vec[idx];
|
||||
}
|
||||
cmvn_stats(0, mean_size) = frame_num;
|
||||
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
|
||||
LOG(INFO) << "the json file have write into " << FLAGS_cmvn_write_path;
|
||||
return 0;
|
||||
}
|
@ -1,719 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// feat/feature-mfcc-test.cc
|
||||
|
||||
// Copyright 2009-2011 Karel Vesely; Petr Motlicek
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "base/kaldi-math.h"
|
||||
#include "feat/feature-mfcc.h"
|
||||
#include "feat/wave-reader.h"
|
||||
#include "matrix/kaldi-matrix-inl.h"
|
||||
|
||||
using namespace kaldi;
|
||||
|
||||
static void UnitTestReadWave() {
|
||||
std::cout << "=== UnitTestReadWave() ===\n";
|
||||
|
||||
Vector<BaseFloat> v, v2;
|
||||
|
||||
std::cout << "<<<=== Reading waveform\n";
|
||||
|
||||
{
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
const Matrix<BaseFloat> data(wave.Data());
|
||||
KALDI_ASSERT(data.NumRows() == 1);
|
||||
v.Resize(data.NumCols());
|
||||
v.CopyFromVec(data.Row(0));
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
|
||||
std::ifstream input("test_data/test_matlab.ascii");
|
||||
KALDI_ASSERT(input.good());
|
||||
v2.Read(input, false);
|
||||
input.close();
|
||||
|
||||
std::cout
|
||||
<< "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
|
||||
KALDI_ASSERT(v.Dim() == v2.Dim());
|
||||
for (int32 i = 0; i < v.Dim(); i++) {
|
||||
KALDI_ASSERT(v(i) == v2(i));
|
||||
}
|
||||
std::cout << "<<<=== Comparing done\n";
|
||||
|
||||
// std::cout << "== The Waveform Samples == \n";
|
||||
// std::cout << v;
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
*/
|
||||
static void UnitTestSimple() {
|
||||
std::cout << "=== UnitTestSimple() ===\n";
|
||||
|
||||
Vector<BaseFloat> v(100000);
|
||||
Matrix<BaseFloat> m;
|
||||
|
||||
// init with noise
|
||||
for (int32 i = 0; i < v.Dim(); i++) {
|
||||
v(i) = (abs(i * 433024253) % 65535) - (65535 / 2);
|
||||
}
|
||||
|
||||
std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
|
||||
// the parametrization object
|
||||
MfccOptions op;
|
||||
// trying to have same opts as baseline.
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "rectangular";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
op.htk_compat = true;
|
||||
|
||||
Mfcc mfcc(op);
|
||||
// use default parameters
|
||||
|
||||
// compute mfccs.
|
||||
mfcc.Compute(v, 1.0, &m);
|
||||
|
||||
// possibly dump
|
||||
// std::cout << "== Output features == \n" << m;
|
||||
std::cout << "Test passed :)\n\n";
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare1() {
|
||||
std::cout << "=== UnitTestHTKCompare1() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.1",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = false; // C0 not energy.
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (i_old != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.1",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.1");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare2() {
|
||||
std::cout << "=== UnitTestHTKCompare2() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.2",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (i_old != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.2",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.2");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare3() {
|
||||
std::cout << "=== UnitTestHTKCompare3() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.3",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
op.mel_opts.low_freq = 20.0;
|
||||
// op.mel_opts.debug_mel = true;
|
||||
op.mel_opts.htk_mode = true;
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.3",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.3");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare4() {
|
||||
std::cout << "=== UnitTestHTKCompare4() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.4",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
op.mel_opts.htk_mode = true;
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.4",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.4");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare5() {
|
||||
std::cout << "=== UnitTestHTKCompare5() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.5",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.vtln_low = 100.0;
|
||||
op.mel_opts.vtln_high = 7500.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
|
||||
BaseFloat vtln_warp =
|
||||
1.1; // our approach identical to htk for warp factor >1,
|
||||
// differs slightly for higher mel bins if warp_factor <0.9
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.5",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.5");
|
||||
}
|
||||
|
||||
static void UnitTestHTKCompare6() {
|
||||
std::cout << "=== UnitTestHTKCompare6() ===\n";
|
||||
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.6",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.97;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.num_bins = 24;
|
||||
op.mel_opts.low_freq = 125.0;
|
||||
op.mel_opts.high_freq = 7800.0;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = false; // C0 not energy.
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.6",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.6");
|
||||
}
|
||||
|
||||
void UnitTestVtln() {
|
||||
// Test the function VtlnWarpFreq.
|
||||
BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20,
|
||||
vtln_high_cutoff = 7400;
|
||||
|
||||
for (size_t i = 0; i < 100; i++) {
|
||||
BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
|
||||
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
freq),
|
||||
freq / warp_factor);
|
||||
|
||||
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
low_freq),
|
||||
low_freq);
|
||||
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
high_freq),
|
||||
high_freq);
|
||||
BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(),
|
||||
freq3 = freq2 +
|
||||
(high_freq - freq2) * RandUniform(); // freq3>=freq2
|
||||
BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
freq2);
|
||||
BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
freq3);
|
||||
KALDI_ASSERT(w3 >= w2); // increasing function.
|
||||
BaseFloat w3dash = MelBanks::VtlnWarpFreq(
|
||||
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, 1.0, freq3);
|
||||
AssertEqual(w3dash, freq3);
|
||||
}
|
||||
}
|
||||
|
||||
static void UnitTestFeat() {
|
||||
UnitTestVtln();
|
||||
UnitTestReadWave();
|
||||
UnitTestSimple();
|
||||
UnitTestHTKCompare1();
|
||||
UnitTestHTKCompare2();
|
||||
// commenting out this one as it doesn't compare right now I normalized
|
||||
// the way the FFT bins are treated (removed offset of 0.5)... this seems
|
||||
// to relate to the way frequency zero behaves.
|
||||
UnitTestHTKCompare3();
|
||||
UnitTestHTKCompare4();
|
||||
UnitTestHTKCompare5();
|
||||
UnitTestHTKCompare6();
|
||||
std::cout << "Tests succeeded.\n";
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
try {
|
||||
for (int i = 0; i < 5; i++) UnitTestFeat();
|
||||
std::cout << "Tests succeeded.\n";
|
||||
return 0;
|
||||
} catch (const std::exception &e) {
|
||||
std::cerr << e.what();
|
||||
return 1;
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue