Merge pull request #1541 from zh794390558/egs

[speechx] move test to examples
pull/1559/head
YangZhou 3 years ago committed by GitHub
commit ebc2aca990
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -50,13 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i entry: bash .pre-commit-hooks/clang-format.hook -i
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$ exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker - id: copyright_checker
name: copyright_checker name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook entry: python .pre-commit-hooks/copyright-check.hook
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$ exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports - repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0 rev: v2.4.0
hooks: hooks:

@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
""" """
rng = np.random.RandomState(epoch) rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1) shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices) rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch] batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False assert clipped is False

@ -36,4 +36,4 @@ def repeat(N, fn):
Returns: Returns:
MultiSequential: Repeated model instance. MultiSequential: Repeated model instance.
""" """
return MultiSequential(*[fn(n) for n in range(N)]) return MultiSequential(* [fn(n) for n in range(N)])

@ -117,6 +117,7 @@ set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
include_directories("${MKLDNN_PATH}/include") include_directories("${MKLDNN_PATH}/include")
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
set(EXTERNAL_LIB "-lrt -ldl -lpthread") set(EXTERNAL_LIB "-lrt -ldl -lpthread")
set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX}) set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB} ${MATH_LIB} ${MKLDNN_LIB}
@ -137,4 +138,7 @@ set(DEPS ${DEPS}
#target_link_libraries(lib_name item0 item1) #target_link_libraries(lib_name item0 item1)
#add_dependencies(lib_name depend-target) #add_dependencies(lib_name depend-target)
set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx)
add_subdirectory(speechx) add_subdirectory(speechx)
add_subdirectory(examples)

@ -16,7 +16,7 @@ if [ ! -d ${boost_SOURCE_DIR} ]; then wget -c https://boostorg.jfrog.io/artifact
echo -e "\n" echo -e "\n"
fi fi
rm -rf build #rm -rf build
mkdir -p build mkdir -p build
cd build cd build

@ -18,6 +18,8 @@ ExternalProject_Add(
SOURCE_DIR ${OpenBLAS_SOURCE_DIR} SOURCE_DIR ${OpenBLAS_SOURCE_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR> CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
CMAKE_GENERATOR "Unix Makefiles") CMAKE_GENERATOR "Unix Makefiles")
# https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition # https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition
ExternalProject_Get_Property(OPENBLAS INSTALL_DIR) ExternalProject_Get_Property(OPENBLAS INSTALL_DIR)
set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR}) set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR})

@ -0,0 +1,5 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)

@ -0,0 +1,5 @@
# Examples
* decoder - offline decoder
* feat - mfcc, linear
* nnet - ds2 nn

@ -0,0 +1,5 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(offline-decoder-main ${CMAKE_CURRENT_SOURCE_DIR}/offline-decoder-main.cc)
target_include_directories(offline-decoder-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline-decoder-main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})

@ -0,0 +1,74 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
DEFINE_string(feature_respecifier, "", "test nnet prob");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// void SplitFeature(kaldi::Matrix<BaseFloat> feature,
// int32 chunk_size,
// std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
//}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_respecifier);
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts;
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet));
// int32 chunk_size = 35;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
decodable->FeedFeatures(feature);
decoder.AdvanceDecode(decodable, 8);
decodable->InputFinished();
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -0,0 +1,10 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear-spectrogram-main ${CMAKE_CURRENT_SOURCE_DIR}/linear-spectrogram-main.cc)
target_include_directories(linear-spectrogram-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(linear-spectrogram-main frontend kaldi-util kaldi-feat-common gflags glog)

@ -0,0 +1,720 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// feat/feature-mfcc-test.cc
// Copyright 2009-2011 Karel Vesely; Petr Motlicek
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "base/kaldi-math.h"
#include "feat/feature-mfcc.h"
#include "feat/wave-reader.h"
#include "matrix/kaldi-matrix-inl.h"
using namespace kaldi;
static void UnitTestReadWave() {
std::cout << "=== UnitTestReadWave() ===\n";
Vector<BaseFloat> v, v2;
std::cout << "<<<=== Reading waveform\n";
{
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
const Matrix<BaseFloat> data(wave.Data());
KALDI_ASSERT(data.NumRows() == 1);
v.Resize(data.NumCols());
v.CopyFromVec(data.Row(0));
}
std::cout
<< "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
std::ifstream input("test_data/test_matlab.ascii");
KALDI_ASSERT(input.good());
v2.Read(input, false);
input.close();
std::cout
<< "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
KALDI_ASSERT(v.Dim() == v2.Dim());
for (int32 i = 0; i < v.Dim(); i++) {
KALDI_ASSERT(v(i) == v2(i));
}
std::cout << "<<<=== Comparing done\n";
// std::cout << "== The Waveform Samples == \n";
// std::cout << v;
std::cout << "Test passed :)\n\n";
}
/**
*/
static void UnitTestSimple() {
std::cout << "=== UnitTestSimple() ===\n";
Vector<BaseFloat> v(100000);
Matrix<BaseFloat> m;
// init with noise
for (int32 i = 0; i < v.Dim(); i++) {
v(i) = (abs(i * 433024253) % 65535) - (65535 / 2);
}
std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
// the parametrization object
MfccOptions op;
// trying to have same opts as baseline.
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "rectangular";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
Mfcc mfcc(op);
// use default parameters
// compute mfccs.
mfcc.Compute(v, 1.0, &m);
// possibly dump
// std::cout << "== Output features == \n" << m;
std::cout << "Test passed :)\n\n";
}
static void UnitTestHTKCompare1() {
std::cout << "=== UnitTestHTKCompare1() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.1",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.1",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.1");
}
static void UnitTestHTKCompare2() {
std::cout << "=== UnitTestHTKCompare2() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.2",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.2",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.2");
}
static void UnitTestHTKCompare3() {
std::cout << "=== UnitTestHTKCompare3() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.3",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 20.0;
// op.mel_opts.debug_mel = true;
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.3",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.3");
}
static void UnitTestHTKCompare4() {
std::cout << "=== UnitTestHTKCompare4() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.4",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.4",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.4");
}
static void UnitTestHTKCompare5() {
std::cout << "=== UnitTestHTKCompare5() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.5",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 0.0;
op.mel_opts.vtln_low = 100.0;
op.mel_opts.vtln_high = 7500.0;
op.mel_opts.htk_mode = true;
BaseFloat vtln_warp =
1.1; // our approach identical to htk for warp factor >1,
// differs slightly for higher mel bins if warp_factor <0.9
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.5",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.5");
}
static void UnitTestHTKCompare6() {
std::cout << "=== UnitTestHTKCompare6() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.6",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.97;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.num_bins = 24;
op.mel_opts.low_freq = 125.0;
op.mel_opts.high_freq = 7800.0;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.6",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.6");
}
void UnitTestVtln() {
// Test the function VtlnWarpFreq.
BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20,
vtln_high_cutoff = 7400;
for (size_t i = 0; i < 100; i++) {
BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq),
freq / warp_factor);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
low_freq),
low_freq);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
high_freq),
high_freq);
BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(),
freq3 = freq2 +
(high_freq - freq2) * RandUniform(); // freq3>=freq2
BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq2);
BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq3);
KALDI_ASSERT(w3 >= w2); // increasing function.
BaseFloat w3dash = MelBanks::VtlnWarpFreq(
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, 1.0, freq3);
AssertEqual(w3dash, freq3);
}
}
static void UnitTestFeat() {
UnitTestVtln();
UnitTestReadWave();
UnitTestSimple();
UnitTestHTKCompare1();
UnitTestHTKCompare2();
// commenting out this one as it doesn't compare right now I normalized
// the way the FFT bins are treated (removed offset of 0.5)... this seems
// to relate to the way frequency zero behaves.
UnitTestHTKCompare3();
UnitTestHTKCompare4();
UnitTestHTKCompare5();
UnitTestHTKCompare6();
std::cout << "Tests succeeded.\n";
}
int main() {
try {
for (int i = 0; i < 5; i++) UnitTestFeat();
std::cout << "Tests succeeded.\n";
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return 1;
}
}

