Merge pull request #1541 from zh794390558/egs

[speechx] move test to examples
pull/1559/head
YangZhou 3 years ago committed by GitHub
commit ebc2aca990
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -50,13 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker
name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
hooks:

@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False

@ -36,4 +36,4 @@ def repeat(N, fn):
Returns:
MultiSequential: Repeated model instance.
"""
return MultiSequential(*[fn(n) for n in range(N)])
return MultiSequential(* [fn(n) for n in range(N)])

@ -117,6 +117,7 @@ set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
include_directories("${MKLDNN_PATH}/include")
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
set(EXTERNAL_LIB "-lrt -ldl -lpthread")
set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB}
@ -137,4 +138,7 @@ set(DEPS ${DEPS}
#target_link_libraries(lib_name item0 item1)
#add_dependencies(lib_name depend-target)
set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx)
add_subdirectory(speechx)
add_subdirectory(examples)

@ -16,7 +16,7 @@ if [ ! -d ${boost_SOURCE_DIR} ]; then wget -c https://boostorg.jfrog.io/artifact
echo -e "\n"
fi
rm -rf build
#rm -rf build
mkdir -p build
cd build

@ -18,6 +18,8 @@ ExternalProject_Add(
SOURCE_DIR ${OpenBLAS_SOURCE_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
CMAKE_GENERATOR "Unix Makefiles")
# https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition
ExternalProject_Get_Property(OPENBLAS INSTALL_DIR)
set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR})

@ -0,0 +1,5 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)

@ -0,0 +1,5 @@
# Examples
* decoder - offline decoder
* feat - mfcc, linear
* nnet - ds2 nn

@ -0,0 +1,5 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(offline-decoder-main ${CMAKE_CURRENT_SOURCE_DIR}/offline-decoder-main.cc)
target_include_directories(offline-decoder-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline-decoder-main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})

@ -0,0 +1,74 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
DEFINE_string(feature_respecifier, "", "test nnet prob");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// void SplitFeature(kaldi::Matrix<BaseFloat> feature,
// int32 chunk_size,
// std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
//}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_respecifier);
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts;
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet));
// int32 chunk_size = 35;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
decodable->FeedFeatures(feature);
decoder.AdvanceDecode(decodable, 8);
decodable->InputFinished();
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -0,0 +1,10 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear-spectrogram-main ${CMAKE_CURRENT_SOURCE_DIR}/linear-spectrogram-main.cc)
target_include_directories(linear-spectrogram-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(linear-spectrogram-main frontend kaldi-util kaldi-feat-common gflags glog)

@ -0,0 +1,720 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// feat/feature-mfcc-test.cc
// Copyright 2009-2011 Karel Vesely; Petr Motlicek
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "base/kaldi-math.h"
#include "feat/feature-mfcc.h"
#include "feat/wave-reader.h"
#include "matrix/kaldi-matrix-inl.h"
using namespace kaldi;
static void UnitTestReadWave() {
std::cout << "=== UnitTestReadWave() ===\n";
Vector<BaseFloat> v, v2;
std::cout << "<<<=== Reading waveform\n";
{
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
const Matrix<BaseFloat> data(wave.Data());
KALDI_ASSERT(data.NumRows() == 1);
v.Resize(data.NumCols());
v.CopyFromVec(data.Row(0));
}
std::cout
<< "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
std::ifstream input("test_data/test_matlab.ascii");
KALDI_ASSERT(input.good());
v2.Read(input, false);
input.close();
std::cout
<< "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
KALDI_ASSERT(v.Dim() == v2.Dim());
for (int32 i = 0; i < v.Dim(); i++) {
KALDI_ASSERT(v(i) == v2(i));
}
std::cout << "<<<=== Comparing done\n";
// std::cout << "== The Waveform Samples == \n";
// std::cout << v;
std::cout << "Test passed :)\n\n";
}
/**
*/
static void UnitTestSimple() {
std::cout << "=== UnitTestSimple() ===\n";
Vector<BaseFloat> v(100000);
Matrix<BaseFloat> m;
// init with noise
for (int32 i = 0; i < v.Dim(); i++) {
v(i) = (abs(i * 433024253) % 65535) - (65535 / 2);
}
std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
// the parametrization object
MfccOptions op;
// trying to have same opts as baseline.
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "rectangular";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
Mfcc mfcc(op);
// use default parameters
// compute mfccs.
mfcc.Compute(v, 1.0, &m);
// possibly dump
// std::cout << "== Output features == \n" << m;
std::cout << "Test passed :)\n\n";
}
static void UnitTestHTKCompare1() {
std::cout << "=== UnitTestHTKCompare1() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.1",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.1",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.1");
}
static void UnitTestHTKCompare2() {
std::cout << "=== UnitTestHTKCompare2() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.2",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.2",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.2");
}
static void UnitTestHTKCompare3() {
std::cout << "=== UnitTestHTKCompare3() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.3",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 20.0;
// op.mel_opts.debug_mel = true;
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.3",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.3");
}
static void UnitTestHTKCompare4() {
std::cout << "=== UnitTestHTKCompare4() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.4",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.4",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.4");
}
static void UnitTestHTKCompare5() {
std::cout << "=== UnitTestHTKCompare5() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.5",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 0.0;
op.mel_opts.vtln_low = 100.0;
op.mel_opts.vtln_high = 7500.0;
op.mel_opts.htk_mode = true;
BaseFloat vtln_warp =
1.1; // our approach identical to htk for warp factor >1,
// differs slightly for higher mel bins if warp_factor <0.9
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.5",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.5");
}
static void UnitTestHTKCompare6() {
std::cout << "=== UnitTestHTKCompare6() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.6",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.97;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.num_bins = 24;
op.mel_opts.low_freq = 125.0;
op.mel_opts.high_freq = 7800.0;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.6",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.6");
}
void UnitTestVtln() {
// Test the function VtlnWarpFreq.
BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20,
vtln_high_cutoff = 7400;
for (size_t i = 0; i < 100; i++) {
BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq),
freq / warp_factor);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
low_freq),
low_freq);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
high_freq),
high_freq);
BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(),
freq3 = freq2 +
(high_freq - freq2) * RandUniform(); // freq3>=freq2
BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq2);
BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq3);
KALDI_ASSERT(w3 >= w2); // increasing function.
BaseFloat w3dash = MelBanks::VtlnWarpFreq(
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, 1.0, freq3);
AssertEqual(w3dash, freq3);
}
}
static void UnitTestFeat() {
UnitTestVtln();
UnitTestReadWave();
UnitTestSimple();
UnitTestHTKCompare1();
UnitTestHTKCompare2();
// commenting out this one as it doesn't compare right now I normalized
// the way the FFT bins are treated (removed offset of 0.5)... this seems
// to relate to the way frequency zero behaves.
UnitTestHTKCompare3();
UnitTestHTKCompare4();
UnitTestHTKCompare5();
UnitTestHTKCompare6();
std::cout << "Tests succeeded.\n";
}
int main() {
try {
for (int i = 0; i < 5; i++) UnitTestFeat();
std::cout << "Tests succeeded.\n";
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return 1;
}
}

@ -0,0 +1,257 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav path");
DEFINE_string(feature_wspecifier, "", "test wav ark");
DEFINE_string(feature_check_wspecifier, "", "test wav ark");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "test wav ark");
std::vector<float> mean_{
-13730251.531853663, -12982852.199316509, -13673844.299583456,
-13089406.559646806, -12673095.524938712, -12823859.223276224,
-13590267.158903603, -14257618.467152044, -14374605.116185192,
-14490009.21822485, -14849827.158924166, -15354435.470563512,
-15834149.206532761, -16172971.985514281, -16348740.496746974,
-16423536.699409386, -16556246.263649225, -16744088.772748645,
-16916184.08510357, -17054034.840031497, -17165612.509455364,
-17255955.470915023, -17322572.527648456, -17408943.862033736,
-17521554.799865916, -17620623.254924215, -17699792.395918526,
-17723364.411134344, -17741483.4433254, -17747426.888704527,
-17733315.928209435, -17748780.160905756, -17808336.883775543,
-17895918.671983004, -18009812.59173023, -18098188.66548325,
-18195798.958462656, -18293617.62980999, -18397432.92077201,
-18505834.787318766, -18585451.8100908, -18652438.235649142,
-18700960.306275308, -18734944.58792185, -18737426.313365128,
-18735347.165987637, -18738813.444170244, -18737086.848890636,
-18731576.2474336, -18717405.44095871, -18703089.25545657,
-18691014.546456724, -18692460.568905357, -18702119.628629155,
-18727710.621126678, -18761582.72034647, -18806745.835547544,
-18850674.8692112, -18884431.510951452, -18919999.992506847,
-18939303.799078144, -18952946.273760635, -18980289.22996379,
-19011610.17803294, -19040948.61805145, -19061021.429847397,
-19112055.53768819, -19149667.414264943, -19201127.05091321,
-19270250.82564605, -19334606.883057203, -19390513.336589377,
-19444176.259208687, -19502755.000038862, -19544333.014549147,
-19612668.183176614, -19681902.19006569, -19771969.951249883,
-19873329.723376893, -19996752.59235844, -20110031.131400537,
-20231658.612529557, -20319378.894054495, -20378534.45718066,
-20413332.089584175, -20438147.844177883, -20443710.248040095,
-20465457.02238927, -20488610.969337028, -20516295.16424432,
-20541423.795738827, -20553192.874953747, -20573605.50701977,
-20577871.61936797, -20571807.008916274, -20556242.38912231,
-20542199.30819195, -20521239.063551214, -20519150.80004532,
-20527204.80248933, -20536933.769257784, -20543470.522332076,
-20549700.089992985, -20551525.24958494, -20554873.406493705,
-20564277.65794227, -20572211.740052115, -20574305.69550465,
-20575494.450104576, -20567092.577932164, -20549302.929608088,
-20545445.11878376, -20546625.326603737, -20549190.03499401,
-20554824.947828256, -20568341.378989458, -20577582.331383612,
-20577980.519402675, -20566603.03458152, -20560131.592262644,
-20552166.469060015, -20549063.06763577, -20544490.562339947,
-20539817.82346569, -20528747.715731595, -20518026.24576161,
-20510977.844974525, -20506874.36087992, -20506731.11977665,
-20510482.133420516, -20507760.92101862, -20494644.834457114,
-20480107.89304893, -20461312.091867123, -20442941.75080173,
-20426123.02834838, -20424607.675283, -20426810.369107097,
-20434024.50097819, -20437404.75544205, -20447688.63916367,
-20460893.335563846, -20482922.735127095, -20503610.119434915,
-20527062.76448319, -20557830.035128627, -20593274.72068722,
-20632528.452965066, -20673637.471334763, -20733106.97143075,
-20842921.0447562, -21054357.83621519, -21416569.534189366,
-21978460.272811692, -22753170.052172784, -23671344.10563395,
-24613499.293358143, -25406477.12230188, -25884377.82156489,
-26049040.62791664, -26996879.104431007};
std::vector<float> variance_{
213747175.10846674, 188395815.34302503, 212706429.10966414,
199109025.81461075, 189235901.23864496, 194901336.53253657,
217481594.29306737, 238689869.12327808, 243977501.24115244,
248479623.6431067, 259766741.47116545, 275516766.7790273,
291271202.3691234, 302693239.8220509, 308627358.3997694,
311143911.38788426, 315446105.07731867, 321705430.9341829,
327458907.4659941, 332245072.43223983, 336251717.5935284,
339694069.7639722, 342188204.4322228, 345587110.31313115,
349903086.2875232, 353660214.20643026, 356700344.5270885,
357665362.3529641, 358493352.05658793, 358857951.620328,
358375239.52774596, 358899733.6342954, 361051818.3511561,
364361716.05025816, 368750322.3771452, 372047800.6462831,
375655861.1349018, 379358519.1980013, 383327605.3935181,
387458599.282341, 390434692.3406868, 392994486.35057056,
394874418.04603153, 396230525.79763395, 396365592.0414835,
396334819.8242737, 396488353.19250053, 396438877.00744957,
396197980.4459586, 395590921.6672991, 395001107.62072515,
394528291.7318225, 394593110.424006, 395018405.59353715,
396110577.5415993, 397506704.0371068, 399400197.4657644,
401243568.2468382, 402687134.7805103, 404136047.2872507,
404883170.001883, 405522253.219517, 406660365.3626476,
407919346.0991902, 409045348.5384909, 409759588.7889818,
411974821.8564483, 413489718.78201455, 415535392.56684107,
418466481.97674364, 421104678.35678065, 423405392.5200779,
425550570.40798235, 427929423.9579701, 429585274.253478,
432368493.55181056, 435193587.13513297, 438886855.20476013,
443058876.8633751, 448181232.5093362, 452883835.6332396,
458056721.77926534, 461816531.22735566, 464363620.1970998,
465886343.5057493, 466928872.0651, 467180536.42647296,
468111848.70714295, 469138695.3071312, 470378429.6930793,
471517958.7132626, 472109050.4262365, 473087417.0177867,
473381322.04648733, 473220195.85483915, 472666071.8998819,
472124669.87879956, 471298571.411737, 471251033.2902761,
471672676.43128747, 472177147.2193172, 472572361.7711908,
472968783.7751127, 473156295.4164052, 473398034.82676554,
473897703.5203811, 474328271.33112127, 474452670.98002136,
474549003.99284613, 474252887.13567275, 473557462.909069,
473483385.85193115, 473609738.04855174, 473746944.82085115,
474016729.91696435, 474617321.94138587, 475045097.237122,
475125402.586558, 474664112.9824912, 474426247.5800283,
474104075.42796475, 473978219.7273978, 473773171.7798875,
473578534.69508696, 473102924.16904145, 472651240.5232615,
472374383.1810912, 472209479.6956096, 472202298.8921673,
472370090.76781124, 472220933.99374026, 471625467.37106377,
470994646.51883453, 470182428.9637543, 469348211.5939578,
468570387.4467277, 468540442.7225135, 468672018.90414184,
468994346.9533251, 469138757.58201426, 469553915.95710236,
470134523.38582784, 471082421.62055486, 471962316.51804745,
472939745.1708408, 474250621.5944825, 475773933.43199486,
477465399.71087736, 479218782.61382693, 481752299.7930922,
486608947.8984568, 496119403.2067917, 512730085.5704984,
539048915.2641417, 576285298.3548826, 621610270.2240586,
669308196.4436442, 710656993.5957186, 736344437.3725077,
745481288.0241544, 801121432.9925804};
int count_ = 912592;
void WriteMatrix() {
kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
for (size_t idx = 0; idx < mean_.size(); ++idx) {
cmvn_stats(0, idx) = mean_[idx];
cmvn_stats(1, idx) = variance_[idx];
}
cmvn_stats(0, mean_.size()) = count_;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(
FLAGS_feature_check_wspecifier);
WriteMatrix();
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning
// window -->linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0;
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
new ppspeech::DecibelNormalizer(db_norm_opt));
ppspeech::LinearSpectrogram linear_spectrogram(
opt, std::move(base_feature_extractor));
ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
float streaming_chunk = 0.36;
int sample_rate = 16000;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << mean_.size();
for (size_t i = 0; i < mean_.size(); i++) {
mean_[i] /= count_;
variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i];
if (variance_[i] < 1.0e-20) {
variance_[i] = 1.0e-20;
}
variance_[i] = 1.0 / std::sqrt(variance_[i]);
}
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
int sample_offset = 0;
std::vector<kaldi::Matrix<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(wav_chunk);
linear_spectrogram.ReadFeats(&features);
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.NumRows();
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features(cur_idx, col_idx) =
(feat(row_idx, col_idx) - mean_[col_idx]) *
variance_[col_idx];
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features_check(feature_rows,
feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features_check(cur_idx, col_idx) = feat(row_idx, col_idx);
}
kaldi::SubVector<BaseFloat> row_feat(features_check, cur_idx);
cmvn.ApplyCMVN(true, &row_feat);
++cur_idx;
}
}
feat_cmvn_check_writer.Write(utt, features_check);
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -0,0 +1,5 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(pp-model-test ${CMAKE_CURRENT_SOURCE_DIR}/pp-model-test.cc)
target_include_directories(pp-model-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(pp-model-test PUBLIC nnet gflags ${DEPS})

@ -0,0 +1,193 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <thread>
#include "paddle_inference_api.h"
using std::cout;
using std::endl;
DEFINE_string(model_path, "avg_1.jit.pdmodel", "xxx.pdmodel");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "xxx.pdiparams");
void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test();
void produce_data(std::vector<std::vector<float>>* data) {
int chunk_size = 35; // chunk_size in frame
int col_size = 161; // feat dim
cout << "chunk size: " << chunk_size << endl;
cout << "feat dim: " << col_size << endl;
data->reserve(chunk_size);
data->back().reserve(col_size);
for (int row = 0; row < chunk_size; ++row) {
data->push_back(std::vector<float>());
for (int col_idx = 0; col_idx < col_size; ++col_idx) {
data->back().push_back(0.201);
}
}
}
void model_forward_test() {
std::cout << "1. read the data" << std::endl;
std::vector<std::vector<float>> feats;
produce_data(&feats);
std::cout << "2. load the model" << std::endl;
;
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
cout << "model path: " << model_graph << endl;
cout << "model param path : " << model_params << endl;
paddle_infer::Config config;
config.SetModel(model_graph, model_params);
config.SwitchIrOptim(false);
cout << "SwitchIrOptim: " << false << endl;
config.DisableFCPadding();
cout << "DisableFCPadding: " << endl;
auto predictor = paddle_infer::CreatePredictor(config);
std::cout << "3. feat shape, row=" << feats.size()
<< ",col=" << feats[0].size() << std::endl;
std::vector<float> pp_input_mat;
for (const auto& item : feats) {
pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
}
std::cout << "4. fead the data to model" << std::endl;
int row = feats.size();
int col = feats[0].size();
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
for (auto name : input_names) {
cout << "model input names: " << name << endl;
}
for (auto name : output_names) {
cout << "model output names: " << name << endl;
}
// input
std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(pp_input_mat.data());
// input length
std::unique_ptr<paddle_infer::Tensor> input_len =
predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
// state_h
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
predictor->GetInputHandle(input_names[2]);
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
int chunk_state_h_box_size =
std::accumulate(chunk_state_h_box_shape.begin(),
chunk_state_h_box_shape.end(),
1,
std::multiplies<int>());
std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
// state_c
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
predictor->GetInputHandle(input_names[3]);
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
int chunk_state_c_box_size =
std::accumulate(chunk_state_c_box_shape.begin(),
chunk_state_c_box_shape.end(),
1,
std::multiplies<int>());
std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
// run
bool success = predictor->Run();
// state_h out
std::unique_ptr<paddle_infer::Tensor> h_out =
predictor->GetOutputHandle(output_names[2]);
std::vector<int> h_out_shape = h_out->shape();
int h_out_size = std::accumulate(
h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
std::vector<float> h_out_data(h_out_size);
h_out->CopyToCpu(h_out_data.data());
// stage_c out
std::unique_ptr<paddle_infer::Tensor> c_out =
predictor->GetOutputHandle(output_names[3]);
std::vector<int> c_out_shape = c_out->shape();
int c_out_size = std::accumulate(
c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
std::vector<float> c_out_data(c_out_size);
c_out->CopyToCpu(c_out_data.data());
// output tensor
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
std::vector<float> output_probs;
int output_size = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
output_probs.resize(output_size);
output_tensor->CopyToCpu(output_probs.data());
row = output_shape[1];
col = output_shape[2];
// probs
std::vector<std::vector<float>> probs;
probs.reserve(row);
for (int i = 0; i < row; i++) {
probs.push_back(std::vector<float>());
probs.back().reserve(col);
for (int j = 0; j < col; j++) {
probs.back().push_back(output_probs[i * col + j]);
}
}
std::vector<std::vector<float>> log_feat = probs;
std::cout << "probs, row: " << log_feat.size()
<< " col: " << log_feat[0].size() << std::endl;
for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
++col_idx) {
std::cout << log_feat[row_idx][col_idx] << " ";
}
std::cout << std::endl;
}
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
model_forward_test();
return 0;
}

@ -31,15 +31,3 @@ ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/decoder
)
add_subdirectory(decoder)
add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc)
target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc)
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
add_executable(offline_decoder_main codelab/decoder_test/offline_decoder_main.cc)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(model_test codelab/nnet_test/model_test.cc)
target_link_libraries(model_test PUBLIC nnet gflags ${DEPS})

@ -43,18 +43,18 @@ typedef unsigned long long uint64;
typedef signed int char32;
const uint8 kuint8max = (( uint8) 0xFF);
const uint16 kuint16max = ((uint16) 0xFFFF);
const uint32 kuint32max = ((uint32) 0xFFFFFFFF);
const uint64 kuint64max = ((uint64) (0xFFFFFFFFFFFFFFFFLL));
const int8 kint8min = (( int8) 0x80);
const int8 kint8max = (( int8) 0x7F);
const int16 kint16min = (( int16) 0x8000);
const int16 kint16max = (( int16) 0x7FFF);
const int32 kint32min = (( int32) 0x80000000);
const int32 kint32max = (( int32) 0x7FFFFFFF);
const int64 kint64min = (( int64) (0x8000000000000000LL));
const int64 kint64max = (( int64) (0x7FFFFFFFFFFFFFFFLL));
const uint8 kuint8max = ((uint8)0xFF);
const uint16 kuint16max = ((uint16)0xFFFF);
const uint32 kuint32max = ((uint32)0xFFFFFFFF);
const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL));
const int8 kint8min = ((int8)0x80);
const int8 kint8max = ((int8)0x7F);
const int16 kint16min = ((int16)0x8000);
const int16 kint16max = ((int16)0x7FFF);
const int32 kint32min = ((int32)0x80000000);
const int32 kint32max = ((int32)0x7FFFFFFF);
const int64 kint64min = ((int64)(0x8000000000000000LL));
const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL));
const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();

@ -15,22 +15,22 @@
#pragma once
#include <deque>
#include <fstream>
#include <iostream>
#include <istream>
#include <fstream>
#include <map>
#include <memory>
#include <mutex>
#include <ostream>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include <mutex>
#include <vector>
#include "base/log.h"
#include "base/flags.h"
#include "base/basic_types.h"
#include "base/flags.h"
#include "base/log.h"
#include "base/macros.h"

@ -23,28 +23,29 @@
#ifndef BASE_THREAD_POOL_H
#define BASE_THREAD_POOL_H
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <vector>
class ThreadPool {
public:
public:
ThreadPool(size_t);
template<class F, class... Args>
template <class F, class... Args>
auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool();
private:
private:
// need to keep track of threads so we can join them
std::vector< std::thread > workers;
std::vector<std::thread> workers;
// the task queue
std::queue< std::function<void()> > tasks;
std::queue<std::function<void()>> tasks;
// synchronization
std::mutex queue_mutex;
@ -53,68 +54,57 @@ private:
};
// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads)
: stop(false)
{
for(size_t i = 0;i<threads;++i)
workers.emplace_back(
[this]
{
for(;;)
{
inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
for (size_t i = 0; i < threads; ++i)
workers.emplace_back([this] {
for (;;) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock,
[this]{ return this->stop || !this->tasks.empty(); });
if(this->stop && this->tasks.empty())
return;
this->condition.wait(lock, [this] {
return this->stop || !this->tasks.empty();
});
if (this->stop && this->tasks.empty()) return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
}
);
});
}
// add new work item to the pool
template<class F, class... Args>
template <class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>
{
-> std::future<typename std::result_of<F(Args...)>::type> {
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared< std::packaged_task<return_type()> >(
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
);
auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool
if(stop)
throw std::runtime_error("enqueue on stopped ThreadPool");
if (stop) throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task](){ (*task)(); });
tasks.emplace([task]() { (*task)(); });
}
condition.notify_one();
return res;
}
// the destructor joins all threads
inline ThreadPool::~ThreadPool()
{
inline ThreadPool::~ThreadPool() {
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for(std::thread &worker: workers)
worker.join();
for (std::thread& worker : workers) worker.join();
}
#endif

@ -1,4 +0,0 @@
# codelab
This directory is here for testing some funcitons temporaril.

@ -1,57 +0,0 @@
// todo refactor, repalce with gtest
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "nnet/paddle_nnet.h"
#include "nnet/decodable.h"
DEFINE_string(feature_respecifier, "", "test nnet prob");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
//void SplitFeature(kaldi::Matrix<BaseFloat> feature,
// int32 chunk_size,
// std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
//}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_respecifier);
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts;
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts;
std::shared_ptr<ppspeech::PaddleNnet> nnet(new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::Decodable> decodable(new ppspeech::Decodable(nnet));
//int32 chunk_size = 35;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
decodable->FeedFeatures(feature);
decoder.AdvanceDecode(decodable, 8);
decodable->InputFinished();
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -1,686 +0,0 @@
// feat/feature-mfcc-test.cc
// Copyright 2009-2011 Karel Vesely; Petr Motlicek
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "feat/feature-mfcc.h"
#include "base/kaldi-math.h"
#include "matrix/kaldi-matrix-inl.h"
#include "feat/wave-reader.h"
using namespace kaldi;
static void UnitTestReadWave() {
std::cout << "=== UnitTestReadWave() ===\n";
Vector<BaseFloat> v, v2;
std::cout << "<<<=== Reading waveform\n";
{
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
const Matrix<BaseFloat> data(wave.Data());
KALDI_ASSERT(data.NumRows() == 1);
v.Resize(data.NumCols());
v.CopyFromVec(data.Row(0));
}
std::cout << "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
std::ifstream input(
"test_data/test_matlab.ascii"
);
KALDI_ASSERT(input.good());
v2.Read(input, false);
input.close();
std::cout << "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
KALDI_ASSERT(v.Dim() == v2.Dim());
for (int32 i = 0; i < v.Dim(); i++) {
KALDI_ASSERT(v(i) == v2(i));
}
std::cout << "<<<=== Comparing done\n";
// std::cout << "== The Waveform Samples == \n";
// std::cout << v;
std::cout << "Test passed :)\n\n";
}
/**
*/
static void UnitTestSimple() {
std::cout << "=== UnitTestSimple() ===\n";
Vector<BaseFloat> v(100000);
Matrix<BaseFloat> m;
// init with noise
for (int32 i = 0; i < v.Dim(); i++) {
v(i) = (abs( i * 433024253 ) % 65535) - (65535 / 2);
}
std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
// the parametrization object
MfccOptions op;
// trying to have same opts as baseline.
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "rectangular";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
Mfcc mfcc(op);
// use default parameters
// compute mfccs.
mfcc.Compute(v, 1.0, &m);
// possibly dump
// std::cout << "== Output features == \n" << m;
std::cout << "Test passed :)\n\n";
}
static void UnitTestHTKCompare1() {
std::cout << "=== UnitTestHTKCompare1() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.1",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.1",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.1");
}
static void UnitTestHTKCompare2() {
std::cout << "=== UnitTestHTKCompare2() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.2",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.2",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.2");
}
static void UnitTestHTKCompare3() {
std::cout << "=== UnitTestHTKCompare3() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.3",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 20.0;
//op.mel_opts.debug_mel = true;
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.3",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.3");
}
static void UnitTestHTKCompare4() {
std::cout << "=== UnitTestHTKCompare4() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.4",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.4",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.4");
}
static void UnitTestHTKCompare5() {
std::cout << "=== UnitTestHTKCompare5() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.5",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 0.0;
op.mel_opts.vtln_low = 100.0;
op.mel_opts.vtln_high = 7500.0;
op.mel_opts.htk_mode = true;
BaseFloat vtln_warp = 1.1; // our approach identical to htk for warp factor >1,
// differs slightly for higher mel bins if warp_factor <0.9
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.5",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.5");
}
static void UnitTestHTKCompare6() {
std::cout << "=== UnitTestHTKCompare6() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.6",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.97;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.num_bins = 24;
op.mel_opts.low_freq = 125.0;
op.mel_opts.high_freq = 7800.0;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.6",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.6");
}
void UnitTestVtln() {
// Test the function VtlnWarpFreq.
BaseFloat low_freq = 10, high_freq = 7800,
vtln_low_cutoff = 20, vtln_high_cutoff = 7400;
for (size_t i = 0; i < 100; i++) {
BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
freq),
freq / warp_factor);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
low_freq),
low_freq);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
high_freq),
high_freq);
BaseFloat freq2 = low_freq + (high_freq-low_freq) * RandUniform(),
freq3 = freq2 + (high_freq-freq2) * RandUniform(); // freq3>=freq2
BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
freq2);
BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
freq3);
KALDI_ASSERT(w3 >= w2); // increasing function.
BaseFloat w3dash = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, 1.0,
freq3);
AssertEqual(w3dash, freq3);
}
}
static void UnitTestFeat() {
UnitTestVtln();
UnitTestReadWave();
UnitTestSimple();
UnitTestHTKCompare1();
UnitTestHTKCompare2();
// commenting out this one as it doesn't compare right now I normalized
// the way the FFT bins are treated (removed offset of 0.5)... this seems
// to relate to the way frequency zero behaves.
UnitTestHTKCompare3();
UnitTestHTKCompare4();
UnitTestHTKCompare5();
UnitTestHTKCompare6();
std::cout << "Tests succeeded.\n";
}
int main() {
try {
for (int i = 0; i < 5; i++)
UnitTestFeat();
std::cout << "Tests succeeded.\n";
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return 1;
}
}

@ -1,125 +0,0 @@
// todo refactor, repalce with gtest
#include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
DEFINE_string(wav_rspecifier, "", "test wav path");
DEFINE_string(feature_wspecifier, "", "test wav ark");
DEFINE_string(feature_check_wspecifier, "", "test wav ark");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "test wav ark");
std::vector<float> mean_{-13730251.531853663, -12982852.199316509, -13673844.299583456, -13089406.559646806, -12673095.524938712, -12823859.223276224, -13590267.158903603, -14257618.467152044, -14374605.116185192, -14490009.21822485, -14849827.158924166, -15354435.470563512, -15834149.206532761, -16172971.985514281, -16348740.496746974, -16423536.699409386, -16556246.263649225, -16744088.772748645, -16916184.08510357, -17054034.840031497, -17165612.509455364, -17255955.470915023, -17322572.527648456, -17408943.862033736, -17521554.799865916, -17620623.254924215, -17699792.395918526, -17723364.411134344, -17741483.4433254, -17747426.888704527, -17733315.928209435, -17748780.160905756, -17808336.883775543, -17895918.671983004, -18009812.59173023, -18098188.66548325, -18195798.958462656, -18293617.62980999, -18397432.92077201, -18505834.787318766, -18585451.8100908, -18652438.235649142, -18700960.306275308, -18734944.58792185, -18737426.313365128, -18735347.165987637, -18738813.444170244, -18737086.848890636, -18731576.2474336, -18717405.44095871, -18703089.25545657, -18691014.546456724, -18692460.568905357, -18702119.628629155, -18727710.621126678, -18761582.72034647, -18806745.835547544, -18850674.8692112, -18884431.510951452, -18919999.992506847, -18939303.799078144, -18952946.273760635, -18980289.22996379, -19011610.17803294, -19040948.61805145, -19061021.429847397, -19112055.53768819, -19149667.414264943, -19201127.05091321, -19270250.82564605, -19334606.883057203, -19390513.336589377, -19444176.259208687, -19502755.000038862, -19544333.014549147, -19612668.183176614, -19681902.19006569, -19771969.951249883, -19873329.723376893, -19996752.59235844, -20110031.131400537, -20231658.612529557, -20319378.894054495, -20378534.45718066, -20413332.089584175, -20438147.844177883, -20443710.248040095, -20465457.02238927, -20488610.969337028, -20516295.16424432, -20541423.795738827, -20553192.874953747, -20573605.50701977, -20577871.61936797, -20571807.008916274, -20556242.38912231, -20542199.30819195, -20521239.063551214, -20519150.80004532, -20527204.80248933, -20536933.769257784, -20543470.522332076, -20549700.089992985, -20551525.24958494, -20554873.406493705, -20564277.65794227, -20572211.740052115, -20574305.69550465, -20575494.450104576, -20567092.577932164, -20549302.929608088, -20545445.11878376, -20546625.326603737, -20549190.03499401, -20554824.947828256, -20568341.378989458, -20577582.331383612, -20577980.519402675, -20566603.03458152, -20560131.592262644, -20552166.469060015, -20549063.06763577, -20544490.562339947, -20539817.82346569, -20528747.715731595, -20518026.24576161, -20510977.844974525, -20506874.36087992, -20506731.11977665, -20510482.133420516, -20507760.92101862, -20494644.834457114, -20480107.89304893, -20461312.091867123, -20442941.75080173, -20426123.02834838, -20424607.675283, -20426810.369107097, -20434024.50097819, -20437404.75544205, -20447688.63916367, -20460893.335563846, -20482922.735127095, -20503610.119434915, -20527062.76448319, -20557830.035128627, -20593274.72068722, -20632528.452965066, -20673637.471334763, -20733106.97143075, -20842921.0447562, -21054357.83621519, -21416569.534189366, -21978460.272811692, -22753170.052172784, -23671344.10563395, -24613499.293358143, -25406477.12230188, -25884377.82156489, -26049040.62791664, -26996879.104431007};
std::vector<float> variance_{213747175.10846674, 188395815.34302503, 212706429.10966414, 199109025.81461075, 189235901.23864496, 194901336.53253657, 217481594.29306737, 238689869.12327808, 243977501.24115244, 248479623.6431067, 259766741.47116545, 275516766.7790273, 291271202.3691234, 302693239.8220509, 308627358.3997694, 311143911.38788426, 315446105.07731867, 321705430.9341829, 327458907.4659941, 332245072.43223983, 336251717.5935284, 339694069.7639722, 342188204.4322228, 345587110.31313115, 349903086.2875232, 353660214.20643026, 356700344.5270885, 357665362.3529641, 358493352.05658793, 358857951.620328, 358375239.52774596, 358899733.6342954, 361051818.3511561, 364361716.05025816, 368750322.3771452, 372047800.6462831, 375655861.1349018, 379358519.1980013, 383327605.3935181, 387458599.282341, 390434692.3406868, 392994486.35057056, 394874418.04603153, 396230525.79763395, 396365592.0414835, 396334819.8242737, 396488353.19250053, 396438877.00744957, 396197980.4459586, 395590921.6672991, 395001107.62072515, 394528291.7318225, 394593110.424006, 395018405.59353715, 396110577.5415993, 397506704.0371068, 399400197.4657644, 401243568.2468382, 402687134.7805103, 404136047.2872507, 404883170.001883, 405522253.219517, 406660365.3626476, 407919346.0991902, 409045348.5384909, 409759588.7889818, 411974821.8564483, 413489718.78201455, 415535392.56684107, 418466481.97674364, 421104678.35678065, 423405392.5200779, 425550570.40798235, 427929423.9579701, 429585274.253478, 432368493.55181056, 435193587.13513297, 438886855.20476013, 443058876.8633751, 448181232.5093362, 452883835.6332396, 458056721.77926534, 461816531.22735566, 464363620.1970998, 465886343.5057493, 466928872.0651, 467180536.42647296, 468111848.70714295, 469138695.3071312, 470378429.6930793, 471517958.7132626, 472109050.4262365, 473087417.0177867, 473381322.04648733, 473220195.85483915, 472666071.8998819, 472124669.87879956, 471298571.411737, 471251033.2902761, 471672676.43128747, 472177147.2193172, 472572361.7711908, 472968783.7751127, 473156295.4164052, 473398034.82676554, 473897703.5203811, 474328271.33112127, 474452670.98002136, 474549003.99284613, 474252887.13567275, 473557462.909069, 473483385.85193115, 473609738.04855174, 473746944.82085115, 474016729.91696435, 474617321.94138587, 475045097.237122, 475125402.586558, 474664112.9824912, 474426247.5800283, 474104075.42796475, 473978219.7273978, 473773171.7798875, 473578534.69508696, 473102924.16904145, 472651240.5232615, 472374383.1810912, 472209479.6956096, 472202298.8921673, 472370090.76781124, 472220933.99374026, 471625467.37106377, 470994646.51883453, 470182428.9637543, 469348211.5939578, 468570387.4467277, 468540442.7225135, 468672018.90414184, 468994346.9533251, 469138757.58201426, 469553915.95710236, 470134523.38582784, 471082421.62055486, 471962316.51804745, 472939745.1708408, 474250621.5944825, 475773933.43199486, 477465399.71087736, 479218782.61382693, 481752299.7930922, 486608947.8984568, 496119403.2067917, 512730085.5704984, 539048915.2641417, 576285298.3548826, 621610270.2240586, 669308196.4436442, 710656993.5957186, 736344437.3725077, 745481288.0241544, 801121432.9925804};
int count_ = 912592;
void WriteMatrix() {
kaldi::Matrix<double> cmvn_stats(2, mean_.size()+ 1);
for (size_t idx = 0; idx < mean_.size(); ++idx) {
cmvn_stats(0, idx) = mean_[idx];
cmvn_stats(1, idx) = variance_[idx];
}
cmvn_stats(0, mean_.size()) = count_;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(FLAGS_feature_check_wspecifier);
WriteMatrix();
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning window -->linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0;
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
new ppspeech::DecibelNormalizer(db_norm_opt));
ppspeech::LinearSpectrogram linear_spectrogram(opt, std::move(base_feature_extractor));
ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
float streaming_chunk = 0.36;
int sample_rate = 16000;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << mean_.size();
for (size_t i = 0; i < mean_.size(); i++) {
mean_[i] /= count_;
variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i];
if (variance_[i] < 1.0e-20) {
variance_[i] = 1.0e-20;
}
variance_[i] = 1.0 / std::sqrt(variance_[i]);
}
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData &wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(), this_channel);
int tot_samples = waveform.Dim();
int sample_offset = 0;
std::vector<kaldi::Matrix<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size = std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(wav_chunk);
linear_spectrogram.ReadFeats(&features);
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.NumRows();
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows, feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features(cur_idx, col_idx) = (feat(row_idx, col_idx) - mean_[col_idx])*variance_[col_idx];
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features_check(feature_rows, feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features_check(cur_idx, col_idx) = feat(row_idx, col_idx);
}
kaldi::SubVector<BaseFloat> row_feat(features_check, cur_idx);
cmvn.ApplyCMVN(true, &row_feat);
++cur_idx;
}
}
feat_cmvn_check_writer.Write(utt, features_check);
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -1,134 +0,0 @@
#include "paddle_inference_api.h"
#include <gflags/gflags.h>
#include <iostream>
#include <thread>
#include <fstream>
#include <iterator>
#include <algorithm>
#include <numeric>
#include <functional>
void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test();
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
model_forward_test();
return 0;
}
void model_forward_test() {
std::cout << "1. read the data" << std::endl;
std::vector<std::vector<float>> feats;
produce_data(&feats);
std::cout << "2. load the model" << std::endl;;
std::string model_graph = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel";
std::string model_params = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams";
paddle_infer::Config config;
config.SetModel(model_graph, model_params);
config.SwitchIrOptim(false);
config.DisableFCPadding();
auto predictor = paddle_infer::CreatePredictor(config);
std::cout << "3. feat shape, row=" << feats.size() << ",col=" << feats[0].size() << std::endl;
std::vector<float> paddle_input_feature_matrix;
for(const auto& item : feats) {
paddle_input_feature_matrix.insert(paddle_input_feature_matrix.end(), item.begin(), item.end());
}
std::cout << "4. fead the data to model" << std::endl;
int row = feats.size();
int col = feats[0].size();
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(paddle_input_feature_matrix.data());
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box = predictor->GetInputHandle(input_names[2]);
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
int chunk_state_h_box_size = std::accumulate(chunk_state_h_box_shape.begin(), chunk_state_h_box_shape.end(),
1, std::multiplies<int>());
std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box = predictor->GetInputHandle(input_names[3]);
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
int chunk_state_c_box_size = std::accumulate(chunk_state_c_box_shape.begin(), chunk_state_c_box_shape.end(),
1, std::multiplies<int>());
std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
bool success = predictor->Run();
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
std::vector<int> h_out_shape = h_out->shape();
int h_out_size = std::accumulate(h_out_shape.begin(), h_out_shape.end(),
1, std::multiplies<int>());
std::vector<float> h_out_data(h_out_size);
h_out->CopyToCpu(h_out_data.data());
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
std::vector<int> c_out_shape = c_out->shape();
int c_out_size = std::accumulate(c_out_shape.begin(), c_out_shape.end(),
1, std::multiplies<int>());
std::vector<float> c_out_data(c_out_size);
c_out->CopyToCpu(c_out_data.data());
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
std::vector<float> output_probs;
int output_size = std::accumulate(output_shape.begin(), output_shape.end(),
1, std::multiplies<int>());
output_probs.resize(output_size);
output_tensor->CopyToCpu(output_probs.data());
row = output_shape[1];
col = output_shape[2];
std::vector<std::vector<float>> probs;
probs.reserve(row);
for (int i = 0; i < row; i++) {
probs.push_back(std::vector<float>());
probs.back().reserve(col);
for (int j = 0; j < col; j++) {
probs.back().push_back(output_probs[i * col + j]);
}
}
std::vector<std::vector<float>> log_feat = probs;
std::cout << "probs, row: " << log_feat.size() << " col: " << log_feat[0].size() << std::endl;
for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size(); ++col_idx) {
std::cout << log_feat[row_idx][col_idx] << " ";
}
std::cout << std::endl;
}
}
void produce_data(std::vector<std::vector<float>>* data) {
int chunk_size = 35;
int col_size = 161;
data->reserve(chunk_size);
data->back().reserve(col_size);
for (int row = 0; row < chunk_size; ++row) {
data->push_back(std::vector<float>());
for (int col_idx = 0; col_idx < col_size; ++col_idx) {
data->back().push_back(0.201);
}
}
}

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/basic_types.h"
struct DecoderResult {

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h"
@ -9,25 +23,23 @@ namespace ppspeech {
using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) :
opts_(opts),
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts),
init_ext_scorer_(nullptr),
blank_id(-1),
space_id(-1),
num_frame_decoded_(0),
root(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
LOG(INFO) << "load the dict failed";
}
LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_.size();
LOG(INFO) << "read the vocabulary success, dict size: "
<< vocabulary_.size();
LOG(INFO) << "language model path: " << opts_.lm_path;
init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha,
opts_.beta,
opts_.lm_path,
vocabulary_);
init_ext_scorer_ = std::make_shared<Scorer>(
opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
}
void CTCBeamSearch::Reset() {
@ -36,7 +48,6 @@ void CTCBeamSearch::Reset() {
}
void CTCBeamSearch::InitDecoder() {
blank_id = 0;
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
@ -51,10 +62,11 @@ void CTCBeamSearch::InitDecoder() {
root = std::make_shared<PathTrie>();
root->score = root->log_prob_b_prev = 0.0;
prefixes.push_back(root.get());
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
if (init_ext_scorer_ != nullptr &&
!init_ext_scorer_->is_character_based()) {
auto fst_dict =
static_cast<fst::StdVectorFst *>(init_ext_scorer_->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
@ -62,23 +74,24 @@ void CTCBeamSearch::InitDecoder() {
}
}
void CTCBeamSearch::Decode(std::shared_ptr<kaldi::DecodableInterface> decodable) {
void CTCBeamSearch::Decode(
std::shared_ptr<kaldi::DecodableInterface> decodable) {
return;
}
int32 CTCBeamSearch::NumFrameDecoded() {
return num_frame_decoded_;
}
int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; }
// todo rename, refactor
void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
void CTCBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames) {
while (max_frames > 0) {
vector<vector<BaseFloat>> likelihood;
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
break;
}
likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
likelihood.push_back(
decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
AdvanceDecoding(likelihood);
max_frames--;
}
@ -93,12 +106,13 @@ void CTCBeamSearch::ResetPrefixes() {
}
}
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs,
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
vector<string>& nbest_words) {
kaldi::Timer timer;
timer.Reset();
AdvanceDecoding(probs);
LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f;
LOG(INFO) << "ctc decoding elapsed time(s) "
<< static_cast<float>(timer.Elapsed()) / 1000.0f;
return 0;
}
@ -124,12 +138,13 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
double cutoff_prob = opts_.cutoff_prob;
size_t cutoff_top_n = opts_.cutoff_top_n;
vector<vector<double>> probs_seq(probs.size(), vector<double>(probs[0].size(), 0));
vector<vector<double>> probs_seq(probs.size(),
vector<double>(probs[0].size(), 0));
int row = probs.size();
int col = probs[0].size();
for(int i = 0; i < row; i++) {
for (int j = 0; j < col; j++){
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
probs_seq[i][j] = static_cast<double>(probs[i][j]);
}
}
@ -141,7 +156,8 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
bool full_beam = false;
if (init_ext_scorer_ != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes,
std::sort(prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
if (num_prefixes == 0) {
@ -181,7 +197,8 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
} // for probs_seq
}
int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
int32 CTCBeamSearch::SearchOneChar(
const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff) {
size_t beam_size = opts_.beam_size;
@ -196,10 +213,8 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
}
if (c == blank_id) {
prefix->log_prob_b_cur = log_sum_exp(
prefix->log_prob_b_cur,
log_prob_c +
prefix->score);
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
@ -207,9 +222,7 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
if (c == prefix->character) {
// p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur,
log_prob_c +
prefix->log_prob_nb_prev);
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
@ -228,7 +241,7 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
// language model scoring
if (init_ext_scorer_ != nullptr &&
(c == space_id || init_ext_scorer_->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
PathTrie* prefix_to_score = nullptr;
// skip scoring the space
if (init_ext_scorer_->is_character_based()) {
prefix_to_score = prefix_new;
@ -247,8 +260,7 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
}
// p_{nb}(l;x_{1:t})
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur,
log_p);
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
return 0;
@ -258,9 +270,7 @@ void CTCBeamSearch::CalculateApproxScore() {
size_t beam_size = opts_.beam_size;
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
@ -274,8 +284,8 @@ void CTCBeamSearch::CalculateApproxScore() {
// remove word insert
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
// remove language model weight:
approx_ctc -=
(init_ext_scorer_->get_sent_log_prob(words)) * init_ext_scorer_->alpha;
approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
init_ext_scorer_->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}
@ -283,13 +293,15 @@ void CTCBeamSearch::CalculateApproxScore() {
void CTCBeamSearch::LMRescore() {
size_t beam_size = opts_.beam_size;
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
if (init_ext_scorer_ != nullptr &&
!init_ext_scorer_->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
score = init_ext_scorer_->get_log_cond_prob(ngram) * init_ext_scorer_->alpha;
score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
score += init_ext_scorer_->beta;
prefix->score += score;
}

@ -1,8 +1,22 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "nnet/decodable-itf.h"
#include "util/parse-options.h"
#include "decoder/ctc_decoders/scorer.h"
#include "decoder/ctc_decoders/path_trie.h"
#pragma once
@ -17,26 +31,27 @@ struct CTCBeamSearchOptions {
int beam_size;
int cutoff_top_n;
int num_proc_bsearch;
CTCBeamSearchOptions() :
dict_file("./model/words.txt"),
CTCBeamSearchOptions()
: dict_file("./model/words.txt"),
lm_path("./model/lm.arpa"),
alpha(1.9f),
beta(5.0),
beam_size(300),
cutoff_prob(0.99f),
cutoff_top_n(40),
num_proc_bsearch(0) {
}
num_proc_bsearch(0) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("dict", &dict_file, "dict file ");
opts->Register("lm-path", &lm_path, "language model file");
opts->Register("alpha", &alpha, "alpha");
opts->Register("beta", &beta, "beta");
opts->Register("beam-size", &beam_size, "beam size for beam search method");
opts->Register(
"beam-size", &beam_size, "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
opts->Register("num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
opts->Register(
"num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
}
};
@ -50,11 +65,13 @@ class CTCBeamSearch {
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs,
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames);
void Reset();
private:
void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam,
@ -66,7 +83,7 @@ class CTCBeamSearch {
CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
//std::vector<DecodeResult> decoder_results_;
// std::vector<DecodeResult> decoder_results_;
std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id;
int space_id;

@ -24,7 +24,8 @@ class FbankExtractor : FeatureExtractorInterface {
public:
explicit FbankExtractor(const FbankOptions& opts,
share_ptr<FeatureExtractorInterface> pre_extractor);
virtual void AcceptWaveform(const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
virtual void AcceptWaveform(
const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0;

@ -0,0 +1,14 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -0,0 +1,14 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -21,7 +21,8 @@ namespace ppspeech {
class FeatureExtractorInterface {
public:
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0;
virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0;
};

@ -25,7 +25,7 @@ using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector;
//todo remove later
// todo remove later
void CopyVector2StdVector_(const VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) {
if (input.Dim() == 0) return;

@ -1,27 +1,40 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/feat/feature-window.h"
#include "base/common.h"
namespace ppspeech {
struct LinearSpectrogramOptions {
kaldi::FrameExtractionOptions frame_opts;
LinearSpectrogramOptions():
frame_opts() {}
LinearSpectrogramOptions() : frame_opts() {}
void Register(kaldi::OptionsItf* opts) {
frame_opts.Register(opts);
}
void Register(kaldi::OptionsItf* opts) { frame_opts.Register(opts); }
};
class LinearSpectrogram : public FeatureExtractorInterface {
public:
explicit LinearSpectrogram(const LinearSpectrogramOptions& opts,
explicit LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return dim_; }
void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats);

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/normalizer.h"
#include "kaldi/feat/cmvn.h"
@ -16,7 +30,8 @@ DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) {
dim_ = 0;
}
void DecibelNormalizer::AcceptWaveform(const kaldi::VectorBase<BaseFloat>& input) {
void DecibelNormalizer::AcceptWaveform(
const kaldi::VectorBase<BaseFloat>& input) {
dim_ = input.Dim();
waveform_.Resize(input.Dim());
waveform_.CopyFromVec(input);
@ -27,7 +42,7 @@ void DecibelNormalizer::Read(kaldi::VectorBase<BaseFloat>* feat) {
Compute(waveform_, feat);
}
//todo remove later
// todo remove later
void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) {
if (input.Dim() == 0) return;
@ -61,7 +76,7 @@ bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
}
// square
for (auto &d : samples) {
for (auto& d : samples) {
if (opts_.convert_int_float) {
d = d * wave_float_normlization;
}
@ -74,14 +89,15 @@ bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
gain = opts_.target_db - rms_db;
if (gain > opts_.max_gain_db) {
LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db << "dB,"
LOG(ERROR)
<< "Unable to normalize segment to " << opts_.target_db << "dB,"
<< "because the the probable gain have exceeds opts_.max_gain_db"
<< opts_.max_gain_db << "dB.";
return false;
}
// Note that this is an in-place transformation.
for (auto &item : samples) {
for (auto& item : samples) {
// python item *= 10.0 ** (gain / 20.0)
item *= std::pow(10.0, gain / 20.0);
}
@ -100,21 +116,20 @@ void CMVN::AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) {
return;
}
void CMVN::Read(kaldi::VectorBase<BaseFloat>* feat) {
return;
}
void CMVN::Read(kaldi::VectorBase<BaseFloat>* feat) { return; }
// feats contain num_frames feature.
void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
KALDI_ASSERT(feats != NULL);
int32 dim = stats_.NumCols() - 1;
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 || feats->Dim() % dim != 0) {
KALDI_ERR << "Dim mismatch: cmvn "
<< stats_.NumRows() << 'x' << stats_.NumCols()
<< ", feats " << feats->Dim() << 'x';
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
feats->Dim() % dim != 0) {
KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x'
<< stats_.NumCols() << ", feats " << feats->Dim() << 'x';
}
if (stats_.NumRows() == 1 && var_norm) {
KALDI_ERR << "You requested variance normalization but no variance stats_ "
KALDI_ERR
<< "You requested variance normalization but no variance stats_ "
<< "are supplied.";
}
@ -122,17 +137,20 @@ void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
// Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
// computing an offset and representing it as stats_, we use a count of one.
if (count < 1.0)
KALDI_ERR << "Insufficient stats_ for cepstral mean and variance normalization: "
KALDI_ERR << "Insufficient stats_ for cepstral mean and variance "
"normalization: "
<< "count = " << count;
if (!var_norm) {
Vector<BaseFloat> offset(feats->Dim());
SubVector<double> mean_stats(stats_.RowData(0), dim);
Vector<double> mean_stats_apply(feats->Dim());
//fill the datat of mean_stats in mean_stats_appy whose dim is equal with the dim of feature.
//the dim of feats = dim * num_frames;
// fill the datat of mean_stats in mean_stats_appy whose dim is equal
// with the dim of feature.
// the dim of feats = dim * num_frames;
for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) {
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim*idx, dim);
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim * idx,
dim);
stats_tmp.CopyFromVec(mean_stats);
}
offset.AddVec(-1.0 / count, mean_stats_apply);
@ -144,18 +162,18 @@ void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
for (int32 d = 0; d < dim; d++) {
double mean, offset, scale;
mean = stats_(0, d)/count;
double var = (stats_(1, d)/count) - mean*mean,
floor = 1.0e-20;
mean = stats_(0, d) / count;
double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20;
if (var < floor) {
KALDI_WARN << "Flooring cepstral variance from " << var << " to "
<< floor;
var = floor;
}
scale = 1.0 / sqrt(var);
if (scale != scale || 1/scale == 0.0)
KALDI_ERR << "NaN or infinity in cepstral mean/variance computation";
offset = -(mean*scale);
if (scale != scale || 1 / scale == 0.0)
KALDI_ERR
<< "NaN or infinity in cepstral mean/variance computation";
offset = -(mean * scale);
for (int32 d_skip = d; d_skip < feats->Dim();) {
norm(0, d_skip) = offset;
norm(1, d_skip) = scale;

@ -1,10 +1,24 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/util/options-itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
namespace ppspeech {
@ -12,26 +26,30 @@ struct DecibelNormalizerOptions {
float target_db;
float max_gain_db;
bool convert_int_float;
DecibelNormalizerOptions() :
target_db(-20),
max_gain_db(300.0),
convert_int_float(false) {}
DecibelNormalizerOptions()
: target_db(-20), max_gain_db(300.0), convert_int_float(false) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("target-db", &target_db, "target db for db normalization");
opts->Register("max-gain-db", &max_gain_db, "max gain db for db normalization");
opts->Register("convert-int-float", &convert_int_float, "if convert int samples to float");
opts->Register(
"target-db", &target_db, "target db for db normalization");
opts->Register(
"max-gain-db", &max_gain_db, "max gain db for db normalization");
opts->Register("convert-int-float",
&convert_int_float,
"if convert int samples to float");
}
};
class DecibelNormalizer : public FeatureExtractorInterface {
public:
explicit DecibelNormalizer(const DecibelNormalizerOptions& opts);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return dim_; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBase<kaldi::BaseFloat>* feat) const;
private:
DecibelNormalizerOptions opts_;
size_t dim_;
@ -43,7 +61,8 @@ class DecibelNormalizer : public FeatureExtractorInterface {
class CMVN : public FeatureExtractorInterface {
public:
explicit CMVN(std::string cmvn_file);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return stats_.NumCols() - 1; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
@ -51,6 +70,7 @@ class CMVN : public FeatureExtractorInterface {
// for test
void ApplyCMVN(bool var_norm, kaldi::VectorBase<BaseFloat>* feats);
void ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats);
private:
kaldi::Matrix<double> stats_;
std::shared_ptr<FeatureExtractorInterface> base_extractor_;

@ -13,4 +13,3 @@
// limitations under the License.
// extract the window of kaldi feat.

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// itf/decodable-itf.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
@ -42,8 +56,10 @@ namespace kaldi {
For online decoding, where the features are coming in in real time, it is
important to understand the IsLastFrame() and NumFramesReady() functions.
There are two ways these are used: the old online-decoding code, in ../online/,
and the new online-decoding code, in ../online2/. In the old online-decoding
There are two ways these are used: the old online-decoding code, in
../online/,
and the new online-decoding code, in ../online2/. In the old
online-decoding
code, the decoder would do:
\code{.cc}
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
@ -52,13 +68,16 @@ namespace kaldi {
\endcode
and the call to IsLastFrame would block if the features had not arrived yet.
The decodable object would have to know when to terminate the decoding. This
online-decoding mode is still supported, it is what happens when you call, for
online-decoding mode is still supported, it is what happens when you call,
for
example, LatticeFasterDecoder::Decode().
We realized that this "blocking" mode of decoding is not very convenient
because it forces the program to be multi-threaded and makes it complex to
control endpointing. In the "new" decoding code, you don't call (for example)
LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(),
control endpointing. In the "new" decoding code, you don't call (for
example)
LatticeFasterDecoder::Decode(), you call
LatticeFasterDecoder::InitDecoding(),
and then each time you get more features, you provide them to the decodable
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
something like this:
@ -68,7 +87,8 @@ namespace kaldi {
}
\endcode
So the decodable object never has IsLastFrame() called. For decoding where
you are starting with a matrix of features, the NumFramesReady() function will
you are starting with a matrix of features, the NumFramesReady() function
will
always just return the number of frames in the file, and IsLastFrame() will
return true for the last frame.
@ -82,30 +102,39 @@ namespace kaldi {
class DecodableInterface {
public:
/// Returns the log likelihood, which will be negated in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() > frame
/// The "frame" starts from zero. You should verify that NumFramesReady() >
/// frame
/// before calling this.
virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
/// Returns true if this is the last frame. Frames are zero-based, so the
/// first frame is zero. IsLastFrame(-1) will return false, unless the file
/// is empty (which is a case that I'm not sure all the code will handle, so
/// be careful). Caution: the behavior of this function in an online setting
/// be careful). Caution: the behavior of this function in an online
/// setting
/// is being changed somewhat. In future it may return false in cases where
/// we haven't yet decided to terminate decoding, but later true if we decide
/// we haven't yet decided to terminate decoding, but later true if we
/// decide
/// to terminate decoding. The plan in future is to rely more on
/// NumFramesReady(), and in future, IsLastFrame() would always return false
/// in an online-decoding setting, and would only return true in a
/// decoding-from-matrix setting where we want to allow the last delta or LDA
/// decoding-from-matrix setting where we want to allow the last delta or
/// LDA
/// features to be flushed out for compatibility with the baseline setup.
virtual bool IsLastFrame(int32 frame) const = 0;
/// The call NumFramesReady() will return the number of frames currently available
/// for this decodable object. This is for use in setups where you don't want the
/// decoder to block while waiting for input. This is newly added as of Jan 2014,
/// and I hope, going forward, to rely on this mechanism more than IsLastFrame to
/// The call NumFramesReady() will return the number of frames currently
/// available
/// for this decodable object. This is for use in setups where you don't
/// want the
/// decoder to block while waiting for input. This is newly added as of Jan
/// 2014,
/// and I hope, going forward, to rely on this mechanism more than
/// IsLastFrame to
/// know when to stop decoding.
virtual int32 NumFramesReady() const {
KALDI_ERR << "NumFramesReady() not implemented for this decodable type.";
KALDI_ERR
<< "NumFramesReady() not implemented for this decodable type.";
return -1;
}

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/decodable.h"
namespace ppspeech {
@ -5,18 +19,14 @@ namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::Matrix;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet):
frontend_(NULL),
nnet_(nnet),
finished_(false),
frames_ready_(0) {
}
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet)
: frontend_(NULL), nnet_(nnet), finished_(false), frames_ready_(0) {}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
frames_ready_ += likelihood.NumRows();
}
//Decodable::Init(DecodableConfig config) {
// Decodable::Init(DecodableConfig config) {
//}
bool Decodable::IsLastFrame(int32 frame) const {
@ -24,18 +34,14 @@ bool Decodable::IsLastFrame(int32 frame) const {
return finished_ && (frame == frames_ready_ - 1);
}
int32 Decodable::NumIndices() const {
return 0;
}
int32 Decodable::NumIndices() const { return 0; }
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
return 0;
}
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return 0; }
void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
nnet_->FeedForward(features, &nnet_cache_);
frames_ready_ += nnet_cache_.NumRows();
return ;
return;
}
std::vector<BaseFloat> Decodable::FrameLogLikelihood(int32 frame) {

@ -1,7 +1,21 @@
#include "nnet/decodable-itf.h"
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/decodable-itf.h"
#include "nnet/nnet_interface.h"
namespace ppspeech {
@ -11,15 +25,18 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface {
public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
//void Init(DecodableOpts config);
// void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const;
virtual int32 NumIndices() const;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& feature); // only for test, todo remove later
void Acceptlikelihood(
const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>&
feature); // only for test, todo remove later
void Reset();
void InputFinished() { finished_ = true; }
private:
std::shared_ptr<FeatureExtractorInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_;

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
@ -10,10 +24,9 @@ namespace ppspeech {
class NnetInterface {
public:
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences)= 0;
kaldi::Matrix<kaldi::BaseFloat>* inferences) = 0;
virtual void Reset() = 0;
virtual ~NnetInterface() {}
};
} // namespace ppspeech

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/paddle_nnet.h"
#include "absl/strings/str_split.h"
@ -21,18 +35,18 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
std::vector<std::string> tmp_shape;
tmp_shape = absl::StrSplit(cache_shapes[i], "-");
std::vector<int> cur_shape;
std::transform(tmp_shape.begin(), tmp_shape.end(),
std::transform(tmp_shape.begin(),
tmp_shape.end(),
std::back_inserter(cur_shape),
[](const std::string& s) {
return atoi(s.c_str());
});
[](const std::string& s) { return atoi(s.c_str()); });
cache_names_idx_[cache_names[i]] = i;
std::shared_ptr<Tensor<BaseFloat>> cache_eout = std::make_shared<Tensor<BaseFloat>>(cur_shape);
std::shared_ptr<Tensor<BaseFloat>> cache_eout =
std::make_shared<Tensor<BaseFloat>>(cur_shape);
cache_encouts_.push_back(cache_eout);
}
}
PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) {
paddle_infer::Config config;
config.SetModel(opts.model_path, opts.params_path);
if (opts.use_gpu) {
@ -45,7 +59,8 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
if (opts.enable_profile) {
config.EnableProfile();
}
pool.reset(new paddle_infer::services::PredictorPool(config, opts.thread_num));
pool.reset(
new paddle_infer::services::PredictorPool(config, opts.thread_num));
if (pool == nullptr) {
LOG(ERROR) << "create the predictor pool failed";
}
@ -68,16 +83,14 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
std::vector<std::string> model_output_names = predictor->GetOutputNames();
assert(output_names_vec.size() == model_output_names.size());
for (size_t i = 0;i < output_names_vec.size(); i++) {
for (size_t i = 0; i < output_names_vec.size(); i++) {
assert(output_names_vec[i] == model_output_names[i]);
}
ReleasePredictor(predictor);
InitCacheEncouts(opts);
}
void PaddleNnet::Reset() {
InitCacheEncouts(opts_);
}
void PaddleNnet::Reset() { InitCacheEncouts(opts_); }
paddle_infer::Predictor* PaddleNnet::GetPredictor() {
LOG(INFO) << "attempt to get a new predictor instance " << std::endl;
@ -130,13 +143,14 @@ shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
return cache_encouts_[iter->second];
}
void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>* inferences) {
void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features,
Matrix<BaseFloat>* inferences) {
paddle_infer::Predictor* predictor = GetPredictor();
int row = features.NumRows();
int col = features.NumCols();
std::vector<BaseFloat> feed_feature;
// todo refactor feed feature: SmileGoat
feed_feature.reserve(row*col);
feed_feature.reserve(row * col);
for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) {
feed_feature.push_back(features(row_idx, col_idx));
@ -146,22 +160,26 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat
std::vector<std::string> output_names = predictor->GetOutputNames();
LOG(INFO) << "feat info: row=" << row << ", col= " << col;
std::unique_ptr<paddle_infer::Tensor> input_tensor = predictor->GetInputHandle(input_names[0]);
std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(feed_feature.data());
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]);
std::unique_ptr<paddle_infer::Tensor> input_len =
predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
std::unique_ptr<paddle_infer::Tensor> h_box = predictor->GetInputHandle(input_names[2]);
std::unique_ptr<paddle_infer::Tensor> h_box =
predictor->GetInputHandle(input_names[2]);
shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
h_box->Reshape(h_cache->get_shape());
h_box->CopyFromCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_box = predictor->GetInputHandle(input_names[3]);
std::unique_ptr<paddle_infer::Tensor> c_box =
predictor->GetInputHandle(input_names[3]);
shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
c_box->Reshape(c_cache->get_shape());
c_box->CopyFromCpu(c_cache->get_data().data());
@ -172,10 +190,12 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat
}
LOG(INFO) << "get the model success";
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
std::unique_ptr<paddle_infer::Tensor> h_out =
predictor->GetOutputHandle(output_names[2]);
assert(h_cache->get_shape() == h_out->shape());
h_out->CopyToCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
std::unique_ptr<paddle_infer::Tensor> c_out =
predictor->GetOutputHandle(output_names[3]);
assert(c_cache->get_shape() == c_out->shape());
c_out->CopyToCpu(c_cache->get_data().data());
@ -187,13 +207,14 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat
col = output_shape[2];
vector<float> inferences_result;
inferences->Resize(row, col);
inferences_result.resize(row*col);
inferences_result.resize(row * col);
output_tensor->CopyToCpu(inferences_result.data());
ReleasePredictor(predictor);
for (int row_idx = 0; row_idx < row; ++row_idx) {
for (int col_idx = 0; col_idx < col; ++col_idx) {
(*inferences)(row_idx, col_idx) = inferences_result[col*row_idx + col_idx];
(*inferences)(row_idx, col_idx) =
inferences_result[col * row_idx + col_idx];
}
}
}

@ -1,8 +1,22 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "nnet/nnet_interface.h"
#include "base/common.h"
#include "nnet/nnet_interface.h"
#include "paddle_inference_api.h"
#include "kaldi/matrix/kaldi-matrix.h"
@ -24,19 +38,27 @@ struct ModelOptions {
std::string cache_shape;
bool enable_fc_padding;
bool enable_profile;
ModelOptions() :
model_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel"),
params_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams"),
ModelOptions()
: model_path(
"../../../../model/paddle_online_deepspeech/model/"
"avg_1.jit.pdmodel"),
params_path(
"../../../../model/paddle_online_deepspeech/model/"
"avg_1.jit.pdiparams"),
thread_num(2),
use_gpu(false),
input_names("audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box"),
output_names("save_infer_model/scale_0.tmp_1,save_infer_model/scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/scale_3.tmp_1"),
input_names(
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_"
"box"),
output_names(
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1"),
cache_names("chunk_state_h_box,chunk_state_c_box"),
cache_shape("3-1-1024,3-1-1024"),
switch_ir_optim(false),
enable_fc_padding(false),
enable_profile(false) {
}
enable_profile(false) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("model-path", &model_path, "model file path");
@ -47,37 +69,37 @@ struct ModelOptions {
opts->Register("output-names", &output_names, "paddle output names");
opts->Register("cache-names", &cache_names, "cache names");
opts->Register("cache-shape", &cache_shape, "cache shape");
opts->Register("switch-ir-optiom", &switch_ir_optim, "paddle SwitchIrOptim option");
opts->Register("enable-fc-padding", &enable_fc_padding, "paddle EnableFCPadding option");
opts->Register("enable-profile", &enable_profile, "paddle EnableProfile option");
opts->Register("switch-ir-optiom",
&switch_ir_optim,
"paddle SwitchIrOptim option");
opts->Register("enable-fc-padding",
&enable_fc_padding,
"paddle EnableFCPadding option");
opts->Register(
"enable-profile", &enable_profile, "paddle EnableProfile option");
}
};
template<typename T>
template <typename T>
class Tensor {
public:
Tensor() {
}
Tensor(const std::vector<int>& shape) :
_shape(shape) {
int data_size = std::accumulate(_shape.begin(), _shape.end(),
1, std::multiplies<int>());
public:
Tensor() {}
Tensor(const std::vector<int>& shape) : _shape(shape) {
int data_size = std::accumulate(
_shape.begin(), _shape.end(), 1, std::multiplies<int>());
LOG(INFO) << "data size: " << data_size;
_data.resize(data_size, 0);
}
void reshape(const std::vector<int>& shape) {
_shape = shape;
int data_size = std::accumulate(_shape.begin(), _shape.end(),
1, std::multiplies<int>());
int data_size = std::accumulate(
_shape.begin(), _shape.end(), 1, std::multiplies<int>());
_data.resize(data_size, 0);
}
const std::vector<int>& get_shape() const {
return _shape;
}
std::vector<T>& get_data() {
return _data;
}
private:
const std::vector<int>& get_shape() const { return _shape; }
std::vector<T>& get_data() { return _data; }
private:
std::vector<int> _shape;
std::vector<T> _data;
};
@ -88,7 +110,8 @@ class PaddleNnet : public NnetInterface {
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences);
virtual void Reset();
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name);
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
const std::string& name);
void InitCacheEncouts(const ModelOptions& opts);
private:

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "utils/file_utils.h"
namespace ppspeech {
@ -17,5 +31,4 @@ bool ReadFileToVector(const std::string& filename,
return true;
}
}

@ -1,8 +1,21 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
namespace ppspeech {
bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* data);
}

Loading…
Cancel
Save