commit
159d8fd628
@ -0,0 +1,115 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from fastapi.testclient import TestClient
|
||||
from vpr_search import app
|
||||
|
||||
from utils.utility import download
|
||||
from utils.utility import unpack
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def download_audio_data():
|
||||
"""
|
||||
Download audio data
|
||||
"""
|
||||
url = "https://paddlespeech.bj.bcebos.com/vector/audio/example_audio.tar.gz"
|
||||
md5sum = "52ac69316c1aa1fdef84da7dd2c67b39"
|
||||
target_dir = "./"
|
||||
filepath = download(url, md5sum, target_dir)
|
||||
unpack(filepath, target_dir, True)
|
||||
|
||||
|
||||
def test_drop():
|
||||
"""
|
||||
Delete the table of MySQL
|
||||
"""
|
||||
response = client.post("/vpr/drop")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_enroll_local(spk: str, audio: str):
|
||||
"""
|
||||
Enroll the audio to MySQL
|
||||
"""
|
||||
response = client.post("/vpr/enroll/local?spk_id=" + spk +
|
||||
"&audio_path=.%2Fexample_audio%2F" + audio + ".wav")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
'status': True,
|
||||
'msg': "Successfully enroll data!"
|
||||
}
|
||||
|
||||
|
||||
def test_search_local():
|
||||
"""
|
||||
Search the spk in MySQL by audio
|
||||
"""
|
||||
response = client.post(
|
||||
"/vpr/recog/local?audio_path=.%2Fexample_audio%2Ftest.wav")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_list():
|
||||
"""
|
||||
Get all records in MySQL
|
||||
"""
|
||||
response = client.get("/vpr/list")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_data(spk: str):
|
||||
"""
|
||||
Get the audio file by spk_id in MySQL
|
||||
"""
|
||||
response = client.get("/vpr/data?spk_id=" + spk)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_del(spk: str):
|
||||
"""
|
||||
Delete the record in MySQL by spk_id
|
||||
"""
|
||||
response = client.post("/vpr/del?spk_id=" + spk)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_count():
|
||||
"""
|
||||
Get the number of spk in MySQL
|
||||
"""
|
||||
response = client.get("/vpr/count")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
download_audio_data()
|
||||
|
||||
test_enroll_local("spk1", "arms_strikes")
|
||||
test_enroll_local("spk2", "sword_wielding")
|
||||
test_enroll_local("spk3", "test")
|
||||
test_list()
|
||||
test_data("spk1")
|
||||
test_count()
|
||||
test_search_local()
|
||||
|
||||
test_del("spk1")
|
||||
test_count()
|
||||
test_search_local()
|
||||
|
||||
test_enroll_local("spk1", "arms_strikes")
|
||||
test_count()
|
||||
test_search_local()
|
||||
|
||||
test_drop()
|
@ -0,0 +1,206 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
from config import UPLOAD_PATH
|
||||
from fastapi import FastAPI
|
||||
from fastapi import File
|
||||
from fastapi import UploadFile
|
||||
from logs import LOGGER
|
||||
from mysql_helpers import MySQLHelper
|
||||
from operations.count import do_count_vpr
|
||||
from operations.count import do_get
|
||||
from operations.count import do_list
|
||||
from operations.drop import do_delete
|
||||
from operations.drop import do_drop_vpr
|
||||
from operations.load import do_enroll
|
||||
from operations.search import do_search_vpr
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"])
|
||||
|
||||
MYSQL_CLI = MySQLHelper()
|
||||
|
||||
# Mkdir 'tmp/audio-data'
|
||||
if not os.path.exists(UPLOAD_PATH):
|
||||
os.makedirs(UPLOAD_PATH)
|
||||
LOGGER.info(f"Mkdir the path: {UPLOAD_PATH}")
|
||||
|
||||
|
||||
@app.post('/vpr/enroll')
|
||||
async def vpr_enroll(table_name: str=None,
|
||||
spk_id: str=None,
|
||||
audio: UploadFile=File(...)):
|
||||
# Enroll the uploaded audio with spk-id into MySQL
|
||||
try:
|
||||
# Save the upload data to server.
|
||||
content = await audio.read()
|
||||
audio_path = os.path.join(UPLOAD_PATH, audio.filename)
|
||||
with open(audio_path, "wb+") as f:
|
||||
f.write(content)
|
||||
do_enroll(table_name, spk_id, audio_path, MYSQL_CLI)
|
||||
LOGGER.info(f"Successfully enrolled {spk_id} online!")
|
||||
return {'status': True, 'msg': "Successfully enroll data!"}
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.post('/vpr/enroll/local')
|
||||
async def vpr_enroll_local(table_name: str=None,
|
||||
spk_id: str=None,
|
||||
audio_path: str=None):
|
||||
# Enroll the local audio with spk-id into MySQL
|
||||
try:
|
||||
do_enroll(table_name, spk_id, audio_path, MYSQL_CLI)
|
||||
LOGGER.info(f"Successfully enrolled {spk_id} locally!")
|
||||
return {'status': True, 'msg': "Successfully enroll data!"}
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.post('/vpr/recog')
|
||||
async def vpr_recog(request: Request,
|
||||
table_name: str=None,
|
||||
audio: UploadFile=File(...)):
|
||||
# Voice print recognition online
|
||||
try:
|
||||
# Save the upload data to server.
|
||||
content = await audio.read()
|
||||
query_audio_path = os.path.join(UPLOAD_PATH, audio.filename)
|
||||
with open(query_audio_path, "wb+") as f:
|
||||
f.write(content)
|
||||
host = request.headers['host']
|
||||
spk_ids, paths, scores = do_search_vpr(host, table_name,
|
||||
query_audio_path, MYSQL_CLI)
|
||||
for spk_id, path, score in zip(spk_ids, paths, scores):
|
||||
LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}, ")
|
||||
res = dict(zip(spk_ids, zip(paths, scores)))
|
||||
# Sort results by distance metric, closest distances first
|
||||
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
|
||||
LOGGER.info("Successfully speaker recognition online!")
|
||||
return res
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.post('/vpr/recog/local')
|
||||
async def vpr_recog_local(request: Request,
|
||||
table_name: str=None,
|
||||
audio_path: str=None):
|
||||
# Voice print recognition locally
|
||||
try:
|
||||
host = request.headers['host']
|
||||
spk_ids, paths, scores = do_search_vpr(host, table_name, audio_path,
|
||||
MYSQL_CLI)
|
||||
for spk_id, path, score in zip(spk_ids, paths, scores):
|
||||
LOGGER.info(f"spk {spk_id}, score {score}, audio path {path}, ")
|
||||
res = dict(zip(spk_ids, zip(paths, scores)))
|
||||
# Sort results by distance metric, closest distances first
|
||||
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
|
||||
LOGGER.info("Successfully speaker recognition locally!")
|
||||
return res
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.post('/vpr/del')
|
||||
async def vpr_del(table_name: str=None, spk_id: str=None):
|
||||
# Delete a record by spk_id in MySQL
|
||||
try:
|
||||
do_delete(table_name, spk_id, MYSQL_CLI)
|
||||
LOGGER.info("Successfully delete a record by spk_id in MySQL")
|
||||
return {'status': True, 'msg': "Successfully delete data!"}
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.get('/vpr/list')
|
||||
async def vpr_list(table_name: str=None):
|
||||
# Get all records in MySQL
|
||||
try:
|
||||
spk_ids, audio_paths = do_list(table_name, MYSQL_CLI)
|
||||
for i in range(len(spk_ids)):
|
||||
LOGGER.debug(f"spk {spk_ids[i]}, audio path {audio_paths[i]}")
|
||||
LOGGER.info("Successfully list all records from mysql!")
|
||||
return spk_ids, audio_paths
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.get('/vpr/data')
|
||||
async def vpr_data(
|
||||
table_name: str=None,
|
||||
spk_id: str=None, ):
|
||||
# Get the audio file from path by spk_id in MySQL
|
||||
try:
|
||||
audio_path = do_get(table_name, spk_id, MYSQL_CLI)
|
||||
LOGGER.info(f"Successfully get audio path {audio_path}!")
|
||||
return FileResponse(audio_path)
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.get('/vpr/count')
|
||||
async def vpr_count(table_name: str=None):
|
||||
# Get the total number of spk in MySQL
|
||||
try:
|
||||
num = do_count_vpr(table_name, MYSQL_CLI)
|
||||
LOGGER.info("Successfully count the number of spk!")
|
||||
return num
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.post('/vpr/drop')
|
||||
async def drop_tables(table_name: str=None):
|
||||
# Delete the table of MySQL
|
||||
try:
|
||||
do_drop_vpr(table_name, MYSQL_CLI)
|
||||
LOGGER.info("Successfully drop tables in MySQL!")
|
||||
return {'status': True, 'msg': "Successfully drop tables!"}
|
||||
except Exception as e:
|
||||
LOGGER.error(e)
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
@app.get('/data')
|
||||
def audio_path(audio_path):
|
||||
# Get the audio file from path
|
||||
try:
|
||||
LOGGER.info(f"Successfully get audio: {audio_path}")
|
||||
return FileResponse(audio_path)
|
||||
except Exception as e:
|
||||
LOGGER.error(f"get audio error: {e}")
|
||||
return {'status': False, 'msg': e}, 400
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run(app=app, host='0.0.0.0', port=8002)
|
After Width: | Height: | Size: 949 KiB |
@ -0,0 +1,18 @@
|
||||
# paddlespeech serving 网页Demo
|
||||
|
||||
- 感谢[wenet](https://github.com/wenet-e2e/wenet)团队的前端demo代码.
|
||||
|
||||
|
||||
## 使用方法
|
||||
### 1. 在本地电脑启动网页服务
|
||||
```
|
||||
python app.py
|
||||
|
||||
```
|
||||
|
||||
### 2. 本地电脑浏览器
|
||||
|
||||
在浏览器中输入127.0.0.1:19999 即可看到相关网页Demo。
|
||||
|
||||
![图片](./paddle_web_demo.png)
|
||||
|
@ -1,7 +1,4 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_subdirectory(feat)
|
||||
add_subdirectory(nnet)
|
||||
add_subdirectory(decoder)
|
||||
|
||||
add_subdirectory(glog)
|
||||
add_subdirectory(ds2_ol)
|
||||
add_subdirectory(dev)
|
@ -1,17 +1,25 @@
|
||||
# Examples
|
||||
# Examples for SpeechX
|
||||
|
||||
* dev - for speechx developer, using for test.
|
||||
* ngram - using to build NGram ARPA lm.
|
||||
* ds2_ol - ds2 streaming test under `aishell-1` test dataset.
|
||||
The entrypoint is `ds2_ol/aishell/run.sh`
|
||||
|
||||
* glog - glog usage
|
||||
* feat - mfcc, linear
|
||||
* nnet - ds2 nn
|
||||
* decoder - online decoder to work as offline
|
||||
|
||||
## How to run
|
||||
|
||||
`run.sh` is the entry point.
|
||||
|
||||
Example to play `decoder`:
|
||||
Example to play `ds2_ol`:
|
||||
|
||||
```
|
||||
pushd decoder
|
||||
pushd ds2_ol/aishell
|
||||
bash run.sh
|
||||
```
|
||||
|
||||
## Display Model with [Netron](https://github.com/lutzroeder/netron)
|
||||
|
||||
```
|
||||
pip install netron
|
||||
netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host 10.21.55.20
|
||||
```
|
||||
|
@ -1 +0,0 @@
|
||||
../../../utils
|
@ -1,18 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_executable(offline_decoder_sliding_chunk_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_sliding_chunk_main.cc)
|
||||
target_include_directories(offline_decoder_sliding_chunk_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(offline_decoder_sliding_chunk_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
|
||||
|
||||
add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
|
||||
target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
|
||||
|
||||
add_executable(offline_wfst_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_wfst_decoder_main.cc)
|
||||
target_include_directories(offline_wfst_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(offline_wfst_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
|
||||
|
||||
add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc)
|
||||
target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
|
||||
|
@ -1,121 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// todo refactor, repalce with gtest
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "decoder/ctc_beam_search_decoder.h"
|
||||
#include "frontend/audio/data_cache.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/decodable.h"
|
||||
#include "nnet/paddle_nnet.h"
|
||||
|
||||
DEFINE_string(feature_respecifier, "", "feature matrix rspecifier");
|
||||
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
|
||||
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
|
||||
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
|
||||
DEFINE_string(lm_path, "lm.klm", "language model");
|
||||
DEFINE_int32(chunk_size, 35, "feat chunk size");
|
||||
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Matrix;
|
||||
using std::vector;
|
||||
|
||||
// test decoder by feeding speech feature, deprecated.
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
kaldi::SequentialBaseFloatMatrixReader feature_reader(
|
||||
FLAGS_feature_respecifier);
|
||||
std::string model_graph = FLAGS_model_path;
|
||||
std::string model_params = FLAGS_param_path;
|
||||
std::string dict_file = FLAGS_dict_file;
|
||||
std::string lm_path = FLAGS_lm_path;
|
||||
int32 chunk_size = FLAGS_chunk_size;
|
||||
LOG(INFO) << "model path: " << model_graph;
|
||||
LOG(INFO) << "model param: " << model_params;
|
||||
LOG(INFO) << "dict path: " << dict_file;
|
||||
LOG(INFO) << "lm path: " << lm_path;
|
||||
LOG(INFO) << "chunk size (frame): " << chunk_size;
|
||||
|
||||
int32 num_done = 0, num_err = 0;
|
||||
|
||||
// frontend + nnet is decodable
|
||||
ppspeech::ModelOptions model_opts;
|
||||
model_opts.model_path = model_graph;
|
||||
model_opts.params_path = model_params;
|
||||
std::shared_ptr<ppspeech::PaddleNnet> nnet(
|
||||
new ppspeech::PaddleNnet(model_opts));
|
||||
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
|
||||
std::shared_ptr<ppspeech::Decodable> decodable(
|
||||
new ppspeech::Decodable(nnet, raw_data));
|
||||
LOG(INFO) << "Init decodeable.";
|
||||
|
||||
// init decoder
|
||||
ppspeech::CTCBeamSearchOptions opts;
|
||||
opts.dict_file = dict_file;
|
||||
opts.lm_path = lm_path;
|
||||
ppspeech::CTCBeamSearch decoder(opts);
|
||||
LOG(INFO) << "Init decoder.";
|
||||
|
||||
decoder.InitDecoder();
|
||||
for (; !feature_reader.Done(); feature_reader.Next()) {
|
||||
string utt = feature_reader.Key();
|
||||
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
|
||||
LOG(INFO) << "utt: " << utt;
|
||||
|
||||
// feat dim
|
||||
raw_data->SetDim(feature.NumCols());
|
||||
LOG(INFO) << "dim: " << raw_data->Dim();
|
||||
|
||||
int32 row_idx = 0;
|
||||
int32 num_chunks = feature.NumRows() / chunk_size;
|
||||
LOG(INFO) << "n chunks: " << num_chunks;
|
||||
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
|
||||
// feat chunk
|
||||
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
|
||||
feature.NumCols());
|
||||
for (int row_id = 0; row_id < chunk_size; ++row_id) {
|
||||
kaldi::SubVector<kaldi::BaseFloat> feat_one_row(feature,
|
||||
row_idx);
|
||||
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
|
||||
feature_chunk.Data() + row_id * feature.NumCols(),
|
||||
feature.NumCols());
|
||||
f_chunk_tmp.CopyFromVec(feat_one_row);
|
||||
row_idx++;
|
||||
}
|
||||
// feed to raw cache
|
||||
raw_data->Accept(feature_chunk);
|
||||
if (chunk_idx == num_chunks - 1) {
|
||||
raw_data->SetFinished();
|
||||
}
|
||||
// decode step
|
||||
decoder.AdvanceDecode(decodable);
|
||||
}
|
||||
|
||||
std::string result;
|
||||
result = decoder.GetFinalBestPath();
|
||||
KALDI_LOG << " the result of " << utt << " is " << result;
|
||||
decodable->Reset();
|
||||
decoder.Reset();
|
||||
++num_done;
|
||||
}
|
||||
|
||||
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
|
||||
<< " with errors.";
|
||||
return (num_done != 0 ? 0 : 1);
|
||||
}
|
@ -1,43 +0,0 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
|
||||
# 2. download model
|
||||
if [ ! -d ../paddle_asr_model ]; then
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
|
||||
tar xzfv paddle_asr_model.tar.gz
|
||||
mv ./paddle_asr_model ../
|
||||
# produce wav scp
|
||||
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
|
||||
fi
|
||||
|
||||
model_dir=../paddle_asr_model
|
||||
feat_wspecifier=./feats.ark
|
||||
cmvn=./cmvn.ark
|
||||
|
||||
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
# 3. gen linear feat
|
||||
linear_spectrogram_main \
|
||||
--wav_rspecifier=scp:$model_dir/wav.scp \
|
||||
--feature_wspecifier=ark,t:$feat_wspecifier \
|
||||
--cmvn_write_path=$cmvn
|
||||
|
||||
# 4. run decoder
|
||||
offline_decoder_main \
|
||||
--feature_respecifier=ark:$feat_wspecifier \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdparams \
|
||||
--dict_file=$model_dir/vocab.txt \
|
||||
--lm_path=$model_dir/avg_1.jit.klm
|
@ -0,0 +1,3 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_subdirectory(glog)
|
@ -1,14 +1,15 @@
|
||||
# This contains the locations of binarys build required for running the examples.
|
||||
|
||||
SPEECHX_ROOT=$PWD/../..
|
||||
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
|
||||
SPEECHX_ROOT=$PWD/../../../
|
||||
|
||||
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
|
||||
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
|
||||
|
||||
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
|
||||
|
||||
export LC_AL=C
|
||||
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
|
||||
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; }
|
||||
|
||||
SPEECHX_BIN=$SPEECHX_EXAMPLES/nnet
|
||||
SPEECHX_BIN=$SPEECHX_EXAMPLES/dev/glog
|
||||
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
|
||||
|
||||
export LC_AL=C
|
@ -0,0 +1,5 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_subdirectory(feat)
|
||||
add_subdirectory(nnet)
|
||||
add_subdirectory(decoder)
|
@ -0,0 +1,11 @@
|
||||
# Deepspeech2 Streaming
|
||||
|
||||
Please go to `aishell` to test it.
|
||||
|
||||
* aishell
|
||||
Deepspeech2 Streaming Decoding under aishell dataset.
|
||||
|
||||
The below is for developing and offline testing:
|
||||
* nnet
|
||||
* feat
|
||||
* decoder
|
@ -0,0 +1,3 @@
|
||||
data
|
||||
exp
|
||||
aishell_*
|
@ -0,0 +1,21 @@
|
||||
# Aishell - Deepspeech2 Streaming
|
||||
|
||||
## CTC Prefix Beam Search w/o LM
|
||||
|
||||
```
|
||||
Overall -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
|
||||
Mandarin -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465
|
||||
Other -> 0.00 % N=0 C=0 S=0 D=0 I=0
|
||||
```
|
||||
|
||||
## CTC Prefix Beam Search w LM
|
||||
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
## CTC WFST
|
||||
|
||||
```
|
||||
|
||||
```
|
@ -0,0 +1 @@
|
||||
../../../../utils/
|
@ -0,0 +1,2 @@
|
||||
data
|
||||
exp
|
@ -0,0 +1,19 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
set(bin_name ctc-prefix-beam-search-decoder-ol)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
|
||||
|
||||
|
||||
set(bin_name wfst-decoder-ol)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
|
||||
|
||||
|
||||
set(bin_name nnet-logprob-decoder-test)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
|
||||
|
@ -0,0 +1,12 @@
|
||||
# ASR Decoder
|
||||
|
||||
ASR Decoder test bins. We using theses bins to test CTC BeamSearch decoder and WFST decoder.
|
||||
|
||||
* decoder_test_main.cc
|
||||
feed nnet output logprob, and only test decoder
|
||||
|
||||
* offline_decoder_sliding_chunk_main.cc
|
||||
feed streaming audio feature, decode as streaming manner.
|
||||
|
||||
* offline_wfst_decoder_main.cc
|
||||
feed streaming audio feature, decode using WFST as streaming manner.
|
@ -0,0 +1,79 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
# input
|
||||
mkdir -p data
|
||||
data=$PWD/data
|
||||
ckpt_dir=$data/model
|
||||
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
|
||||
vocb_dir=$ckpt_dir/data/lang_char/
|
||||
|
||||
lm=$data/zh_giga.no_cna_cmn.prune01244.klm
|
||||
|
||||
# output
|
||||
exp_dir=./exp
|
||||
mkdir -p $exp_dir
|
||||
|
||||
# 2. download model
|
||||
if [[ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]]; then
|
||||
mkdir -p data/model
|
||||
pushd data/model
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
|
||||
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
|
||||
popd
|
||||
fi
|
||||
|
||||
# produce wav scp
|
||||
if [ ! -f data/wav.scp ]; then
|
||||
pushd data
|
||||
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
|
||||
echo "utt1 " $PWD/zh.wav > wav.scp
|
||||
popd
|
||||
fi
|
||||
|
||||
# download lm
|
||||
if [ ! -f $lm ]; then
|
||||
pushd data
|
||||
wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm
|
||||
popd
|
||||
fi
|
||||
|
||||
|
||||
feat_wspecifier=$exp_dir/feats.ark
|
||||
cmvn=$exp_dir/cmvn.ark
|
||||
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
# dump json cmvn to kaldi
|
||||
cmvn-json2kaldi \
|
||||
--json_file $ckpt_dir/data/mean_std.json \
|
||||
--cmvn_write_path $exp_dir/cmvn.ark \
|
||||
--binary=false
|
||||
echo "convert json cmvn to kaldi ark."
|
||||
|
||||
|
||||
# generate linear feature as streaming
|
||||
linear-spectrogram-wo-db-norm-ol \
|
||||
--wav_rspecifier=scp:$data/wav.scp \
|
||||
--feature_wspecifier=ark,t:$feat_wspecifier \
|
||||
--cmvn_file=$exp_dir/cmvn.ark
|
||||
echo "compute linear spectrogram feature."
|
||||
|
||||
# run ctc beam search decoder as streaming
|
||||
ctc-prefix-beam-search-decoder-ol \
|
||||
--result_wspecifier=ark,t:$exp_dir/result.txt \
|
||||
--feature_rspecifier=ark:$feat_wspecifier \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdiparams \
|
||||
--dict_file=$vocb_dir/vocab.txt \
|
||||
--lm_path=$lm
|
@ -0,0 +1,2 @@
|
||||
exp
|
||||
data
|
@ -0,0 +1,12 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
set(bin_name linear-spectrogram-wo-db-norm-ol)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags glog)
|
||||
|
||||
|
||||
set(bin_name cmvn-json2kaldi)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog)
|
@ -0,0 +1,7 @@
|
||||
# Deepspeech2 Straming Audio Feature
|
||||
|
||||
ASR audio feature test bins. We using theses bins to test linaer/fbank/mfcc asr feature as streaming manner.
|
||||
|
||||
* linear_spectrogram_without_db_norm_main.cc
|
||||
|
||||
compute linear spectrogram w/o db norm in streaming manner.
|
@ -0,0 +1,81 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Note: Do not print/log ondemand object.
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "kaldi/matrix/kaldi-matrix.h"
|
||||
#include "kaldi/util/kaldi-io.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "utils/simdjson.h"
|
||||
|
||||
DEFINE_string(json_file, "", "cmvn json file");
|
||||
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
|
||||
DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
|
||||
|
||||
using namespace simdjson;
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;
|
||||
|
||||
try {
|
||||
padded_string json = padded_string::load(FLAGS_json_file);
|
||||
|
||||
ondemand::parser parser;
|
||||
ondemand::document doc = parser.iterate(json);
|
||||
ondemand::value val = doc;
|
||||
|
||||
ondemand::array mean_stat = val["mean_stat"];
|
||||
std::vector<kaldi::BaseFloat> mean_stat_vec;
|
||||
for (double x : mean_stat) {
|
||||
mean_stat_vec.push_back(x);
|
||||
}
|
||||
// LOG(INFO) << mean_stat; this line will casue
|
||||
// simdjson::simdjson_error("Objects and arrays can only be iterated
|
||||
// when
|
||||
// they are first encountered")
|
||||
|
||||
ondemand::array var_stat = val["var_stat"];
|
||||
std::vector<kaldi::BaseFloat> var_stat_vec;
|
||||
for (double x : var_stat) {
|
||||
var_stat_vec.push_back(x);
|
||||
}
|
||||
|
||||
kaldi::int32 frame_num = uint64_t(val["frame_num"]);
|
||||
LOG(INFO) << "nframe: " << frame_num;
|
||||
|
||||
size_t mean_size = mean_stat_vec.size();
|
||||
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
|
||||
for (size_t idx = 0; idx < mean_size; ++idx) {
|
||||
cmvn_stats(0, idx) = mean_stat_vec[idx];
|
||||
cmvn_stats(1, idx) = var_stat_vec[idx];
|
||||
}
|
||||
cmvn_stats(0, mean_size) = frame_num;
|
||||
LOG(INFO) << cmvn_stats;
|
||||
|
||||
kaldi::WriteKaldiObject(
|
||||
cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
|
||||
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
|
||||
LOG(INFO) << "Binary: " << FLAGS_binary;
|
||||
} catch (simdjson::simdjson_error& err) {
|
||||
LOG(ERR) << err.what();
|
||||
}
|
||||
|
||||
|
||||
return 0;
|
||||
}
|
@ -0,0 +1,57 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. ./path.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
# 2. download model
|
||||
if [ ! -e data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
|
||||
mkdir -p data/model
|
||||
pushd data/model
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
|
||||
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
|
||||
popd
|
||||
fi
|
||||
|
||||
# produce wav scp
|
||||
if [ ! -f data/wav.scp ]; then
|
||||
mkdir -p data
|
||||
pushd data
|
||||
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
|
||||
echo "utt1 " $PWD/zh.wav > wav.scp
|
||||
popd
|
||||
fi
|
||||
|
||||
|
||||
# input
|
||||
data_dir=./data
|
||||
exp_dir=./exp
|
||||
model_dir=$data_dir/model/
|
||||
|
||||
mkdir -p $exp_dir
|
||||
|
||||
|
||||
# 3. run feat
|
||||
export GLOG_logtostderr=1
|
||||
|
||||
cmvn-json2kaldi \
|
||||
--json_file $model_dir/data/mean_std.json \
|
||||
--cmvn_write_path $exp_dir/cmvn.ark \
|
||||
--binary=false
|
||||
echo "convert json cmvn to kaldi ark."
|
||||
|
||||
|
||||
linear-spectrogram-wo-db-norm-ol \
|
||||
--wav_rspecifier=scp:$data_dir/wav.scp \
|
||||
--feature_wspecifier=ark,t:$exp_dir/feats.ark \
|
||||
--cmvn_file=$exp_dir/cmvn.ark
|
||||
echo "compute linear spectrogram feature."
|
||||
|
||||
|
@ -0,0 +1,2 @@
|
||||
data
|
||||
exp
|
@ -0,0 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
set(bin_name ds2-model-ol-test)
|
||||
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
|
||||
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS})
|
@ -0,0 +1,3 @@
|
||||
# Deepspeech2 Streaming NNet Test
|
||||
|
||||
Using for ds2 streaming nnet inference test.
|
@ -0,0 +1,38 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
# 2. download model
|
||||
if [ ! -f data/model/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz ]; then
|
||||
mkdir -p data/model
|
||||
pushd data/model
|
||||
wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
|
||||
tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
|
||||
popd
|
||||
fi
|
||||
|
||||
# produce wav scp
|
||||
if [ ! -f data/wav.scp ]; then
|
||||
mkdir -p data
|
||||
pushd data
|
||||
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
|
||||
echo "utt1 " $PWD/zh.wav > wav.scp
|
||||
popd
|
||||
fi
|
||||
|
||||
ckpt_dir=./data/model
|
||||
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
|
||||
|
||||
ds2-model-ol-test \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdiparams
|
||||
|
@ -1,18 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
|
||||
add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
|
||||
target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(mfcc-test kaldi-mfcc)
|
||||
|
||||
add_executable(linear_spectrogram_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_main.cc)
|
||||
target_include_directories(linear_spectrogram_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
|
||||
|
||||
add_executable(linear_spectrogram_without_db_norm_main ${CMAKE_CURRENT_SOURCE_DIR}/linear_spectrogram_without_db_norm_main.cc)
|
||||
target_include_directories(linear_spectrogram_without_db_norm_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(linear_spectrogram_without_db_norm_main frontend kaldi-util kaldi-feat-common gflags glog)
|
||||
|
||||
add_executable(cmvn_json2binary_main ${CMAKE_CURRENT_SOURCE_DIR}/cmvn_json2binary_main.cc)
|
||||
target_include_directories(cmvn_json2binary_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(cmvn_json2binary_main utils kaldi-util kaldi-matrix gflags glog)
|
@ -1,58 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "kaldi/matrix/kaldi-matrix.h"
|
||||
#include "kaldi/util/kaldi-io.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "utils/simdjson.h"
|
||||
|
||||
DEFINE_string(json_file, "", "cmvn json file");
|
||||
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
|
||||
DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
|
||||
|
||||
using namespace simdjson;
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
ondemand::parser parser;
|
||||
padded_string json = padded_string::load(FLAGS_json_file);
|
||||
ondemand::document val = parser.iterate(json);
|
||||
ondemand::object doc = val;
|
||||
kaldi::int32 frame_num = uint64_t(doc["frame_num"]);
|
||||
auto mean_stat = doc["mean_stat"];
|
||||
std::vector<kaldi::BaseFloat> mean_stat_vec;
|
||||
for (double x : mean_stat) {
|
||||
mean_stat_vec.push_back(x);
|
||||
}
|
||||
auto var_stat = doc["var_stat"];
|
||||
std::vector<kaldi::BaseFloat> var_stat_vec;
|
||||
for (double x : var_stat) {
|
||||
var_stat_vec.push_back(x);
|
||||
}
|
||||
|
||||
size_t mean_size = mean_stat_vec.size();
|
||||
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
|
||||
for (size_t idx = 0; idx < mean_size; ++idx) {
|
||||
cmvn_stats(0, idx) = mean_stat_vec[idx];
|
||||
cmvn_stats(1, idx) = var_stat_vec[idx];
|
||||
}
|
||||
cmvn_stats(0, mean_size) = frame_num;
|
||||
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
|
||||
LOG(INFO) << "the json file have write into " << FLAGS_cmvn_write_path;
|
||||
return 0;
|
||||
}
|
@ -1,719 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// feat/feature-mfcc-test.cc
|
||||
|
||||
// Copyright 2009-2011 Karel Vesely; Petr Motlicek
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "base/kaldi-math.h"
|
||||
#include "feat/feature-mfcc.h"
|
||||
#include "feat/wave-reader.h"
|
||||
#include "matrix/kaldi-matrix-inl.h"
|
||||
|
||||
using namespace kaldi;
|
||||
|
||||
static void UnitTestReadWave() {
|
||||
std::cout << "=== UnitTestReadWave() ===\n";
|
||||
|
||||
Vector<BaseFloat> v, v2;
|
||||
|
||||
std::cout << "<<<=== Reading waveform\n";
|
||||
|
||||
{
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
const Matrix<BaseFloat> data(wave.Data());
|
||||
KALDI_ASSERT(data.NumRows() == 1);
|
||||
v.Resize(data.NumCols());
|
||||
v.CopyFromVec(data.Row(0));
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
|
||||
std::ifstream input("test_data/test_matlab.ascii");
|
||||
KALDI_ASSERT(input.good());
|
||||
v2.Read(input, false);
|
||||
input.close();
|
||||
|
||||
std::cout
|
||||
<< "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
|
||||
KALDI_ASSERT(v.Dim() == v2.Dim());
|
||||
for (int32 i = 0; i < v.Dim(); i++) {
|
||||
KALDI_ASSERT(v(i) == v2(i));
|
||||
}
|
||||
std::cout << "<<<=== Comparing done\n";
|
||||
|
||||
// std::cout << "== The Waveform Samples == \n";
|
||||
// std::cout << v;
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
*/
|
||||
static void UnitTestSimple() {
|
||||
std::cout << "=== UnitTestSimple() ===\n";
|
||||
|
||||
Vector<BaseFloat> v(100000);
|
||||
Matrix<BaseFloat> m;
|
||||
|
||||
// init with noise
|
||||
for (int32 i = 0; i < v.Dim(); i++) {
|
||||
v(i) = (abs(i * 433024253) % 65535) - (65535 / 2);
|
||||
}
|
||||
|
||||
std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
|
||||
// the parametrization object
|
||||
MfccOptions op;
|
||||
// trying to have same opts as baseline.
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "rectangular";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
op.htk_compat = true;
|
||||
|
||||
Mfcc mfcc(op);
|
||||
// use default parameters
|
||||
|
||||
// compute mfccs.
|
||||
mfcc.Compute(v, 1.0, &m);
|
||||
|
||||
// possibly dump
|
||||
// std::cout << "== Output features == \n" << m;
|
||||
std::cout << "Test passed :)\n\n";
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare1() {
|
||||
std::cout << "=== UnitTestHTKCompare1() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.1",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = false; // C0 not energy.
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (i_old != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.1",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.1");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare2() {
|
||||
std::cout << "=== UnitTestHTKCompare2() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.2",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (i_old != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.2",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.2");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare3() {
|
||||
std::cout << "=== UnitTestHTKCompare3() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.3",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
op.mel_opts.low_freq = 20.0;
|
||||
// op.mel_opts.debug_mel = true;
|
||||
op.mel_opts.htk_mode = true;
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.3",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.3");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare4() {
|
||||
std::cout << "=== UnitTestHTKCompare4() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.4",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
op.mel_opts.htk_mode = true;
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.4",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.4");
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestHTKCompare5() {
|
||||
std::cout << "=== UnitTestHTKCompare5() ===\n";
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.5",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = true; // Use energy.
|
||||
op.mel_opts.low_freq = 0.0;
|
||||
op.mel_opts.vtln_low = 100.0;
|
||||
op.mel_opts.vtln_high = 7500.0;
|
||||
op.mel_opts.htk_mode = true;
|
||||
|
||||
BaseFloat vtln_warp =
|
||||
1.1; // our approach identical to htk for warp factor >1,
|
||||
// differs slightly for higher mel bins if warp_factor <0.9
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.5",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.5");
|
||||
}
|
||||
|
||||
static void UnitTestHTKCompare6() {
|
||||
std::cout << "=== UnitTestHTKCompare6() ===\n";
|
||||
|
||||
|
||||
std::ifstream is("test_data/test.wav", std::ios_base::binary);
|
||||
WaveData wave;
|
||||
wave.Read(is);
|
||||
KALDI_ASSERT(wave.Data().NumRows() == 1);
|
||||
SubVector<BaseFloat> waveform(wave.Data(), 0);
|
||||
|
||||
// read the HTK features
|
||||
Matrix<BaseFloat> htk_features;
|
||||
{
|
||||
std::ifstream is("test_data/test.wav.fea_htk.6",
|
||||
std::ios::in | std::ios_base::binary);
|
||||
bool ans = ReadHtk(is, &htk_features, 0);
|
||||
KALDI_ASSERT(ans);
|
||||
}
|
||||
|
||||
// use mfcc with default configuration...
|
||||
MfccOptions op;
|
||||
op.frame_opts.dither = 0.0;
|
||||
op.frame_opts.preemph_coeff = 0.97;
|
||||
op.frame_opts.window_type = "hamming";
|
||||
op.frame_opts.remove_dc_offset = false;
|
||||
op.frame_opts.round_to_power_of_two = true;
|
||||
op.mel_opts.num_bins = 24;
|
||||
op.mel_opts.low_freq = 125.0;
|
||||
op.mel_opts.high_freq = 7800.0;
|
||||
op.htk_compat = true;
|
||||
op.use_energy = false; // C0 not energy.
|
||||
|
||||
Mfcc mfcc(op);
|
||||
|
||||
// calculate kaldi features
|
||||
Matrix<BaseFloat> kaldi_raw_features;
|
||||
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
|
||||
|
||||
DeltaFeaturesOptions delta_opts;
|
||||
Matrix<BaseFloat> kaldi_features;
|
||||
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
|
||||
|
||||
// compare the results
|
||||
bool passed = true;
|
||||
int32 i_old = -1;
|
||||
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
|
||||
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
|
||||
// Ignore ends-- we make slightly different choices than
|
||||
// HTK about how to treat the deltas at the ends.
|
||||
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
|
||||
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
|
||||
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
|
||||
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
|
||||
// print the non-matching data only once per-line
|
||||
if (static_cast<int32>(i_old) != i) {
|
||||
std::cout << "\n\n\n[HTK-row: " << i << "] "
|
||||
<< htk_features.Row(i) << "\n";
|
||||
std::cout << "[Kaldi-row: " << i << "] "
|
||||
<< kaldi_features.Row(i) << "\n\n\n";
|
||||
i_old = i;
|
||||
}
|
||||
// print indices of non-matching cells
|
||||
std::cout << "[" << i << ", " << j << "]";
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!passed) KALDI_ERR << "Test failed";
|
||||
|
||||
// write the htk features for later inspection
|
||||
HtkHeader header = {
|
||||
kaldi_features.NumRows(),
|
||||
100000, // 10ms
|
||||
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
|
||||
021406 // MFCC_D_A_0
|
||||
};
|
||||
{
|
||||
std::ofstream os("tmp.test.wav.fea_kaldi.6",
|
||||
std::ios::out | std::ios::binary);
|
||||
WriteHtk(os, kaldi_features, header);
|
||||
}
|
||||
|
||||
std::cout << "Test passed :)\n\n";
|
||||
|
||||
unlink("tmp.test.wav.fea_kaldi.6");
|
||||
}
|
||||
|
||||
void UnitTestVtln() {
|
||||
// Test the function VtlnWarpFreq.
|
||||
BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20,
|
||||
vtln_high_cutoff = 7400;
|
||||
|
||||
for (size_t i = 0; i < 100; i++) {
|
||||
BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
|
||||
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
freq),
|
||||
freq / warp_factor);
|
||||
|
||||
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
low_freq),
|
||||
low_freq);
|
||||
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
high_freq),
|
||||
high_freq);
|
||||
BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(),
|
||||
freq3 = freq2 +
|
||||
(high_freq - freq2) * RandUniform(); // freq3>=freq2
|
||||
BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
freq2);
|
||||
BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
|
||||
vtln_high_cutoff,
|
||||
low_freq,
|
||||
high_freq,
|
||||
warp_factor,
|
||||
freq3);
|
||||
KALDI_ASSERT(w3 >= w2); // increasing function.
|
||||
BaseFloat w3dash = MelBanks::VtlnWarpFreq(
|
||||
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, 1.0, freq3);
|
||||
AssertEqual(w3dash, freq3);
|
||||
}
|
||||
}
|
||||
|
||||
static void UnitTestFeat() {
|
||||
UnitTestVtln();
|
||||
UnitTestReadWave();
|
||||
UnitTestSimple();
|
||||
UnitTestHTKCompare1();
|
||||
UnitTestHTKCompare2();
|
||||
// commenting out this one as it doesn't compare right now I normalized
|
||||
// the way the FFT bins are treated (removed offset of 0.5)... this seems
|
||||
// to relate to the way frequency zero behaves.
|
||||
UnitTestHTKCompare3();
|
||||
UnitTestHTKCompare4();
|
||||
UnitTestHTKCompare5();
|
||||
UnitTestHTKCompare6();
|
||||
std::cout << "Tests succeeded.\n";
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
try {
|
||||
for (int i = 0; i < 5; i++) UnitTestFeat();
|
||||
std::cout << "Tests succeeded.\n";
|
||||
return 0;
|
||||
} catch (const std::exception &e) {
|
||||
std::cerr << e.what();
|
||||
return 1;
|
||||
}
|
||||
}
|
@ -1,270 +0,0 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// todo refactor, repalce with gtest
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "kaldi/feat/wave-reader.h"
|
||||
#include "kaldi/util/kaldi-io.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
|
||||
#include "frontend/audio/audio_cache.h"
|
||||
#include "frontend/audio/data_cache.h"
|
||||
#include "frontend/audio/feature_cache.h"
|
||||
#include "frontend/audio/frontend_itf.h"
|
||||
#include "frontend/audio/linear_spectrogram.h"
|
||||
#include "frontend/audio/normalizer.h"
|
||||
|
||||
DEFINE_string(wav_rspecifier, "", "test wav scp path");
|
||||
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
|
||||
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
|
||||
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
|
||||
|
||||
|
||||
std::vector<float> mean_{
|
||||
-13730251.531853663, -12982852.199316509, -13673844.299583456,
|
||||
-13089406.559646806, -12673095.524938712, -12823859.223276224,
|
||||
-13590267.158903603, -14257618.467152044, -14374605.116185192,
|
||||
-14490009.21822485, -14849827.158924166, -15354435.470563512,
|
||||
-15834149.206532761, -16172971.985514281, -16348740.496746974,
|
||||
-16423536.699409386, -16556246.263649225, -16744088.772748645,
|
||||
-16916184.08510357, -17054034.840031497, -17165612.509455364,
|
||||
-17255955.470915023, -17322572.527648456, -17408943.862033736,
|
||||
-17521554.799865916, -17620623.254924215, -17699792.395918526,
|
||||
-17723364.411134344, -17741483.4433254, -17747426.888704527,
|
||||
-17733315.928209435, -17748780.160905756, -17808336.883775543,
|
||||
-17895918.671983004, -18009812.59173023, -18098188.66548325,
|
||||
-18195798.958462656, -18293617.62980999, -18397432.92077201,
|
||||
-18505834.787318766, -18585451.8100908, -18652438.235649142,
|
||||
-18700960.306275308, -18734944.58792185, -18737426.313365128,
|
||||
-18735347.165987637, -18738813.444170244, -18737086.848890636,
|
||||
-18731576.2474336, -18717405.44095871, -18703089.25545657,
|
||||
-18691014.546456724, -18692460.568905357, -18702119.628629155,
|
||||
-18727710.621126678, -18761582.72034647, -18806745.835547544,
|
||||
-18850674.8692112, -18884431.510951452, -18919999.992506847,
|
||||
-18939303.799078144, -18952946.273760635, -18980289.22996379,
|
||||
-19011610.17803294, -19040948.61805145, -19061021.429847397,
|
||||
-19112055.53768819, -19149667.414264943, -19201127.05091321,
|
||||
-19270250.82564605, -19334606.883057203, -19390513.336589377,
|
||||
-19444176.259208687, -19502755.000038862, -19544333.014549147,
|
||||
-19612668.183176614, -19681902.19006569, -19771969.951249883,
|
||||
-19873329.723376893, -19996752.59235844, -20110031.131400537,
|
||||
-20231658.612529557, -20319378.894054495, -20378534.45718066,
|
||||
-20413332.089584175, -20438147.844177883, -20443710.248040095,
|
||||
-20465457.02238927, -20488610.969337028, -20516295.16424432,
|
||||
-20541423.795738827, -20553192.874953747, -20573605.50701977,
|
||||
-20577871.61936797, -20571807.008916274, -20556242.38912231,
|
||||
-20542199.30819195, -20521239.063551214, -20519150.80004532,
|
||||
-20527204.80248933, -20536933.769257784, -20543470.522332076,
|
||||
-20549700.089992985, -20551525.24958494, -20554873.406493705,
|
||||
-20564277.65794227, -20572211.740052115, -20574305.69550465,
|
||||
-20575494.450104576, -20567092.577932164, -20549302.929608088,
|
||||
-20545445.11878376, -20546625.326603737, -20549190.03499401,
|
||||
-20554824.947828256, -20568341.378989458, -20577582.331383612,
|
||||
-20577980.519402675, -20566603.03458152, -20560131.592262644,
|
||||
-20552166.469060015, -20549063.06763577, -20544490.562339947,
|
||||
-20539817.82346569, -20528747.715731595, -20518026.24576161,
|
||||
-20510977.844974525, -20506874.36087992, -20506731.11977665,
|
||||
-20510482.133420516, -20507760.92101862, -20494644.834457114,
|
||||
-20480107.89304893, -20461312.091867123, -20442941.75080173,
|
||||
-20426123.02834838, -20424607.675283, -20426810.369107097,
|
||||
-20434024.50097819, -20437404.75544205, -20447688.63916367,
|
||||
-20460893.335563846, -20482922.735127095, -20503610.119434915,
|
||||
-20527062.76448319, -20557830.035128627, -20593274.72068722,
|
||||
-20632528.452965066, -20673637.471334763, -20733106.97143075,
|
||||
-20842921.0447562, -21054357.83621519, -21416569.534189366,
|
||||
-21978460.272811692, -22753170.052172784, -23671344.10563395,
|
||||
-24613499.293358143, -25406477.12230188, -25884377.82156489,
|
||||
-26049040.62791664, -26996879.104431007};
|
||||
std::vector<float> variance_{
|
||||
213747175.10846674, 188395815.34302503, 212706429.10966414,
|
||||
199109025.81461075, 189235901.23864496, 194901336.53253657,
|
||||
217481594.29306737, 238689869.12327808, 243977501.24115244,
|
||||
248479623.6431067, 259766741.47116545, 275516766.7790273,
|
||||
291271202.3691234, 302693239.8220509, 308627358.3997694,
|
||||
311143911.38788426, 315446105.07731867, 321705430.9341829,
|
||||
327458907.4659941, 332245072.43223983, 336251717.5935284,
|
||||
339694069.7639722, 342188204.4322228, 345587110.31313115,
|
||||
349903086.2875232, 353660214.20643026, 356700344.5270885,
|
||||
357665362.3529641, 358493352.05658793, 358857951.620328,
|
||||
358375239.52774596, 358899733.6342954, 361051818.3511561,
|
||||
364361716.05025816, 368750322.3771452, 372047800.6462831,
|
||||
375655861.1349018, 379358519.1980013, 383327605.3935181,
|
||||
387458599.282341, 390434692.3406868, 392994486.35057056,
|
||||
394874418.04603153, 396230525.79763395, 396365592.0414835,
|
||||
396334819.8242737, 396488353.19250053, 396438877.00744957,
|
||||
396197980.4459586, 395590921.6672991, 395001107.62072515,
|
||||
394528291.7318225, 394593110.424006, 395018405.59353715,
|
||||
396110577.5415993, 397506704.0371068, 399400197.4657644,
|
||||
401243568.2468382, 402687134.7805103, 404136047.2872507,
|
||||
404883170.001883, 405522253.219517, 406660365.3626476,
|
||||
407919346.0991902, 409045348.5384909, 409759588.7889818,
|
||||
411974821.8564483, 413489718.78201455, 415535392.56684107,
|
||||
418466481.97674364, 421104678.35678065, 423405392.5200779,
|
||||
425550570.40798235, 427929423.9579701, 429585274.253478,
|
||||
432368493.55181056, 435193587.13513297, 438886855.20476013,
|
||||
443058876.8633751, 448181232.5093362, 452883835.6332396,
|
||||
458056721.77926534, 461816531.22735566, 464363620.1970998,
|
||||
465886343.5057493, 466928872.0651, 467180536.42647296,
|
||||
468111848.70714295, 469138695.3071312, 470378429.6930793,
|
||||
471517958.7132626, 472109050.4262365, 473087417.0177867,
|
||||
473381322.04648733, 473220195.85483915, 472666071.8998819,
|
||||
472124669.87879956, 471298571.411737, 471251033.2902761,
|
||||
471672676.43128747, 472177147.2193172, 472572361.7711908,
|
||||
472968783.7751127, 473156295.4164052, 473398034.82676554,
|
||||
473897703.5203811, 474328271.33112127, 474452670.98002136,
|
||||
474549003.99284613, 474252887.13567275, 473557462.909069,
|
||||
473483385.85193115, 473609738.04855174, 473746944.82085115,
|
||||
474016729.91696435, 474617321.94138587, 475045097.237122,
|
||||
475125402.586558, 474664112.9824912, 474426247.5800283,
|
||||
474104075.42796475, 473978219.7273978, 473773171.7798875,
|
||||
473578534.69508696, 473102924.16904145, 472651240.5232615,
|
||||
472374383.1810912, 472209479.6956096, 472202298.8921673,
|
||||
472370090.76781124, 472220933.99374026, 471625467.37106377,
|
||||
470994646.51883453, 470182428.9637543, 469348211.5939578,
|
||||
468570387.4467277, 468540442.7225135, 468672018.90414184,
|
||||
468994346.9533251, 469138757.58201426, 469553915.95710236,
|
||||
470134523.38582784, 471082421.62055486, 471962316.51804745,
|
||||
472939745.1708408, 474250621.5944825, 475773933.43199486,
|
||||
477465399.71087736, 479218782.61382693, 481752299.7930922,
|
||||
486608947.8984568, 496119403.2067917, 512730085.5704984,
|
||||
539048915.2641417, 576285298.3548826, 621610270.2240586,
|
||||
669308196.4436442, 710656993.5957186, 736344437.3725077,
|
||||
745481288.0241544, 801121432.9925804};
|
||||
int count_ = 912592;
|
||||
|
||||
void WriteMatrix() {
|
||||
kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
|
||||
for (size_t idx = 0; idx < mean_.size(); ++idx) {
|
||||
cmvn_stats(0, idx) = mean_[idx];
|
||||
cmvn_stats(1, idx) = variance_[idx];
|
||||
}
|
||||
cmvn_stats(0, mean_.size()) = count_;
|
||||
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, false);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
|
||||
FLAGS_wav_rspecifier);
|
||||
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
|
||||
WriteMatrix();
|
||||
|
||||
|
||||
int32 num_done = 0, num_err = 0;
|
||||
|
||||
// feature pipeline: wave cache --> decibel_normalizer --> hanning
|
||||
// window -->linear_spectrogram --> global cmvn -> feat cache
|
||||
|
||||
// std::unique_ptr<ppspeech::FrontendInterface> data_source(new
|
||||
// ppspeech::DataCache());
|
||||
std::unique_ptr<ppspeech::FrontendInterface> data_source(
|
||||
new ppspeech::AudioCache());
|
||||
|
||||
ppspeech::DecibelNormalizerOptions db_norm_opt;
|
||||
std::unique_ptr<ppspeech::FrontendInterface> db_norm(
|
||||
new ppspeech::DecibelNormalizer(db_norm_opt, std::move(data_source)));
|
||||
|
||||
ppspeech::LinearSpectrogramOptions opt;
|
||||
opt.frame_opts.frame_length_ms = 20;
|
||||
opt.frame_opts.frame_shift_ms = 10;
|
||||
opt.streaming_chunk = FLAGS_streaming_chunk;
|
||||
opt.frame_opts.dither = 0.0;
|
||||
opt.frame_opts.remove_dc_offset = false;
|
||||
opt.frame_opts.window_type = "hanning";
|
||||
opt.frame_opts.preemph_coeff = 0.0;
|
||||
LOG(INFO) << "frame length (ms): " << opt.frame_opts.frame_length_ms;
|
||||
LOG(INFO) << "frame shift (ms): " << opt.frame_opts.frame_shift_ms;
|
||||
|
||||
std::unique_ptr<ppspeech::FrontendInterface> linear_spectrogram(
|
||||
new ppspeech::LinearSpectrogram(opt, std::move(db_norm)));
|
||||
|
||||
std::unique_ptr<ppspeech::FrontendInterface> cmvn(new ppspeech::CMVN(
|
||||
FLAGS_cmvn_write_path, std::move(linear_spectrogram)));
|
||||
|
||||
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
|
||||
LOG(INFO) << "feat dim: " << feature_cache.Dim();
|
||||
|
||||
int sample_rate = 16000;
|
||||
float streaming_chunk = FLAGS_streaming_chunk;
|
||||
int chunk_sample_size = streaming_chunk * sample_rate;
|
||||
LOG(INFO) << "sr: " << sample_rate;
|
||||
LOG(INFO) << "chunk size (s): " << streaming_chunk;
|
||||
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
|
||||
|
||||
|
||||
for (; !wav_reader.Done(); wav_reader.Next()) {
|
||||
std::string utt = wav_reader.Key();
|
||||
const kaldi::WaveData& wave_data = wav_reader.Value();
|
||||
LOG(INFO) << "process utt: " << utt;
|
||||
|
||||
int32 this_channel = 0;
|
||||
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
|
||||
this_channel);
|
||||
int tot_samples = waveform.Dim();
|
||||
LOG(INFO) << "wav len (sample): " << tot_samples;
|
||||
|
||||
int sample_offset = 0;
|
||||
std::vector<kaldi::Vector<BaseFloat>> feats;
|
||||
int feature_rows = 0;
|
||||
while (sample_offset < tot_samples) {
|
||||
int cur_chunk_size =
|
||||
std::min(chunk_sample_size, tot_samples - sample_offset);
|
||||
|
||||
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
|
||||
for (int i = 0; i < cur_chunk_size; ++i) {
|
||||
wav_chunk(i) = waveform(sample_offset + i);
|
||||
}
|
||||
|
||||
kaldi::Vector<BaseFloat> features;
|
||||
feature_cache.Accept(wav_chunk);
|
||||
if (cur_chunk_size < chunk_sample_size) {
|
||||
feature_cache.SetFinished();
|
||||
}
|
||||
feature_cache.Read(&features);
|
||||
if (features.Dim() == 0) break;
|
||||
|
||||
feats.push_back(features);
|
||||
sample_offset += cur_chunk_size;
|
||||
feature_rows += features.Dim() / feature_cache.Dim();
|
||||
}
|
||||
|
||||
int cur_idx = 0;
|
||||
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
|
||||
feature_cache.Dim());
|
||||
for (auto feat : feats) {
|
||||
int num_rows = feat.Dim() / feature_cache.Dim();
|
||||
for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
|
||||
for (size_t col_idx = 0; col_idx < feature_cache.Dim();
|
||||
++col_idx) {
|
||||
features(cur_idx, col_idx) =
|
||||
feat(row_idx * feature_cache.Dim() + col_idx);
|
||||
}
|
||||
++cur_idx;
|
||||
}
|
||||
}
|
||||
feat_writer.Write(utt, features);
|
||||
feature_cache.Reset();
|
||||
|
||||
if (num_done % 50 == 0 && num_done != 0)
|
||||
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
|
||||
num_done++;
|
||||
}
|
||||
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
|
||||
<< " with errors.";
|
||||
return (num_done != 0 ? 0 : 1);
|
||||
}
|
@ -1,32 +0,0 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. ./path.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
# 2. download model
|
||||
if [ ! -d ../paddle_asr_model ]; then
|
||||
wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
|
||||
tar xzfv paddle_asr_model.tar.gz
|
||||
mv ./paddle_asr_model ../
|
||||
# produce wav scp
|
||||
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
|
||||
fi
|
||||
|
||||
model_dir=../paddle_asr_model
|
||||
feat_wspecifier=./feats.ark
|
||||
cmvn=./cmvn.ark
|
||||
|
||||
# 3. run feat
|
||||
export GLOG_logtostderr=1
|
||||
linear_spectrogram_main \
|
||||
--wav_rspecifier=scp:$model_dir/wav.scp \
|
||||
--feature_wspecifier=ark,t:$feat_wspecifier \
|
||||
--cmvn_write_path=$cmvn
|
@ -0,0 +1,2 @@
|
||||
data
|
||||
exp
|
@ -0,0 +1 @@
|
||||
# NGram Train
|
@ -1,5 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_executable(pp-model-test ${CMAKE_CURRENT_SOURCE_DIR}/pp-model-test.cc)
|
||||
target_include_directories(pp-model-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(pp-model-test PUBLIC nnet gflags ${DEPS})
|
@ -1,29 +0,0 @@
|
||||
#!/bin/bash
|
||||
set +x
|
||||
set -e
|
||||
|
||||
. path.sh
|
||||
|
||||
# 1. compile
|
||||
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
|
||||
pushd ${SPEECHX_ROOT}
|
||||
bash build.sh
|
||||
popd
|
||||
fi
|
||||
|
||||
# 2. download model
|
||||
if [ ! -d ../paddle_asr_model ]; then
|
||||
wget https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/paddle_asr_model.tar.gz
|
||||
tar xzfv paddle_asr_model.tar.gz
|
||||
mv ./paddle_asr_model ../
|
||||
# produce wav scp
|
||||
echo "utt1 " $PWD/../paddle_asr_model/BAC009S0764W0290.wav > ../paddle_asr_model/wav.scp
|
||||
fi
|
||||
|
||||
model_dir=../paddle_asr_model
|
||||
|
||||
# 4. run decoder
|
||||
pp-model-test \
|
||||
--model_path=$model_dir/avg_1.jit.pdmodel \
|
||||
--param_path=$model_dir/avg_1.jit.pdparams
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# CopyRight WeNet Apache-2.0 License
|
||||
|
||||
import re, sys, unicodedata
|
||||
import codecs
|
Loading…
Reference in new issue