Merge pull request #1541 from zh794390558/egs

[speechx] move test to examples
pull/1559/head
YangZhou 3 years ago committed by GitHub
commit ebc2aca990
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -50,13 +50,13 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i entry: bash .pre-commit-hooks/clang-format.hook -i
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$ exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker - id: copyright_checker
name: copyright_checker name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook entry: python .pre-commit-hooks/copyright-check.hook
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$ exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports - repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0 rev: v2.4.0
hooks: hooks:

@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
""" """
rng = np.random.RandomState(epoch) rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1) shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices) rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch] batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False assert clipped is False

@ -36,4 +36,4 @@ def repeat(N, fn):
Returns: Returns:
MultiSequential: Repeated model instance. MultiSequential: Repeated model instance.
""" """
return MultiSequential(*[fn(n) for n in range(N)]) return MultiSequential(* [fn(n) for n in range(N)])

@ -117,6 +117,7 @@ set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
include_directories("${MKLDNN_PATH}/include") include_directories("${MKLDNN_PATH}/include")
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
set(EXTERNAL_LIB "-lrt -ldl -lpthread") set(EXTERNAL_LIB "-lrt -ldl -lpthread")
set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX}) set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
set(DEPS ${DEPS} set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB} ${MATH_LIB} ${MKLDNN_LIB}
@ -137,4 +138,7 @@ set(DEPS ${DEPS}
#target_link_libraries(lib_name item0 item1) #target_link_libraries(lib_name item0 item1)
#add_dependencies(lib_name depend-target) #add_dependencies(lib_name depend-target)
add_subdirectory(speechx) set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx)
add_subdirectory(speechx)
add_subdirectory(examples)

@ -16,7 +16,7 @@ if [ ! -d ${boost_SOURCE_DIR} ]; then wget -c https://boostorg.jfrog.io/artifact
echo -e "\n" echo -e "\n"
fi fi
rm -rf build #rm -rf build
mkdir -p build mkdir -p build
cd build cd build

@ -18,6 +18,8 @@ ExternalProject_Add(
SOURCE_DIR ${OpenBLAS_SOURCE_DIR} SOURCE_DIR ${OpenBLAS_SOURCE_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR> CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
CMAKE_GENERATOR "Unix Makefiles") CMAKE_GENERATOR "Unix Makefiles")
# https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition # https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition
ExternalProject_Get_Property(OPENBLAS INSTALL_DIR) ExternalProject_Get_Property(OPENBLAS INSTALL_DIR)
set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR}) set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR})

@ -0,0 +1,5 @@
# Aggregator script: adds each example sub-directory to the build.
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
# One add_subdirectory() per example group; each directory is expected to
# provide its own CMakeLists.txt.
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)

@ -0,0 +1,5 @@
# Examples
* decoder - offline decoder
* feat - MFCC and linear spectrogram features
* nnet - ds2 nn

@ -0,0 +1,5 @@
# Builds the offline-decoder-main example executable.
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
# Single-source executable; the source file lives next to this script.
add_executable(offline-decoder-main ${CMAKE_CURRENT_SOURCE_DIR}/offline-decoder-main.cc)
# Headers are resolved relative to SPEECHX_ROOT and its bundled kaldi tree.
# NOTE(review): SPEECHX_ROOT is assumed to be set by a parent CMakeLists.txt.
target_include_directories(offline-decoder-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
# NOTE(review): ${DEPS} is assumed to be populated by a parent CMakeLists.txt
# (Paddle inference and math libraries) - confirm before reordering.
target_link_libraries(offline-decoder-main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})

@ -0,0 +1,74 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// TODO: refactor, replace with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
// Command-line flag holding the specifier handed to the
// kaldi::SequentialBaseFloatMatrixReader in main(); each entry it yields is
// one utterance's feature matrix to decode. Defaults to the empty string.
DEFINE_string(feature_respecifier, "", "test nnet prob");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// void SplitFeature(kaldi::Matrix<BaseFloat> feature,
// int32 chunk_size,
// std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
//}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_respecifier);
// test nnet_output --> decoder result
int32 num_done = 0, num_err = 0;
ppspeech::CTCBeamSearchOptions opts;
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet));
// int32 chunk_size = 35;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
decodable->FeedFeatures(feature);
decoder.AdvanceDecode(decodable, 8);
decodable->InputFinished();
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
++num_done;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -0,0 +1,10 @@
# Builds the two feature-extraction example executables.
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
# mfcc-test: built from feature-mfcc-test.cc and linked against kaldi-mfcc.
add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc)
target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(mfcc-test kaldi-mfcc)
# linear-spectrogram-main: built from linear-spectrogram-main.cc and linked
# against the frontend library plus kaldi utilities, gflags and glog.
add_executable(linear-spectrogram-main ${CMAKE_CURRENT_SOURCE_DIR}/linear-spectrogram-main.cc)
target_include_directories(linear-spectrogram-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(linear-spectrogram-main frontend kaldi-util kaldi-feat-common gflags glog)

@ -0,0 +1,720 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// feat/feature-mfcc-test.cc
// Copyright 2009-2011 Karel Vesely; Petr Motlicek
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <cstdint>
#include <fstream>
#include <iostream>
#include "base/kaldi-math.h"
#include "feat/feature-mfcc.h"
#include "feat/wave-reader.h"
#include "matrix/kaldi-matrix-inl.h"
using namespace kaldi;
static void UnitTestReadWave() {
std::cout << "=== UnitTestReadWave() ===\n";
Vector<BaseFloat> v, v2;
std::cout << "<<<=== Reading waveform\n";
{
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
const Matrix<BaseFloat> data(wave.Data());
KALDI_ASSERT(data.NumRows() == 1);
v.Resize(data.NumCols());
v.CopyFromVec(data.Row(0));
}
std::cout
<< "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
std::ifstream input("test_data/test_matlab.ascii");
KALDI_ASSERT(input.good());
v2.Read(input, false);
input.close();
std::cout
<< "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
KALDI_ASSERT(v.Dim() == v2.Dim());
for (int32 i = 0; i < v.Dim(); i++) {
KALDI_ASSERT(v(i) == v2(i));
}
std::cout << "<<<=== Comparing done\n";
// std::cout << "== The Waveform Samples == \n";
// std::cout << v;
std::cout << "Test passed :)\n\n";
}
/**
*/
static void UnitTestSimple() {
std::cout << "=== UnitTestSimple() ===\n";
Vector<BaseFloat> v(100000);
Matrix<BaseFloat> m;
// init with noise
for (int32 i = 0; i < v.Dim(); i++) {
v(i) = (abs(i * 433024253) % 65535) - (65535 / 2);
}
std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
// the parametrization object
MfccOptions op;
// trying to have same opts as baseline.
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "rectangular";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
Mfcc mfcc(op);
// use default parameters
// compute mfccs.
mfcc.Compute(v, 1.0, &m);
// possibly dump
// std::cout << "== Output features == \n" << m;
std::cout << "Test passed :)\n\n";
}
static void UnitTestHTKCompare1() {
std::cout << "=== UnitTestHTKCompare1() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.1",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.1",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.1");
}
static void UnitTestHTKCompare2() {
std::cout << "=== UnitTestHTKCompare2() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.2",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.2",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.2");
}
static void UnitTestHTKCompare3() {
std::cout << "=== UnitTestHTKCompare3() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.3",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 20.0;
// op.mel_opts.debug_mel = true;
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.3",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.3");
}
static void UnitTestHTKCompare4() {
std::cout << "=== UnitTestHTKCompare4() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.4",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.4",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.4");
}
static void UnitTestHTKCompare5() {
std::cout << "=== UnitTestHTKCompare5() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.5",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 0.0;
op.mel_opts.vtln_low = 100.0;
op.mel_opts.vtln_high = 7500.0;
op.mel_opts.htk_mode = true;
BaseFloat vtln_warp =
1.1; // our approach identical to htk for warp factor >1,
// differs slightly for higher mel bins if warp_factor <0.9
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.5",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.5");
}
static void UnitTestHTKCompare6() {
std::cout << "=== UnitTestHTKCompare6() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.6",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.97;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.num_bins = 24;
op.mel_opts.low_freq = 125.0;
op.mel_opts.high_freq = 7800.0;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] "
<< htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] "
<< kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}
}
}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float) * kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.6",
std::ios::out | std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.6");
}
void UnitTestVtln() {
// Test the function VtlnWarpFreq.
BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20,
vtln_high_cutoff = 7400;
for (size_t i = 0; i < 100; i++) {
BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq),
freq / warp_factor);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
low_freq),
low_freq);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
high_freq),
high_freq);
BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(),
freq3 = freq2 +
(high_freq - freq2) * RandUniform(); // freq3>=freq2
BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq2);
BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff,
vtln_high_cutoff,
low_freq,
high_freq,
warp_factor,
freq3);
KALDI_ASSERT(w3 >= w2); // increasing function.
BaseFloat w3dash = MelBanks::VtlnWarpFreq(
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, 1.0, freq3);
AssertEqual(w3dash, freq3);
}
}
// Runs every feature unit test once, in a fixed order, then reports success.
static void UnitTestFeat() {
    using TestFn = void (*)();
    // NOTE: an earlier comment here said UnitTestHTKCompare3 was commented
    // out after the FFT-bin handling changed (removed offset of 0.5), but the
    // test is in fact part of the run.
    const TestFn tests[] = {
        UnitTestVtln,
        UnitTestReadWave,
        UnitTestSimple,
        UnitTestHTKCompare1,
        UnitTestHTKCompare2,
        UnitTestHTKCompare3,
        UnitTestHTKCompare4,
        UnitTestHTKCompare5,
        UnitTestHTKCompare6,
    };
    for (TestFn test : tests) {
        test();
    }
    std::cout << "Tests succeeded.\n";
}
// Runs the whole feature test suite five times in a row.
// Returns 0 on success; if any std::exception escapes, its message is written
// to stderr and 1 is returned.
int main() {
    int exit_code = 0;
    try {
        int round = 0;
        while (round < 5) {
            UnitTestFeat();
            ++round;
        }
        std::cout << "Tests succeeded.\n";
    } catch (const std::exception &e) {
        std::cerr << e.what();
        exit_code = 1;
    }
    return exit_code;
}