@ -0,0 +1,257 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, repalce with gtest
#include "base/flags.h"
#include "base/log.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav path");
DEFINE_string(feature_wspecifier, "", "test wav ark");
DEFINE_string(feature_check_wspecifier, "", "test wav ark");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "test wav ark");
std::vector<float> mean_{
-13730251.531853663, -12982852.199316509, -13673844.299583456,
-13089406.559646806, -12673095.524938712, -12823859.223276224,
-13590267.158903603, -14257618.467152044, -14374605.116185192,
-14490009.21822485, -14849827.158924166, -15354435.470563512,
-15834149.206532761, -16172971.985514281, -16348740.496746974,
-16423536.699409386, -16556246.263649225, -16744088.772748645,
-16916184.08510357, -17054034.840031497, -17165612.509455364,
-17255955.470915023, -17322572.527648456, -17408943.862033736,
-17521554.799865916, -17620623.254924215, -17699792.395918526,
-17723364.411134344, -17741483.4433254, -17747426.888704527,
-17733315.928209435, -17748780.160905756, -17808336.883775543,
-17895918.671983004, -18009812.59173023, -18098188.66548325,
-18195798.958462656, -18293617.62980999, -18397432.92077201,
-18505834.787318766, -18585451.8100908, -18652438.235649142,
-18700960.306275308, -18734944.58792185, -18737426.313365128,
-18735347.165987637, -18738813.444170244, -18737086.848890636,
-18731576.2474336, -18717405.44095871, -18703089.25545657,
-18691014.546456724, -18692460.568905357, -18702119.628629155,
-18727710.621126678, -18761582.72034647, -18806745.835547544,
-18850674.8692112, -18884431.510951452, -18919999.992506847,
-18939303.799078144, -18952946.273760635, -18980289.22996379,
-19011610.17803294, -19040948.61805145, -19061021.429847397,
-19112055.53768819, -19149667.414264943, -19201127.05091321,
-19270250.82564605, -19334606.883057203, -19390513.336589377,
-19444176.259208687, -19502755.000038862, -19544333.014549147,
-19612668.183176614, -19681902.19006569, -19771969.951249883,
-19873329.723376893, -19996752.59235844, -20110031.131400537,
-20231658.612529557, -20319378.894054495, -20378534.45718066,
-20413332.089584175, -20438147.844177883, -20443710.248040095,
-20465457.02238927, -20488610.969337028, -20516295.16424432,
-20541423.795738827, -20553192.874953747, -20573605.50701977,
-20577871.61936797, -20571807.008916274, -20556242.38912231,
-20542199.30819195, -20521239.063551214, -20519150.80004532,
-20527204.80248933, -20536933.769257784, -20543470.522332076,
-20549700.089992985, -20551525.24958494, -20554873.406493705,
-20564277.65794227, -20572211.740052115, -20574305.69550465,
-20575494.450104576, -20567092.577932164, -20549302.929608088,
-20545445.11878376, -20546625.326603737, -20549190.03499401,
-20554824.947828256, -20568341.378989458, -20577582.331383612,
-20577980.519402675, -20566603.03458152, -20560131.592262644,
-20552166.469060015, -20549063.06763577, -20544490.562339947,
-20539817.82346569, -20528747.715731595, -20518026.24576161,
-20510977.844974525, -20506874.36087992, -20506731.11977665,
-20510482.133420516, -20507760.92101862, -20494644.834457114,
-20480107.89304893, -20461312.091867123, -20442941.75080173,
-20426123.02834838, -20424607.675283, -20426810.369107097,
-20434024.50097819, -20437404.75544205, -20447688.63916367,
-20460893.335563846, -20482922.735127095, -20503610.119434915,
-20527062.76448319, -20557830.035128627, -20593274.72068722,
-20632528.452965066, -20673637.471334763, -20733106.97143075,
-20842921.0447562, -21054357.83621519, -21416569.534189366,
-21978460.272811692, -22753170.052172784, -23671344.10563395,
-24613499.293358143, -25406477.12230188, -25884377.82156489,
-26049040.62791664, -26996879.104431007};
std::vector<float> variance_{
213747175.10846674, 188395815.34302503, 212706429.10966414,
199109025.81461075, 189235901.23864496, 194901336.53253657,
217481594.29306737, 238689869.12327808, 243977501.24115244,
248479623.6431067, 259766741.47116545, 275516766.7790273,
291271202.3691234, 302693239.8220509, 308627358.3997694,
311143911.38788426, 315446105.07731867, 321705430.9341829,
327458907.4659941, 332245072.43223983, 336251717.5935284,
339694069.7639722, 342188204.4322228, 345587110.31313115,
349903086.2875232, 353660214.20643026, 356700344.5270885,
357665362.3529641, 358493352.05658793, 358857951.620328,
358375239.52774596, 358899733.6342954, 361051818.3511561,
364361716.05025816, 368750322.3771452, 372047800.6462831,
375655861.1349018, 379358519.1980013, 383327605.3935181,
387458599.282341, 390434692.3406868, 392994486.35057056,
394874418.04603153, 396230525.79763395, 396365592.0414835,
396334819.8242737, 396488353.19250053, 396438877.00744957,
396197980.4459586, 395590921.6672991, 395001107.62072515,
394528291.7318225, 394593110.424006, 395018405.59353715,
396110577.5415993, 397506704.0371068, 399400197.4657644,
401243568.2468382, 402687134.7805103, 404136047.2872507,
404883170.001883, 405522253.219517, 406660365.3626476,
407919346.0991902, 409045348.5384909, 409759588.7889818,
411974821.8564483, 413489718.78201455, 415535392.56684107,
418466481.97674364, 421104678.35678065, 423405392.5200779,
425550570.40798235, 427929423.9579701, 429585274.253478,
432368493.55181056, 435193587.13513297, 438886855.20476013,
443058876.8633751, 448181232.5093362, 452883835.6332396,
458056721.77926534, 461816531.22735566, 464363620.1970998,
465886343.5057493, 466928872.0651, 467180536.42647296,
468111848.70714295, 469138695.3071312, 470378429.6930793,
471517958.7132626, 472109050.4262365, 473087417.0177867,
473381322.04648733, 473220195.85483915, 472666071.8998819,
472124669.87879956, 471298571.411737, 471251033.2902761,
471672676.43128747, 472177147.2193172, 472572361.7711908,
472968783.7751127, 473156295.4164052, 473398034.82676554,
473897703.5203811, 474328271.33112127, 474452670.98002136,
474549003.99284613, 474252887.13567275, 473557462.909069,
473483385.85193115, 473609738.04855174, 473746944.82085115,
474016729.91696435, 474617321.94138587, 475045097.237122,
475125402.586558, 474664112.9824912, 474426247.5800283,
474104075.42796475, 473978219.7273978, 473773171.7798875,
473578534.69508696, 473102924.16904145, 472651240.5232615,
472374383.1810912, 472209479.6956096, 472202298.8921673,
472370090.76781124, 472220933.99374026, 471625467.37106377,
470994646.51883453, 470182428.9637543, 469348211.5939578,
468570387.4467277, 468540442.7225135, 468672018.90414184,
468994346.9533251, 469138757.58201426, 469553915.95710236,
470134523.38582784, 471082421.62055486, 471962316.51804745,
472939745.1708408, 474250621.5944825, 475773933.43199486,
477465399.71087736, 479218782.61382693, 481752299.7930922,
486608947.8984568, 496119403.2067917, 512730085.5704984,
539048915.2641417, 576285298.3548826, 621610270.2240586,
669308196.4436442, 710656993.5957186, 736344437.3725077,
745481288.0241544, 801121432.9925804};
int count_ = 912592;
void WriteMatrix() {
kaldi::Matrix<double> cmvn_stats(2, mean_.size() + 1);
for (size_t idx = 0; idx < mean_.size(); ++idx) {
cmvn_stats(0, idx) = mean_[idx];
cmvn_stats(1, idx) = variance_[idx];
}
cmvn_stats(0, mean_.size()) = count_;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(
FLAGS_feature_check_wspecifier);
WriteMatrix();
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning
// window -->linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0;
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
new ppspeech::DecibelNormalizer(db_norm_opt));
ppspeech::LinearSpectrogram linear_spectrogram(
opt, std::move(base_feature_extractor));
ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
float streaming_chunk = 0.36;
int sample_rate = 16000;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << mean_.size();
for (size_t i = 0; i < mean_.size(); i++) {
mean_[i] /= count_;
variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i];
if (variance_[i] < 1.0e-20) {
variance_[i] = 1.0e-20;
}
variance_[i] = 1.0 / std::sqrt(variance_[i]);
}
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
int sample_offset = 0;
std::vector<kaldi::Matrix<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(wav_chunk);
linear_spectrogram.ReadFeats(&features);
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.NumRows();
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features(cur_idx, col_idx) =
(feat(row_idx, col_idx) - mean_[col_idx]) *
variance_[col_idx];
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features_check(feature_rows,
feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features_check(cur_idx, col_idx) = feat(row_idx, col_idx);
}
kaldi::SubVector<BaseFloat> row_feat(features_check, cur_idx);
cmvn.ApplyCMVN(true, &row_feat);
++cur_idx;
}
}
feat_cmvn_check_writer.Write(utt, features_check);
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -0,0 +1,5 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(pp-model-test ${CMAKE_CURRENT_SOURCE_DIR}/pp-model-test.cc)
target_include_directories(pp-model-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(pp-model-test PUBLIC nnet gflags ${DEPS})

@ -0,0 +1,193 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <thread>
#include "paddle_inference_api.h"
using std::cout;
using std::endl;
DEFINE_string(model_path, "avg_1.jit.pdmodel", "xxx.pdmodel");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "xxx.pdiparams");
void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test();
void produce_data(std::vector<std::vector<float>>* data) {
int chunk_size = 35; // chunk_size in frame
int col_size = 161; // feat dim
cout << "chunk size: " << chunk_size << endl;
cout << "feat dim: " << col_size << endl;
data->reserve(chunk_size);
data->back().reserve(col_size);
for (int row = 0; row < chunk_size; ++row) {
data->push_back(std::vector<float>());
for (int col_idx = 0; col_idx < col_size; ++col_idx) {
data->back().push_back(0.201);
}
}
}
void model_forward_test() {
std::cout << "1. read the data" << std::endl;
std::vector<std::vector<float>> feats;
produce_data(&feats);
std::cout << "2. load the model" << std::endl;
;
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
cout << "model path: " << model_graph << endl;
cout << "model param path : " << model_params << endl;
paddle_infer::Config config;
config.SetModel(model_graph, model_params);
config.SwitchIrOptim(false);
cout << "SwitchIrOptim: " << false << endl;
config.DisableFCPadding();
cout << "DisableFCPadding: " << endl;
auto predictor = paddle_infer::CreatePredictor(config);
std::cout << "3. feat shape, row=" << feats.size()
<< ",col=" << feats[0].size() << std::endl;
std::vector<float> pp_input_mat;
for (const auto& item : feats) {
pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
}
std::cout << "4. fead the data to model" << std::endl;
int row = feats.size();
int col = feats[0].size();
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
for (auto name : input_names) {
cout << "model input names: " << name << endl;
}
for (auto name : output_names) {
cout << "model output names: " << name << endl;
}
// input
std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(pp_input_mat.data());
// input length
std::unique_ptr<paddle_infer::Tensor> input_len =
predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
// state_h
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
predictor->GetInputHandle(input_names[2]);
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
int chunk_state_h_box_size =
std::accumulate(chunk_state_h_box_shape.begin(),
chunk_state_h_box_shape.end(),
1,
std::multiplies<int>());
std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
// state_c
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
predictor->GetInputHandle(input_names[3]);
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
int chunk_state_c_box_size =
std::accumulate(chunk_state_c_box_shape.begin(),
chunk_state_c_box_shape.end(),
1,
std::multiplies<int>());
std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
// run
bool success = predictor->Run();
// state_h out
std::unique_ptr<paddle_infer::Tensor> h_out =
predictor->GetOutputHandle(output_names[2]);
std::vector<int> h_out_shape = h_out->shape();
int h_out_size = std::accumulate(
h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies<int>());
std::vector<float> h_out_data(h_out_size);
h_out->CopyToCpu(h_out_data.data());
// stage_c out
std::unique_ptr<paddle_infer::Tensor> c_out =
predictor->GetOutputHandle(output_names[3]);
std::vector<int> c_out_shape = c_out->shape();
int c_out_size = std::accumulate(
c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies<int>());
std::vector<float> c_out_data(c_out_size);
c_out->CopyToCpu(c_out_data.data());
// output tensor
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
std::vector<float> output_probs;
int output_size = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
output_probs.resize(output_size);
output_tensor->CopyToCpu(output_probs.data());
row = output_shape[1];
col = output_shape[2];
// probs
std::vector<std::vector<float>> probs;
probs.reserve(row);
for (int i = 0; i < row; i++) {
probs.push_back(std::vector<float>());
probs.back().reserve(col);
for (int j = 0; j < col; j++) {
probs.back().push_back(output_probs[i * col + j]);
}
}
std::vector<std::vector<float>> log_feat = probs;
std::cout << "probs, row: " << log_feat.size()
<< " col: " << log_feat[0].size() << std::endl;
for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
++col_idx) {
std::cout << log_feat[row_idx][col_idx] << " ";
}
std::cout << std::endl;
}
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
model_forward_test();
return 0;
}

@ -31,15 +31,3 @@ ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/decoder ${CMAKE_CURRENT_SOURCE_DIR}/decoder
) )
add_subdirectory(decoder) add_subdirectory(decoder)
add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc)
target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc)
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
add_executable(offline_decoder_main codelab/decoder_test/offline_decoder_main.cc)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(model_test codelab/nnet_test/model_test.cc)
target_link_libraries(model_test PUBLIC nnet gflags ${DEPS})