@ -0,0 +1,257 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// TODO: refactor, replace with gtest
#include "base/flags.h"
#include "base/log.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
// Command-line flags for this example.
// wav_rspecifier: specifier handed to the SequentialTableReader<WaveHolder>
// in main(), i.e. where the input waveforms are read from.
DEFINE_string(wav_rspecifier, "", "test wav path");
// feature_wspecifier: specifier handed to the BaseFloatMatrixWriter named
// feat_writer in main().
DEFINE_string(feature_wspecifier, "", "test wav ark");
// feature_check_wspecifier: specifier handed to the BaseFloatMatrixWriter
// named feat_cmvn_check_writer in main().
// NOTE(review): the help text is identical to the flag above - confirm the
// intended description.
DEFINE_string(feature_check_wspecifier, "", "test wav ark");
// cmvn_write_path: file that WriteMatrix() writes the CMVN statistics to and
// that is then passed to the ppspeech::CMVN constructor in main().
// NOTE(review): help text looks copy-pasted from the flags above.
DEFINE_string(cmvn_write_path, "./cmvn.ark", "test wav ark");
std::vector<float> mean_{
-13730251.531853663, -12982852.199316509, -13673844.299583456,
-13089406.559646806, -12673095.524938712, -12823859.223276224,
-13590267.158903603, -14257618.467152044, -14374605.116185192,
-14490009.21822485, -14849827.158924166, -15354435.470563512,
-15834149.206532761, -16172971.985514281, -16348740.496746974,
-16423536.699409386, -16556246.263649225, -16744088.772748645,
-16916184.08510357, -17054034.840031497, -17165612.509455364,
-17255955.470915023, -17322572.527648456, -17408943.862033736,
-17521554.799865916, -17620623.254924215, -17699792.395918526,
-17723364.411134344, -17741483.4433254, -17747426.888704527,
-17733315.928209435, -17748780.160905756, -17808336.883775543,
-17895918.671983004, -18009812.59173023, -18098188.66548325,
-18195798.958462656, -18293617.62980999, -18397432.92077201,
-18505834.787318766, -18585451.8100908, -18652438.235649142,
-18700960.306275308, -18734944.58792185, -18737426.313365128,
-18735347.165987637, -18738813.444170244, -18737086.848890636,
-18731576.2474336, -18717405.44095871, -18703089.25545657,
-18691014.546456724, -18692460.568905357, -18702119.628629155,
-18727710.621126678, -18761582.72034647, -18806745.835547544,
-18850674.8692112, -18884431.510951452, -18919999.992506847,
-18939303.799078144, -18952946.273760635, -18980289.22996379,
-19011610.17803294, -19040948.61805145, -19061021.429847397,
-19112055.53768819, -19149667.414264943, -19201127.05091321,
-19270250.82564605, -19334606.883057203, -19390513.336589377,
-19444176.259208687, -19502755.000038862, -19544333.014549147,
-19612668.183176614, -19681902.19006569, -19771969.951249883,
-19873329.723376893, -19996752.59235844, -20110031.131400537,
-20231658.612529557, -20319378.894054495, -20378534.45718066,
-20413332.089584175, -20438147.844177883, -20443710.248040095,
-20465457.02238927, -20488610.969337028, -20516295.16424432,
-20541423.795738827, -20553192.874953747, -20573605.50701977,
-20577871.61936797, -20571807.008916274, -20556242.38912231,
-20542199.30819195, -20521239.063551214, -20519150.80004532,
-20527204.80248933, -20536933.769257784, -20543470.522332076,
-20549700.089992985, -20551525.24958494, -20554873.406493705,
-20564277.65794227, -20572211.740052115, -20574305.69550465,
-20575494.450104576, -20567092.577932164, -20549302.929608088,
-20545445.11878376, -20546625.326603737, -20549190.03499401,
-20554824.947828256, -20568341.378989458, -20577582.331383612,
-20577980.519402675, -20566603.03458152, -20560131.592262644,
-20552166.469060015, -20549063.06763577, -20544490.562339947,
-20539817.82346569, -20528747.715731595, -20518026.24576161,
-20510977.844974525, -20506874.36087992, -20506731.11977665,
-20510482.133420516, -20507760.92101862, -20494644.834457114,
-20480107.89304893, -20461312.091867123, -20442941.75080173,
-20426123.02834838, -20424607.675283, -20426810.369107097,
-20434024.50097819, -20437404.75544205, -20447688.63916367,
-20460893.335563846, -20482922.735127095, -20503610.119434915,
-20527062.76448319, -20557830.035128627, -20593274.72068722,
-20632528.452965066, -20673637.471334763, -20733106.97143075,
-20842921.0447562, -21054357.83621519, -21416569.534189366,
-21978460.272811692, -22753170.052172784, -23671344.10563395,
-24613499.293358143, -25406477.12230188, -25884377.82156489,
-26049040.62791664, -26996879.104431007};
std::vector<float> variance_{
213747175.10846674, 188395815.34302503, 212706429.10966414,
199109025.81461075, 189235901.23864496, 194901336.53253657,
217481594.29306737, 238689869.12327808, 243977501.24115244,
248479623.6431067, 259766741.47116545, 275516766.7790273,
291271202.3691234, 302693239.8220509, 308627358.3997694,
311143911.38788426, 315446105.07731867, 321705430.9341829,
327458907.4659941, 332245072.43223983, 336251717.5935284,
339694069.7639722, 342188204.4322228, 345587110.31313115,
349903086.2875232, 353660214.20643026, 356700344.5270885,
357665362.3529641, 358493352.05658793, 358857951.620328,
358375239.52774596, 358899733.6342954, 361051818.3511561,
364361716.05025816, 368750322.3771452, 372047800.6462831,
375655861.1349018, 379358519.1980013, 383327605.3935181,
387458599.282341, 390434692.3406868, 392994486.35057056,
394874418.04603153, 396230525.79763395, 396365592.0414835,
396334819.8242737, 396488353.19250053, 396438877.00744957,
396197980.4459586, 395590921.6672991, 395001107.62072515,
394528291.7318225, 394593110.424006, 395018405.59353715,
396110577.5415993, 397506704.0371068, 399400197.4657644,
401243568.2468382, 402687134.7805103, 404136047.2872507,
404883170.001883, 405522253.219517, 406660365.3626476,
407919346.0991902, 409045348.5384909, 409759588.7889818,
411974821.8564483, 413489718.78201455, 415535392.56684107,
418466481.97674364, 421104678.35678065, 423405392.5200779,
425550570.40798235, 427929423.9579701, 429585274.253478,
432368493.55181056, 435193587.13513297, 438886855.20476013,
443058876.8633751, 448181232.5093362, 452883835.6332396,
458056721.77926534, 461816531.22735566, 464363620.1970998,
465886343.5057493, 466928872.0651, 467180536.42647296,
468111848.70714295, 469138695.3071312, 470378429.6930793,
471517958.7132626, 472109050.4262365, 473087417.0177867,
473381322.04648733, 473220195.85483915, 472666071.8998819,
472124669.87879956, 471298571.411737, 471251033.2902761,
471672676.43128747, 472177147.2193172, 472572361.7711908,
472968783.7751127, 473156295.4164052, 473398034.82676554,
473897703.5203811, 474328271.33112127, 474452670.98002136,
474549003.99284613, 474252887.13567275, 473557462.909069,
473483385.85193115, 473609738.04855174, 473746944.82085115,
474016729.91696435, 474617321.94138587, 475045097.237122,
475125402.586558, 474664112.9824912, 474426247.5800283,
474104075.42796475, 473978219.7273978, 473773171.7798875,
473578534.69508696, 473102924.16904145, 472651240.5232615,
472374383.1810912, 472209479.6956096, 472202298.8921673,
472370090.76781124, 472220933.99374026, 471625467.37106377,
470994646.51883453, 470182428.9637543, 469348211.5939578,
468570387.4467277, 468540442.7225135, 468672018.90414184,
468994346.9533251, 469138757.58201426, 469553915.95710236,
470134523.38582784, 471082421.62055486, 471962316.51804745,
472939745.1708408, 474250621.5944825, 475773933.43199486,
477465399.71087736, 479218782.61382693, 481752299.7930922,
486608947.8984568, 496119403.2067917, 512730085.5704984,
539048915.2641417, 576285298.3548826, 621610270.2240586,
669308196.4436442, 710656993.5957186, 736344437.3725077,
745481288.0241544, 801121432.9925804};
// Divisor that main() applies to the accumulated mean_/variance_ sums;
// WriteMatrix() also stores it as the last entry of row 0 of the stats matrix.
int count_ = 912592;
// Packs the hard-coded statistics into a 2 x (dim + 1) matrix and writes it to
// FLAGS_cmvn_write_path. Row 0 holds mean_ followed by count_; row 1 holds
// variance_ (its final entry is not assigned here). The trailing `true` is
// forwarded unchanged to kaldi::WriteKaldiObject.
void WriteMatrix() {
    const size_t dim = mean_.size();
    kaldi::Matrix<double> cmvn_stats(2, dim + 1);

    size_t col = 0;
    while (col < dim) {
        cmvn_stats(0, col) = mean_[col];
        cmvn_stats(1, col) = variance_[col];
        ++col;
    }

    // The extra column of row 0 carries the count.
    cmvn_stats(0, dim) = count_;
    kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(
FLAGS_feature_check_wspecifier);
WriteMatrix();
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning
// window -->linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0;
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
new ppspeech::DecibelNormalizer(db_norm_opt));
ppspeech::LinearSpectrogram linear_spectrogram(
opt, std::move(base_feature_extractor));
ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
float streaming_chunk = 0.36;
int sample_rate = 16000;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << mean_.size();
for (size_t i = 0; i < mean_.size(); i++) {
mean_[i] /= count_;
variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i];
if (variance_[i] < 1.0e-20) {
variance_[i] = 1.0e-20;
}
variance_[i] = 1.0 / std::sqrt(variance_[i]);
}
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
int sample_offset = 0;
std::vector<kaldi::Matrix<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(wav_chunk);
linear_spectrogram.ReadFeats(&features);
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.NumRows();
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features(cur_idx, col_idx) =
(feat(row_idx, col_idx) - mean_[col_idx]) *
variance_[col_idx];
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features_check(feature_rows,
feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features_check(cur_idx, col_idx) = feat(row_idx, col_idx);
}
kaldi::SubVector<BaseFloat> row_feat(features_check, cur_idx);
cmvn.ApplyCMVN(true, &row_feat);
++cur_idx;
}
}
feat_cmvn_check_writer.Write(utt, features_check);
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -0,0 +1,5 @@
# Build rules for the pp-model-test example binary.
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
# Single-source executable built from pp-model-test.cc in this directory.
add_executable(pp-model-test ${CMAKE_CURRENT_SOURCE_DIR}/pp-model-test.cc)
# Headers are looked up under SPEECHX_ROOT and its kaldi/ subdirectory.
target_include_directories(pp-model-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
# Links nnet and gflags plus whatever ${DEPS} expands to (defined outside this file).
target_link_libraries(pp-model-test PUBLIC nnet gflags ${DEPS})

@ -0,0 +1,193 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <algorithm>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <thread>
#include "paddle_inference_api.h"
using std::cout;
using std::endl;
// Command-line flags: model_path is the model graph file (*.pdmodel) and
// param_path the parameter file (*.pdiparams); model_forward_test() passes
// both to paddle_infer::Config::SetModel.
DEFINE_string(model_path, "avg_1.jit.pdmodel", "xxx.pdmodel");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "xxx.pdiparams");
// Forward declarations of the two helpers defined below.
void produce_data(std::vector<std::vector<float>>* data);
void model_forward_test();
// Appends one synthetic feature chunk to *data: chunk_size rows, each holding
// col_size floats set to 0.201. The chunk size and feature dimension are also
// printed to stdout.
//
// @param data  Output container; rows are appended, existing rows are kept.
void produce_data(std::vector<std::vector<float>>* data) {
    const int chunk_size = 35;  // chunk_size in frame
    const int col_size = 161;   // feat dim
    std::cout << "chunk size: " << chunk_size << std::endl;
    std::cout << "feat dim: " << col_size << std::endl;

    // Reserve room for the new rows up front. back() must not be touched
    // here: on an empty vector that is undefined behaviour.
    data->reserve(data->size() + chunk_size);
    for (int row = 0; row < chunk_size; ++row) {
        // Construct each row at its final size in one step.
        data->emplace_back(col_size, static_cast<float>(0.201));
    }
}
// Runs a single forward pass of the model given by FLAGS_model_path /
// FLAGS_param_path on the synthetic chunk from produce_data() and prints the
// resulting probability matrix to stdout.
//
// The model is expected to expose at least four inputs (features, feature
// length, state_h, state_c) and four outputs (output 0 = probabilities,
// outputs 2 and 3 = state_h / state_c). If that does not hold, or the
// predictor run fails, an error is printed to stderr and the function returns
// without reading any output tensor.
void model_forward_test() {
    std::cout << "1. read the data" << std::endl;
    std::vector<std::vector<float>> feats;
    produce_data(&feats);
    if (feats.empty()) {
        std::cerr << "model_forward_test: no input features produced"
                  << std::endl;
        return;
    }

    std::cout << "2. load the model" << std::endl;
    std::string model_graph = FLAGS_model_path;
    std::string model_params = FLAGS_param_path;
    std::cout << "model path: " << model_graph << std::endl;
    std::cout << "model param path : " << model_params << std::endl;
    paddle_infer::Config config;
    config.SetModel(model_graph, model_params);
    config.SwitchIrOptim(false);
    std::cout << "SwitchIrOptim: " << false << std::endl;
    config.DisableFCPadding();
    std::cout << "DisableFCPadding: " << std::endl;
    auto predictor = paddle_infer::CreatePredictor(config);

    std::cout << "3. feat shape, row=" << feats.size()
              << ",col=" << feats[0].size() << std::endl;
    // Flatten the rows into one contiguous buffer for CopyFromCpu.
    std::vector<float> pp_input_mat;
    for (const auto& item : feats) {
        pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
    }

    std::cout << "4. fead the data to model" << std::endl;
    int row = feats.size();
    int col = feats[0].size();

    std::vector<std::string> input_names = predictor->GetInputNames();
    std::vector<std::string> output_names = predictor->GetOutputNames();
    for (const auto& name : input_names) {
        std::cout << "model input names: " << name << std::endl;
    }
    for (const auto& name : output_names) {
        std::cout << "model output names: " << name << std::endl;
    }
    // Indices 0..3 are used below on both lists; refuse to continue rather
    // than index past the end.
    if (input_names.size() < 4 || output_names.size() < 4) {
        std::cerr << "model_forward_test: expected at least 4 inputs and 4 "
                     "outputs, got "
                  << input_names.size() << " and " << output_names.size()
                  << std::endl;
        return;
    }

    // Number of elements described by a tensor shape.
    const auto num_elements = [](const std::vector<int>& shape) {
        return std::accumulate(
            shape.begin(), shape.end(), 1, std::multiplies<int>());
    };

    // input
    std::unique_ptr<paddle_infer::Tensor> input_tensor =
        predictor->GetInputHandle(input_names[0]);
    std::vector<int> INPUT_SHAPE = {1, row, col};
    input_tensor->Reshape(INPUT_SHAPE);
    input_tensor->CopyFromCpu(pp_input_mat.data());

    // input length
    std::unique_ptr<paddle_infer::Tensor> input_len =
        predictor->GetInputHandle(input_names[1]);
    std::vector<int> input_len_size = {1};
    input_len->Reshape(input_len_size);
    std::vector<int64_t> audio_len;
    audio_len.push_back(row);
    input_len->CopyFromCpu(audio_len.data());

    // state_h (zero-initialised)
    std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box =
        predictor->GetInputHandle(input_names[2]);
    std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
    chunk_state_h_box->Reshape(chunk_state_h_box_shape);
    std::vector<float> chunk_state_h_box_data(
        num_elements(chunk_state_h_box_shape), 0.0f);
    chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());

    // state_c (zero-initialised)
    std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box =
        predictor->GetInputHandle(input_names[3]);
    std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
    chunk_state_c_box->Reshape(chunk_state_c_box_shape);
    std::vector<float> chunk_state_c_box_data(
        num_elements(chunk_state_c_box_shape), 0.0f);
    chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());

    // run; the outputs are only meaningful if this succeeds
    if (!predictor->Run()) {
        std::cerr << "model_forward_test: predictor->Run() failed"
                  << std::endl;
        return;
    }

    // state_h out (copied back to exercise the output path; not printed)
    std::unique_ptr<paddle_infer::Tensor> h_out =
        predictor->GetOutputHandle(output_names[2]);
    std::vector<int> h_out_shape = h_out->shape();
    std::vector<float> h_out_data(num_elements(h_out_shape));
    h_out->CopyToCpu(h_out_data.data());

    // stage_c out (copied back to exercise the output path; not printed)
    std::unique_ptr<paddle_infer::Tensor> c_out =
        predictor->GetOutputHandle(output_names[3]);
    std::vector<int> c_out_shape = c_out->shape();
    std::vector<float> c_out_data(num_elements(c_out_shape));
    c_out->CopyToCpu(c_out_data.data());

    // output tensor
    std::unique_ptr<paddle_infer::Tensor> output_tensor =
        predictor->GetOutputHandle(output_names[0]);
    std::vector<int> output_shape = output_tensor->shape();
    // Dimensions 1 and 2 are read below as rows and columns.
    if (output_shape.size() < 3) {
        std::cerr << "model_forward_test: expected an output of rank >= 3, "
                     "got rank "
                  << output_shape.size() << std::endl;
        return;
    }
    std::vector<float> output_probs(num_elements(output_shape));
    output_tensor->CopyToCpu(output_probs.data());
    row = output_shape[1];
    col = output_shape[2];

    // probs: reshape the flat buffer into row x col
    std::vector<std::vector<float>> probs;
    probs.reserve(row);
    for (int i = 0; i < row; i++) {
        probs.push_back(std::vector<float>());
        probs.back().reserve(col);
        for (int j = 0; j < col; j++) {
            probs.back().push_back(output_probs[i * col + j]);
        }
    }

    std::cout << "probs, row: " << probs.size()
              << " col: " << (probs.empty() ? 0 : probs[0].size())
              << std::endl;
    for (size_t row_idx = 0; row_idx < probs.size(); ++row_idx) {
        for (size_t col_idx = 0; col_idx < probs[row_idx].size(); ++col_idx) {
            std::cout << probs[row_idx][col_idx] << " ";
        }
        std::cout << std::endl;
    }
}
// Entry point: parses the gflags command line (model_path / param_path), runs
// the forward-pass test once, and always returns 0.
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
model_forward_test();
return 0;
}

@ -30,16 +30,4 @@ include_directories(
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/decoder ${CMAKE_CURRENT_SOURCE_DIR}/decoder
) )
add_subdirectory(decoder) add_subdirectory(decoder)
add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc)
target_link_libraries(mfcc-test kaldi-mfcc)
add_executable(linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc)
target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog)
add_executable(offline_decoder_main codelab/decoder_test/offline_decoder_main.cc)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(model_test codelab/nnet_test/model_test.cc)
target_link_libraries(model_test PUBLIC nnet gflags ${DEPS})

@ -18,22 +18,22 @@
#include <limits> #include <limits>
typedef float BaseFloat; typedef float BaseFloat;
typedef double double64; typedef double double64;
typedef signed char int8; typedef signed char int8;
typedef short int16; typedef short int16;
typedef int int32; typedef int int32;
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
typedef long int64; typedef long int64;
#else #else
typedef long long int64; typedef long long int64;
#endif #endif
typedef unsigned char uint8; typedef unsigned char uint8;
typedef unsigned short uint16; typedef unsigned short uint16;
typedef unsigned int uint32; typedef unsigned int uint32;
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
typedef unsigned long uint64; typedef unsigned long uint64;
@ -41,20 +41,20 @@ typedef unsigned long uint64;
typedef unsigned long long uint64; typedef unsigned long long uint64;
#endif #endif
typedef signed int char32; typedef signed int char32;
const uint8 kuint8max = (( uint8) 0xFF); const uint8 kuint8max = ((uint8)0xFF);
const uint16 kuint16max = ((uint16) 0xFFFF); const uint16 kuint16max = ((uint16)0xFFFF);
const uint32 kuint32max = ((uint32) 0xFFFFFFFF); const uint32 kuint32max = ((uint32)0xFFFFFFFF);
const uint64 kuint64max = ((uint64) (0xFFFFFFFFFFFFFFFFLL)); const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL));
const int8 kint8min = (( int8) 0x80); const int8 kint8min = ((int8)0x80);
const int8 kint8max = (( int8) 0x7F); const int8 kint8max = ((int8)0x7F);
const int16 kint16min = (( int16) 0x8000); const int16 kint16min = ((int16)0x8000);
const int16 kint16max = (( int16) 0x7FFF); const int16 kint16max = ((int16)0x7FFF);
const int32 kint32min = (( int32) 0x80000000); const int32 kint32min = ((int32)0x80000000);
const int32 kint32max = (( int32) 0x7FFFFFFF); const int32 kint32max = ((int32)0x7FFFFFFF);
const int64 kint64min = (( int64) (0x8000000000000000LL)); const int64 kint64min = ((int64)(0x8000000000000000LL));
const int64 kint64max = (( int64) (0x7FFFFFFFFFFFFFFFLL)); const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL));
const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max(); const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min(); const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();

@ -15,22 +15,22 @@
#pragma once #pragma once
#include <deque> #include <deque>
#include <fstream>
#include <iostream> #include <iostream>
#include <istream> #include <istream>
#include <fstream>
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex>
#include <ostream> #include <ostream>
#include <set> #include <set>
#include <sstream> #include <sstream>
#include <stack> #include <stack>
#include <string> #include <string>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <mutex> #include <vector>
#include "base/log.h"
#include "base/flags.h"
#include "base/basic_types.h" #include "base/basic_types.h"
#include "base/flags.h"
#include "base/log.h"
#include "base/macros.h" #include "base/macros.h"

@ -18,8 +18,8 @@ namespace ppspeech {
#ifndef DISALLOW_COPY_AND_ASSIGN #ifndef DISALLOW_COPY_AND_ASSIGN
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \ #define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&) = delete; \ TypeName(const TypeName&) = delete; \
void operator=(const TypeName&) = delete void operator=(const TypeName&) = delete
#endif #endif
} // namespace pp_speech } // namespace pp_speech

@ -23,98 +23,88 @@
#ifndef BASE_THREAD_POOL_H #ifndef BASE_THREAD_POOL_H
#define BASE_THREAD_POOL_H #define BASE_THREAD_POOL_H
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable> #include <condition_variable>
#include <future>
#include <functional> #include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept> #include <stdexcept>
#include <thread>
#include <vector>
class ThreadPool { class ThreadPool {
public: public:
ThreadPool(size_t); ThreadPool(size_t);
template<class F, class... Args> template <class F, class... Args>
auto enqueue(F&& f, Args&&... args) auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>; -> std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool(); ~ThreadPool();
private:
private:
// need to keep track of threads so we can join them // need to keep track of threads so we can join them
std::vector< std::thread > workers; std::vector<std::thread> workers;
// the task queue // the task queue
std::queue< std::function<void()> > tasks; std::queue<std::function<void()>> tasks;
// synchronization // synchronization
std::mutex queue_mutex; std::mutex queue_mutex;
std::condition_variable condition; std::condition_variable condition;
bool stop; bool stop;
}; };
// the constructor just launches some amount of workers // the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads) inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
: stop(false) for (size_t i = 0; i < threads; ++i)
{ workers.emplace_back([this] {
for(size_t i = 0;i<threads;++i) for (;;) {
workers.emplace_back( std::function<void()> task;
[this]
{
for(;;)
{ {
std::function<void()> task; std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock, [this] {
{ return this->stop || !this->tasks.empty();
std::unique_lock<std::mutex> lock(this->queue_mutex); });
this->condition.wait(lock, if (this->stop && this->tasks.empty()) return;
[this]{ return this->stop || !this->tasks.empty(); }); task = std::move(this->tasks.front());
if(this->stop && this->tasks.empty()) this->tasks.pop();
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
} }
task();
} }
); });
} }
// add new work item to the pool // add new work item to the pool
template<class F, class... Args> template <class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args) auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type> -> std::future<typename std::result_of<F(Args...)>::type> {
{
using return_type = typename std::result_of<F(Args...)>::type; using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared< std::packaged_task<return_type()> >( auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...) std::bind(std::forward<F>(f), std::forward<Args>(args)...));
);
std::future<return_type> res = task->get_future(); std::future<return_type> res = task->get_future();
{ {
std::unique_lock<std::mutex> lock(queue_mutex); std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool // don't allow enqueueing after stopping the pool
if(stop) if (stop) throw std::runtime_error("enqueue on stopped ThreadPool");
throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task](){ (*task)(); }); tasks.emplace([task]() { (*task)(); });
} }
condition.notify_one(); condition.notify_one();
return res; return res;
} }
// the destructor joins all threads // the destructor joins all threads
inline ThreadPool::~ThreadPool() inline ThreadPool::~ThreadPool() {
{
{ {
std::unique_lock<std::mutex> lock(queue_mutex); std::unique_lock<std::mutex> lock(queue_mutex);
stop = true; stop = true;
} }
condition.notify_all(); condition.notify_all();
for(std::thread &worker: workers) for (std::thread& worker : workers) worker.join();
worker.join();
} }
#endif #endif

@ -1,4 +0,0 @@
# codelab
This directory is here for temporarily testing some functions.

@ -1,57 +0,0 @@
// TODO: refactor, replace with gtest
#include "decoder/ctc_beam_search_decoder.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "nnet/paddle_nnet.h"
#include "nnet/decodable.h"
DEFINE_string(feature_respecifier, "", "test nnet prob");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
//void SplitFeature(kaldi::Matrix<BaseFloat> feature,
// int32 chunk_size,
// std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
//}
// Decodes every feature matrix read from FLAGS_feature_respecifier with the
// CTC beam-search decoder and logs the best path per utterance.
//
// Returns 0 if at least one utterance was decoded, 1 otherwise.
int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  google::InitGoogleLogging(argv[0]);

  kaldi::SequentialBaseFloatMatrixReader feature_reader(
      FLAGS_feature_respecifier);

  // test nnet_output --> decoder result
  int32 num_done = 0, num_err = 0;

  ppspeech::CTCBeamSearchOptions opts;
  ppspeech::CTCBeamSearch decoder(opts);

  ppspeech::ModelOptions model_opts;
  std::shared_ptr<ppspeech::PaddleNnet> nnet(
      new ppspeech::PaddleNnet(model_opts));

  std::shared_ptr<ppspeech::Decodable> decodable(
      new ppspeech::Decodable(nnet));

  // int32 chunk_size = 35;
  decoder.InitDecoder();
  for (; !feature_reader.Done(); feature_reader.Next()) {
    // Qualified explicitly so the code does not depend on a using-declaration
    // leaking out of some included header.
    const std::string utt = feature_reader.Key();
    // Bound by reference: the reader's value is only used inside this
    // iteration, before Next() is called, so a deep copy is unnecessary.
    const kaldi::Matrix<BaseFloat>& feature = feature_reader.Value();

    decodable->FeedFeatures(feature);
    decoder.AdvanceDecode(decodable, 8);
    decodable->InputFinished();

    const std::string result = decoder.GetFinalBestPath();
    KALDI_LOG << " the result of " << utt << " is " << result;

    // Clear the decodable's state before the next utterance.
    decodable->Reset();
    ++num_done;
  }

  KALDI_LOG << "Done " << num_done << " utterances, " << num_err
            << " with errors.";
  return (num_done != 0 ? 0 : 1);
}

@ -1,686 +0,0 @@
// feat/feature-mfcc-test.cc
// Copyright 2009-2011 Karel Vesely; Petr Motlicek
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "feat/feature-mfcc.h"
#include "base/kaldi-math.h"
#include "matrix/kaldi-matrix-inl.h"
#include "feat/wave-reader.h"
using namespace kaldi;
// Reads test_data/test.wav through WaveData and checks that its samples are
// exactly equal to the reference vector stored in test_data/test_matlab.ascii.
// Any failed KALDI_ASSERT aborts the test.
static void UnitTestReadWave() {
std::cout << "=== UnitTestReadWave() ===\n";
Vector<BaseFloat> v, v2;
std::cout << "<<<=== Reading waveform\n";
{
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
const Matrix<BaseFloat> data(wave.Data());
// The comparison below only covers one row, so require exactly one.
KALDI_ASSERT(data.NumRows() == 1);
v.Resize(data.NumCols());
v.CopyFromVec(data.Row(0));
}
std::cout << "<<<=== Reading Vector<BaseFloat> waveform, prepared by matlab\n";
std::ifstream input(
"test_data/test_matlab.ascii"
);
KALDI_ASSERT(input.good());
// Second argument false: presumably selects text (non-binary) mode, matching
// the .ascii file name -- confirm against Vector::Read.
v2.Read(input, false);
input.close();
std::cout << "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n";
KALDI_ASSERT(v.Dim() == v2.Dim());
// Exact equality is intended: no tolerance is applied to the samples.
for (int32 i = 0; i < v.Dim(); i++) {
KALDI_ASSERT(v(i) == v2(i));
}
std::cout << "<<<=== Comparing done\n";
// std::cout << "== The Waveform Samples == \n";
// std::cout << v;
std::cout << "Test passed :)\n\n";
}
/**
*/
static void UnitTestSimple() {
std::cout << "=== UnitTestSimple() ===\n";
Vector<BaseFloat> v(100000);
Matrix<BaseFloat> m;
// init with noise
for (int32 i = 0; i < v.Dim(); i++) {
v(i) = (abs( i * 433024253 ) % 65535) - (65535 / 2);
}
std::cout << "<<<=== Just make sure it runs... Nothing is compared\n";
// the parametrization object
MfccOptions op;
// trying to have same opts as baseline.
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "rectangular";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
Mfcc mfcc(op);
// use default parameters
// compute mfccs.
mfcc.Compute(v, 1.0, &m);
// possibly dump
// std::cout << "== Output features == \n" << m;
std::cout << "Test passed :)\n\n";
}
static void UnitTestHTKCompare1() {
std::cout << "=== UnitTestHTKCompare1() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.1",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.1",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.1");
}
static void UnitTestHTKCompare2() {
std::cout << "=== UnitTestHTKCompare2() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.2",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.mel_opts.htk_mode = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (i_old != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.2",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.2");
}
static void UnitTestHTKCompare3() {
std::cout << "=== UnitTestHTKCompare3() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.3",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 20.0;
//op.mel_opts.debug_mel = true;
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.3",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.3");
}
static void UnitTestHTKCompare4() {
std::cout << "=== UnitTestHTKCompare4() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.4",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.low_freq = 0.0;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.htk_mode = true;
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.4",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.4");
}
static void UnitTestHTKCompare5() {
std::cout << "=== UnitTestHTKCompare5() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.5",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.htk_compat = true;
op.use_energy = true; // Use energy.
op.mel_opts.low_freq = 0.0;
op.mel_opts.vtln_low = 100.0;
op.mel_opts.vtln_high = 7500.0;
op.mel_opts.htk_mode = true;
BaseFloat vtln_warp = 1.1; // our approach identical to htk for warp factor >1,
// differs slightly for higher mel bins if warp_factor <0.9
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.5",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.5");
}
static void UnitTestHTKCompare6() {
std::cout << "=== UnitTestHTKCompare6() ===\n";
std::ifstream is("test_data/test.wav", std::ios_base::binary);
WaveData wave;
wave.Read(is);
KALDI_ASSERT(wave.Data().NumRows() == 1);
SubVector<BaseFloat> waveform(wave.Data(), 0);
// read the HTK features
Matrix<BaseFloat> htk_features;
{
std::ifstream is("test_data/test.wav.fea_htk.6",
std::ios::in | std::ios_base::binary);
bool ans = ReadHtk(is, &htk_features, 0);
KALDI_ASSERT(ans);
}
// use mfcc with default configuration...
MfccOptions op;
op.frame_opts.dither = 0.0;
op.frame_opts.preemph_coeff = 0.97;
op.frame_opts.window_type = "hamming";
op.frame_opts.remove_dc_offset = false;
op.frame_opts.round_to_power_of_two = true;
op.mel_opts.num_bins = 24;
op.mel_opts.low_freq = 125.0;
op.mel_opts.high_freq = 7800.0;
op.htk_compat = true;
op.use_energy = false; // C0 not energy.
Mfcc mfcc(op);
// calculate kaldi features
Matrix<BaseFloat> kaldi_raw_features;
mfcc.Compute(waveform, 1.0, &kaldi_raw_features);
DeltaFeaturesOptions delta_opts;
Matrix<BaseFloat> kaldi_features;
ComputeDeltas(delta_opts,
kaldi_raw_features,
&kaldi_features);
// compare the results
bool passed = true;
int32 i_old = -1;
KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows());
KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols());
// Ignore ends-- we make slightly different choices than
// HTK about how to treat the deltas at the ends.
for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) {
for (int32 j = 0; j < kaldi_features.NumCols(); j++) {
BaseFloat a = kaldi_features(i, j), b = htk_features(i, j);
if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!!
// print the non-matching data only once per-line
if (static_cast<int32>(i_old) != i) {
std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n";
std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n";
i_old = i;
}
// print indices of non-matching cells
std::cout << "[" << i << ", " << j << "]";
passed = false;
}}}
if (!passed) KALDI_ERR << "Test failed";
// write the htk features for later inspection
HtkHeader header = {
kaldi_features.NumRows(),
100000, // 10ms
static_cast<int16>(sizeof(float)*kaldi_features.NumCols()),
021406 // MFCC_D_A_0
};
{
std::ofstream os("tmp.test.wav.fea_kaldi.6",
std::ios::out|std::ios::binary);
WriteHtk(os, kaldi_features, header);
}
std::cout << "Test passed :)\n\n";
unlink("tmp.test.wav.fea_kaldi.6");
}
void UnitTestVtln() {
// Test the function VtlnWarpFreq.
BaseFloat low_freq = 10, high_freq = 7800,
vtln_low_cutoff = 20, vtln_high_cutoff = 7400;
for (size_t i = 0; i < 100; i++) {
BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2;
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
freq),
freq / warp_factor);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
low_freq),
low_freq);
AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
high_freq),
high_freq);
BaseFloat freq2 = low_freq + (high_freq-low_freq) * RandUniform(),
freq3 = freq2 + (high_freq-freq2) * RandUniform(); // freq3>=freq2
BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
freq2);
BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, warp_factor,
freq3);
KALDI_ASSERT(w3 >= w2); // increasing function.
BaseFloat w3dash = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff,
low_freq, high_freq, 1.0,
freq3);
AssertEqual(w3dash, freq3);
}
}
// Runs the unit tests listed below once each, in this fixed order, and then
// prints a success message to stdout.
static void UnitTestFeat() {
UnitTestVtln();
UnitTestReadWave();
UnitTestSimple();
UnitTestHTKCompare1();
UnitTestHTKCompare2();
// NOTE(review): an earlier comment here said UnitTestHTKCompare3 was being
// commented out because it stopped matching after the treatment of the FFT
// bins was normalized (removed offset of 0.5), apparently related to how
// frequency zero behaves.  The call below is active, so that remark looks
// stale -- confirm whether the test is expected to pass.
UnitTestHTKCompare3();
UnitTestHTKCompare4();
UnitTestHTKCompare5();
UnitTestHTKCompare6();
std::cout << "Tests succeeded.\n";
}
// Test driver: runs the whole UnitTestFeat() suite five times in a row.
// Returns 0 when every run completes, or 1 after printing the message of the
// first std::exception that escapes.
int main() {
  const int kNumRuns = 5;
  try {
    int run = 0;
    while (run < kNumRuns) {
      UnitTestFeat();
      ++run;
    }
    std::cout << "Tests succeeded.\n";
  } catch (const std::exception &e) {
    std::cerr << e.what();
    return 1;
  }
  return 0;
}

@ -1,125 +0,0 @@
// TODO: refactor, replace with gtest
#include "frontend/linear_spectrogram.h"
#include "frontend/normalizer.h"
#include "frontend/feature_extractor_interface.h"
#include "kaldi/util/table-types.h"
#include "base/log.h"
#include "base/flags.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
// Command-line flags for this test program.  The help strings describe how
// each flag is consumed in main() and WriteMatrix() below.
DEFINE_string(wav_rspecifier, "", "rspecifier of the input wav table");
DEFINE_string(feature_wspecifier,
              "",
              "wspecifier for the features normalized with the hard-coded "
              "mean/variance statistics");
DEFINE_string(feature_check_wspecifier,
              "",
              "wspecifier for the features normalized by ppspeech::CMVN "
              "(written for comparison)");
DEFINE_string(cmvn_write_path,
              "./cmvn.ark",
              "path the CMVN statistics are written to and then loaded from");