@ -43,18 +43,18 @@ typedef unsigned long long uint64;
typedef signed int char32; typedef signed int char32;
const uint8 kuint8max = (( uint8) 0xFF); const uint8 kuint8max = ((uint8)0xFF);
const uint16 kuint16max = ((uint16) 0xFFFF); const uint16 kuint16max = ((uint16)0xFFFF);
const uint32 kuint32max = ((uint32) 0xFFFFFFFF); const uint32 kuint32max = ((uint32)0xFFFFFFFF);
const uint64 kuint64max = ((uint64) (0xFFFFFFFFFFFFFFFFLL)); const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL));
const int8 kint8min = (( int8) 0x80); const int8 kint8min = ((int8)0x80);
const int8 kint8max = (( int8) 0x7F); const int8 kint8max = ((int8)0x7F);
const int16 kint16min = (( int16) 0x8000); const int16 kint16min = ((int16)0x8000);
const int16 kint16max = (( int16) 0x7FFF); const int16 kint16max = ((int16)0x7FFF);
const int32 kint32min = (( int32) 0x80000000); const int32 kint32min = ((int32)0x80000000);
const int32 kint32max = (( int32) 0x7FFFFFFF); const int32 kint32max = ((int32)0x7FFFFFFF);
const int64 kint64min = (( int64) (0x8000000000000000LL)); const int64 kint64min = ((int64)(0x8000000000000000LL));
const int64 kint64max = (( int64) (0x7FFFFFFFFFFFFFFFLL)); const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL));
const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max(); const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min(); const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();

@ -15,22 +15,22 @@
#pragma once #pragma once
#include <deque> #include <deque>
#include <fstream>
#include <iostream> #include <iostream>
#include <istream> #include <istream>
#include <fstream>
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex>
#include <ostream> #include <ostream>
#include <set> #include <set>
#include <sstream> #include <sstream>
#include <stack> #include <stack>
#include <string> #include <string>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <mutex> #include <vector>
#include "base/log.h"
#include "base/flags.h"
#include "base/basic_types.h" #include "base/basic_types.h"
#include "base/flags.h"
#include "base/log.h"
#include "base/macros.h" #include "base/macros.h"

@ -23,28 +23,29 @@
#ifndef BASE_THREAD_POOL_H #ifndef BASE_THREAD_POOL_H
#define BASE_THREAD_POOL_H #define BASE_THREAD_POOL_H
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable> #include <condition_variable>
#include <future>
#include <functional> #include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept> #include <stdexcept>
#include <thread>
#include <vector>
class ThreadPool { class ThreadPool {
public: public:
ThreadPool(size_t); ThreadPool(size_t);
template<class F, class... Args> template <class F, class... Args>
auto enqueue(F&& f, Args&&... args) auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>; -> std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool(); ~ThreadPool();
private:
private:
// need to keep track of threads so we can join them // need to keep track of threads so we can join them
std::vector< std::thread > workers; std::vector<std::thread> workers;
// the task queue // the task queue
std::queue< std::function<void()> > tasks; std::queue<std::function<void()>> tasks;
// synchronization // synchronization
std::mutex queue_mutex; std::mutex queue_mutex;
@ -53,68 +54,57 @@ private:
}; };
// the constructor just launches some amount of workers // the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads) inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
: stop(false) for (size_t i = 0; i < threads; ++i)
{ workers.emplace_back([this] {
for(size_t i = 0;i<threads;++i) for (;;) {
workers.emplace_back(
[this]
{
for(;;)
{
std::function<void()> task; std::function<void()> task;
{ {
std::unique_lock<std::mutex> lock(this->queue_mutex); std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock, this->condition.wait(lock, [this] {
[this]{ return this->stop || !this->tasks.empty(); }); return this->stop || !this->tasks.empty();
if(this->stop && this->tasks.empty()) });
return; if (this->stop && this->tasks.empty()) return;
task = std::move(this->tasks.front()); task = std::move(this->tasks.front());
this->tasks.pop(); this->tasks.pop();
} }
task(); task();
} }
} });
);
} }
// add new work item to the pool // add new work item to the pool
template<class F, class... Args> template <class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args) auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type> -> std::future<typename std::result_of<F(Args...)>::type> {
{
using return_type = typename std::result_of<F(Args...)>::type; using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared< std::packaged_task<return_type()> >( auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...) std::bind(std::forward<F>(f), std::forward<Args>(args)...));
);
std::future<return_type> res = task->get_future(); std::future<return_type> res = task->get_future();
{ {
std::unique_lock<std::mutex> lock(queue_mutex); std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool // don't allow enqueueing after stopping the pool
if(stop) if (stop) throw std::runtime_error("enqueue on stopped ThreadPool");
throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task](){ (*task)(); }); tasks.emplace([task]() { (*task)(); });
} }
condition.notify_one(); condition.notify_one();
return res; return res;
} }
// the destructor joins all threads // the destructor joins all threads
inline ThreadPool::~ThreadPool() inline ThreadPool::~ThreadPool() {
{
{ {
std::unique_lock<std::mutex> lock(queue_mutex); std::unique_lock<std::mutex> lock(queue_mutex);
stop = true; stop = true;
} }
condition.notify_all(); condition.notify_all();
for(std::thread &worker: workers) for (std::thread& worker : workers) worker.join();
worker.join();
} }
#endif #endif

@ -1,4 +0,0 @@
# codelab
This directory is here for testing some funcitons temporaril.

@ -1,57 +0,0 @@
// todo refactor, repalce with gtest
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "nnet/paddle_nnet.h"
#include "nnet/decodable.h"
DEFINE_string(feature_respecifier, "", "test nnet prob");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
//void SplitFeature(kaldi::Matrix<BaseFloat> feature,
// int32 chunk_size,
// std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
//}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_respecifier);
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts;
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts;
std::shared_ptr<ppspeech::PaddleNnet> nnet(new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::Decodable> decodable(new ppspeech::Decodable(nnet));
//int32 chunk_size = 35;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
decodable->FeedFeatures(feature);
decoder.AdvanceDecode(decodable, 8);
decodable->InputFinished();
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -1,686 +0,0 @@
// feat/feature-mfcc-test.cc
// Copyright 2009-2011 Karel Vesely; Petr Motlicek
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "feat/feature-mfcc.h"
#include "base/kaldi-math.h"
#include "matrix/kaldi-matrix-inl.h"
#include "feat/wave-reader.h"
using namespace kaldi;
static void UnitTestReadWave() {
std::cout << "=== UnitTestReadWave() ===\n";
Vector<BaseFloat> v, v2;
std::cout << "<<<=== Reading waveform\n";
{
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
const Matrix<BaseFloat> data(wave.Data());
KALDI_ASSERT(data.NumRows() == 1);
v.Resize(data.NumCols());
v.CopyFromVec(data.Row(0));
}
std::cout << "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
std::ifstream input(
"test_data/test_matlab.ascii"
);
KALDI_ASSERT(input.good());
v2.Read(input, false);
input.close();
std::cout << "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
KALDI_ASSERT(v.Dim() == v2.Dim());
for (int32 i = 0; i < v.Dim(); i++) {
KALDI_ASSERT(v(i) == v2(i));
}
std::cout << "<<<=== Comparing done\n";
// std::cout << "== The Waveform Samples == \n";
// std::cout << v;
std::cout << "Test passed :)\n\n";
}
/**
*/
static void UnitTestSimple() {
std::cout << "=== UnitTestSimple() ===\n";
Vector<BaseFloat> v(100000);
Matrix<BaseFloat> m;
// init with noise
for (int32 i = 0; i < v.Dim(); i++) {
v(i) = (abs( i * 433024253 ) % 65535) - (65535 / 2);
}
std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
// the parametrization object
MfccOptions op;
// trying to have same opts as baseline.
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "rectangular";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
Mfcc mfcc(op);
// use default parameters
// compute mfccs.
mfcc.Compute(v, 1.0, &m);
// possibly dump
// std::cout << "== Output features == \n" << m;
std::cout << "Test passed :)\n\n";
}
static void UnitTestHTKCompare1() {
std::cout << "=== UnitTestHTKCompare1() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.1",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.1",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.1");
}
static void UnitTestHTKCompare2() {
std::cout << "=== UnitTestHTKCompare2() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.2",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.2",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.2");
}
static void UnitTestHTKCompare3() {
std::cout << "=== UnitTestHTKCompare3() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.3",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 20.0;
//op.mel_opts.debug_mel = true;
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.3",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.3");
}
static void UnitTestHTKCompare4() {
std::cout << "=== UnitTestHTKCompare4() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.4",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.4",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.4");
}
static void UnitTestHTKCompare5() {
std::cout << "=== UnitTestHTKCompare5() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.5",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 0.0;
op.mel_opts.vtln_low = 100.0;
op.mel_opts.vtln_high = 7500.0;
op.mel_opts.htk_mode = true;
BaseFloat vtln_warp = 1.1; // our approach identical to htk for warp factor >1,
// differs slightly for higher mel bins if warp_factor <0.9
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.5",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.5");
}
static void UnitTestHTKCompare6() {
std::cout << "=== UnitTestHTKCompare6() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.6",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.97;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.num_bins = 24;
op.mel_opts.low_freq = 125.0;
op.mel_opts.high_freq = 7800.0;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.6",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.6");
}
void UnitTestVtln() {
// Test the function VtlnWarpFreq.
BaseFloat low_freq = 10, high_freq = 7800,
vtln_low_cutoff = 20, vtln_high_cutoff = 7400;
for (size_t i = 0; i < 100; i++) {
BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
freq),
freq / warp_factor);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
low_freq),
low_freq);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
high_freq),
high_freq);
BaseFloat freq2 = low_freq + (high_freq-low_freq) * RandUniform(),
freq3 = freq2 + (high_freq-freq2) * RandUniform(); // freq3>=freq2
BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
freq2);
BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
freq3);
KALDI_ASSERT(w3 >= w2); // increasing function.
BaseFloat w3dash = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, 1.0,
freq3);
AssertEqual(w3dash, freq3);
}
}
static void UnitTestFeat() {
UnitTestVtln();
UnitTestReadWave();
UnitTestSimple();
UnitTestHTKCompare1();
UnitTestHTKCompare2();
// commenting out this one as it doesn't compare right now I normalized
// the way the FFT bins are treated (removed offset of 0.5)... this seems
// to relate to the way frequency zero behaves.
UnitTestHTKCompare3();
UnitTestHTKCompare4();
UnitTestHTKCompare5();
UnitTestHTKCompare6();
std::cout << "Tests succeeded.\n";
}
int main() {
try {
for (int i = 0; i < 5; i++)
UnitTestFeat();
std::cout << "Tests succeeded.\n";
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return 1;
}
}