std::vector<float> mean_{-13730251.531853663, -12982852.199316509, -13673844.299583456, -13089406.559646806, -12673095.524938712, -12823859.223276224, -13590267.158903603, -14257618.467152044, -14374605.116185192, -14490009.21822485, -14849827.158924166, -15354435.470563512, -15834149.206532761, -16172971.985514281, -16348740.496746974, -16423536.699409386, -16556246.263649225, -16744088.772748645, -16916184.08510357, -17054034.840031497, -17165612.509455364, -17255955.470915023, -17322572.527648456, -17408943.862033736, -17521554.799865916, -17620623.254924215, -17699792.395918526, -17723364.411134344, -17741483.4433254, -17747426.888704527, -17733315.928209435, -17748780.160905756, -17808336.883775543, -17895918.671983004, -18009812.59173023, -18098188.66548325, -18195798.958462656, -18293617.62980999, -18397432.92077201, -18505834.787318766, -18585451.8100908, -18652438.235649142, -18700960.306275308, -18734944.58792185, -18737426.313365128, -18735347.165987637, -18738813.444170244, -18737086.848890636, -18731576.2474336, -18717405.44095871, -18703089.25545657, -18691014.546456724, -18692460.568905357, -18702119.628629155, -18727710.621126678, -18761582.72034647, -18806745.835547544, -18850674.8692112, -18884431.510951452, -18919999.992506847, -18939303.799078144, -18952946.273760635, -18980289.22996379, -19011610.17803294, -19040948.61805145, -19061021.429847397, -19112055.53768819, -19149667.414264943, -19201127.05091321, -19270250.82564605, -19334606.883057203, -19390513.336589377, -19444176.259208687, -19502755.000038862, -19544333.014549147, -19612668.183176614, -19681902.19006569, -19771969.951249883, -19873329.723376893, -19996752.59235844, -20110031.131400537, -20231658.612529557, -20319378.894054495, -20378534.45718066, -20413332.089584175, -20438147.844177883, -20443710.248040095, -20465457.02238927, -20488610.969337028, -20516295.16424432, -20541423.795738827, -20553192.874953747, -20573605.50701977, -20577871.61936797, -20571807.008916274, 
-20556242.38912231, -20542199.30819195, -20521239.063551214, -20519150.80004532, -20527204.80248933, -20536933.769257784, -20543470.522332076, -20549700.089992985, -20551525.24958494, -20554873.406493705, -20564277.65794227, -20572211.740052115, -20574305.69550465, -20575494.450104576, -20567092.577932164, -20549302.929608088, -20545445.11878376, -20546625.326603737, -20549190.03499401, -20554824.947828256, -20568341.378989458, -20577582.331383612, -20577980.519402675, -20566603.03458152, -20560131.592262644, -20552166.469060015, -20549063.06763577, -20544490.562339947, -20539817.82346569, -20528747.715731595, -20518026.24576161, -20510977.844974525, -20506874.36087992, -20506731.11977665, -20510482.133420516, -20507760.92101862, -20494644.834457114, -20480107.89304893, -20461312.091867123, -20442941.75080173, -20426123.02834838, -20424607.675283, -20426810.369107097, -20434024.50097819, -20437404.75544205, -20447688.63916367, -20460893.335563846, -20482922.735127095, -20503610.119434915, -20527062.76448319, -20557830.035128627, -20593274.72068722, -20632528.452965066, -20673637.471334763, -20733106.97143075, -20842921.0447562, -21054357.83621519, -21416569.534189366, -21978460.272811692, -22753170.052172784, -23671344.10563395, -24613499.293358143, -25406477.12230188, -25884377.82156489, -26049040.62791664, -26996879.104431007};
std::vector<float> variance_{213747175.10846674, 188395815.34302503, 212706429.10966414, 199109025.81461075, 189235901.23864496, 194901336.53253657, 217481594.29306737, 238689869.12327808, 243977501.24115244, 248479623.6431067, 259766741.47116545, 275516766.7790273, 291271202.3691234, 302693239.8220509, 308627358.3997694, 311143911.38788426, 315446105.07731867, 321705430.9341829, 327458907.4659941, 332245072.43223983, 336251717.5935284, 339694069.7639722, 342188204.4322228, 345587110.31313115, 349903086.2875232, 353660214.20643026, 356700344.5270885, 357665362.3529641, 358493352.05658793, 358857951.620328, 358375239.52774596, 358899733.6342954, 361051818.3511561, 364361716.05025816, 368750322.3771452, 372047800.6462831, 375655861.1349018, 379358519.1980013, 383327605.3935181, 387458599.282341, 390434692.3406868, 392994486.35057056, 394874418.04603153, 396230525.79763395, 396365592.0414835, 396334819.8242737, 396488353.19250053, 396438877.00744957, 396197980.4459586, 395590921.6672991, 395001107.62072515, 394528291.7318225, 394593110.424006, 395018405.59353715, 396110577.5415993, 397506704.0371068, 399400197.4657644, 401243568.2468382, 402687134.7805103, 404136047.2872507, 404883170.001883, 405522253.219517, 406660365.3626476, 407919346.0991902, 409045348.5384909, 409759588.7889818, 411974821.8564483, 413489718.78201455, 415535392.56684107, 418466481.97674364, 421104678.35678065, 423405392.5200779, 425550570.40798235, 427929423.9579701, 429585274.253478, 432368493.55181056, 435193587.13513297, 438886855.20476013, 443058876.8633751, 448181232.5093362, 452883835.6332396, 458056721.77926534, 461816531.22735566, 464363620.1970998, 465886343.5057493, 466928872.0651, 467180536.42647296, 468111848.70714295, 469138695.3071312, 470378429.6930793, 471517958.7132626, 472109050.4262365, 473087417.0177867, 473381322.04648733, 473220195.85483915, 472666071.8998819, 472124669.87879956, 471298571.411737, 471251033.2902761, 471672676.43128747, 472177147.2193172, 472572361.7711908, 
472968783.7751127, 473156295.4164052, 473398034.82676554, 473897703.5203811, 474328271.33112127, 474452670.98002136, 474549003.99284613, 474252887.13567275, 473557462.909069, 473483385.85193115, 473609738.04855174, 473746944.82085115, 474016729.91696435, 474617321.94138587, 475045097.237122, 475125402.586558, 474664112.9824912, 474426247.5800283, 474104075.42796475, 473978219.7273978, 473773171.7798875, 473578534.69508696, 473102924.16904145, 472651240.5232615, 472374383.1810912, 472209479.6956096, 472202298.8921673, 472370090.76781124, 472220933.99374026, 471625467.37106377, 470994646.51883453, 470182428.9637543, 469348211.5939578, 468570387.4467277, 468540442.7225135, 468672018.90414184, 468994346.9533251, 469138757.58201426, 469553915.95710236, 470134523.38582784, 471082421.62055486, 471962316.51804745, 472939745.1708408, 474250621.5944825, 475773933.43199486, 477465399.71087736, 479218782.61382693, 481752299.7930922, 486608947.8984568, 496119403.2067917, 512730085.5704984, 539048915.2641417, 576285298.3548826, 621610270.2240586, 669308196.4436442, 710656993.5957186, 736344437.3725077, 745481288.0241544, 801121432.9925804};
int count_ = 912592;
void WriteMatrix() {
kaldi::Matrix<double> cmvn_stats(2, mean_.size()+ 1);
for (size_t idx = 0; idx < mean_.size(); ++idx) {
cmvn_stats(0, idx) = mean_[idx];
cmvn_stats(1, idx) = variance_[idx];
}
cmvn_stats(0, mean_.size()) = count_;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true);
}
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(FLAGS_feature_check_wspecifier);
WriteMatrix();
// test feature linear_spectorgram: wave --> decibel_normalizer --> hanning window -->linear_spectrogram --> cmvn
int32 num_done = 0, num_err = 0;
ppspeech::LinearSpectrogramOptions opt;
opt.frame_opts.frame_length_ms = 20;
opt.frame_opts.frame_shift_ms = 10;
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> base_feature_extractor(
new ppspeech::DecibelNormalizer(db_norm_opt));
ppspeech::LinearSpectrogram linear_spectrogram(opt, std::move(base_feature_extractor));
ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
float streaming_chunk = 0.36;
int sample_rate = 16000;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << mean_.size();
for (size_t i = 0; i < mean_.size(); i++) {
mean_[i] /= count_;
variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i];
if (variance_[i] < 1.0e-20) {
variance_[i] = 1.0e-20;
}
variance_[i] = 1.0 / std::sqrt(variance_[i]);
}
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData &wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(), this_channel);
int tot_samples = waveform.Dim();
int sample_offset = 0;
std::vector<kaldi::Matrix<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size = std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
kaldi::Matrix<BaseFloat> features;
linear_spectrogram.AcceptWaveform(wav_chunk);
linear_spectrogram.ReadFeats(&features);
feats.push_back(features);
sample_offset += cur_chunk_size;
feature_rows += features.NumRows();
}
int cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows, feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features(cur_idx, col_idx) = (feat(row_idx, col_idx) - mean_[col_idx])*variance_[col_idx];
}
++cur_idx;
}
}
feat_writer.Write(utt, features);
cur_idx = 0;
kaldi::Matrix<kaldi::BaseFloat> features_check(feature_rows, feats[0].NumCols());
for (auto feat : feats) {
for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
features_check(cur_idx, col_idx) = feat(row_idx, col_idx);
}
kaldi::SubVector<BaseFloat> row_feat(features_check, cur_idx);
cmvn.ApplyCMVN(true, &row_feat);
++cur_idx;
}
}
feat_cmvn_check_writer.Write(utt, features_check);
if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
num_done++;
}
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -1,134 +0,0 @@
#include "paddle_inference_api.h"
#include <gflags/gflags.h>
#include <iostream>
#include <thread>
#include <fstream>
#include <iterator>
#include <algorithm>
#include <numeric>
#include <functional>
// Fills *data with the synthetic input feature rows used by the test
// (defined at the bottom of this file).
void produce_data(std::vector<std::vector<float>>* data);
// Loads the model, runs one forward pass on the synthetic input and prints
// the resulting probabilities (defined below).
void model_forward_test();
// Entry point: parses the gflags command-line flags, runs the forward test
// once and returns 0.
int main(int argc, char* argv[]) {
// The third argument is gflags' remove_flags switch (true here).
gflags::ParseCommandLineFlags(&argc, &argv, true);
model_forward_test();
return 0;
}
// Runs a single forward pass of the model stored at the hard-coded paths
// below on the synthetic input from produce_data(), then prints the output
// probabilities to stdout, one row per line.
//
// The model is fed four inputs in the order reported by GetInputNames():
// the feature matrix (1 x row x col), the number of rows, and two zero-filled
// state tensors of shape 3 x 1 x 1024.  Outputs 2 and 3 are copied back but
// not printed; output 0 is treated as (batch, row, col) and only the first
// row * col values are printed.
//
// On failure (no input rows, Run() returning false, or an output of rank
// below 3) a message is written to stderr and the function returns early
// instead of reading invalid data.
void model_forward_test() {
    std::cout << "1. read the data" << std::endl;
    std::vector<std::vector<float>> feats;
    produce_data(&feats);
    if (feats.empty()) {
        // feats[0] is read below; bail out rather than index an empty vector.
        std::cerr << "no input features were produced" << std::endl;
        return;
    }

    std::cout << "2. load the model" << std::endl;
    std::string model_graph = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel";
    std::string model_params = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams";
    paddle_infer::Config config;
    config.SetModel(model_graph, model_params);
    config.SwitchIrOptim(false);
    config.DisableFCPadding();
    auto predictor = paddle_infer::CreatePredictor(config);

    std::cout << "3. feat shape, row=" << feats.size() << ",col=" << feats[0].size() << std::endl;
    // Flatten the rows into one contiguous buffer for CopyFromCpu.
    std::vector<float> paddle_input_feature_matrix;
    paddle_input_feature_matrix.reserve(feats.size() * feats[0].size());
    for (const auto& item : feats) {
        paddle_input_feature_matrix.insert(paddle_input_feature_matrix.end(), item.begin(), item.end());
    }

    std::cout << "4. fead the data to model" << std::endl;
    int row = feats.size();
    int col = feats[0].size();
    // NOTE(review): the fixed indices 0..3 below assume the model exposes at
    // least four inputs and four outputs -- confirm against the model.
    std::vector<std::string> input_names = predictor->GetInputNames();
    std::vector<std::string> output_names = predictor->GetOutputNames();

    // Input 0: the feature matrix with a leading dimension of 1.
    std::unique_ptr<paddle_infer::Tensor> input_tensor =
        predictor->GetInputHandle(input_names[0]);
    std::vector<int> INPUT_SHAPE = {1, row, col};
    input_tensor->Reshape(INPUT_SHAPE);
    input_tensor->CopyFromCpu(paddle_input_feature_matrix.data());

    // Input 1: the number of feature rows, as a single int64.
    std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]);
    std::vector<int> input_len_size = {1};
    input_len->Reshape(input_len_size);
    std::vector<int64_t> audio_len;
    audio_len.push_back(row);
    input_len->CopyFromCpu(audio_len.data());

    // Inputs 2 and 3: state tensors, both initialised to zero.
    std::unique_ptr<paddle_infer::Tensor> chunk_state_h_box = predictor->GetInputHandle(input_names[2]);
    std::vector<int> chunk_state_h_box_shape = {3, 1, 1024};
    chunk_state_h_box->Reshape(chunk_state_h_box_shape);
    int chunk_state_h_box_size = std::accumulate(chunk_state_h_box_shape.begin(), chunk_state_h_box_shape.end(),
                                                 1, std::multiplies<int>());
    std::vector<float> chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
    chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());

    std::unique_ptr<paddle_infer::Tensor> chunk_state_c_box = predictor->GetInputHandle(input_names[3]);
    std::vector<int> chunk_state_c_box_shape = {3, 1, 1024};
    chunk_state_c_box->Reshape(chunk_state_c_box_shape);
    int chunk_state_c_box_size = std::accumulate(chunk_state_c_box_shape.begin(), chunk_state_c_box_shape.end(),
                                                 1, std::multiplies<int>());
    std::vector<float> chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
    chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());

    // Do not read any output if the forward pass itself reported failure.
    bool success = predictor->Run();
    if (!success) {
        std::cerr << "predictor->Run() failed" << std::endl;
        return;
    }

    // Outputs 2 and 3 are fetched to exercise the copy path; their values
    // are not used further in this test.
    std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
    std::vector<int> h_out_shape = h_out->shape();
    int h_out_size = std::accumulate(h_out_shape.begin(), h_out_shape.end(),
                                     1, std::multiplies<int>());
    std::vector<float> h_out_data(h_out_size);
    h_out->CopyToCpu(h_out_data.data());

    std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
    std::vector<int> c_out_shape = c_out->shape();
    int c_out_size = std::accumulate(c_out_shape.begin(), c_out_shape.end(),
                                     1, std::multiplies<int>());
    std::vector<float> c_out_data(c_out_size);
    c_out->CopyToCpu(c_out_data.data());

    // Output 0: the probabilities.
    std::unique_ptr<paddle_infer::Tensor> output_tensor =
        predictor->GetOutputHandle(output_names[0]);
    std::vector<int> output_shape = output_tensor->shape();
    if (output_shape.size() < 3) {
        // Dimensions 1 and 2 are read below as row and col.
        std::cerr << "unexpected output rank: " << output_shape.size() << std::endl;
        return;
    }
    int output_size = std::accumulate(output_shape.begin(), output_shape.end(),
                                      1, std::multiplies<int>());
    std::vector<float> output_probs(output_size);
    output_tensor->CopyToCpu(output_probs.data());

    // Re-shape the flat buffer into row x col.
    row = output_shape[1];
    col = output_shape[2];
    std::vector<std::vector<float>> probs;
    probs.reserve(row);
    for (int i = 0; i < row; i++) {
        probs.push_back(std::vector<float>());
        probs.back().reserve(col);
        for (int j = 0; j < col; j++) {
            probs.back().push_back(output_probs[i * col + j]);
        }
    }

    // Print straight from probs; guard the column count for an empty result.
    std::cout << "probs, row: " << probs.size() << " col: "
              << (probs.empty() ? 0 : probs[0].size()) << std::endl;
    for (size_t row_idx = 0; row_idx < probs.size(); ++row_idx) {
        for (size_t col_idx = 0; col_idx < probs[row_idx].size(); ++col_idx) {
            std::cout << probs[row_idx][col_idx] << " ";
        }
        std::cout << std::endl;
    }
}
// Appends 35 rows of 161 floats, every element equal to 0.201 (converted to
// float), to *data.  Rows already present in *data are left untouched.
//
// @param data  output container; must not be null.
void produce_data(std::vector<std::vector<float>>* data) {
    const int chunk_size = 35;
    const int col_size = 161;
    const float fill_value = static_cast<float>(0.201);

    // Reserve room for the new rows only.  The previous code also called
    // data->back().reserve(...) at this point, which is undefined behaviour
    // when *data is still empty.
    data->reserve(data->size() + chunk_size);
    for (int row = 0; row < chunk_size; ++row) {
        // Build each row at its final size in one allocation.
        data->push_back(std::vector<float>(col_size, fill_value));
    }
}

@ -1,7 +1,21 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/basic_types.h" #include "base/basic_types.h"
struct DecoderResult { struct DecoderResult {
BaseFloat acoustic_score; BaseFloat acoustic_score;
std::vector<int32> words_idx; std::vector<int32> words_idx;
std::vector<pair<int32, int32>> time_stamp; std::vector<pair<int32, int32>> time_stamp;
}; };

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_beam_search_decoder.h" #include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h" #include "base/basic_types.h"
@ -9,292 +23,290 @@ namespace ppspeech {
using std::vector; using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) : CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
opts_(opts), : opts_(opts),
init_ext_scorer_(nullptr), init_ext_scorer_(nullptr),
blank_id(-1), blank_id(-1),
space_id(-1), space_id(-1),
num_frame_decoded_(0), num_frame_decoded_(0),
root(nullptr) { root(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file; LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
LOG(INFO) << "load the dict failed"; LOG(INFO) << "load the dict failed";
} }
LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_.size(); LOG(INFO) << "read the vocabulary success, dict size: "
<< vocabulary_.size();
LOG(INFO) << "language model path: " << opts_.lm_path; LOG(INFO) << "language model path: " << opts_.lm_path;
init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha, init_ext_scorer_ = std::make_shared<Scorer>(
opts_.beta, opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
opts_.lm_path,
vocabulary_);
} }
void CTCBeamSearch::Reset() { void CTCBeamSearch::Reset() {
num_frame_decoded_ = 0; num_frame_decoded_ = 0;
ResetPrefixes(); ResetPrefixes();
} }
void CTCBeamSearch::InitDecoder() { void CTCBeamSearch::InitDecoder() {
blank_id = 0; blank_id = 0;
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
space_id = it - vocabulary_.begin(); space_id = it - vocabulary_.begin();
// if no space in vocabulary // if no space in vocabulary
if ((size_t)space_id >= vocabulary_.size()) { if ((size_t)space_id >= vocabulary_.size()) {
space_id = -2; space_id = -2;
} }
ResetPrefixes(); ResetPrefixes();
root = std::make_shared<PathTrie>(); root = std::make_shared<PathTrie>();
root->score = root->log_prob_b_prev = 0.0; root->score = root->log_prob_b_prev = 0.0;
prefixes.push_back(root.get()); prefixes.push_back(root.get());
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) { if (init_ext_scorer_ != nullptr &&
!init_ext_scorer_->is_character_based()) {
auto fst_dict = auto fst_dict =
static_cast<fst::StdVectorFst *>(init_ext_scorer_->dictionary); static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
root->set_dictionary(dict_ptr); root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT); auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root->set_matcher(matcher); root->set_matcher(matcher);
} }
} }
void CTCBeamSearch::Decode(std::shared_ptr<kaldi::DecodableInterface> decodable) { void CTCBeamSearch::Decode(
return; std::shared_ptr<kaldi::DecodableInterface> decodable) {
return;
} }
int32 CTCBeamSearch::NumFrameDecoded() { int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; }
return num_frame_decoded_;
}
// todo rename, refactor // todo rename, refactor
void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable, void CTCBeamSearch::AdvanceDecode(
int max_frames) { const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames) {
while (max_frames > 0) { while (max_frames > 0) {
vector<vector<BaseFloat>> likelihood; vector<vector<BaseFloat>> likelihood;
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) { if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
break; break;
} }
likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1)); likelihood.push_back(
AdvanceDecoding(likelihood); decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
max_frames--; AdvanceDecoding(likelihood);
max_frames--;
} }
} }
void CTCBeamSearch::ResetPrefixes() { void CTCBeamSearch::ResetPrefixes() {
for (size_t i = 0; i < prefixes.size(); i++) { for (size_t i = 0; i < prefixes.size(); i++) {
if (prefixes[i] != nullptr) { if (prefixes[i] != nullptr) {
delete prefixes[i]; delete prefixes[i];
prefixes[i] = nullptr; prefixes[i] = nullptr;
}
} }
}
} }
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs, int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
vector<string>& nbest_words) { vector<string>& nbest_words) {
kaldi::Timer timer; kaldi::Timer timer;
timer.Reset(); timer.Reset();
AdvanceDecoding(probs); AdvanceDecoding(probs);
LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f; LOG(INFO) << "ctc decoding elapsed time(s) "
return 0; << static_cast<float>(timer.Elapsed()) / 1000.0f;
} return 0;
}
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() { vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size); return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
} }
string CTCBeamSearch::GetBestPath() { string CTCBeamSearch::GetBestPath() {
std::vector<std::pair<double, std::string>> result; std::vector<std::pair<double, std::string>> result;
result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size); result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
return result[0].second; return result[0].second;
} }
string CTCBeamSearch::GetFinalBestPath() { string CTCBeamSearch::GetFinalBestPath() {
CalculateApproxScore(); CalculateApproxScore();
LMRescore(); LMRescore();
return GetBestPath(); return GetBestPath();
} }
void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) { void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
size_t num_time_steps = probs.size(); size_t num_time_steps = probs.size();
size_t beam_size = opts_.beam_size; size_t beam_size = opts_.beam_size;
double cutoff_prob = opts_.cutoff_prob; double cutoff_prob = opts_.cutoff_prob;
size_t cutoff_top_n = opts_.cutoff_top_n; size_t cutoff_top_n = opts_.cutoff_top_n;
vector<vector<double>> probs_seq(probs.size(), vector<double>(probs[0].size(), 0)); vector<vector<double>> probs_seq(probs.size(),
vector<double>(probs[0].size(), 0));
int row = probs.size();
int col = probs[0].size(); int row = probs.size();
for(int i = 0; i < row; i++) { int col = probs[0].size();
for (int j = 0; j < col; j++){ for (int i = 0; i < row; i++) {
probs_seq[i][j] = static_cast<double>(probs[i][j]); for (int j = 0; j < col; j++) {
} probs_seq[i][j] = static_cast<double>(probs[i][j]);
} }
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
const auto& prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (init_ext_scorer_ != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes,
prefix_compare);
if (num_prefixes == 0) {
continue;
}
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
std::max(0.0, init_ext_scorer_->beta);
full_beam = (num_prefixes == beam_size);
} }
vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
// loop over chars
size_t log_prob_idx_len = log_prob_idx.size();
for (size_t index = 0; index < log_prob_idx_len; index++) {
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
}
prefixes.clear();
// update log probs
root->iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
} // if
num_frame_decoded_++;
} // for probs_seq
}
int32 CTCBeamSearch::SearchOneChar(const bool& full_beam, for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
const std::pair<size_t, BaseFloat>& log_prob_idx, const auto& prob = probs_seq[time_step];
const BaseFloat& min_cutoff) {
size_t beam_size = opts_.beam_size; float min_cutoff = -NUM_FLT_INF;
const auto& c = log_prob_idx.first; bool full_beam = false;
const auto& log_prob_c = log_prob_idx.second; if (init_ext_scorer_ != nullptr) {
size_t prefixes_len = std::min(prefixes.size(), beam_size); size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(),
for (size_t i = 0; i < prefixes_len; ++i) { prefixes.begin() + num_prefixes,
auto prefix = prefixes[i]; prefix_compare);
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break; if (num_prefixes == 0) {
} continue;
}
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
std::max(0.0, init_ext_scorer_->beta);
full_beam = (num_prefixes == beam_size);
}
if (c == blank_id) { vector<std::pair<size_t, float>> log_prob_idx =
prefix->log_prob_b_cur = log_sum_exp( get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
prefix->log_prob_b_cur,
log_prob_c +
prefix->score);
continue;
}
// repeated character // loop over chars
if (c == prefix->character) { size_t log_prob_idx_len = log_prob_idx.size();
// p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1}) for (size_t index = 0; index < log_prob_idx_len; index++) {
prefix->log_prob_nb_cur = log_sum_exp( SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
prefix->log_prob_nb_cur, }
log_prob_c +
prefix->log_prob_nb_prev); prefixes.clear();
}
// update log probs
root->iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
} // if
num_frame_decoded_++;
} // for probs_seq
}
// get new prefix int32 CTCBeamSearch::SearchOneChar(
auto prefix_new = prefix->get_path_trie(c); const bool& full_beam,
if (prefix_new != nullptr) { const std::pair<size_t, BaseFloat>& log_prob_idx,
float log_p = -NUM_FLT_INF; const BaseFloat& min_cutoff) {
if (c == prefix->character && size_t beam_size = opts_.beam_size;
prefix->log_prob_b_prev > -NUM_FLT_INF) { const auto& c = log_prob_idx.first;
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1}) const auto& log_prob_c = log_prob_idx.second;
log_p = log_prob_c + prefix->log_prob_b_prev; size_t prefixes_len = std::min(prefixes.size(), beam_size);
} else if (c != prefix->character) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1}) for (size_t i = 0; i < prefixes_len; ++i) {
log_p = log_prob_c + prefix->score; auto prefix = prefixes[i];
} if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
// language model scoring
if (init_ext_scorer_ != nullptr &&
(c == space_id || init_ext_scorer_->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (init_ext_scorer_->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
} }
float score = 0.0; if (c == blank_id) {
vector<string> ngram; prefix->log_prob_b_cur =
ngram = init_ext_scorer_->make_ngram(prefix_to_score); log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
// lm score: p_{lm}(W)^{\alpha} + \beta continue;
score = init_ext_scorer_->get_log_cond_prob(ngram) * }
init_ext_scorer_->alpha;
log_p += score; // repeated character
log_p += init_ext_scorer_->beta; if (c == prefix->character) {
} // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
// p_{nb}(l;x_{1:t}) prefix->log_prob_nb_cur = log_sum_exp(
prefix_new->log_prob_nb_cur = prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
log_sum_exp(prefix_new->log_prob_nb_cur, }
log_p);
} // get new prefix
} // end of loop over prefix auto prefix_new = prefix->get_path_trie(c);
return 0; if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (init_ext_scorer_ != nullptr &&
(c == space_id || init_ext_scorer_->is_character_based())) {
PathTrie* prefix_to_score = nullptr;
// skip scoring the space
if (init_ext_scorer_->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
vector<string> ngram;
ngram = init_ext_scorer_->make_ngram(prefix_to_score);
// lm score: p_{lm}(W)^{\alpha} + \beta
score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
log_p += score;
log_p += init_ext_scorer_->beta;
}
// p_{nb}(l;x_{1:t})
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
return 0;
} }
void CTCBeamSearch::CalculateApproxScore() { void CTCBeamSearch::CalculateApproxScore() {
size_t beam_size = opts_.beam_size; size_t beam_size = opts_.beam_size;
size_t num_prefixes = std::min(prefixes.size(), beam_size); size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort( std::sort(
prefixes.begin(), prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
prefixes.begin() + num_prefixes,
prefix_compare); // compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
// compute aproximate ctc score as the return score, without affecting the for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
// return order of decoding result. To delete when decoder gets stable. double approx_ctc = prefixes[i]->score;
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { if (init_ext_scorer_ != nullptr) {
double approx_ctc = prefixes[i]->score; vector<int> output;
if (init_ext_scorer_ != nullptr) { prefixes[i]->get_path_vec(output);
vector<int> output; auto prefix_length = output.size();
prefixes[i]->get_path_vec(output); auto words = init_ext_scorer_->split_labels(output);
auto prefix_length = output.size(); // remove word insert
auto words = init_ext_scorer_->split_labels(output); approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
// remove word insert // remove language model weight:
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta; approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
// remove language model weight: init_ext_scorer_->alpha;
approx_ctc -= }
(init_ext_scorer_->get_sent_log_prob(words)) * init_ext_scorer_->alpha; prefixes[i]->approx_ctc = approx_ctc;
} }
prefixes[i]->approx_ctc = approx_ctc;
}
} }
void CTCBeamSearch::LMRescore() { void CTCBeamSearch::LMRescore() {
size_t beam_size = opts_.beam_size; size_t beam_size = opts_.beam_size;
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) { if (init_ext_scorer_ != nullptr &&
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { !init_ext_scorer_->is_character_based()) {
auto prefix = prefixes[i]; for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
if (!prefix->is_empty() && prefix->character != space_id) { auto prefix = prefixes[i];
float score = 0.0; if (!prefix->is_empty() && prefix->character != space_id) {
vector<string> ngram = init_ext_scorer_->make_ngram(prefix); float score = 0.0;
score = init_ext_scorer_->get_log_cond_prob(ngram) * init_ext_scorer_->alpha; vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
score += init_ext_scorer_->beta; score = init_ext_scorer_->get_log_cond_prob(ngram) *
prefix->score += score; init_ext_scorer_->alpha;
} score += init_ext_scorer_->beta;
prefix->score += score;
}
}
} }
}
} }
} // namespace ppspeech } // namespace ppspeech

@ -1,8 +1,22 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h" #include "base/common.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "nnet/decodable-itf.h" #include "nnet/decodable-itf.h"
#include "util/parse-options.h" #include "util/parse-options.h"
#include "decoder/ctc_decoders/scorer.h"
#include "decoder/ctc_decoders/path_trie.h"
#pragma once #pragma once
@ -17,63 +31,66 @@ struct CTCBeamSearchOptions {
int beam_size; int beam_size;
int cutoff_top_n; int cutoff_top_n;
int num_proc_bsearch; int num_proc_bsearch;
CTCBeamSearchOptions() : CTCBeamSearchOptions()
dict_file("./model/words.txt"), : dict_file("./model/words.txt"),
lm_path("./model/lm.arpa"), lm_path("./model/lm.arpa"),
alpha(1.9f), alpha(1.9f),
beta(5.0), beta(5.0),
beam_size(300), beam_size(300),
cutoff_prob(0.99f), cutoff_prob(0.99f),
cutoff_top_n(40), cutoff_top_n(40),
num_proc_bsearch(0) { num_proc_bsearch(0) {}
}
void Register(kaldi::OptionsItf* opts) { void Register(kaldi::OptionsItf* opts) {
opts->Register("dict", &dict_file, "dict file "); opts->Register("dict", &dict_file, "dict file ");
opts->Register("lm-path", &lm_path, "language model file"); opts->Register("lm-path", &lm_path, "language model file");
opts->Register("alpha", &alpha, "alpha"); opts->Register("alpha", &alpha, "alpha");
opts->Register("beta", &beta, "beta"); opts->Register("beta", &beta, "beta");
opts->Register("beam-size", &beam_size, "beam size for beam search method"); opts->Register(
"beam-size", &beam_size, "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs"); opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n"); opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
opts->Register("num-proc-bsearch", &num_proc_bsearch, "num proc bsearch"); opts->Register(
"num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
} }
}; };
class CTCBeamSearch { class CTCBeamSearch {
public: public:
explicit CTCBeamSearch(const CTCBeamSearchOptions& opts); explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
~CTCBeamSearch() {} ~CTCBeamSearch() {}
void InitDecoder(); void InitDecoder();
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable); void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath(); std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath(); std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath(); std::string GetFinalBestPath();
int NumFrameDecoded(); int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs, int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words); std::vector<std::string>& nbest_words);
void AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable, void AdvanceDecode(
int max_frames); const std::shared_ptr<kaldi::DecodableInterface>& decodable,
int max_frames);
void Reset(); void Reset();
private:
void ResetPrefixes(); private:
int32 SearchOneChar(const bool& full_beam, void ResetPrefixes();
const std::pair<size_t, BaseFloat>& log_prob_idx, int32 SearchOneChar(const bool& full_beam,
const BaseFloat& min_cutoff); const std::pair<size_t, BaseFloat>& log_prob_idx,
void CalculateApproxScore(); const BaseFloat& min_cutoff);
void LMRescore(); void CalculateApproxScore();
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs); void LMRescore();
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later CTCBeamSearchOptions opts_;
//std::vector<DecodeResult> decoder_results_; std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
std::vector<std::string> vocabulary_; // todo remove later // std::vector<DecodeResult> decoder_results_;
size_t blank_id; std::vector<std::string> vocabulary_; // todo remove later
int space_id; size_t blank_id;
std::shared_ptr<PathTrie> root; int space_id;
std::vector<PathTrie*> prefixes; std::shared_ptr<PathTrie> root;
int num_frame_decoded_; std::vector<PathTrie*> prefixes;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); int num_frame_decoded_;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
}; };
} // namespace basr } // namespace basr

@ -22,15 +22,16 @@ namespace ppspeech {
class FbankExtractor : FeatureExtractorInterface { class FbankExtractor : FeatureExtractorInterface {
public: public:
explicit FbankExtractor(const FbankOptions& opts, explicit FbankExtractor(const FbankOptions& opts,
share_ptr<FeatureExtractorInterface> pre_extractor); share_ptr<FeatureExtractorInterface> pre_extractor);
virtual void AcceptWaveform(const kaldi::Vector<kaldi::BaseFloat>& input) = 0; virtual void AcceptWaveform(
const kaldi::Vector<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0; virtual void Read(kaldi::Vector<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0; virtual size_t Dim() const = 0;
private: private:
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& wave, bool Compute(const kaldi::Vector<kaldi::BaseFloat>& wave,
kaldi::Vector<kaldi::BaseFloat>* feat) const; kaldi::Vector<kaldi::BaseFloat>* feat) const;
}; };
} // namespace ppspeech } // namespace ppspeech

@ -0,0 +1,14 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -0,0 +1,14 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -21,7 +21,8 @@ namespace ppspeech {
class FeatureExtractorInterface { class FeatureExtractorInterface {
public: public:
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0; virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input) = 0;
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0; virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat) = 0;
virtual size_t Dim() const = 0; virtual size_t Dim() const = 0;
}; };