@ -1,125 +0,0 @@
// todo refactor, repalce with gtest
#include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
DEFINE_string(wav_rspecifier, "", "test wav path");
DEFINE_string(feature_wspecifier, "", "test wav ark");
DEFINE_string(feature_check_wspecifier, "", "test wav ark");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "test wav ark");
std::vector<float> mean_{-13730251.531853663, -12982852.199316509, -13673844.299583456, -13089406.559646806, -12673095.524938712, -12823859.223276224, -13590267.158903603, -14257618.467152044, -14374605.116185192, -14490009.21822485, -14849827.158924166, -15354435.470563512, -15834149.206532761, -16172971.985514281, -16348740.496746974, -16423536.699409386, -16556246.263649225, -16744088.772748645, -16916184.08510357, -17054034.840031497, -17165612.509455364, -17255955.470915023, -17322572.527648456, -17408943.862033736, -17521554.799865916, -17620623.254924215, -17699792.395918526, -17723364.411134344, -17741483.4433254, -17747426.888704527, -17733315.928209435, -17748780.160905756, -17808336.883775543, -17895918.671983004, -18009812.59173023, -18098188.66548325, -18195798.958462656, -18293617.62980999, -18397432.92077201, -18505834.787318766, -18585451.8100908, -18652438.235649142, -18700960.306275308, -18734944.58792185, -18737426.313365128, -18735347.165987637, -18738813.444170244, -18737086.848890636, -18731576.2474336, -18717405.44095871, -18703089.25545657, -18691014.546456724, -18692460.568905357, -18702119.628629155, -18727710.621126678, -18761582.72034647, -18806745.835547544, -18850674.8692112, -18884431.510951452, -18919999.992506847, -18939303.799078144, -18952946.273760635, -18980289.22996379, -19011610.17803294, -19040948.61805145, -19061021.429847397, -19112055.53768819, -19149667.414264943, -19201127.05091321, -19270250.82564605, -19334606.883057203, -19390513.336589377, -19444176.259208687, -19502755.000038862, -19544333.014549147, -19612668.183176614, -19681902.19006569, -19771969.951249883, -19873329.723376893, -19996752.59235844, -20110031.131400537, -20231658.612529557, -20319378.894054495, -20378534.45718066, -20413332.089584175, -20438147.844177883, -20443710.248040095, -20465457.02238927, -20488610.969337028, -20516295.16424432, -20541423.795738827, -20553192.874953747, -20573605.50701977, -20577871.61936797, -20571807.008916274, -20556242.38912231, -20542199.30819195, -20521239.063551214, -20519150.80004532, -20527204.80248933, -20536933.769257784, -20543470.522332076, -20549700.089992985, -20551525.24958494, -20554873.406493705, -20564277.65794227, -20572211.740052115, -20574305.69550465, -20575494.450104576, -20567092.577932164, -20549302.929608088, -20545445.11878376, -20546625.326603737, -20549190.03499401, -20554824.947828256, -20568341.378989458, -20577582.331383612, -20577980.519402675, -20566603.03458152, -20560131.592262644, -20552166.469060015, -20549063.06763577, -20544490.562339947, -20539817.82346569, -20528747.715731595, -20518026.24576161, -20510977.844974525, -20506874.36087992, -20506731.11977665, -20510482.133420516, -20507760.92101862, -20494644.834457114, -20480107.89304893, -20461312.091867123, -20442941.75080173, -20426123.02834838, -20424607.675283, -20426810.369107097, -20434024.50097819, -20437404.75544205, -20447688.63916367, -20460893.335563846, -20482922.735127095, -20503610.119434915, -20527062.76448319, -20557830.035128627, -20593274.72068722, -20632528.452965066, -20673637.471334763, -20733106.97143075, -20842921.0447562, -21054357.83621519, -21416569.534189366, -21978460.272811692, -22753170.052172784, -23671344.10563395, -24613499.293358143, -25406477.12230188, -25884377.82156489, -26049040.62791664, -26996879.104431007};
std::vector<float> variance_{213747175.10846674, 188395815.34302503, 212706429.10966414, 199109025.81461075, 189235901.23864496, 194901336.53253657, 217481594.29306737, 238689869.12327808, 243977501.24115244, 248479623.6431067, 259766741.47116545, 275516766.7790273, 291271202.3691234, 302693239.8220509, 308627358.3997694, 311143911.38788426, 315446105.07731867, 321705430.9341829, 327458907.4659941, 332245072.43223983, 336251717.5935284, 339694069.7639722, 342188204.4322228, 345587110.31313115, 349903086.2875232, 353660214.20643026, 356700344.5270885, 357665362.3529641, 358493352.05658793, 358857951.620328, 358375239.52774596, 358899733.6342954, 361051818.3511561, 364361716.05025816, 368750322.3771452, 372047800.6462831, 375655861.1349018, 379358519.1980013, 383327605.3935181, 387458599.282341, 390434692.3406868, 392994486.35057056, 394874418.04603153, 396230525.79763395, 396365592.0414835, 396334819.8242737, 396488353.19250053, 396438877.00744957, 396197980.4459586, 395590921.6672991, 395001107.62072515, 394528291.7318225, 394593110.424006, 395018405.59353715, 396110577.5415993, 397506704.0371068, 399400197.4657644, 401243568.2468382, 402687134.7805103, 404136047.2872507, 404883170.001883, 405522253.219517, 406660365.3626476, 407919346.0991902, 409045348.5384909, 409759588.7889818, 411974821.8564483, 413489718.78201455, 415535392.56684107, 418466481.97674364, 421104678.35678065, 423405392.5200779, 425550570.40798235, 427929423.9579701, 429585274.253478, 432368493.55181056, 435193587.13513297, 438886855.20476013, 443058876.8633751, 448181232.5093362, 452883835.6332396, 458056721.77926534, 461816531.22735566, 464363620.1970998, 465886343.5057493, 466928872.0651, 467180536.42647296, 468111848.70714295, 469138695.3071312, 470378429.6930793, 471517958.7132626, 472109050.4262365, 473087417.0177867, 473381322.04648733, 473220195.85483915, 472666071.8998819, 472124669.87879956, 471298571.411737, 471251033.2902761, 471672676.43128747, 472177147.2193172, 472572361.7711908, 472968783.7751127, 473156295.4164052, 473398034.82676554, 473897703.5203811, 474328271.33112127, 474452670.98002136, 474549003.99284613, 474252887.13567275, 473557462.909069, 473483385.85193115, 473609738.04855174, 473746944.82085115, 474016729.91696435, 474617321.94138587, 475045097.237122, 475125402.586558, 474664112.9824912, 474426247.5800283, 474104075.42796475, 473978219.7273978, 473773171.7798875, 473578534.69508696, 473102924.16904145, 472651240.5232615, 472374383.1810912, 472209479.6956096, 472202298.8921673, 472370090.76781124, 472220933.99374026, 471625467.37106377, 470994646.51883453, 470182428.9637543, 469348211.5939578, 468570387.4467277, 468540442.7225135, 468672018.90414184, 468994346.9533251, 469138757.58201426, 469553915.95710236, 470134523.38582784, 471082421.62055486, 471962316.51804745, 472939745.1708408, 474250621.5944825, 475773933.43199486, 477465399.71087736, 479218782.61382693, 481752299.7930922, 486608947.8984568, 496119403.2067917, 512730085.5704984, 539048915.2641417, 576285298.3548826, 621610270.2240586, 669308196.4436442, 710656993.5957186, 736344437.3725077, 745481288.0241544, 801121432.9925804};
int count_ = 912592;
void WriteMatrix() {
kaldi::Matrix<double> cmvn_stats(2, mean_.size()+ 1);
for (size_t idx = 0; idx < mean_.size(); ++idx) {
cmvn_stats(0, idx) = mean_[idx];
cmvn_stats(1, idx) = variance_[idx];
}
cmvn_stats(0, mean_.size()) = count_;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(FLAGS_feature_check_wspecifier);
WriteMatrix();
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning window -->linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0;
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
new ppspeech::DecibelNormalizer(db_norm_opt));
ppspeech::LinearSpectrogram linear_spectrogram(opt, std::move(base_feature_extractor));
ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
float streaming_chunk = 0.36;
int sample_rate = 16000;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << mean_.size();
for (size_t i = 0; i < mean_.size(); i++) {
mean_[i] /= count_;
variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i];
if (variance_[i] < 1.0e-20) {
variance_[i] = 1.0e-20;
}
variance_[i] = 1.0 / std::sqrt(variance_[i]);
}
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData &wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(), this_channel);
int tot_samples = waveform.Dim();
int sample_offset = 0;
std::vector<kaldi::Matrix<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size = std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(wav_chunk);
linear_spectrogram.ReadFeats(&features);
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.NumRows();
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows, feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features(cur_idx, col_idx) = (feat(row_idx, col_idx) - mean_[col_idx])*variance_[col_idx];
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features_check(feature_rows, feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features_check(cur_idx, col_idx) = feat(row_idx, col_idx);
}
kaldi::SubVector<BaseFloat> row_feat(features_check, cur_idx);
cmvn.ApplyCMVN(true, &row_feat);
++cur_idx;
}
}
feat_cmvn_check_writer.Write(utt, features_check);
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -1,134 +0,0 @@
#include "paddle_inference_api.h"
#include <gflags/gflags.h>
#include <iostream>
#include <thread>
#include <fstream>
#include <iterator>
#include <algorithm>
#include <numeric>
#include <functional>
void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test();
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
model_forward_test();
return 0;
}
void model_forward_test() {
std::cout << "1. read the data" << std::endl;
std::vector<std::vector<float>> feats;
produce_data(&feats);
std::cout << "2. load the model" << std::endl;;
std::string model_graph = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel";
std::string model_params = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams";
paddle_infer::Config config;
config.SetModel(model_graph, model_params);
config.SwitchIrOptim(false);
config.DisableFCPadding();
auto predictor = paddle_infer::CreatePredictor(config);
std::cout << "3. feat shape, row=" << feats.size() << ",col=" << feats[0].size() << std::endl;
std::vector<float> paddle_input_feature_matrix;
for(const auto& item : feats) {
paddle_input_feature_matrix.insert(paddle_input_feature_matrix.end(), item.begin(), item.end());
}
std::cout << "4. fead the data to model" << std::endl;
int row = feats.size();
int col = feats[0].size();
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(paddle_input_feature_matrix.data());
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box = predictor->GetInputHandle(input_names[2]);
std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
chunk_state_h_box->Reshape(chunk_state_h_box_shape);
int chunk_state_h_box_size = std::accumulate(chunk_state_h_box_shape.begin(), chunk_state_h_box_shape.end(),
1, std::multiplies<int>());
std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box = predictor->GetInputHandle(input_names[3]);
std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
chunk_state_c_box->Reshape(chunk_state_c_box_shape);
int chunk_state_c_box_size = std::accumulate(chunk_state_c_box_shape.begin(), chunk_state_c_box_shape.end(),
1, std::multiplies<int>());
std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
bool success = predictor->Run();
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
std::vector<int> h_out_shape = h_out->shape();
int h_out_size = std::accumulate(h_out_shape.begin(), h_out_shape.end(),
1, std::multiplies<int>());
std::vector<float> h_out_data(h_out_size);
h_out->CopyToCpu(h_out_data.data());
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
std::vector<int> c_out_shape = c_out->shape();
int c_out_size = std::accumulate(c_out_shape.begin(), c_out_shape.end(),
1, std::multiplies<int>());
std::vector<float> c_out_data(c_out_size);
c_out->CopyToCpu(c_out_data.data());
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
std::vector<float> output_probs;
int output_size = std::accumulate(output_shape.begin(), output_shape.end(),
1, std::multiplies<int>());
output_probs.resize(output_size);
output_tensor->CopyToCpu(output_probs.data());
row = output_shape[1];
col = output_shape[2];
std::vector<std::vector<float>> probs;
probs.reserve(row);
for (int i = 0; i < row; i++) {
probs.push_back(std::vector<float>());
probs.back().reserve(col);
for (int j = 0; j < col; j++) {
probs.back().push_back(output_probs[i * col + j]);
}
}
std::vector<std::vector<float>> log_feat = probs;
std::cout << "probs, row: " << log_feat.size() << " col: " << log_feat[0].size() << std::endl;
for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
for (size_t col_idx = 0; col_idx < log_feat[row_idx].size(); ++col_idx) {
std::cout << log_feat[row_idx][col_idx] << " ";
}
std::cout << std::endl;
}
}
void produce_data(std::vector<std::vector<float>>* data) {
int chunk_size = 35;
int col_size = 161;
data->reserve(chunk_size);
data->back().reserve(col_size);
for (int row = 0; row < chunk_size; ++row) {
data->push_back(std::vector<float>());
for (int col_idx = 0; col_idx < col_size; ++col_idx) {
data->back().push_back(0.201);
}
}
}

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/basic_types.h" #include "base/basic_types.h"
struct DecoderResult { struct DecoderResult {

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_beam_search_decoder.h" #include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h" #include "base/basic_types.h"
@ -9,25 +23,23 @@ namespace ppspeech {
using std::vector; using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) : CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
opts_(opts), : opts_(opts),
init_ext_scorer_(nullptr), init_ext_scorer_(nullptr),
blank_id(-1), blank_id(-1),
space_id(-1), space_id(-1),
num_frame_decoded_(0), num_frame_decoded_(0),
root(nullptr) { root(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file; LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
LOG(INFO) << "load the dict failed"; LOG(INFO) << "load the dict failed";
} }
LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_.size(); LOG(INFO) << "read the vocabulary success, dict size: "
<< vocabulary_.size();
LOG(INFO) << "language model path: " << opts_.lm_path; LOG(INFO) << "language model path: " << opts_.lm_path;
init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha, init_ext_scorer_ = std::make_shared<Scorer>(
opts_.beta, opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
opts_.lm_path,
vocabulary_);
} }
void CTCBeamSearch::Reset() { void CTCBeamSearch::Reset() {
@ -36,7 +48,6 @@ void CTCBeamSearch::Reset() {
} }
void CTCBeamSearch::InitDecoder() { void CTCBeamSearch::InitDecoder() {
blank_id = 0; blank_id = 0;
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
@ -51,10 +62,11 @@ void CTCBeamSearch::InitDecoder() {
root = std::make_shared<PathTrie>(); root = std::make_shared<PathTrie>();
root->score = root->log_prob_b_prev = 0.0; root->score = root->log_prob_b_prev = 0.0;
prefixes.push_back(root.get()); prefixes.push_back(root.get());
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) { if (init_ext_scorer_ != nullptr &&
!init_ext_scorer_->is_character_based()) {
auto fst_dict = auto fst_dict =
static_cast<fst::StdVectorFst *>(init_ext_scorer_->dictionary); static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
root->set_dictionary(dict_ptr); root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT); auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
@ -62,23 +74,24 @@ void CTCBeamSearch::InitDecoder() {
} }
} }
void CTCBeamSearch::Decode(std::shared_ptr<kaldi::DecodableInterface> decodable) { void CTCBeamSearch::Decode(
std::shared_ptr<kaldi::DecodableInterface> decodable) {
return; return;
} }
int32 CTCBeamSearch::NumFrameDecoded() { int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; }
return num_frame_decoded_;
}
// todo rename, refactor // todo rename, refactor
void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable, void CTCBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames) { int max_frames) {
while (max_frames > 0) { while (max_frames > 0) {
vector<vector<BaseFloat>> likelihood; vector<vector<BaseFloat>> likelihood;
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) { if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
break; break;
} }
likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1)); likelihood.push_back(
decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
AdvanceDecoding(likelihood); AdvanceDecoding(likelihood);
max_frames--; max_frames--;
} }
@ -93,12 +106,13 @@ void CTCBeamSearch::ResetPrefixes() {
} }
} }
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs, int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
vector<string>& nbest_words) { vector<string>& nbest_words) {
kaldi::Timer timer; kaldi::Timer timer;
timer.Reset(); timer.Reset();
AdvanceDecoding(probs); AdvanceDecoding(probs);
LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f; LOG(INFO) << "ctc decoding elapsed time(s) "
<< static_cast<float>(timer.Elapsed()) / 1000.0f;
return 0; return 0;
} }
@ -124,12 +138,13 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
double cutoff_prob = opts_.cutoff_prob; double cutoff_prob = opts_.cutoff_prob;
size_t cutoff_top_n = opts_.cutoff_top_n; size_t cutoff_top_n = opts_.cutoff_top_n;
vector<vector<double>> probs_seq(probs.size(), vector<double>(probs[0].size(), 0)); vector<vector<double>> probs_seq(probs.size(),
vector<double>(probs[0].size(), 0));
int row = probs.size(); int row = probs.size();
int col = probs[0].size(); int col = probs[0].size();
for(int i = 0; i < row; i++) { for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++){ for (int j = 0; j < col; j++) {
probs_seq[i][j] = static_cast<double>(probs[i][j]); probs_seq[i][j] = static_cast<double>(probs[i][j]);
} }
} }
@ -141,7 +156,8 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
bool full_beam = false; bool full_beam = false;
if (init_ext_scorer_ != nullptr) { if (init_ext_scorer_ != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size); size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, std::sort(prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare); prefix_compare);
if (num_prefixes == 0) { if (num_prefixes == 0) {
@ -181,7 +197,8 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
} // for probs_seq } // for probs_seq
} }
int32 CTCBeamSearch::SearchOneChar(const bool& full_beam, int32 CTCBeamSearch::SearchOneChar(
const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx, const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff) { const BaseFloat& min_cutoff) {
size_t beam_size = opts_.beam_size; size_t beam_size = opts_.beam_size;
@ -196,10 +213,8 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
} }
if (c == blank_id) { if (c == blank_id) {
prefix->log_prob_b_cur = log_sum_exp( prefix->log_prob_b_cur =
prefix->log_prob_b_cur, log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
log_prob_c +
prefix->score);
continue; continue;
} }
@ -207,9 +222,7 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
if (c == prefix->character) { if (c == prefix->character) {
// p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1}) // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
prefix->log_prob_nb_cur = log_sum_exp( prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
log_prob_c +
prefix->log_prob_nb_prev);
} }
// get new prefix // get new prefix
@ -228,7 +241,7 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
// language model scoring // language model scoring
if (init_ext_scorer_ != nullptr && if (init_ext_scorer_ != nullptr &&
(c == space_id || init_ext_scorer_->is_character_based())) { (c == space_id || init_ext_scorer_->is_character_based())) {
PathTrie *prefix_to_score = nullptr; PathTrie* prefix_to_score = nullptr;
// skip scoring the space // skip scoring the space
if (init_ext_scorer_->is_character_based()) { if (init_ext_scorer_->is_character_based()) {
prefix_to_score = prefix_new; prefix_to_score = prefix_new;
@ -247,8 +260,7 @@ int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
} }
// p_{nb}(l;x_{1:t}) // p_{nb}(l;x_{1:t})
prefix_new->log_prob_nb_cur = prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
log_p);
} }
} // end of loop over prefix } // end of loop over prefix
return 0; return 0;
@ -258,9 +270,7 @@ void CTCBeamSearch::CalculateApproxScore() {
size_t beam_size = opts_.beam_size; size_t beam_size = opts_.beam_size;
size_t num_prefixes = std::min(prefixes.size(), beam_size); size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort( std::sort(
prefixes.begin(), prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
prefixes.begin() + num_prefixes,
prefix_compare);
// compute aproximate ctc score as the return score, without affecting the // compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable. // return order of decoding result. To delete when decoder gets stable.
@ -274,8 +284,8 @@ void CTCBeamSearch::CalculateApproxScore() {
// remove word insert // remove word insert
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta; approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
// remove language model weight: // remove language model weight:
approx_ctc -= approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
(init_ext_scorer_->get_sent_log_prob(words)) * init_ext_scorer_->alpha; init_ext_scorer_->alpha;
} }
prefixes[i]->approx_ctc = approx_ctc; prefixes[i]->approx_ctc = approx_ctc;
} }
@ -283,13 +293,15 @@ void CTCBeamSearch::CalculateApproxScore() {
void CTCBeamSearch::LMRescore() { void CTCBeamSearch::LMRescore() {
size_t beam_size = opts_.beam_size; size_t beam_size = opts_.beam_size;
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) { if (init_ext_scorer_ != nullptr &&
!init_ext_scorer_->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i]; auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) { if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0; float score = 0.0;
vector<string> ngram = init_ext_scorer_->make_ngram(prefix); vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
score = init_ext_scorer_->get_log_cond_prob(ngram) * init_ext_scorer_->alpha; score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
score += init_ext_scorer_->beta; score += init_ext_scorer_->beta;
prefix->score += score; prefix->score += score;
} }

@ -1,8 +1,22 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h" #include "base/common.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "nnet/decodable-itf.h" #include "nnet/decodable-itf.h"
#include "util/parse-options.h" #include "util/parse-options.h"
#include "decoder/ctc_decoders/scorer.h"
#include "decoder/ctc_decoders/path_trie.h"
#pragma once #pragma once
@ -17,26 +31,27 @@ struct CTCBeamSearchOptions {
int beam_size; int beam_size;
int cutoff_top_n; int cutoff_top_n;
int num_proc_bsearch; int num_proc_bsearch;
CTCBeamSearchOptions() : CTCBeamSearchOptions()
dict_file("./model/words.txt"), : dict_file("./model/words.txt"),
lm_path("./model/lm.arpa"), lm_path("./model/lm.arpa"),
alpha(1.9f), alpha(1.9f),
beta(5.0), beta(5.0),
beam_size(300), beam_size(300),
cutoff_prob(0.99f), cutoff_prob(0.99f),
cutoff_top_n(40), cutoff_top_n(40),
num_proc_bsearch(0) { num_proc_bsearch(0) {}
}
void Register(kaldi::OptionsItf* opts) { void Register(kaldi::OptionsItf* opts) {
opts->Register("dict", &dict_file, "dict file "); opts->Register("dict", &dict_file, "dict file ");
opts->Register("lm-path", &lm_path, "language model file"); opts->Register("lm-path", &lm_path, "language model file");
opts->Register("alpha", &alpha, "alpha"); opts->Register("alpha", &alpha, "alpha");
opts->Register("beta", &beta, "beta"); opts->Register("beta", &beta, "beta");
opts->Register("beam-size", &beam_size, "beam size for beam search method"); opts->Register(
"beam-size", &beam_size, "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs"); opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n"); opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
opts->Register("num-proc-bsearch", &num_proc_bsearch, "num proc bsearch"); opts->Register(
"num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
} }
}; };
@ -50,11 +65,13 @@ class CTCBeamSearch {
std::vector<std::pair<double, std::string>> GetNBestPath(); std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath(); std::string GetFinalBestPath();
int NumFrameDecoded(); int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs, int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words); std::vector<std::string>& nbest_words);
void AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable, void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames); int max_frames);
void Reset(); void Reset();
private: private:
void ResetPrefixes(); void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam, int32 SearchOneChar(const bool& full_beam,
@ -66,7 +83,7 @@ class CTCBeamSearch {
CTCBeamSearchOptions opts_; CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
//std::vector<DecodeResult> decoder_results_; // std::vector<DecodeResult> decoder_results_;
std::vector<std::string> vocabulary_; // todo remove later std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id; size_t blank_id;
int space_id; int space_id;

@ -24,7 +24,8 @@ class FbankExtractor : FeatureExtractorInterface {
public: public:
explicit FbankExtractor(const FbankOptions& opts, explicit FbankExtractor(const FbankOptions& opts,
share_ptr<FeatureExtractorInterface> pre_extractor); share_ptr<FeatureExtractorInterface> pre_extractor);
virtual void AcceptWaveform(const kaldi::Vector<kaldi::BaseFloat>& input) = 0; virtual void AcceptWaveform(
const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0; virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0; virtual size_t Dim() const = 0;

@ -0,0 +1,14 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -0,0 +1,14 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -21,7 +21,8 @@ namespace ppspeech {
class FeatureExtractorInterface { class FeatureExtractorInterface {
public: public:
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0; virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0; virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0; virtual size_t Dim() const = 0;
}; };

@ -25,7 +25,7 @@ using kaldi::VectorBase;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector; using std::vector;
//todo remove later // todo remove later
void CopyVector2StdVector_(const VectorBase<BaseFloat>& input, void CopyVector2StdVector_(const VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) { vector<BaseFloat>* output) {
if (input.Dim() == 0) return; if (input.Dim() == 0) return;

@ -1,27 +1,40 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/feature_extractor_interface.h"
#include "kaldi/feat/feature-window.h" #include "kaldi/feat/feature-window.h"
#include "base/common.h"
namespace ppspeech { namespace ppspeech {
struct LinearSpectrogramOptions { struct LinearSpectrogramOptions {
kaldi::FrameExtractionOptions frame_opts; kaldi::FrameExtractionOptions frame_opts;
LinearSpectrogramOptions(): LinearSpectrogramOptions() : frame_opts() {}
frame_opts() {}
void Register(kaldi::OptionsItf* opts) { void Register(kaldi::OptionsItf* opts) { frame_opts.Register(opts); }
frame_opts.Register(opts);
}
}; };
class LinearSpectrogram : public FeatureExtractorInterface { class LinearSpectrogram : public FeatureExtractorInterface {
public: public:
explicit LinearSpectrogram(const LinearSpectrogramOptions& opts, explicit LinearSpectrogram(
const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor); std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input); virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat); virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return dim_; } virtual size_t Dim() const { return dim_; }
void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats); void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats);

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/normalizer.h" #include "frontend/normalizer.h"
#include "kaldi/feat/cmvn.h" #include "kaldi/feat/cmvn.h"
@ -16,7 +30,8 @@ DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) {
dim_ = 0; dim_ = 0;
} }
void DecibelNormalizer::AcceptWaveform(const kaldi::VectorBase<BaseFloat>& input) { void DecibelNormalizer::AcceptWaveform(
const kaldi::VectorBase<BaseFloat>& input) {
dim_ = input.Dim(); dim_ = input.Dim();
waveform_.Resize(input.Dim()); waveform_.Resize(input.Dim());
waveform_.CopyFromVec(input); waveform_.CopyFromVec(input);
@ -27,7 +42,7 @@ void DecibelNormalizer::Read(kaldi::VectorBase<BaseFloat>* feat) {
Compute(waveform_, feat); Compute(waveform_, feat);
} }
//todo remove later // todo remove later
void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input, void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) { vector<BaseFloat>* output) {
if (input.Dim() == 0) return; if (input.Dim() == 0) return;
@ -61,7 +76,7 @@ bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
} }
// square // square
for (auto &d : samples) { for (auto& d : samples) {
if (opts_.convert_int_float) { if (opts_.convert_int_float) {
d = d * wave_float_normlization; d = d * wave_float_normlization;
} }
@ -74,14 +89,15 @@ bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
gain = opts_.target_db - rms_db; gain = opts_.target_db - rms_db;
if (gain > opts_.max_gain_db) { if (gain > opts_.max_gain_db) {
LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db << "dB," LOG(ERROR)
<< "Unable to normalize segment to " << opts_.target_db << "dB,"
<< "because the the probable gain have exceeds opts_.max_gain_db" << "because the the probable gain have exceeds opts_.max_gain_db"
<< opts_.max_gain_db << "dB."; << opts_.max_gain_db << "dB.";
return false; return false;
} }
// Note that this is an in-place transformation. // Note that this is an in-place transformation.
for (auto &item : samples) { for (auto& item : samples) {
// python item *= 10.0 ** (gain / 20.0) // python item *= 10.0 ** (gain / 20.0)
item *= std::pow(10.0, gain / 20.0); item *= std::pow(10.0, gain / 20.0);
} }
@ -100,21 +116,20 @@ void CMVN::AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) {
return; return;
} }
void CMVN::Read(kaldi::VectorBase<BaseFloat>* feat) { void CMVN::Read(kaldi::VectorBase<BaseFloat>* feat) { return; }
return;
}
// feats contain num_frames feature. // feats contain num_frames feature.
void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) { void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
KALDI_ASSERT(feats != NULL); KALDI_ASSERT(feats != NULL);
int32 dim = stats_.NumCols() - 1; int32 dim = stats_.NumCols() - 1;
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 || feats->Dim() % dim != 0) { if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
KALDI_ERR << "Dim mismatch: cmvn " feats->Dim() % dim != 0) {
<< stats_.NumRows() << 'x' << stats_.NumCols() KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x'
<< ", feats " << feats->Dim() << 'x'; << stats_.NumCols() << ", feats " << feats->Dim() << 'x';
} }
if (stats_.NumRows() == 1 && var_norm) { if (stats_.NumRows() == 1 && var_norm) {
KALDI_ERR << "You requested variance normalization but no variance stats_ " KALDI_ERR
<< "You requested variance normalization but no variance stats_ "
<< "are supplied."; << "are supplied.";
} }
@ -122,17 +137,20 @@ void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
// Do not change the threshold of 1.0 here: in the balanced-cmvn code, when // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
// computing an offset and representing it as stats_, we use a count of one. // computing an offset and representing it as stats_, we use a count of one.
if (count < 1.0) if (count < 1.0)
KALDI_ERR << "Insufficient stats_ for cepstral mean and variance normalization: " KALDI_ERR << "Insufficient stats_ for cepstral mean and variance "
"normalization: "
<< "count = " << count; << "count = " << count;
if (!var_norm) { if (!var_norm) {
Vector<BaseFloat> offset(feats->Dim()); Vector<BaseFloat> offset(feats->Dim());
SubVector<double> mean_stats(stats_.RowData(0), dim); SubVector<double> mean_stats(stats_.RowData(0), dim);
Vector<double> mean_stats_apply(feats->Dim()); Vector<double> mean_stats_apply(feats->Dim());
//fill the datat of mean_stats in mean_stats_appy whose dim is equal with the dim of feature. // fill the datat of mean_stats in mean_stats_appy whose dim is equal
//the dim of feats = dim * num_frames; // with the dim of feature.
// the dim of feats = dim * num_frames;
for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) { for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) {
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim*idx, dim); SubVector<double> stats_tmp(mean_stats_apply.Data() + dim * idx,
dim);
stats_tmp.CopyFromVec(mean_stats); stats_tmp.CopyFromVec(mean_stats);
} }
offset.AddVec(-1.0 / count, mean_stats_apply); offset.AddVec(-1.0 / count, mean_stats_apply);
@ -144,18 +162,18 @@ void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
kaldi::Matrix<BaseFloat> norm(2, feats->Dim()); kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
for (int32 d = 0; d < dim; d++) { for (int32 d = 0; d < dim; d++) {
double mean, offset, scale; double mean, offset, scale;
mean = stats_(0, d)/count; mean = stats_(0, d) / count;
double var = (stats_(1, d)/count) - mean*mean, double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20;
floor = 1.0e-20;
if (var < floor) { if (var < floor) {
KALDI_WARN << "Flooring cepstral variance from " << var << " to " KALDI_WARN << "Flooring cepstral variance from " << var << " to "
<< floor; << floor;
var = floor; var = floor;
} }
scale = 1.0 / sqrt(var); scale = 1.0 / sqrt(var);
if (scale != scale || 1/scale == 0.0) if (scale != scale || 1 / scale == 0.0)
KALDI_ERR << "NaN or infinity in cepstral mean/variance computation"; KALDI_ERR
offset = -(mean*scale); << "NaN or infinity in cepstral mean/variance computation";
offset = -(mean * scale);
for (int32 d_skip = d; d_skip < feats->Dim();) { for (int32 d_skip = d; d_skip < feats->Dim();) {
norm(0, d_skip) = offset; norm(0, d_skip) = offset;
norm(1, d_skip) = scale; norm(1, d_skip) = scale;

@ -1,10 +1,24 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
#include "base/common.h" #include "base/common.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/feature_extractor_interface.h"
#include "kaldi/util/options-itf.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
namespace ppspeech { namespace ppspeech {
@ -12,26 +26,30 @@ struct DecibelNormalizerOptions {
float target_db; float target_db;
float max_gain_db; float max_gain_db;
bool convert_int_float; bool convert_int_float;
DecibelNormalizerOptions() : DecibelNormalizerOptions()
target_db(-20), : target_db(-20), max_gain_db(300.0), convert_int_float(false) {}
max_gain_db(300.0),
convert_int_float(false) {}
void Register(kaldi::OptionsItf* opts) { void Register(kaldi::OptionsItf* opts) {
opts->Register("target-db", &target_db, "target db for db normalization"); opts->Register(
opts->Register("max-gain-db", &max_gain_db, "max gain db for db normalization"); "target-db", &target_db, "target db for db normalization");
opts->Register("convert-int-float", &convert_int_float, "if convert int samples to float"); opts->Register(
"max-gain-db", &max_gain_db, "max gain db for db normalization");
opts->Register("convert-int-float",
&convert_int_float,
"if convert int samples to float");
} }
}; };
class DecibelNormalizer : public FeatureExtractorInterface { class DecibelNormalizer : public FeatureExtractorInterface {
public: public:
explicit DecibelNormalizer(const DecibelNormalizerOptions& opts); explicit DecibelNormalizer(const DecibelNormalizerOptions& opts);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input); virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat); virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return dim_; } virtual size_t Dim() const { return dim_; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input, bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBase<kaldi::BaseFloat>* feat) const; kaldi::VectorBase<kaldi::BaseFloat>* feat) const;
private: private:
DecibelNormalizerOptions opts_; DecibelNormalizerOptions opts_;
size_t dim_; size_t dim_;
@ -43,7 +61,8 @@ class DecibelNormalizer : public FeatureExtractorInterface {
class CMVN : public FeatureExtractorInterface { class CMVN : public FeatureExtractorInterface {
public: public:
explicit CMVN(std::string cmvn_file); explicit CMVN(std::string cmvn_file);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input); virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat); virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return stats_.NumCols() - 1; } virtual size_t Dim() const { return stats_.NumCols() - 1; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input, bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
@ -51,6 +70,7 @@ class CMVN : public FeatureExtractorInterface {
// for test // for test
void ApplyCMVN(bool var_norm, kaldi::VectorBase<BaseFloat>* feats); void ApplyCMVN(bool var_norm, kaldi::VectorBase<BaseFloat>* feats);
void ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats); void ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats);
private: private:
kaldi::Matrix<double> stats_; kaldi::Matrix<double> stats_;
std::shared_ptr<FeatureExtractorInterface> base_extractor_; std::shared_ptr<FeatureExtractorInterface> base_extractor_;

@ -13,4 +13,3 @@
// limitations under the License. // limitations under the License.
// extract the window of kaldi feat. // extract the window of kaldi feat.

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// itf/decodable-itf.h // itf/decodable-itf.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University; // Copyright 2009-2011 Microsoft Corporation; Saarland University;
@ -42,8 +56,10 @@ namespace kaldi {
For online decoding, where the features are coming in in real time, it is For online decoding, where the features are coming in in real time, it is
important to understand the IsLastFrame() and NumFramesReady() functions. important to understand the IsLastFrame() and NumFramesReady() functions.
There are two ways these are used: the old online-decoding code, in ../online/, There are two ways these are used: the old online-decoding code, in
and the new online-decoding code, in ../online2/. In the old online-decoding ../online/,
and the new online-decoding code, in ../online2/. In the old
online-decoding
code, the decoder would do: code, the decoder would do:
\code{.cc} \code{.cc}
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) { for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
@ -52,13 +68,16 @@ namespace kaldi {
\endcode \endcode
and the call to IsLastFrame would block if the features had not arrived yet. and the call to IsLastFrame would block if the features had not arrived yet.
The decodable object would have to know when to terminate the decoding. This The decodable object would have to know when to terminate the decoding. This
online-decoding mode is still supported, it is what happens when you call, for online-decoding mode is still supported, it is what happens when you call,
for
example, LatticeFasterDecoder::Decode(). example, LatticeFasterDecoder::Decode().
We realized that this "blocking" mode of decoding is not very convenient We realized that this "blocking" mode of decoding is not very convenient
because it forces the program to be multi-threaded and makes it complex to because it forces the program to be multi-threaded and makes it complex to
control endpointing. In the "new" decoding code, you don't call (for example) control endpointing. In the "new" decoding code, you don't call (for
LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(), example)
LatticeFasterDecoder::Decode(), you call
LatticeFasterDecoder::InitDecoding(),
and then each time you get more features, you provide them to the decodable and then each time you get more features, you provide them to the decodable
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
something like this: something like this:
@ -68,7 +87,8 @@ namespace kaldi {
} }
\endcode \endcode
So the decodable object never has IsLastFrame() called. For decoding where So the decodable object never has IsLastFrame() called. For decoding where
you are starting with a matrix of features, the NumFramesReady() function will you are starting with a matrix of features, the NumFramesReady() function
will
always just return the number of frames in the file, and IsLastFrame() will always just return the number of frames in the file, and IsLastFrame() will
return true for the last frame. return true for the last frame.
@ -82,30 +102,39 @@ namespace kaldi {
class DecodableInterface { class DecodableInterface {
public: public:
/// Returns the log likelihood, which will be negated in the decoder. /// Returns the log likelihood, which will be negated in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() > frame /// The "frame" starts from zero. You should verify that NumFramesReady() >
/// frame
/// before calling this. /// before calling this.
virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0; virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
/// Returns true if this is the last frame. Frames are zero-based, so the /// Returns true if this is the last frame. Frames are zero-based, so the
/// first frame is zero. IsLastFrame(-1) will return false, unless the file /// first frame is zero. IsLastFrame(-1) will return false, unless the file
/// is empty (which is a case that I'm not sure all the code will handle, so /// is empty (which is a case that I'm not sure all the code will handle, so
/// be careful). Caution: the behavior of this function in an online setting /// be careful). Caution: the behavior of this function in an online
/// setting
/// is being changed somewhat. In future it may return false in cases where /// is being changed somewhat. In future it may return false in cases where
/// we haven't yet decided to terminate decoding, but later true if we decide /// we haven't yet decided to terminate decoding, but later true if we
/// decide
/// to terminate decoding. The plan in future is to rely more on /// to terminate decoding. The plan in future is to rely more on
/// NumFramesReady(), and in future, IsLastFrame() would always return false /// NumFramesReady(), and in future, IsLastFrame() would always return false
/// in an online-decoding setting, and would only return true in a /// in an online-decoding setting, and would only return true in a
/// decoding-from-matrix setting where we want to allow the last delta or LDA /// decoding-from-matrix setting where we want to allow the last delta or
/// LDA
/// features to be flushed out for compatibility with the baseline setup. /// features to be flushed out for compatibility with the baseline setup.
virtual bool IsLastFrame(int32 frame) const = 0; virtual bool IsLastFrame(int32 frame) const = 0;
/// The call NumFramesReady() will return the number of frames currently available /// The call NumFramesReady() will return the number of frames currently
/// for this decodable object. This is for use in setups where you don't want the /// available
/// decoder to block while waiting for input. This is newly added as of Jan 2014, /// for this decodable object. This is for use in setups where you don't
/// and I hope, going forward, to rely on this mechanism more than IsLastFrame to /// want the
/// decoder to block while waiting for input. This is newly added as of Jan
/// 2014,
/// and I hope, going forward, to rely on this mechanism more than
/// IsLastFrame to
/// know when to stop decoding. /// know when to stop decoding.
virtual int32 NumFramesReady() const { virtual int32 NumFramesReady() const {
KALDI_ERR << "NumFramesReady() not implemented for this decodable type."; KALDI_ERR
<< "NumFramesReady() not implemented for this decodable type.";
return -1; return -1;
} }

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/decodable.h" #include "nnet/decodable.h"
namespace ppspeech { namespace ppspeech {
@ -5,18 +19,14 @@ namespace ppspeech {
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet): Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet)
frontend_(NULL), : frontend_(NULL), nnet_(nnet), finished_(false), frames_ready_(0) {}
nnet_(nnet),
finished_(false),
frames_ready_(0) {
}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) { void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
frames_ready_ += likelihood.NumRows(); frames_ready_ += likelihood.NumRows();
} }
//Decodable::Init(DecodableConfig config) { // Decodable::Init(DecodableConfig config) {
//} //}
bool Decodable::IsLastFrame(int32 frame) const { bool Decodable::IsLastFrame(int32 frame) const {
@ -24,18 +34,14 @@ bool Decodable::IsLastFrame(int32 frame) const {
return finished_ && (frame == frames_ready_ - 1); return finished_ && (frame == frames_ready_ - 1);
} }
int32 Decodable::NumIndices() const { int32 Decodable::NumIndices() const { return 0; }
return 0;
}
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return 0; }
return 0;
}
void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) { void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
nnet_->FeedForward(features, &nnet_cache_); nnet_->FeedForward(features, &nnet_cache_);
frames_ready_ += nnet_cache_.NumRows(); frames_ready_ += nnet_cache_.NumRows();
return ; return;
} }
std::vector<BaseFloat> Decodable::FrameLogLikelihood(int32 frame) { std::vector<BaseFloat> Decodable::FrameLogLikelihood(int32 frame) {

@ -1,7 +1,21 @@
#include "nnet/decodable-itf.h" // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h" #include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/feature_extractor_interface.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/decodable-itf.h"
#include "nnet/nnet_interface.h" #include "nnet/nnet_interface.h"
namespace ppspeech { namespace ppspeech {
@ -11,15 +25,18 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface { class Decodable : public kaldi::DecodableInterface {
public: public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet); explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
//void Init(DecodableOpts config); // void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const; virtual bool IsLastFrame(int32 frame) const;
virtual int32 NumIndices() const; virtual int32 NumIndices() const;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame); virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later void Acceptlikelihood(
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& feature); // only for test, todo remove later const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>&
feature); // only for test, todo remove later
void Reset(); void Reset();
void InputFinished() { finished_ = true; } void InputFinished() { finished_ = true; }
private: private:
std::shared_ptr<FeatureExtractorInterface> frontend_; std::shared_ptr<FeatureExtractorInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_; std::shared_ptr<NnetInterface> nnet_;

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
@ -10,10 +24,9 @@ namespace ppspeech {
class NnetInterface { class NnetInterface {
public: public:
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features, virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences)= 0; kaldi::Matrix<kaldi::BaseFloat>* inferences) = 0;
virtual void Reset() = 0; virtual void Reset() = 0;
virtual ~NnetInterface() {} virtual ~NnetInterface() {}
}; };
} // namespace ppspeech } // namespace ppspeech

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/paddle_nnet.h" #include "nnet/paddle_nnet.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
@ -21,18 +35,18 @@ void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
std::vector<std::string> tmp_shape; std::vector<std::string> tmp_shape;
tmp_shape = absl::StrSplit(cache_shapes[i], "-"); tmp_shape = absl::StrSplit(cache_shapes[i], "-");
std::vector<int> cur_shape; std::vector<int> cur_shape;
std::transform(tmp_shape.begin(), tmp_shape.end(), std::transform(tmp_shape.begin(),
tmp_shape.end(),
std::back_inserter(cur_shape), std::back_inserter(cur_shape),
[](const std::string& s) { [](const std::string& s) { return atoi(s.c_str()); });
return atoi(s.c_str());
});
cache_names_idx_[cache_names[i]] = i; cache_names_idx_[cache_names[i]] = i;
std::shared_ptr<Tensor<BaseFloat>> cache_eout = std::make_shared<Tensor<BaseFloat>>(cur_shape); std::shared_ptr<Tensor<BaseFloat>> cache_eout =
std::make_shared<Tensor<BaseFloat>>(cur_shape);
cache_encouts_.push_back(cache_eout); cache_encouts_.push_back(cache_eout);
} }
} }
PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) { PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) {
paddle_infer::Config config; paddle_infer::Config config;
config.SetModel(opts.model_path, opts.params_path); config.SetModel(opts.model_path, opts.params_path);
if (opts.use_gpu) { if (opts.use_gpu) {
@ -45,7 +59,8 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
if (opts.enable_profile) { if (opts.enable_profile) {
config.EnableProfile(); config.EnableProfile();
} }
pool.reset(new paddle_infer::services::PredictorPool(config, opts.thread_num)); pool.reset(
new paddle_infer::services::PredictorPool(config, opts.thread_num));
if (pool == nullptr) { if (pool == nullptr) {
LOG(ERROR) << "create the predictor pool failed"; LOG(ERROR) << "create the predictor pool failed";
} }
@ -68,16 +83,14 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
std::vector<std::string> model_output_names = predictor->GetOutputNames(); std::vector<std::string> model_output_names = predictor->GetOutputNames();
assert(output_names_vec.size() == model_output_names.size()); assert(output_names_vec.size() == model_output_names.size());
for (size_t i = 0;i < output_names_vec.size(); i++) { for (size_t i = 0; i < output_names_vec.size(); i++) {
assert(output_names_vec[i] == model_output_names[i]); assert(output_names_vec[i] == model_output_names[i]);
} }
ReleasePredictor(predictor); ReleasePredictor(predictor);
InitCacheEncouts(opts); InitCacheEncouts(opts);
} }
void PaddleNnet::Reset() { void PaddleNnet::Reset() { InitCacheEncouts(opts_); }
InitCacheEncouts(opts_);
}
paddle_infer::Predictor* PaddleNnet::GetPredictor() { paddle_infer::Predictor* PaddleNnet::GetPredictor() {
LOG(INFO) << "attempt to get a new predictor instance " << std::endl; LOG(INFO) << "attempt to get a new predictor instance " << std::endl;
@ -130,13 +143,14 @@ shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
return cache_encouts_[iter->second]; return cache_encouts_[iter->second];
} }
void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>* inferences) { void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features,
Matrix<BaseFloat>* inferences) {
paddle_infer::Predictor* predictor = GetPredictor(); paddle_infer::Predictor* predictor = GetPredictor();
int row = features.NumRows(); int row = features.NumRows();
int col = features.NumCols(); int col = features.NumCols();
std::vector<BaseFloat> feed_feature; std::vector<BaseFloat> feed_feature;
// todo refactor feed feature: SmileGoat // todo refactor feed feature: SmileGoat
feed_feature.reserve(row*col); feed_feature.reserve(row * col);
for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) { for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) { for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) {
feed_feature.push_back(features(row_idx, col_idx)); feed_feature.push_back(features(row_idx, col_idx));
@ -146,22 +160,26 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat
std::vector<std::string> output_names = predictor->GetOutputNames(); std::vector<std::string> output_names = predictor->GetOutputNames();
LOG(INFO) << "feat info: row=" << row << ", col= " << col; LOG(INFO) << "feat info: row=" << row << ", col= " << col;
std::unique_ptr<paddle_infer::Tensor> input_tensor = predictor->GetInputHandle(input_names[0]); std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col}; std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE); input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(feed_feature.data()); input_tensor->CopyFromCpu(feed_feature.data());
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]); std::unique_ptr<paddle_infer::Tensor> input_len =
predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1}; std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size); input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len; std::vector<int64_t> audio_len;
audio_len.push_back(row); audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data()); input_len->CopyFromCpu(audio_len.data());
std::unique_ptr<paddle_infer::Tensor> h_box = predictor->GetInputHandle(input_names[2]); std::unique_ptr<paddle_infer::Tensor> h_box =
predictor->GetInputHandle(input_names[2]);
shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]); shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
h_box->Reshape(h_cache->get_shape()); h_box->Reshape(h_cache->get_shape());
h_box->CopyFromCpu(h_cache->get_data().data()); h_box->CopyFromCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_box = predictor->GetInputHandle(input_names[3]); std::unique_ptr<paddle_infer::Tensor> c_box =
predictor->GetInputHandle(input_names[3]);
shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]); shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
c_box->Reshape(c_cache->get_shape()); c_box->Reshape(c_cache->get_shape());
c_box->CopyFromCpu(c_cache->get_data().data()); c_box->CopyFromCpu(c_cache->get_data().data());
@ -172,10 +190,12 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat
} }
LOG(INFO) << "get the model success"; LOG(INFO) << "get the model success";
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]); std::unique_ptr<paddle_infer::Tensor> h_out =
predictor->GetOutputHandle(output_names[2]);
assert(h_cache->get_shape() == h_out->shape()); assert(h_cache->get_shape() == h_out->shape());
h_out->CopyToCpu(h_cache->get_data().data()); h_out->CopyToCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]); std::unique_ptr<paddle_infer::Tensor> c_out =
predictor->GetOutputHandle(output_names[3]);
assert(c_cache->get_shape() == c_out->shape()); assert(c_cache->get_shape() == c_out->shape());
c_out->CopyToCpu(c_cache->get_data().data()); c_out->CopyToCpu(c_cache->get_data().data());
@ -187,13 +207,14 @@ void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat
col = output_shape[2]; col = output_shape[2];
vector<float> inferences_result; vector<float> inferences_result;
inferences->Resize(row, col); inferences->Resize(row, col);
inferences_result.resize(row*col); inferences_result.resize(row * col);
output_tensor->CopyToCpu(inferences_result.data()); output_tensor->CopyToCpu(inferences_result.data());
ReleasePredictor(predictor); ReleasePredictor(predictor);
for (int row_idx = 0; row_idx < row; ++row_idx) { for (int row_idx = 0; row_idx < row; ++row_idx) {
for (int col_idx = 0; col_idx < col; ++col_idx) { for (int col_idx = 0; col_idx < col; ++col_idx) {
(*inferences)(row_idx, col_idx) = inferences_result[col*row_idx + col_idx]; (*inferences)(row_idx, col_idx) =
inferences_result[col * row_idx + col_idx];
} }
} }
} }