@ -25,97 +25,97 @@ using kaldi::VectorBase;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector; using std::vector;
//todo remove later // todo remove later
void CopyVector2StdVector_(const VectorBase<BaseFloat>& input, void CopyVector2StdVector_(const VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) { vector<BaseFloat>* output) {
if (input.Dim() == 0) return; if (input.Dim() == 0) return;
output->resize(input.Dim()); output->resize(input.Dim());
for (size_t idx = 0; idx < input.Dim(); ++idx) { for (size_t idx = 0; idx < input.Dim(); ++idx) {
(*output)[idx] = input(idx); (*output)[idx] = input(idx);
} }
} }
void CopyStdVector2Vector_(const vector<BaseFloat>& input, void CopyStdVector2Vector_(const vector<BaseFloat>& input,
Vector<BaseFloat>* output) { Vector<BaseFloat>* output) {
if (input.empty()) return; if (input.empty()) return;
output->Resize(input.size()); output->Resize(input.size());
for (size_t idx = 0; idx < input.size(); ++idx) { for (size_t idx = 0; idx < input.size(); ++idx) {
(*output)(idx) = input[idx]; (*output)(idx) = input[idx];
} }
} }
LinearSpectrogram::LinearSpectrogram( LinearSpectrogram::LinearSpectrogram(
const LinearSpectrogramOptions& opts, const LinearSpectrogramOptions& opts,
std::unique_ptr<FeatureExtractorInterface> base_extractor) { std::unique_ptr<FeatureExtractorInterface> base_extractor) {
opts_ = opts; opts_ = opts;
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
int32 window_size = opts.frame_opts.WindowSize(); int32 window_size = opts.frame_opts.WindowSize();
int32 window_shift = opts.frame_opts.WindowShift(); int32 window_shift = opts.frame_opts.WindowShift();
fft_points_ = window_size; fft_points_ = window_size;
hanning_window_.resize(window_size); hanning_window_.resize(window_size);
double a = M_2PI / (window_size - 1); double a = M_2PI / (window_size - 1);
hanning_window_energy_ = 0; hanning_window_energy_ = 0;
for (int i = 0; i < window_size; ++i) { for (int i = 0; i < window_size; ++i) {
hanning_window_[i] = 0.5 - 0.5 * cos(a * i); hanning_window_[i] = 0.5 - 0.5 * cos(a * i);
hanning_window_energy_ += hanning_window_[i] * hanning_window_[i]; hanning_window_energy_ += hanning_window_[i] * hanning_window_[i];
} }
dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz
} }
void LinearSpectrogram::AcceptWaveform(const VectorBase<BaseFloat>& input) { void LinearSpectrogram::AcceptWaveform(const VectorBase<BaseFloat>& input) {
base_extractor_->AcceptWaveform(input); base_extractor_->AcceptWaveform(input);
} }
void LinearSpectrogram::Hanning(vector<float>* data) const { void LinearSpectrogram::Hanning(vector<float>* data) const {
CHECK_GE(data->size(), hanning_window_.size()); CHECK_GE(data->size(), hanning_window_.size());
for (size_t i = 0; i < hanning_window_.size(); ++i) { for (size_t i = 0; i < hanning_window_.size(); ++i) {
data->at(i) *= hanning_window_[i]; data->at(i) *= hanning_window_[i];
} }
} }
bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v, bool LinearSpectrogram::NumpyFft(vector<BaseFloat>* v,
vector<BaseFloat>* real, vector<BaseFloat>* real,
vector<BaseFloat>* img) const { vector<BaseFloat>* img) const {
Vector<BaseFloat> v_tmp; Vector<BaseFloat> v_tmp;
CopyStdVector2Vector_(*v, &v_tmp); CopyStdVector2Vector_(*v, &v_tmp);
RealFft(&v_tmp, true); RealFft(&v_tmp, true);
CopyVector2StdVector_(v_tmp, v); CopyVector2StdVector_(v_tmp, v);
real->push_back(v->at(0)); real->push_back(v->at(0));
img->push_back(0); img->push_back(0);
for (int i = 1; i < v->size() / 2; i++) { for (int i = 1; i < v->size() / 2; i++) {
real->push_back(v->at(2 * i)); real->push_back(v->at(2 * i));
img->push_back(v->at(2 * i + 1)); img->push_back(v->at(2 * i + 1));
} }
real->push_back(v->at(1)); real->push_back(v->at(1));
img->push_back(0); img->push_back(0);
return true; return true;
} }
// todo remove later // todo remove later
void LinearSpectrogram::ReadFeats(Matrix<BaseFloat>* feats) { void LinearSpectrogram::ReadFeats(Matrix<BaseFloat>* feats) {
Vector<BaseFloat> tmp; Vector<BaseFloat> tmp;
waveform_.Resize(base_extractor_->Dim()); waveform_.Resize(base_extractor_->Dim());
Compute(tmp, &waveform_); Compute(tmp, &waveform_);
vector<vector<BaseFloat>> result; vector<vector<BaseFloat>> result;
vector<BaseFloat> feats_vec; vector<BaseFloat> feats_vec;
CopyVector2StdVector_(waveform_, &feats_vec); CopyVector2StdVector_(waveform_, &feats_vec);
Compute(feats_vec, result); Compute(feats_vec, result);
feats->Resize(result.size(), result[0].size()); feats->Resize(result.size(), result[0].size());
for (int row_idx = 0; row_idx < result.size(); ++row_idx) { for (int row_idx = 0; row_idx < result.size(); ++row_idx) {
for (int col_idx = 0; col_idx < result[0].size(); ++col_idx) { for (int col_idx = 0; col_idx < result[0].size(); ++col_idx) {
(*feats)(row_idx, col_idx) = result[row_idx][col_idx]; (*feats)(row_idx, col_idx) = result[row_idx][col_idx];
}
} }
} waveform_.Resize(0);
waveform_.Resize(0);
} }
void LinearSpectrogram::Read(VectorBase<BaseFloat>* feat) { void LinearSpectrogram::Read(VectorBase<BaseFloat>* feat) {
// todo // todo
return; return;
} }
// only for test, remove later // only for test, remove later
@ -129,49 +129,49 @@ void LinearSpectrogram::Compute(const VectorBase<kaldi::BaseFloat>& input,
// todo: refactor later (SmileGoat) // todo: refactor later (SmileGoat)
bool LinearSpectrogram::Compute(const vector<float>& wave, bool LinearSpectrogram::Compute(const vector<float>& wave,
vector<vector<float>>& feat) { vector<vector<float>>& feat) {
int num_samples = wave.size(); int num_samples = wave.size();
const int& frame_length = opts_.frame_opts.WindowSize(); const int& frame_length = opts_.frame_opts.WindowSize();
const int& sample_rate = opts_.frame_opts.samp_freq; const int& sample_rate = opts_.frame_opts.samp_freq;
const int& frame_shift = opts_.frame_opts.WindowShift(); const int& frame_shift = opts_.frame_opts.WindowShift();
const int& fft_points = fft_points_; const int& fft_points = fft_points_;
const float scale = hanning_window_energy_ * sample_rate; const float scale = hanning_window_energy_ * sample_rate;
if (num_samples < frame_length) { if (num_samples < frame_length) {
return true; return true;
} }
int num_frames = 1 + ((num_samples - frame_length) / frame_shift); int num_frames = 1 + ((num_samples - frame_length) / frame_shift);
feat.resize(num_frames); feat.resize(num_frames);
vector<float> fft_real((fft_points_ / 2 + 1), 0); vector<float> fft_real((fft_points_ / 2 + 1), 0);
vector<float> fft_img((fft_points_ / 2 + 1), 0); vector<float> fft_img((fft_points_ / 2 + 1), 0);
vector<float> v(frame_length, 0); vector<float> v(frame_length, 0);
vector<float> power((fft_points / 2 + 1)); vector<float> power((fft_points / 2 + 1));
for (int i = 0; i < num_frames; ++i) { for (int i = 0; i < num_frames; ++i) {
vector<float> data(wave.data() + i * frame_shift, vector<float> data(wave.data() + i * frame_shift,
wave.data() + i * frame_shift + frame_length); wave.data() + i * frame_shift + frame_length);
Hanning(&data); Hanning(&data);
fft_img.clear(); fft_img.clear();
fft_real.clear(); fft_real.clear();
v.assign(data.begin(), data.end()); v.assign(data.begin(), data.end());
NumpyFft(&v, &fft_real, &fft_img); NumpyFft(&v, &fft_real, &fft_img);
feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz
for (int j = 0; j < (fft_points / 2 + 1); ++j) { for (int j = 0; j < (fft_points / 2 + 1); ++j) {
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j]; power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
feat[i][j] = power[j]; feat[i][j] = power[j];
if (j == 0 || j == feat[0].size() - 1) { if (j == 0 || j == feat[0].size() - 1) {
feat[i][j] /= scale; feat[i][j] /= scale;
} else { } else {
feat[i][j] *= (2.0 / scale); feat[i][j] *= (2.0 / scale);
} }
// log added eps=1e-14 // log added eps=1e-14
feat[i][j] = std::log(feat[i][j] + 1e-14); feat[i][j] = std::log(feat[i][j] + 1e-14);
}
} }
} return true;
return true;
} }
} // namespace ppspeech } // namespace ppspeech

@ -1,32 +1,45 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/feature_extractor_interface.h"
#include "kaldi/feat/feature-window.h" #include "kaldi/feat/feature-window.h"
#include "base/common.h"
namespace ppspeech { namespace ppspeech {
struct LinearSpectrogramOptions { struct LinearSpectrogramOptions {
kaldi::FrameExtractionOptions frame_opts; kaldi::FrameExtractionOptions frame_opts;
LinearSpectrogramOptions(): LinearSpectrogramOptions() : frame_opts() {}
frame_opts() {}
void Register(kaldi::OptionsItf* opts) { void Register(kaldi::OptionsItf* opts) { frame_opts.Register(opts); }
frame_opts.Register(opts);
}
}; };
class LinearSpectrogram : public FeatureExtractorInterface { class LinearSpectrogram : public FeatureExtractorInterface {
public: public:
explicit LinearSpectrogram(const LinearSpectrogramOptions& opts, explicit LinearSpectrogram(
std::unique_ptr<FeatureExtractorInterface> base_extractor); const LinearSpectrogramOptions& opts,
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input); std::unique_ptr<FeatureExtractorInterface> base_extractor);
virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat); virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return dim_; } virtual size_t Dim() const { return dim_; }
void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats); void ReadFeats(kaldi::Matrix<kaldi::BaseFloat>* feats);
private: private:
void Hanning(std::vector<kaldi::BaseFloat>* data) const; void Hanning(std::vector<kaldi::BaseFloat>* data) const;
bool Compute(const std::vector<kaldi::BaseFloat>& wave, bool Compute(const std::vector<kaldi::BaseFloat>& wave,
std::vector<std::vector<kaldi::BaseFloat>>& feat); std::vector<std::vector<kaldi::BaseFloat>>& feat);
@ -41,7 +54,7 @@ class LinearSpectrogram : public FeatureExtractorInterface {
std::vector<kaldi::BaseFloat> hanning_window_; std::vector<kaldi::BaseFloat> hanning_window_;
kaldi::BaseFloat hanning_window_energy_; kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_; LinearSpectrogramOptions opts_;
kaldi::Vector<kaldi::BaseFloat> waveform_; // remove later, todo(SmileGoat) kaldi::Vector<kaldi::BaseFloat> waveform_; // remove later, todo(SmileGoat)
std::unique_ptr<FeatureExtractorInterface> base_extractor_; std::unique_ptr<FeatureExtractorInterface> base_extractor_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram); DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
}; };

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/normalizer.h" #include "frontend/normalizer.h"
#include "kaldi/feat/cmvn.h" #include "kaldi/feat/cmvn.h"
@ -12,169 +26,173 @@ using std::vector;
using kaldi::SubVector; using kaldi::SubVector;
DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) { DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) {
opts_ = opts; opts_ = opts;
dim_ = 0; dim_ = 0;
} }
void DecibelNormalizer::AcceptWaveform(const kaldi::VectorBase<BaseFloat>& input) { void DecibelNormalizer::AcceptWaveform(
dim_ = input.Dim(); const kaldi::VectorBase<BaseFloat>& input) {
waveform_.Resize(input.Dim()); dim_ = input.Dim();
waveform_.CopyFromVec(input); waveform_.Resize(input.Dim());
waveform_.CopyFromVec(input);
} }
void DecibelNormalizer::Read(kaldi::VectorBase<BaseFloat>* feat) { void DecibelNormalizer::Read(kaldi::VectorBase<BaseFloat>* feat) {
if (waveform_.Dim() == 0) return; if (waveform_.Dim() == 0) return;
Compute(waveform_, feat); Compute(waveform_, feat);
} }
//todo remove later // todo remove later
void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input, void CopyVector2StdVector(const kaldi::VectorBase<BaseFloat>& input,
vector<BaseFloat>* output) { vector<BaseFloat>* output) {
if (input.Dim() == 0) return; if (input.Dim() == 0) return;
output->resize(input.Dim()); output->resize(input.Dim());
for (size_t idx = 0; idx < input.Dim(); ++idx) { for (size_t idx = 0; idx < input.Dim(); ++idx) {
(*output)[idx] = input(idx); (*output)[idx] = input(idx);
} }
} }
void CopyStdVector2Vector(const vector<BaseFloat>& input, void CopyStdVector2Vector(const vector<BaseFloat>& input,
VectorBase<BaseFloat>* output) { VectorBase<BaseFloat>* output) {
if (input.empty()) return; if (input.empty()) return;
assert(input.size() == output->Dim()); assert(input.size() == output->Dim());
for (size_t idx = 0; idx < input.size(); ++idx) { for (size_t idx = 0; idx < input.size(); ++idx) {
(*output)(idx) = input[idx]; (*output)(idx) = input[idx];
} }
} }
bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input, bool DecibelNormalizer::Compute(const VectorBase<BaseFloat>& input,
VectorBase<BaseFloat>* feat) const { VectorBase<BaseFloat>* feat) const {
// calculate db rms // calculate db rms
BaseFloat rms_db = 0.0; BaseFloat rms_db = 0.0;
BaseFloat mean_square = 0.0; BaseFloat mean_square = 0.0;
BaseFloat gain = 0.0; BaseFloat gain = 0.0;
BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1)); BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1));
vector<BaseFloat> samples; vector<BaseFloat> samples;
samples.resize(input.Dim()); samples.resize(input.Dim());
for (int32 i = 0; i < samples.size(); ++i) { for (int32 i = 0; i < samples.size(); ++i) {
samples[i] = input(i); samples[i] = input(i);
}
// square
for (auto &d : samples) {
if (opts_.convert_int_float) {
d = d * wave_float_normlization;
} }
mean_square += d * d;
} // square
for (auto& d : samples) {
// mean if (opts_.convert_int_float) {
mean_square /= samples.size(); d = d * wave_float_normlization;
rms_db = 10 * std::log10(mean_square); }
gain = opts_.target_db - rms_db; mean_square += d * d;
}
if (gain > opts_.max_gain_db) {
LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db << "dB," // mean
<< "because the the probable gain have exceeds opts_.max_gain_db" mean_square /= samples.size();
<< opts_.max_gain_db << "dB."; rms_db = 10 * std::log10(mean_square);
return false; gain = opts_.target_db - rms_db;
}
if (gain > opts_.max_gain_db) {
// Note that this is an in-place transformation. LOG(ERROR)
for (auto &item : samples) { << "Unable to normalize segment to " << opts_.target_db << "dB,"
// python item *= 10.0 ** (gain / 20.0) << "because the the probable gain have exceeds opts_.max_gain_db"
item *= std::pow(10.0, gain / 20.0); << opts_.max_gain_db << "dB.";
} return false;
}
CopyStdVector2Vector(samples, feat);
return true; // Note that this is an in-place transformation.
for (auto& item : samples) {
// python item *= 10.0 ** (gain / 20.0)
item *= std::pow(10.0, gain / 20.0);
}
CopyStdVector2Vector(samples, feat);
return true;
} }
// Loads the CMVN statistics matrix from `cmvn_file` into stats_. The file may
// be in binary or text form; kaldi::Input reports which one it found.
CMVN::CMVN(std::string cmvn_file) : var_norm_(true) {
    bool is_binary;
    kaldi::Input stats_input(cmvn_file, &is_binary);
    stats_.Read(stats_input.Stream(), is_binary);
}
// Placeholder: the incoming waveform is currently ignored.
void CMVN::AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input) {
    // Intentionally empty.
}
// Placeholder: nothing is written to `feat` yet.
void CMVN::Read(kaldi::VectorBase<BaseFloat>* feat) {
    // Intentionally empty.
}
// feats contain num_frames feature. // feats contain num_frames feature.
void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) { void CMVN::ApplyCMVN(bool var_norm, VectorBase<BaseFloat>* feats) {
KALDI_ASSERT(feats != NULL); KALDI_ASSERT(feats != NULL);
int32 dim = stats_.NumCols() - 1; int32 dim = stats_.NumCols() - 1;
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 || feats->Dim() % dim != 0) { if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
KALDI_ERR << "Dim mismatch: cmvn " feats->Dim() % dim != 0) {
<< stats_.NumRows() << 'x' << stats_.NumCols() KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x'
<< ", feats " << feats->Dim() << 'x'; << stats_.NumCols() << ", feats " << feats->Dim() << 'x';
}
if (stats_.NumRows() == 1 && var_norm) {
KALDI_ERR << "You requested variance normalization but no variance stats_ "
<< "are supplied.";
}
double count = stats_(0, dim);
// Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
// computing an offset and representing it as stats_, we use a count of one.
if (count < 1.0)
KALDI_ERR << "Insufficient stats_ for cepstral mean and variance normalization: "
<< "count = " << count;
if (!var_norm) {
Vector<BaseFloat> offset(feats->Dim());
SubVector<double> mean_stats(stats_.RowData(0), dim);
Vector<double> mean_stats_apply(feats->Dim());
//fill the datat of mean_stats in mean_stats_appy whose dim is equal with the dim of feature.
//the dim of feats = dim * num_frames;
for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) {
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim*idx, dim);
stats_tmp.CopyFromVec(mean_stats);
} }
offset.AddVec(-1.0 / count, mean_stats_apply); if (stats_.NumRows() == 1 && var_norm) {
feats->AddVec(1.0, offset); KALDI_ERR
return; << "You requested variance normalization but no variance stats_ "
} << "are supplied.";
// norm(0, d) = mean offset; }
// norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
kaldi::Matrix<BaseFloat> norm(2, feats->Dim()); double count = stats_(0, dim);
for (int32 d = 0; d < dim; d++) { // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
double mean, offset, scale; // computing an offset and representing it as stats_, we use a count of one.
mean = stats_(0, d)/count; if (count < 1.0)
double var = (stats_(1, d)/count) - mean*mean, KALDI_ERR << "Insufficient stats_ for cepstral mean and variance "
floor = 1.0e-20; "normalization: "
if (var < floor) { << "count = " << count;
KALDI_WARN << "Flooring cepstral variance from " << var << " to "
<< floor; if (!var_norm) {
var = floor; Vector<BaseFloat> offset(feats->Dim());
SubVector<double> mean_stats(stats_.RowData(0), dim);
Vector<double> mean_stats_apply(feats->Dim());
// fill the datat of mean_stats in mean_stats_appy whose dim is equal
// with the dim of feature.
// the dim of feats = dim * num_frames;
for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) {
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim * idx,
dim);
stats_tmp.CopyFromVec(mean_stats);
}
offset.AddVec(-1.0 / count, mean_stats_apply);
feats->AddVec(1.0, offset);
return;
} }
scale = 1.0 / sqrt(var); // norm(0, d) = mean offset;
if (scale != scale || 1/scale == 0.0) // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
KALDI_ERR << "NaN or infinity in cepstral mean/variance computation"; kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
offset = -(mean*scale); for (int32 d = 0; d < dim; d++) {
for (int32 d_skip = d; d_skip < feats->Dim();) { double mean, offset, scale;
norm(0, d_skip) = offset; mean = stats_(0, d) / count;
norm(1, d_skip) = scale; double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20;
d_skip = d_skip + dim; if (var < floor) {
KALDI_WARN << "Flooring cepstral variance from " << var << " to "
<< floor;
var = floor;
}
scale = 1.0 / sqrt(var);
if (scale != scale || 1 / scale == 0.0)
KALDI_ERR
<< "NaN or infinity in cepstral mean/variance computation";
offset = -(mean * scale);
for (int32 d_skip = d; d_skip < feats->Dim();) {
norm(0, d_skip) = offset;
norm(1, d_skip) = scale;
d_skip = d_skip + dim;
}
} }
} // Apply the normalization.
// Apply the normalization. feats->MulElements(norm.Row(1));
feats->MulElements(norm.Row(1)); feats->AddVec(1.0, norm.Row(0));
feats->AddVec(1.0, norm.Row(0));
} }
void CMVN::ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats) { void CMVN::ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats) {
ApplyCmvn(stats_, var_norm, feats); ApplyCmvn(stats_, var_norm, feats);
} }
bool CMVN::Compute(const VectorBase<BaseFloat>& input, bool CMVN::Compute(const VectorBase<BaseFloat>& input,
VectorBase<BaseFloat>* feat) const { VectorBase<BaseFloat>* feat) const {
return false; return false;
} }
} // namespace ppspeech } // namespace ppspeech

@ -1,37 +1,55 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
#include "base/common.h" #include "base/common.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/feature_extractor_interface.h"
#include "kaldi/util/options-itf.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/options-itf.h"
namespace ppspeech { namespace ppspeech {
struct DecibelNormalizerOptions { struct DecibelNormalizerOptions {
float target_db; float target_db;
float max_gain_db; float max_gain_db;
bool convert_int_float; bool convert_int_float;
DecibelNormalizerOptions() : DecibelNormalizerOptions()
target_db(-20), : target_db(-20), max_gain_db(300.0), convert_int_float(false) {}
max_gain_db(300.0),
convert_int_float(false) {}
void Register(kaldi::OptionsItf* opts) { void Register(kaldi::OptionsItf* opts) {
opts->Register("target-db", &target_db, "target db for db normalization"); opts->Register(
opts->Register("max-gain-db", &max_gain_db, "max gain db for db normalization"); "target-db", &target_db, "target db for db normalization");
opts->Register("convert-int-float", &convert_int_float, "if convert int samples to float"); opts->Register(
"max-gain-db", &max_gain_db, "max gain db for db normalization");
opts->Register("convert-int-float",
&convert_int_float,
"if convert int samples to float");
} }
}; };
class DecibelNormalizer : public FeatureExtractorInterface { class DecibelNormalizer : public FeatureExtractorInterface {
public: public:
explicit DecibelNormalizer(const DecibelNormalizerOptions& opts); explicit DecibelNormalizer(const DecibelNormalizerOptions& opts);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input); virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat); virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return dim_; } virtual size_t Dim() const { return dim_; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input, bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
kaldi::VectorBase<kaldi::BaseFloat>* feat) const; kaldi::VectorBase<kaldi::BaseFloat>* feat) const;
private: private:
DecibelNormalizerOptions opts_; DecibelNormalizerOptions opts_;
size_t dim_; size_t dim_;
@ -43,7 +61,8 @@ class DecibelNormalizer : public FeatureExtractorInterface {
class CMVN : public FeatureExtractorInterface { class CMVN : public FeatureExtractorInterface {
public: public:
explicit CMVN(std::string cmvn_file); explicit CMVN(std::string cmvn_file);
virtual void AcceptWaveform(const kaldi::VectorBase<kaldi::BaseFloat>& input); virtual void AcceptWaveform(
const kaldi::VectorBase<kaldi::BaseFloat>& input);
virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat); virtual void Read(kaldi::VectorBase<kaldi::BaseFloat>* feat);
virtual size_t Dim() const { return stats_.NumCols() - 1; } virtual size_t Dim() const { return stats_.NumCols() - 1; }
bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input, bool Compute(const kaldi::VectorBase<kaldi::BaseFloat>& input,
@ -51,6 +70,7 @@ class CMVN : public FeatureExtractorInterface {
// for test // for test
void ApplyCMVN(bool var_norm, kaldi::VectorBase<BaseFloat>* feats); void ApplyCMVN(bool var_norm, kaldi::VectorBase<BaseFloat>* feats);
void ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats); void ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase<BaseFloat>* feats);
private: private:
kaldi::Matrix<double> stats_; kaldi::Matrix<double> stats_;
std::shared_ptr<FeatureExtractorInterface> base_extractor_; std::shared_ptr<FeatureExtractorInterface> base_extractor_;

@ -13,4 +13,3 @@
// limitations under the License. // limitations under the License.
// extract the window of kaldi feat. // extract the window of kaldi feat.

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// itf/decodable-itf.h // itf/decodable-itf.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University; // Copyright 2009-2011 Microsoft Corporation; Saarland University;
@ -42,8 +56,10 @@ namespace kaldi {
For online decoding, where the features are coming in in real time, it is For online decoding, where the features are coming in in real time, it is
important to understand the IsLastFrame() and NumFramesReady() functions. important to understand the IsLastFrame() and NumFramesReady() functions.
There are two ways these are used: the old online-decoding code, in ../online/, There are two ways these are used: the old online-decoding code, in
and the new online-decoding code, in ../online2/. In the old online-decoding ../online/,
and the new online-decoding code, in ../online2/. In the old
online-decoding
code, the decoder would do: code, the decoder would do:
\code{.cc} \code{.cc}
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) { for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
@ -52,13 +68,16 @@ namespace kaldi {
\endcode \endcode
and the call to IsLastFrame would block if the features had not arrived yet. and the call to IsLastFrame would block if the features had not arrived yet.
The decodable object would have to know when to terminate the decoding. This The decodable object would have to know when to terminate the decoding. This
online-decoding mode is still supported, it is what happens when you call, for online-decoding mode is still supported, it is what happens when you call,
for
example, LatticeFasterDecoder::Decode(). example, LatticeFasterDecoder::Decode().
We realized that this "blocking" mode of decoding is not very convenient We realized that this "blocking" mode of decoding is not very convenient
because it forces the program to be multi-threaded and makes it complex to because it forces the program to be multi-threaded and makes it complex to
control endpointing. In the "new" decoding code, you don't call (for example) control endpointing. In the "new" decoding code, you don't call (for
LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(), example)
LatticeFasterDecoder::Decode(), you call
LatticeFasterDecoder::InitDecoding(),
and then each time you get more features, you provide them to the decodable and then each time you get more features, you provide them to the decodable
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
something like this: something like this:
@ -68,7 +87,8 @@ namespace kaldi {
} }
\endcode \endcode
So the decodable object never has IsLastFrame() called. For decoding where So the decodable object never has IsLastFrame() called. For decoding where
you are starting with a matrix of features, the NumFramesReady() function will you are starting with a matrix of features, the NumFramesReady() function
will
always just return the number of frames in the file, and IsLastFrame() will always just return the number of frames in the file, and IsLastFrame() will
return true for the last frame. return true for the last frame.
@ -80,43 +100,52 @@ namespace kaldi {
frame of the file once we've decided to terminate decoding. frame of the file once we've decided to terminate decoding.
*/ */
class DecodableInterface {
  public:
    /// Returns the log likelihood, which will be negated in the decoder.
    /// The "frame" starts from zero. You should verify that
    /// NumFramesReady() > frame before calling this.
    virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;

    /// Returns true if this is the last frame. Frames are zero-based, so the
    /// first frame is zero. IsLastFrame(-1) will return false, unless the
    /// file is empty (which is a case that I'm not sure all the code will
    /// handle, so be careful). Caution: the behavior of this function in an
    /// online setting is being changed somewhat. In future it may return
    /// false in cases where we haven't yet decided to terminate decoding, but
    /// later true if we decide to terminate decoding. The plan in future is
    /// to rely more on NumFramesReady(), and in future, IsLastFrame() would
    /// always return false in an online-decoding setting, and would only
    /// return true in a decoding-from-matrix setting where we want to allow
    /// the last delta or LDA features to be flushed out for compatibility
    /// with the baseline setup.
    virtual bool IsLastFrame(int32 frame) const = 0;

    /// The call NumFramesReady() will return the number of frames currently
    /// available for this decodable object. This is for use in setups where
    /// you don't want the decoder to block while waiting for input. This is
    /// newly added as of Jan 2014, and I hope, going forward, to rely on this
    /// mechanism more than IsLastFrame to know when to stop decoding.
    /// The default implementation reports an error through KALDI_ERR.
    virtual int32 NumFramesReady() const {
        KALDI_ERR
            << "NumFramesReady() not implemented for this decodable type.";
        return -1;
    }

    /// Returns the number of states in the acoustic model
    /// (they will be indexed one-based, i.e. from 1 to NumIndices();
    /// this is for compatibility with OpenFst).
    virtual int32 NumIndices() const = 0;

    /// Returns a vector of log likelihoods for the given frame.
    /// NOTE(review): presumably one entry per output index - confirm against
    /// the implementations.
    virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame) = 0;

    virtual ~DecodableInterface() {}
};
/// @} /// @}
} // namespace Kaldi } // namespace Kaldi

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/decodable.h" #include "nnet/decodable.h"
namespace ppspeech { namespace ppspeech {
@ -5,51 +19,43 @@ namespace ppspeech {
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
// Creates a decodable backed by `nnet`. No frontend is attached, input is not
// yet marked finished, and no frames are ready.
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet)
    : frontend_(nullptr),
      nnet_(nnet),
      finished_(false),
      frames_ready_(0) {}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) { void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
frames_ready_ += likelihood.NumRows(); frames_ready_ += likelihood.NumRows();
} }
//Decodable::Init(DecodableConfig config) { // Decodable::Init(DecodableConfig config) {
//} //}
// Returns true only once input has been marked finished and `frame` is the
// index of the final ready frame. `frame` must not exceed frames_ready_.
bool Decodable::IsLastFrame(int32 frame) const {
    CHECK_LE(frame, frames_ready_);
    if (!finished_) {
        return false;
    }
    return frame == frames_ready_ - 1;
}
// Placeholder: always reports zero indices.
int32 Decodable::NumIndices() const {
    const int32 num_indices = 0;
    return num_indices;
}
// Placeholder: ignores both arguments and always returns 0.
BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
    const BaseFloat log_likelihood = 0;
    return log_likelihood;
}
void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) { void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
nnet_->FeedForward(features, &nnet_cache_); nnet_->FeedForward(features, &nnet_cache_);
frames_ready_ += nnet_cache_.NumRows(); frames_ready_ += nnet_cache_.NumRows();
return ; return;
} }
// Returns row `frame` of the cached network output as a vector with one
// entry per column of nnet_cache_.
std::vector<BaseFloat> Decodable::FrameLogLikelihood(int32 frame) {
    const int32 num_cols = nnet_cache_.NumCols();
    // Size the vector up front: reserve() only allocates capacity, so writing
    // through operator[] afterwards would be out of bounds and the returned
    // vector would still report size() == 0.
    std::vector<BaseFloat> result(num_cols);
    for (int32 idx = 0; idx < num_cols; ++idx) {
        result[idx] = nnet_cache_(frame, idx);
    }
    return result;
}
void Decodable::Reset() { void Decodable::Reset() {
// frontend_.Reset(); // frontend_.Reset();
nnet_->Reset(); nnet_->Reset();
} }
} // namespace ppspeech } // namespace ppspeech