@ -1,8 +1,22 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
#include "nnet/nnet_interface.h"
#include "base/common.h" #include "base/common.h"
#include "nnet/nnet_interface.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
@ -24,19 +38,27 @@ struct ModelOptions {
std::string cache_shape; std::string cache_shape;
bool enable_fc_padding; bool enable_fc_padding;
bool enable_profile; bool enable_profile;
ModelOptions() : ModelOptions()
model_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel"), : model_path(
params_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams"), "../../../../model/paddle_online_deepspeech/model/"
"avg_1.jit.pdmodel"),
params_path(
"../../../../model/paddle_online_deepspeech/model/"
"avg_1.jit.pdiparams"),
thread_num(2), thread_num(2),
use_gpu(false), use_gpu(false),
input_names("audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box"), input_names(
output_names("save_infer_model/scale_0.tmp_1,save_infer_model/scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/scale_3.tmp_1"), "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_"
"box"),
output_names(
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1"),
cache_names("chunk_state_h_box,chunk_state_c_box"), cache_names("chunk_state_h_box,chunk_state_c_box"),
cache_shape("3-1-1024,3-1-1024"), cache_shape("3-1-1024,3-1-1024"),
switch_ir_optim(false), switch_ir_optim(false),
enable_fc_padding(false), enable_fc_padding(false),
enable_profile(false) { enable_profile(false) {}
}
void Register(kaldi::OptionsItf* opts) { void Register(kaldi::OptionsItf* opts) {
opts->Register("model-path", &model_path, "model file path"); opts->Register("model-path", &model_path, "model file path");
@ -47,37 +69,37 @@ struct ModelOptions {
opts->Register("output-names", &output_names, "paddle output names"); opts->Register("output-names", &output_names, "paddle output names");
opts->Register("cache-names", &cache_names, "cache names"); opts->Register("cache-names", &cache_names, "cache names");
opts->Register("cache-shape", &cache_shape, "cache shape"); opts->Register("cache-shape", &cache_shape, "cache shape");
opts->Register("switch-ir-optiom", &switch_ir_optim, "paddle SwitchIrOptim option"); opts->Register("switch-ir-optiom",
opts->Register("enable-fc-padding", &enable_fc_padding, "paddle EnableFCPadding option"); &switch_ir_optim,
opts->Register("enable-profile", &enable_profile, "paddle EnableProfile option"); "paddle SwitchIrOptim option");
opts->Register("enable-fc-padding",
&enable_fc_padding,
"paddle EnableFCPadding option");
opts->Register(
"enable-profile", &enable_profile, "paddle EnableProfile option");
} }
}; };
template<typename T> template <typename T>
class Tensor { class Tensor {
public: public:
Tensor() { Tensor() {}
} Tensor(const std::vector<int>& shape) : _shape(shape) {
Tensor(const std::vector<int>& shape) : int data_size = std::accumulate(
_shape(shape) { _shape.begin(), _shape.end(), 1, std::multiplies<int>());
int data_size = std::accumulate(_shape.begin(), _shape.end(),
1, std::multiplies<int>());
LOG(INFO) << "data size: " << data_size; LOG(INFO) << "data size: " << data_size;
_data.resize(data_size, 0); _data.resize(data_size, 0);
} }
void reshape(const std::vector<int>& shape) { void reshape(const std::vector<int>& shape) {
_shape = shape; _shape = shape;
int data_size = std::accumulate(_shape.begin(), _shape.end(), int data_size = std::accumulate(
1, std::multiplies<int>()); _shape.begin(), _shape.end(), 1, std::multiplies<int>());
_data.resize(data_size, 0); _data.resize(data_size, 0);
} }
const std::vector<int>& get_shape() const { const std::vector<int>& get_shape() const { return _shape; }
return _shape; std::vector<T>& get_data() { return _data; }
}
std::vector<T>& get_data() { private:
return _data;
}
private:
std::vector<int> _shape; std::vector<int> _shape;
std::vector<T> _data; std::vector<T> _data;
}; };
@ -88,7 +110,8 @@ class PaddleNnet : public NnetInterface {
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features, virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences); kaldi::Matrix<kaldi::BaseFloat>* inferences);
virtual void Reset(); virtual void Reset();
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name); std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
const std::string& name);
void InitCacheEncouts(const ModelOptions& opts); void InitCacheEncouts(const ModelOptions& opts);
private: private:

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "utils/file_utils.h" #include "utils/file_utils.h"
namespace ppspeech { namespace ppspeech {
@ -17,5 +31,4 @@ bool ReadFileToVector(const std::string& filename,
return true; return true;
} }
} }

@ -1,8 +1,21 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h" #include "base/common.h"
namespace ppspeech { namespace ppspeech {
bool ReadFileToVector(const std::string& filename, bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* data); std::vector<std::string>* data);
} }

Loading…
Cancel
Save