@ -1,7 +1,21 @@
#include "nnet/decodable-itf.h" // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h" #include "base/common.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "frontend/feature_extractor_interface.h" #include "frontend/feature_extractor_interface.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/decodable-itf.h"
#include "nnet/nnet_interface.h" #include "nnet/nnet_interface.h"
namespace ppspeech { namespace ppspeech {
@ -9,17 +23,20 @@ namespace ppspeech {
struct DecodableOpts; struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface { class Decodable : public kaldi::DecodableInterface {
public: public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet); explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
//void Init(DecodableOpts config); // void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame) const; virtual bool IsLastFrame(int32 frame) const;
virtual int32 NumIndices() const; virtual int32 NumIndices() const;
virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame); virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later void Acceptlikelihood(
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>& feature); // only for test, todo remove later const kaldi::Matrix<kaldi::BaseFloat>& likelihood); // remove later
void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>&
feature); // only for test, todo remove later
void Reset(); void Reset();
void InputFinished() { finished_ = true; } void InputFinished() { finished_ = true; }
private: private:
std::shared_ptr<FeatureExtractorInterface> frontend_; std::shared_ptr<FeatureExtractorInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_; std::shared_ptr<NnetInterface> nnet_;

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
@ -10,10 +24,9 @@ namespace ppspeech {
class NnetInterface { class NnetInterface {
public: public:
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features, virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences)= 0; kaldi::Matrix<kaldi::BaseFloat>* inferences) = 0;
virtual void Reset() = 0; virtual void Reset() = 0;
virtual ~NnetInterface() {} virtual ~NnetInterface() {}
}; };
} // namespace ppspeech } // namespace ppspeech

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/paddle_nnet.h" #include "nnet/paddle_nnet.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
@ -9,43 +23,44 @@ using std::shared_ptr;
using kaldi::Matrix; using kaldi::Matrix;
void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
std::vector<std::string> cache_names; std::vector<std::string> cache_names;
cache_names = absl::StrSplit(opts.cache_names, ","); cache_names = absl::StrSplit(opts.cache_names, ",");
std::vector<std::string> cache_shapes; std::vector<std::string> cache_shapes;
cache_shapes = absl::StrSplit(opts.cache_shape, ","); cache_shapes = absl::StrSplit(opts.cache_shape, ",");
assert(cache_shapes.size() == cache_names.size()); assert(cache_shapes.size() == cache_names.size());
cache_encouts_.clear(); cache_encouts_.clear();
cache_names_idx_.clear(); cache_names_idx_.clear();
for (size_t i = 0; i < cache_shapes.size(); i++) { for (size_t i = 0; i < cache_shapes.size(); i++) {
std::vector<std::string> tmp_shape; std::vector<std::string> tmp_shape;
tmp_shape = absl::StrSplit(cache_shapes[i], "-"); tmp_shape = absl::StrSplit(cache_shapes[i], "-");
std::vector<int> cur_shape; std::vector<int> cur_shape;
std::transform(tmp_shape.begin(), tmp_shape.end(), std::transform(tmp_shape.begin(),
std::back_inserter(cur_shape), tmp_shape.end(),
[](const std::string& s) { std::back_inserter(cur_shape),
return atoi(s.c_str()); [](const std::string& s) { return atoi(s.c_str()); });
}); cache_names_idx_[cache_names[i]] = i;
cache_names_idx_[cache_names[i]] = i; std::shared_ptr<Tensor<BaseFloat>> cache_eout =
std::shared_ptr<Tensor<BaseFloat>> cache_eout = std::make_shared<Tensor<BaseFloat>>(cur_shape); std::make_shared<Tensor<BaseFloat>>(cur_shape);
cache_encouts_.push_back(cache_eout); cache_encouts_.push_back(cache_eout);
} }
} }
PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) { PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) {
paddle_infer::Config config; paddle_infer::Config config;
config.SetModel(opts.model_path, opts.params_path); config.SetModel(opts.model_path, opts.params_path);
if (opts.use_gpu) { if (opts.use_gpu) {
config.EnableUseGpu(500, 0); config.EnableUseGpu(500, 0);
} }
config.SwitchIrOptim(opts.switch_ir_optim); config.SwitchIrOptim(opts.switch_ir_optim);
if (opts.enable_fc_padding == false) { if (opts.enable_fc_padding == false) {
config.DisableFCPadding(); config.DisableFCPadding();
} }
if (opts.enable_profile) { if (opts.enable_profile) {
config.EnableProfile(); config.EnableProfile();
} }
pool.reset(new paddle_infer::services::PredictorPool(config, opts.thread_num)); pool.reset(
new paddle_infer::services::PredictorPool(config, opts.thread_num));
if (pool == nullptr) { if (pool == nullptr) {
LOG(ERROR) << "create the predictor pool failed"; LOG(ERROR) << "create the predictor pool failed";
} }
@ -59,7 +74,7 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
vector<string> input_names_vec = absl::StrSplit(opts.input_names, ","); vector<string> input_names_vec = absl::StrSplit(opts.input_names, ",");
vector<string> output_names_vec = absl::StrSplit(opts.output_names, ","); vector<string> output_names_vec = absl::StrSplit(opts.output_names, ",");
paddle_infer::Predictor* predictor = GetPredictor(); paddle_infer::Predictor* predictor = GetPredictor();
std::vector<std::string> model_input_names = predictor->GetInputNames(); std::vector<std::string> model_input_names = predictor->GetInputNames();
assert(input_names_vec.size() == model_input_names.size()); assert(input_names_vec.size() == model_input_names.size());
for (size_t i = 0; i < model_input_names.size(); i++) { for (size_t i = 0; i < model_input_names.size(); i++) {
@ -68,16 +83,14 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) {
std::vector<std::string> model_output_names = predictor->GetOutputNames(); std::vector<std::string> model_output_names = predictor->GetOutputNames();
assert(output_names_vec.size() == model_output_names.size()); assert(output_names_vec.size() == model_output_names.size());
for (size_t i = 0;i < output_names_vec.size(); i++) { for (size_t i = 0; i < output_names_vec.size(); i++) {
assert(output_names_vec[i] == model_output_names[i]); assert(output_names_vec[i] == model_output_names[i]);
} }
ReleasePredictor(predictor); ReleasePredictor(predictor);
InitCacheEncouts(opts); InitCacheEncouts(opts);
} }
void PaddleNnet::Reset() { void PaddleNnet::Reset() { InitCacheEncouts(opts_); }
InitCacheEncouts(opts_);
}
paddle_infer::Predictor* PaddleNnet::GetPredictor() { paddle_infer::Predictor* PaddleNnet::GetPredictor() {
LOG(INFO) << "attempt to get a new predictor instance " << std::endl; LOG(INFO) << "attempt to get a new predictor instance " << std::endl;
@ -122,80 +135,88 @@ int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) {
} }
// Looks up the cache tensor registered under `name`. Returns nullptr when no
// tensor with that name exists.
shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
    auto entry = cache_names_idx_.find(name);
    const bool is_known = (entry != cache_names_idx_.end());
    if (!is_known) {
        return nullptr;
    }
    assert(entry->second < cache_encouts_.size());
    return cache_encouts_[entry->second];
}
void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features, Matrix<BaseFloat>* inferences) { void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features,
paddle_infer::Predictor* predictor = GetPredictor(); Matrix<BaseFloat>* inferences) {
int row = features.NumRows(); paddle_infer::Predictor* predictor = GetPredictor();
int col = features.NumCols(); int row = features.NumRows();
std::vector<BaseFloat> feed_feature; int col = features.NumCols();
// todo refactor feed feature: SmileGoat std::vector<BaseFloat> feed_feature;
feed_feature.reserve(row*col); // todo refactor feed feature: SmileGoat
for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) { feed_feature.reserve(row * col);
for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) { for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
feed_feature.push_back(features(row_idx, col_idx)); for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) {
feed_feature.push_back(features(row_idx, col_idx));
}
}
std::vector<std::string> input_names = predictor->GetInputNames();
std::vector<std::string> output_names = predictor->GetOutputNames();
LOG(INFO) << "feat info: row=" << row << ", col= " << col;
std::unique_ptr<paddle_infer::Tensor> input_tensor =
predictor->GetInputHandle(input_names[0]);
std::vector<int> INPUT_SHAPE = {1, row, col};
input_tensor->Reshape(INPUT_SHAPE);
input_tensor->CopyFromCpu(feed_feature.data());
std::unique_ptr<paddle_infer::Tensor> input_len =
predictor->GetInputHandle(input_names[1]);
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size);
std::vector<int64_t> audio_len;
audio_len.push_back(row);
input_len->CopyFromCpu(audio_len.data());
std::unique_ptr<paddle_infer::Tensor> h_box =
predictor->GetInputHandle(input_names[2]);
shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]);
h_box->Reshape(h_cache->get_shape());
h_box->CopyFromCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_box =
predictor->GetInputHandle(input_names[3]);
shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]);
c_box->Reshape(c_cache->get_shape());
c_box->CopyFromCpu(c_cache->get_data().data());
bool success = predictor->Run();
if (success == false) {
LOG(INFO) << "predictor run occurs error";
} }
}
std::vector<std::string> input_names = predictor->GetInputNames(); LOG(INFO) << "get the model success";
std::vector<std::string> output_names = predictor->GetOutputNames(); std::unique_ptr<paddle_infer::Tensor> h_out =
LOG(INFO) << "feat info: row=" << row << ", col= " << col; predictor->GetOutputHandle(output_names[2]);
assert(h_cache->get_shape() == h_out->shape());
std::unique_ptr<paddle_infer::Tensor> input_tensor = predictor->GetInputHandle(input_names[0]); h_out->CopyToCpu(h_cache->get_data().data());
std::vector<int> INPUT_SHAPE = {1, row, col}; std::unique_ptr<paddle_infer::Tensor> c_out =
input_tensor->Reshape(INPUT_SHAPE); predictor->GetOutputHandle(output_names[3]);
input_tensor->CopyFromCpu(feed_feature.data()); assert(c_cache->get_shape() == c_out->shape());
std::unique_ptr<paddle_infer::Tensor> input_len = predictor->GetInputHandle(input_names[1]); c_out->CopyToCpu(c_cache->get_data().data());
std::vector<int> input_len_size = {1};
input_len->Reshape(input_len_size); // get result
std::vector<int64_t> audio_len; std::unique_ptr<paddle_infer::Tensor> output_tensor =
audio_len.push_back(row); predictor->GetOutputHandle(output_names[0]);
input_len->CopyFromCpu(audio_len.data()); std::vector<int> output_shape = output_tensor->shape();
row = output_shape[1];
std::unique_ptr<paddle_infer::Tensor> h_box = predictor->GetInputHandle(input_names[2]); col = output_shape[2];
shared_ptr<Tensor<BaseFloat>> h_cache = GetCacheEncoder(input_names[2]); vector<float> inferences_result;
h_box->Reshape(h_cache->get_shape()); inferences->Resize(row, col);
h_box->CopyFromCpu(h_cache->get_data().data()); inferences_result.resize(row * col);
std::unique_ptr<paddle_infer::Tensor> c_box = predictor->GetInputHandle(input_names[3]); output_tensor->CopyToCpu(inferences_result.data());
shared_ptr<Tensor<float>> c_cache = GetCacheEncoder(input_names[3]); ReleasePredictor(predictor);
c_box->Reshape(c_cache->get_shape());
c_box->CopyFromCpu(c_cache->get_data().data()); for (int row_idx = 0; row_idx < row; ++row_idx) {
bool success = predictor->Run(); for (int col_idx = 0; col_idx < col; ++col_idx) {
(*inferences)(row_idx, col_idx) =
if (success == false) { inferences_result[col * row_idx + col_idx];
LOG(INFO) << "predictor run occurs error"; }
}
LOG(INFO) << "get the model success";
std::unique_ptr<paddle_infer::Tensor> h_out = predictor->GetOutputHandle(output_names[2]);
assert(h_cache->get_shape() == h_out->shape());
h_out->CopyToCpu(h_cache->get_data().data());
std::unique_ptr<paddle_infer::Tensor> c_out = predictor->GetOutputHandle(output_names[3]);
assert(c_cache->get_shape() == c_out->shape());
c_out->CopyToCpu(c_cache->get_data().data());
// get result
std::unique_ptr<paddle_infer::Tensor> output_tensor =
predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_tensor->shape();
row = output_shape[1];
col = output_shape[2];
vector<float> inferences_result;
inferences->Resize(row, col);
inferences_result.resize(row*col);
output_tensor->CopyToCpu(inferences_result.data());
ReleasePredictor(predictor);
for (int row_idx = 0; row_idx < row; ++row_idx) {
for (int col_idx = 0; col_idx < col; ++col_idx) {
(*inferences)(row_idx, col_idx) = inferences_result[col*row_idx + col_idx];
} }
}
} }
} // namespace ppspeech } // namespace ppspeech

@ -1,8 +1,22 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
#include "nnet/nnet_interface.h"
#include "base/common.h" #include "base/common.h"
#include "nnet/nnet_interface.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h"
#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/matrix/kaldi-matrix.h"
@ -13,71 +27,79 @@
namespace ppspeech { namespace ppspeech {
struct ModelOptions { struct ModelOptions {
std::string model_path; std::string model_path;
std::string params_path; std::string params_path;
int thread_num; int thread_num;
bool use_gpu; bool use_gpu;
bool switch_ir_optim; bool switch_ir_optim;
std::string input_names; std::string input_names;
std::string output_names; std::string output_names;
std::string cache_names; std::string cache_names;
std::string cache_shape; std::string cache_shape;
bool enable_fc_padding; bool enable_fc_padding;
bool enable_profile; bool enable_profile;
ModelOptions() : ModelOptions()
model_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel"), : model_path(
params_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams"), "../../../../model/paddle_online_deepspeech/model/"
thread_num(2), "avg_1.jit.pdmodel"),
use_gpu(false), params_path(
input_names("audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box"), "../../../../model/paddle_online_deepspeech/model/"
output_names("save_infer_model/scale_0.tmp_1,save_infer_model/scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/scale_3.tmp_1"), "avg_1.jit.pdiparams"),
cache_names("chunk_state_h_box,chunk_state_c_box"), thread_num(2),
cache_shape("3-1-1024,3-1-1024"), use_gpu(false),
switch_ir_optim(false), input_names(
enable_fc_padding(false), "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_"
enable_profile(false) { "box"),
} output_names(
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1"),
cache_names("chunk_state_h_box,chunk_state_c_box"),
cache_shape("3-1-1024,3-1-1024"),
switch_ir_optim(false),
enable_fc_padding(false),
enable_profile(false) {}
void Register(kaldi::OptionsItf* opts) { void Register(kaldi::OptionsItf* opts) {
opts->Register("model-path", &model_path, "model file path"); opts->Register("model-path", &model_path, "model file path");
opts->Register("model-params", &params_path, "params model file path"); opts->Register("model-params", &params_path, "params model file path");
opts->Register("thread-num", &thread_num, "thread num"); opts->Register("thread-num", &thread_num, "thread num");
opts->Register("use-gpu", &use_gpu, "if use gpu"); opts->Register("use-gpu", &use_gpu, "if use gpu");
opts->Register("input-names", &input_names, "paddle input names"); opts->Register("input-names", &input_names, "paddle input names");
opts->Register("output-names", &output_names, "paddle output names"); opts->Register("output-names", &output_names, "paddle output names");
opts->Register("cache-names", &cache_names, "cache names"); opts->Register("cache-names", &cache_names, "cache names");
opts->Register("cache-shape", &cache_shape, "cache shape"); opts->Register("cache-shape", &cache_shape, "cache shape");
opts->Register("switch-ir-optiom", &switch_ir_optim, "paddle SwitchIrOptim option"); opts->Register("switch-ir-optiom",
opts->Register("enable-fc-padding", &enable_fc_padding, "paddle EnableFCPadding option"); &switch_ir_optim,
opts->Register("enable-profile", &enable_profile, "paddle EnableProfile option"); "paddle SwitchIrOptim option");
} opts->Register("enable-fc-padding",
&enable_fc_padding,
"paddle EnableFCPadding option");
opts->Register(
"enable-profile", &enable_profile, "paddle EnableProfile option");
}
}; };
template<typename T> template <typename T>
class Tensor { class Tensor {
public: public:
Tensor() { Tensor() {}
} Tensor(const std::vector<int>& shape) : _shape(shape) {
Tensor(const std::vector<int>& shape) : int data_size = std::accumulate(
_shape(shape) { _shape.begin(), _shape.end(), 1, std::multiplies<int>());
int data_size = std::accumulate(_shape.begin(), _shape.end(),
1, std::multiplies<int>());
LOG(INFO) << "data size: " << data_size; LOG(INFO) << "data size: " << data_size;
_data.resize(data_size, 0); _data.resize(data_size, 0);
} }
void reshape(const std::vector<int>& shape) { void reshape(const std::vector<int>& shape) {
_shape = shape; _shape = shape;
int data_size = std::accumulate(_shape.begin(), _shape.end(), int data_size = std::accumulate(
1, std::multiplies<int>()); _shape.begin(), _shape.end(), 1, std::multiplies<int>());
_data.resize(data_size, 0); _data.resize(data_size, 0);
} }
const std::vector<int>& get_shape() const { const std::vector<int>& get_shape() const { return _shape; }
return _shape; std::vector<T>& get_data() { return _data; }
}
std::vector<T>& get_data() { private:
return _data;
}
private:
std::vector<int> _shape; std::vector<int> _shape;
std::vector<T> _data; std::vector<T> _data;
}; };
@ -85,15 +107,16 @@ private:
class PaddleNnet : public NnetInterface { class PaddleNnet : public NnetInterface {
public: public:
PaddleNnet(const ModelOptions& opts); PaddleNnet(const ModelOptions& opts);
virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features, virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences); kaldi::Matrix<kaldi::BaseFloat>* inferences);
virtual void Reset(); virtual void Reset();
std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(const std::string& name); std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
const std::string& name);
void InitCacheEncouts(const ModelOptions& opts); void InitCacheEncouts(const ModelOptions& opts);
private: private:
paddle_infer::Predictor* GetPredictor(); paddle_infer::Predictor* GetPredictor();
int ReleasePredictor(paddle_infer::Predictor* predictor); int ReleasePredictor(paddle_infer::Predictor* predictor);
std::unique_ptr<paddle_infer::services::PredictorPool> pool; std::unique_ptr<paddle_infer::services::PredictorPool> pool;
std::vector<bool> pool_usages; std::vector<bool> pool_usages;
@ -107,4 +130,4 @@ class PaddleNnet : public NnetInterface {
DISALLOW_COPY_AND_ASSIGN(PaddleNnet); DISALLOW_COPY_AND_ASSIGN(PaddleNnet);
}; };
} // namespace ppspeech } // namespace ppspeech

@ -1,3 +1,17 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "utils/file_utils.h" #include "utils/file_utils.h"
namespace ppspeech { namespace ppspeech {
@ -17,5 +31,4 @@ bool ReadFileToVector(const std::string& filename,
return true; return true;
} }
} }

@ -1,8 +1,21 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h" #include "base/common.h"
namespace ppspeech { namespace ppspeech {
bool ReadFileToVector(const std::string& filename, bool ReadFileToVector(const std::string& filename,
std::vector<std::string>* data); std::vector<std::string>* data);
} }

Loading…
Cancel
Save