diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7fb01708..09e92a66 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,13 +50,13 @@ repos: entry: bash .pre-commit-hooks/clang-format.hook -i language: system files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ - exclude: (?=speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$ + exclude: (?=speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$ - id: copyright_checker name: copyright_checker entry: python .pre-commit-hooks/copyright-check.hook language: system files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ - exclude: (?=third_party|pypinyin|speechx/speechx/kaldi).*(\.cpp|\.cc|\.h|\.py)$ + exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch).*(\.cpp|\.cc|\.h|\.py)$ - repo: https://github.com/asottile/reorder_python_imports rev: v2.4.0 hooks: diff --git a/paddlespeech/s2t/io/sampler.py b/paddlespeech/s2t/io/sampler.py index 89752bb9..ac55af12 100644 --- a/paddlespeech/s2t/io/sampler.py +++ b/paddlespeech/s2t/io/sampler.py @@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False): """ rng = np.random.RandomState(epoch) shift_len = rng.randint(0, batch_size - 1) - batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) + batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) rng.shuffle(batch_indices) batch_indices = [item for batch in batch_indices for item in batch] assert clipped is False diff --git a/paddlespeech/t2s/modules/transformer/repeat.py b/paddlespeech/t2s/modules/transformer/repeat.py index 2073a78b..1e946adf 100644 --- a/paddlespeech/t2s/modules/transformer/repeat.py +++ b/paddlespeech/t2s/modules/transformer/repeat.py @@ -36,4 +36,4 @@ def repeat(N, fn): Returns: MultiSequential: Repeated model instance. 
""" - return MultiSequential(*[fn(n) for n in range(N)]) + return MultiSequential(* [fn(n) for n in range(N)]) diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index 083e180d..f1330d1d 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -117,6 +117,7 @@ set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn") include_directories("${MKLDNN_PATH}/include") set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) set(EXTERNAL_LIB "-lrt -ldl -lpthread") + set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${MATH_LIB} ${MKLDNN_LIB} @@ -137,4 +138,7 @@ set(DEPS ${DEPS} #target_link_libraries(lib_name item0 item1) #add_dependencies(lib_name depend-target) -add_subdirectory(speechx) \ No newline at end of file +set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx) + +add_subdirectory(speechx) +add_subdirectory(examples) \ No newline at end of file diff --git a/speechx/build.sh b/speechx/build.sh index b3406387..3e9600d5 100755 --- a/speechx/build.sh +++ b/speechx/build.sh @@ -16,7 +16,7 @@ if [ ! 
-d ${boost_SOURCE_DIR} ]; then wget -c https://boostorg.jfrog.io/artifact echo -e "\n" fi -rm -rf build +#rm -rf build mkdir -p build cd build diff --git a/speechx/cmake/external/openblas.cmake b/speechx/cmake/external/openblas.cmake index 14e17195..3c202f7f 100644 --- a/speechx/cmake/external/openblas.cmake +++ b/speechx/cmake/external/openblas.cmake @@ -18,6 +18,8 @@ ExternalProject_Add( SOURCE_DIR ${OpenBLAS_SOURCE_DIR} CMAKE_ARGS -DCMAKE_INSTALL_PREFIX= CMAKE_GENERATOR "Unix Makefiles") + + # https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition ExternalProject_Get_Property(OPENBLAS INSTALL_DIR) set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR}) diff --git a/speechx/examples/.gitkeep b/speechx/examples/.gitkeep deleted file mode 100644 index e69de29b..00000000 diff --git a/speechx/examples/CMakeLists.txt b/speechx/examples/CMakeLists.txt new file mode 100644 index 00000000..ef0a72b8 --- /dev/null +++ b/speechx/examples/CMakeLists.txt @@ -0,0 +1,5 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +add_subdirectory(feat) +add_subdirectory(nnet) +add_subdirectory(decoder) diff --git a/speechx/examples/README.md b/speechx/examples/README.md new file mode 100644 index 00000000..fde9a361 --- /dev/null +++ b/speechx/examples/README.md @@ -0,0 +1,5 @@ +# Examples + +* decoder - offline decoder +* feat - mfcc, linear +* nnet - ds2 nn diff --git a/speechx/examples/decoder/CMakeLists.txt b/speechx/examples/decoder/CMakeLists.txt new file mode 100644 index 00000000..cf90d094 --- /dev/null +++ b/speechx/examples/decoder/CMakeLists.txt @@ -0,0 +1,5 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +add_executable(offline-decoder-main ${CMAKE_CURRENT_SOURCE_DIR}/offline-decoder-main.cc) +target_include_directories(offline-decoder-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(offline-decoder-main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix 
kaldi-util ${DEPS}) \ No newline at end of file diff --git a/speechx/examples/decoder/offline-decoder-main.cc b/speechx/examples/decoder/offline-decoder-main.cc new file mode 100644 index 00000000..8e6e7850 --- /dev/null +++ b/speechx/examples/decoder/offline-decoder-main.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// todo refactor, repalce with gtest + +#include "base/flags.h" +#include "base/log.h" +#include "decoder/ctc_beam_search_decoder.h" +#include "kaldi/util/table-types.h" +#include "nnet/decodable.h" +#include "nnet/paddle_nnet.h" + +DEFINE_string(feature_respecifier, "", "test nnet prob"); + +using kaldi::BaseFloat; +using kaldi::Matrix; +using std::vector; + +// void SplitFeature(kaldi::Matrix feature, +// int32 chunk_size, +// std::vector* feature_chunks) { + +//} + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + kaldi::SequentialBaseFloatMatrixReader feature_reader( + FLAGS_feature_respecifier); + + // test nnet_output --> decoder result + int32 num_done = 0, num_err = 0; + + ppspeech::CTCBeamSearchOptions opts; + ppspeech::CTCBeamSearch decoder(opts); + + ppspeech::ModelOptions model_opts; + std::shared_ptr nnet( + new ppspeech::PaddleNnet(model_opts)); + + std::shared_ptr decodable( + new ppspeech::Decodable(nnet)); + + // int32 chunk_size = 35; + 
decoder.InitDecoder(); + for (; !feature_reader.Done(); feature_reader.Next()) { + string utt = feature_reader.Key(); + const kaldi::Matrix feature = feature_reader.Value(); + decodable->FeedFeatures(feature); + decoder.AdvanceDecode(decodable, 8); + decodable->InputFinished(); + std::string result; + result = decoder.GetFinalBestPath(); + KALDI_LOG << " the result of " << utt << " is " << result; + decodable->Reset(); + ++num_done; + } + + KALDI_LOG << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 0 : 1); +} \ No newline at end of file diff --git a/speechx/examples/feat/CMakeLists.txt b/speechx/examples/feat/CMakeLists.txt new file mode 100644 index 00000000..44738e60 --- /dev/null +++ b/speechx/examples/feat/CMakeLists.txt @@ -0,0 +1,10 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + + +add_executable(mfcc-test ${CMAKE_CURRENT_SOURCE_DIR}/feature-mfcc-test.cc) +target_include_directories(mfcc-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(mfcc-test kaldi-mfcc) + +add_executable(linear-spectrogram-main ${CMAKE_CURRENT_SOURCE_DIR}/linear-spectrogram-main.cc) +target_include_directories(linear-spectrogram-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(linear-spectrogram-main frontend kaldi-util kaldi-feat-common gflags glog) \ No newline at end of file diff --git a/speechx/examples/feat/feature-mfcc-test.cc b/speechx/examples/feat/feature-mfcc-test.cc new file mode 100644 index 00000000..ae32aba9 --- /dev/null +++ b/speechx/examples/feat/feature-mfcc-test.cc @@ -0,0 +1,720 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// feat/feature-mfcc-test.cc + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include + +#include "base/kaldi-math.h" +#include "feat/feature-mfcc.h" +#include "feat/wave-reader.h" +#include "matrix/kaldi-matrix-inl.h" + +using namespace kaldi; + + +static void UnitTestReadWave() { + std::cout << "=== UnitTestReadWave() ===\n"; + + Vector v, v2; + + std::cout << "<<<=== Reading waveform\n"; + + { + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + const Matrix data(wave.Data()); + KALDI_ASSERT(data.NumRows() == 1); + v.Resize(data.NumCols()); + v.CopyFromVec(data.Row(0)); + } + + std::cout + << "<<<=== Reading Vector waveform, prepared by matlab\n"; + std::ifstream input("test_data/test_matlab.ascii"); + KALDI_ASSERT(input.good()); + v2.Read(input, false); + input.close(); + + std::cout + << "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n"; + KALDI_ASSERT(v.Dim() == v2.Dim()); + for (int32 i = 0; i < v.Dim(); i++) { + KALDI_ASSERT(v(i) == v2(i)); + } + std::cout << "<<<=== Comparing done\n"; + + // std::cout << "== The Waveform Samples == \n"; + // std::cout << v; + + std::cout << "Test passed :)\n\n"; +} + + +/** + */ +static void UnitTestSimple() { + std::cout << "=== UnitTestSimple() ===\n"; + + Vector v(100000); + Matrix m; + + // init with noise + for (int32 i = 0; i < v.Dim(); i++) { + v(i) = (abs(i * 433024253) % 65535) - (65535 / 2); + } + + std::cout << "<<<=== Just make sure it runs... Nothing is compared\n"; + // the parametrization object + MfccOptions op; + // trying to have same opts as baseline. + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.0; + op.frame_opts.window_type = "rectangular"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.low_freq = 0.0; + op.mel_opts.htk_mode = true; + op.htk_compat = true; + + Mfcc mfcc(op); + // use default parameters + + // compute mfccs. 
+ mfcc.Compute(v, 1.0, &m); + + // possibly dump + // std::cout << "== Output features == \n" << m; + std::cout << "Test passed :)\n\n"; +} + + +static void UnitTestHTKCompare1() { + std::cout << "=== UnitTestHTKCompare1() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.1", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.low_freq = 0.0; + op.mel_opts.htk_mode = true; + op.htk_compat = true; + op.use_energy = false; // C0 not energy. + + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! 
+ // print the non-matching data only once per-line + if (i_old != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " + << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " + << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + } + } + } + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float) * kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.1", + std::ios::out | std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.1"); +} + + +static void UnitTestHTKCompare2() { + std::cout << "=== UnitTestHTKCompare2() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.2", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.low_freq = 0.0; + op.mel_opts.htk_mode = true; + op.htk_compat = true; + op.use_energy = true; // Use energy. 
+ + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! + // print the non-matching data only once per-line + if (i_old != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " + << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " + << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + } + } + } + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float) * kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.2", + std::ios::out | std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.2"); +} + + +static void UnitTestHTKCompare3() { + std::cout << "=== UnitTestHTKCompare3() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + 
std::ifstream is("test_data/test.wav.fea_htk.3", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.htk_compat = true; + op.use_energy = true; // Use energy. + op.mel_opts.low_freq = 20.0; + // op.mel_opts.debug_mel = true; + op.mel_opts.htk_mode = true; + + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! 
+ // print the non-matching data only once per-line + if (static_cast(i_old) != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " + << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " + << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + } + } + } + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float) * kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.3", + std::ios::out | std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.3"); +} + + +static void UnitTestHTKCompare4() { + std::cout << "=== UnitTestHTKCompare4() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.4", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.low_freq = 0.0; + op.htk_compat = true; + op.use_energy = true; // Use energy. 
+ op.mel_opts.htk_mode = true; + + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! + // print the non-matching data only once per-line + if (static_cast(i_old) != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " + << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " + << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + } + } + } + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float) * kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.4", + std::ios::out | std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.4"); +} + + +static void UnitTestHTKCompare5() { + std::cout << "=== UnitTestHTKCompare5() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK 
features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.5", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.htk_compat = true; + op.use_energy = true; // Use energy. + op.mel_opts.low_freq = 0.0; + op.mel_opts.vtln_low = 100.0; + op.mel_opts.vtln_high = 7500.0; + op.mel_opts.htk_mode = true; + + BaseFloat vtln_warp = + 1.1; // our approach identical to htk for warp factor >1, + // differs slightly for higher mel bins if warp_factor <0.9 + + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! 
+ // print the non-matching data only once per-line + if (static_cast(i_old) != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " + << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " + << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + } + } + } + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float) * kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.5", + std::ios::out | std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.5"); +} + +static void UnitTestHTKCompare6() { + std::cout << "=== UnitTestHTKCompare6() ===\n"; + + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.6", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.97; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.num_bins = 24; + op.mel_opts.low_freq = 125.0; + op.mel_opts.high_freq = 7800.0; + op.htk_compat = true; + op.use_energy = false; // C0 not energy. 
+ + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, kaldi_raw_features, &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i + 10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! + // print the non-matching data only once per-line + if (static_cast(i_old) != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " + << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " + << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + } + } + } + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float) * kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.6", + std::ios::out | std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.6"); +} + +void UnitTestVtln() { + // Test the function VtlnWarpFreq. 
+ BaseFloat low_freq = 10, high_freq = 7800, vtln_low_cutoff = 20, + vtln_high_cutoff = 7400; + + for (size_t i = 0; i < 100; i++) { + BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2; + AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, + vtln_high_cutoff, + low_freq, + high_freq, + warp_factor, + freq), + freq / warp_factor); + + AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, + vtln_high_cutoff, + low_freq, + high_freq, + warp_factor, + low_freq), + low_freq); + AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, + vtln_high_cutoff, + low_freq, + high_freq, + warp_factor, + high_freq), + high_freq); + BaseFloat freq2 = low_freq + (high_freq - low_freq) * RandUniform(), + freq3 = freq2 + + (high_freq - freq2) * RandUniform(); // freq3>=freq2 + BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, + vtln_high_cutoff, + low_freq, + high_freq, + warp_factor, + freq2); + BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, + vtln_high_cutoff, + low_freq, + high_freq, + warp_factor, + freq3); + KALDI_ASSERT(w3 >= w2); // increasing function. + BaseFloat w3dash = MelBanks::VtlnWarpFreq( + vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, 1.0, freq3); + AssertEqual(w3dash, freq3); + } +} + +static void UnitTestFeat() { + UnitTestVtln(); + UnitTestReadWave(); + UnitTestSimple(); + UnitTestHTKCompare1(); + UnitTestHTKCompare2(); + // commenting out this one as it doesn't compare right now I normalized + // the way the FFT bins are treated (removed offset of 0.5)... this seems + // to relate to the way frequency zero behaves. 
+ UnitTestHTKCompare3(); + UnitTestHTKCompare4(); + UnitTestHTKCompare5(); + UnitTestHTKCompare6(); + std::cout << "Tests succeeded.\n"; +} + + +int main() { + try { + for (int i = 0; i < 5; i++) UnitTestFeat(); + std::cout << "Tests succeeded.\n"; + return 0; + } catch (const std::exception &e) { + std::cerr << e.what(); + return 1; + } +} diff --git a/speechx/examples/feat/linear-spectrogram-main.cc b/speechx/examples/feat/linear-spectrogram-main.cc new file mode 100644 index 00000000..3e2342c2 --- /dev/null +++ b/speechx/examples/feat/linear-spectrogram-main.cc @@ -0,0 +1,257 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// todo refactor, repalce with gtest + +#include "base/flags.h" +#include "base/log.h" +#include "frontend/feature_extractor_interface.h" +#include "frontend/linear_spectrogram.h" +#include "frontend/normalizer.h" +#include "kaldi/feat/wave-reader.h" +#include "kaldi/util/kaldi-io.h" +#include "kaldi/util/table-types.h" + +DEFINE_string(wav_rspecifier, "", "test wav path"); +DEFINE_string(feature_wspecifier, "", "test wav ark"); +DEFINE_string(feature_check_wspecifier, "", "test wav ark"); +DEFINE_string(cmvn_write_path, "./cmvn.ark", "test wav ark"); + + +std::vector mean_{ + -13730251.531853663, -12982852.199316509, -13673844.299583456, + -13089406.559646806, -12673095.524938712, -12823859.223276224, + -13590267.158903603, -14257618.467152044, -14374605.116185192, + -14490009.21822485, -14849827.158924166, -15354435.470563512, + -15834149.206532761, -16172971.985514281, -16348740.496746974, + -16423536.699409386, -16556246.263649225, -16744088.772748645, + -16916184.08510357, -17054034.840031497, -17165612.509455364, + -17255955.470915023, -17322572.527648456, -17408943.862033736, + -17521554.799865916, -17620623.254924215, -17699792.395918526, + -17723364.411134344, -17741483.4433254, -17747426.888704527, + -17733315.928209435, -17748780.160905756, -17808336.883775543, + -17895918.671983004, -18009812.59173023, -18098188.66548325, + -18195798.958462656, -18293617.62980999, -18397432.92077201, + -18505834.787318766, -18585451.8100908, -18652438.235649142, + -18700960.306275308, -18734944.58792185, -18737426.313365128, + -18735347.165987637, -18738813.444170244, -18737086.848890636, + -18731576.2474336, -18717405.44095871, -18703089.25545657, + -18691014.546456724, -18692460.568905357, -18702119.628629155, + -18727710.621126678, -18761582.72034647, -18806745.835547544, + -18850674.8692112, -18884431.510951452, -18919999.992506847, + -18939303.799078144, -18952946.273760635, -18980289.22996379, + -19011610.17803294, -19040948.61805145, -19061021.429847397, + 
-19112055.53768819, -19149667.414264943, -19201127.05091321, + -19270250.82564605, -19334606.883057203, -19390513.336589377, + -19444176.259208687, -19502755.000038862, -19544333.014549147, + -19612668.183176614, -19681902.19006569, -19771969.951249883, + -19873329.723376893, -19996752.59235844, -20110031.131400537, + -20231658.612529557, -20319378.894054495, -20378534.45718066, + -20413332.089584175, -20438147.844177883, -20443710.248040095, + -20465457.02238927, -20488610.969337028, -20516295.16424432, + -20541423.795738827, -20553192.874953747, -20573605.50701977, + -20577871.61936797, -20571807.008916274, -20556242.38912231, + -20542199.30819195, -20521239.063551214, -20519150.80004532, + -20527204.80248933, -20536933.769257784, -20543470.522332076, + -20549700.089992985, -20551525.24958494, -20554873.406493705, + -20564277.65794227, -20572211.740052115, -20574305.69550465, + -20575494.450104576, -20567092.577932164, -20549302.929608088, + -20545445.11878376, -20546625.326603737, -20549190.03499401, + -20554824.947828256, -20568341.378989458, -20577582.331383612, + -20577980.519402675, -20566603.03458152, -20560131.592262644, + -20552166.469060015, -20549063.06763577, -20544490.562339947, + -20539817.82346569, -20528747.715731595, -20518026.24576161, + -20510977.844974525, -20506874.36087992, -20506731.11977665, + -20510482.133420516, -20507760.92101862, -20494644.834457114, + -20480107.89304893, -20461312.091867123, -20442941.75080173, + -20426123.02834838, -20424607.675283, -20426810.369107097, + -20434024.50097819, -20437404.75544205, -20447688.63916367, + -20460893.335563846, -20482922.735127095, -20503610.119434915, + -20527062.76448319, -20557830.035128627, -20593274.72068722, + -20632528.452965066, -20673637.471334763, -20733106.97143075, + -20842921.0447562, -21054357.83621519, -21416569.534189366, + -21978460.272811692, -22753170.052172784, -23671344.10563395, + -24613499.293358143, -25406477.12230188, -25884377.82156489, + -26049040.62791664, 
-26996879.104431007}; +std::vector variance_{ + 213747175.10846674, 188395815.34302503, 212706429.10966414, + 199109025.81461075, 189235901.23864496, 194901336.53253657, + 217481594.29306737, 238689869.12327808, 243977501.24115244, + 248479623.6431067, 259766741.47116545, 275516766.7790273, + 291271202.3691234, 302693239.8220509, 308627358.3997694, + 311143911.38788426, 315446105.07731867, 321705430.9341829, + 327458907.4659941, 332245072.43223983, 336251717.5935284, + 339694069.7639722, 342188204.4322228, 345587110.31313115, + 349903086.2875232, 353660214.20643026, 356700344.5270885, + 357665362.3529641, 358493352.05658793, 358857951.620328, + 358375239.52774596, 358899733.6342954, 361051818.3511561, + 364361716.05025816, 368750322.3771452, 372047800.6462831, + 375655861.1349018, 379358519.1980013, 383327605.3935181, + 387458599.282341, 390434692.3406868, 392994486.35057056, + 394874418.04603153, 396230525.79763395, 396365592.0414835, + 396334819.8242737, 396488353.19250053, 396438877.00744957, + 396197980.4459586, 395590921.6672991, 395001107.62072515, + 394528291.7318225, 394593110.424006, 395018405.59353715, + 396110577.5415993, 397506704.0371068, 399400197.4657644, + 401243568.2468382, 402687134.7805103, 404136047.2872507, + 404883170.001883, 405522253.219517, 406660365.3626476, + 407919346.0991902, 409045348.5384909, 409759588.7889818, + 411974821.8564483, 413489718.78201455, 415535392.56684107, + 418466481.97674364, 421104678.35678065, 423405392.5200779, + 425550570.40798235, 427929423.9579701, 429585274.253478, + 432368493.55181056, 435193587.13513297, 438886855.20476013, + 443058876.8633751, 448181232.5093362, 452883835.6332396, + 458056721.77926534, 461816531.22735566, 464363620.1970998, + 465886343.5057493, 466928872.0651, 467180536.42647296, + 468111848.70714295, 469138695.3071312, 470378429.6930793, + 471517958.7132626, 472109050.4262365, 473087417.0177867, + 473381322.04648733, 473220195.85483915, 472666071.8998819, + 472124669.87879956, 
471298571.411737, 471251033.2902761, + 471672676.43128747, 472177147.2193172, 472572361.7711908, + 472968783.7751127, 473156295.4164052, 473398034.82676554, + 473897703.5203811, 474328271.33112127, 474452670.98002136, + 474549003.99284613, 474252887.13567275, 473557462.909069, + 473483385.85193115, 473609738.04855174, 473746944.82085115, + 474016729.91696435, 474617321.94138587, 475045097.237122, + 475125402.586558, 474664112.9824912, 474426247.5800283, + 474104075.42796475, 473978219.7273978, 473773171.7798875, + 473578534.69508696, 473102924.16904145, 472651240.5232615, + 472374383.1810912, 472209479.6956096, 472202298.8921673, + 472370090.76781124, 472220933.99374026, 471625467.37106377, + 470994646.51883453, 470182428.9637543, 469348211.5939578, + 468570387.4467277, 468540442.7225135, 468672018.90414184, + 468994346.9533251, 469138757.58201426, 469553915.95710236, + 470134523.38582784, 471082421.62055486, 471962316.51804745, + 472939745.1708408, 474250621.5944825, 475773933.43199486, + 477465399.71087736, 479218782.61382693, 481752299.7930922, + 486608947.8984568, 496119403.2067917, 512730085.5704984, + 539048915.2641417, 576285298.3548826, 621610270.2240586, + 669308196.4436442, 710656993.5957186, 736344437.3725077, + 745481288.0241544, 801121432.9925804}; +int count_ = 912592; + +void WriteMatrix() { + kaldi::Matrix cmvn_stats(2, mean_.size() + 1); + for (size_t idx = 0; idx < mean_.size(); ++idx) { + cmvn_stats(0, idx) = mean_[idx]; + cmvn_stats(1, idx) = variance_[idx]; + } + cmvn_stats(0, mean_.size()) = count_; + kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true); +} + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + kaldi::SequentialTableReader wav_reader( + FLAGS_wav_rspecifier); + kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier); + kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer( + FLAGS_feature_check_wspecifier); + WriteMatrix(); + + 
// test feature linear_spectrogram: wave --> decibel_normalizer --> hanning
+    // window -->linear_spectrogram --> cmvn
+    int32 num_done = 0, num_err = 0;
+    ppspeech::LinearSpectrogramOptions opt;
+    opt.frame_opts.frame_length_ms = 20;
+    opt.frame_opts.frame_shift_ms = 10;
+    ppspeech::DecibelNormalizerOptions db_norm_opt;
+    std::unique_ptr base_feature_extractor(
+        new ppspeech::DecibelNormalizer(db_norm_opt));
+    ppspeech::LinearSpectrogram linear_spectrogram(
+        opt, std::move(base_feature_extractor));
+
+    ppspeech::CMVN cmvn(FLAGS_cmvn_write_path);
+
+    float streaming_chunk = 0.36;
+    int sample_rate = 16000;
+    int chunk_sample_size = streaming_chunk * sample_rate;  // samples per AcceptWaveform call
+
+    LOG(INFO) << mean_.size();
+    for (size_t i = 0; i < mean_.size(); i++) {  // in place: mean_ /= count_, variance_ -> 1/sqrt(var)
+        mean_[i] /= count_;
+        variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i];
+        if (variance_[i] < 1.0e-20) {
+            variance_[i] = 1.0e-20;
+        }
+        variance_[i] = 1.0 / std::sqrt(variance_[i]);
+    }
+
+    for (; !wav_reader.Done(); wav_reader.Next()) {
+        std::string utt = wav_reader.Key();
+        const kaldi::WaveData& wave_data = wav_reader.Value();
+
+        int32 this_channel = 0;
+        kaldi::SubVector waveform(wave_data.Data(),
+                                  this_channel);
+        int tot_samples = waveform.Dim();
+        int sample_offset = 0;
+        std::vector> feats;
+        int feature_rows = 0;
+        while (sample_offset < tot_samples) {
+            int cur_chunk_size =
+                std::min(chunk_sample_size, tot_samples - sample_offset);
+            kaldi::Vector wav_chunk(cur_chunk_size);
+            for (int i = 0; i < cur_chunk_size; ++i) {
+                wav_chunk(i) = waveform(sample_offset + i);
+            }
+            kaldi::Matrix features;
+            linear_spectrogram.AcceptWaveform(wav_chunk);
+            linear_spectrogram.ReadFeats(&features);
+
+            feats.push_back(features);
+            sample_offset += cur_chunk_size;
+            feature_rows += features.NumRows();
+        }
+        if (feats.empty()) { ++num_err; continue; }  // zero-sample wave: feats[0] below would be out of range
+        int cur_idx = 0;
+        kaldi::Matrix features(feature_rows,
+                               feats[0].NumCols());
+        for (const auto& feat : feats) {  // by reference: no per-chunk matrix copy
+            for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
+                for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
+                    features(cur_idx, col_idx) =
+                        (feat(row_idx, col_idx) - mean_[col_idx]) *
+                        variance_[col_idx];
+                }
+                ++cur_idx;
+            }
+        }
+        feat_writer.Write(utt, features);
+
+        cur_idx = 0;
+        kaldi::Matrix features_check(feature_rows,
+                                     feats[0].NumCols());
+        for (const auto& feat : feats) {  // same rows again, normalised through ppspeech::CMVN
+            for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) {
+                for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) {
+                    features_check(cur_idx, col_idx) = feat(row_idx, col_idx);
+                }
+                kaldi::SubVector row_feat(features_check, cur_idx);
+                cmvn.ApplyCMVN(true, &row_feat);
+                ++cur_idx;
+            }
+        }
+        feat_cmvn_check_writer.Write(utt, features_check);
+
+        if (num_done % 50 == 0 && num_done != 0)
+            KALDI_VLOG(2) << "Processed " << num_done << " utterances";
+        num_done++;
+    }
+    KALDI_LOG << "Done " << num_done << " utterances, " << num_err
+              << " with errors.";
+    return (num_done != 0 ? 0 : 1);
+}
diff --git a/speechx/examples/nnet/CMakeLists.txt b/speechx/examples/nnet/CMakeLists.txt
new file mode 100644
index 00000000..20f4008c
--- /dev/null
+++ b/speechx/examples/nnet/CMakeLists.txt
@@ -0,0 +1,5 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+add_executable(pp-model-test ${CMAKE_CURRENT_SOURCE_DIR}/pp-model-test.cc)
+target_include_directories(pp-model-test PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(pp-model-test PUBLIC nnet gflags ${DEPS})
\ No newline at end of file
diff --git a/speechx/examples/nnet/pp-model-test.cc b/speechx/examples/nnet/pp-model-test.cc
new file mode 100644
index 00000000..2db354a7
--- /dev/null
+++ b/speechx/examples/nnet/pp-model-test.cc
@@ -0,0 +1,193 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "paddle_inference_api.h"
+
+using std::cout;
+using std::endl;
+
+DEFINE_string(model_path, "avg_1.jit.pdmodel", "xxx.pdmodel");
+DEFINE_string(param_path, "avg_1.jit.pdiparams", "xxx.pdiparams");
+
+// produce_data appends a 35 x 161 chunk of constant features (0.201) to *data.
+void produce_data(std::vector>* data);
+void model_forward_test();
+
+void produce_data(std::vector>* data) {
+    int chunk_size = 35;  // chunk_size in frame
+    int col_size = 161;   // feat dim
+    cout << "chunk size: " << chunk_size << endl;
+    cout << "feat dim: " << col_size << endl;
+
+    data->reserve(chunk_size);
+    for (int row = 0; row < chunk_size; ++row) {
+        data->push_back(std::vector());
+        data->back().reserve(col_size);  // row must exist first: back() on an empty vector is UB
+        for (int col_idx = 0; col_idx < col_size; ++col_idx) {
+            data->back().push_back(0.201);
+        }
+    }
+}
+
+void model_forward_test() {
+    std::cout << "1. read the data" << std::endl;
+    std::vector> feats;
+    produce_data(&feats);
+
+    std::cout << "2. load the model" << std::endl;
+    // model graph / parameter paths come from the flags defined above
+    std::string model_graph = FLAGS_model_path;
+    std::string model_params = FLAGS_param_path;
+    cout << "model path: " << model_graph << endl;
+    cout << "model param path : " << model_params << endl;
+
+    paddle_infer::Config config;
+    config.SetModel(model_graph, model_params);
+    config.SwitchIrOptim(false);
+    cout << "SwitchIrOptim: " << false << endl;
+    config.DisableFCPadding();
+    cout << "DisableFCPadding: " << endl;
+    auto predictor = paddle_infer::CreatePredictor(config);
+
+    std::cout << "3. feat shape, row=" << feats.size()
+              << ",col=" << feats[0].size() << std::endl;
+    std::vector pp_input_mat;
+    for (const auto& item : feats) {
+        pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end());
+    }
+
+    std::cout << "4. fead the data to model" << std::endl;
+    int row = feats.size();
+    int col = feats[0].size();
+    std::vector input_names = predictor->GetInputNames();
+    std::vector output_names = predictor->GetOutputNames();
+    for (const auto& name : input_names) {
+        cout << "model input names: " << name << endl;
+    }
+    for (const auto& name : output_names) {
+        cout << "model output names: " << name << endl;
+    }
+
+    // input
+    std::unique_ptr input_tensor =
+        predictor->GetInputHandle(input_names[0]);
+    std::vector INPUT_SHAPE = {1, row, col};
+    input_tensor->Reshape(INPUT_SHAPE);
+    input_tensor->CopyFromCpu(pp_input_mat.data());
+
+    // input length
+    std::unique_ptr input_len =
+        predictor->GetInputHandle(input_names[1]);
+    std::vector input_len_size = {1};
+    input_len->Reshape(input_len_size);
+    std::vector audio_len;
+    audio_len.push_back(row);
+    input_len->CopyFromCpu(audio_len.data());
+
+    // state_h
+    std::unique_ptr chunk_state_h_box =
+        predictor->GetInputHandle(input_names[2]);
+    std::vector chunk_state_h_box_shape = {3, 1, 1024};
+    chunk_state_h_box->Reshape(chunk_state_h_box_shape);
+    int chunk_state_h_box_size =
+        std::accumulate(chunk_state_h_box_shape.begin(),
+                        chunk_state_h_box_shape.end(),
+                        1,
+                        std::multiplies());
+    std::vector chunk_state_h_box_data(chunk_state_h_box_size, 0.0f);
+    chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data());
+
+    // state_c
+    std::unique_ptr chunk_state_c_box =
+        predictor->GetInputHandle(input_names[3]);
+    std::vector chunk_state_c_box_shape = {3, 1, 1024};
+    chunk_state_c_box->Reshape(chunk_state_c_box_shape);
+    int chunk_state_c_box_size =
+        std::accumulate(chunk_state_c_box_shape.begin(),
+                        chunk_state_c_box_shape.end(),
+                        1,
+                        std::multiplies());
+    std::vector chunk_state_c_box_data(chunk_state_c_box_size, 0.0f);
+    chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data());
+
+    // run
+    bool success = predictor->Run();
+    if (!success) { std::cerr << "predictor Run() failed" << std::endl; return; }  // outputs are not valid
+    // state_h out
+    std::unique_ptr h_out =
+        predictor->GetOutputHandle(output_names[2]);
+    std::vector h_out_shape = h_out->shape();
+    int h_out_size = std::accumulate(
+        h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies());
+    std::vector h_out_data(h_out_size);
+    h_out->CopyToCpu(h_out_data.data());
+
+    // state_c out
+    std::unique_ptr c_out =
+        predictor->GetOutputHandle(output_names[3]);
+    std::vector c_out_shape = c_out->shape();
+    int c_out_size = std::accumulate(
+        c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies());
+    std::vector c_out_data(c_out_size);
+    c_out->CopyToCpu(c_out_data.data());
+
+    // output tensor
+    std::unique_ptr output_tensor =
+        predictor->GetOutputHandle(output_names[0]);
+    std::vector output_shape = output_tensor->shape();
+    std::vector output_probs;
+    int output_size = std::accumulate(
+        output_shape.begin(), output_shape.end(), 1, std::multiplies());
+    output_probs.resize(output_size);
+    output_tensor->CopyToCpu(output_probs.data());
+    row = output_shape[1];
+    col = output_shape[2];
+
+    // probs
+    std::vector> probs;
+    probs.reserve(row);
+    for (int i = 0; i < row; i++) {
+        probs.push_back(std::vector());
+        probs.back().reserve(col);
+
+        for (int j = 0; j < col; j++) {
+            probs.back().push_back(output_probs[i * col + j]);
+        }
+    }
+
+    std::vector> log_feat = probs;
+    std::cout << "probs, row: " << log_feat.size()
+              << " col: " << log_feat[0].size() << std::endl;
+    for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) {
+        for (size_t col_idx = 0; col_idx < log_feat[row_idx].size();
+             ++col_idx) {
+            std::cout << log_feat[row_idx][col_idx] << " ";
+        }
+        std::cout << std::endl;
+    }
+}
+
+int main(int argc, char* argv[]) {
+    gflags::ParseCommandLineFlags(&argc, &argv, true);
+    model_forward_test();
+    return 0;
+}
diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt
index
4a296ec8..225abee7 100644 --- a/speechx/speechx/CMakeLists.txt +++ b/speechx/speechx/CMakeLists.txt @@ -30,16 +30,4 @@ include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/decoder ) -add_subdirectory(decoder) - -add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc) -target_link_libraries(mfcc-test kaldi-mfcc) - -add_executable(linear_spectrogram_main codelab/feat_test/linear_spectrogram_main.cc) -target_link_libraries(linear_spectrogram_main frontend kaldi-util kaldi-feat-common gflags glog) - -add_executable(offline_decoder_main codelab/decoder_test/offline_decoder_main.cc) -target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) - -add_executable(model_test codelab/nnet_test/model_test.cc) -target_link_libraries(model_test PUBLIC nnet gflags ${DEPS}) +add_subdirectory(decoder) \ No newline at end of file diff --git a/speechx/speechx/base/basic_types.h b/speechx/speechx/base/basic_types.h index 1186efd5..206b7be6 100644 --- a/speechx/speechx/base/basic_types.h +++ b/speechx/speechx/base/basic_types.h @@ -18,22 +18,22 @@ #include -typedef float BaseFloat; -typedef double double64; +typedef float BaseFloat; +typedef double double64; -typedef signed char int8; -typedef short int16; -typedef int int32; +typedef signed char int8; +typedef short int16; +typedef int int32; #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) -typedef long int64; +typedef long int64; #else -typedef long long int64; +typedef long long int64; #endif -typedef unsigned char uint8; -typedef unsigned short uint16; -typedef unsigned int uint32; +typedef unsigned char uint8; +typedef unsigned short uint16; +typedef unsigned int uint32; #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) typedef unsigned long uint64; @@ -41,20 +41,20 @@ typedef unsigned long uint64; typedef unsigned long long uint64; #endif -typedef signed int char32; - -const uint8 kuint8max 
= (( uint8) 0xFF); -const uint16 kuint16max = ((uint16) 0xFFFF); -const uint32 kuint32max = ((uint32) 0xFFFFFFFF); -const uint64 kuint64max = ((uint64) (0xFFFFFFFFFFFFFFFFLL)); -const int8 kint8min = (( int8) 0x80); -const int8 kint8max = (( int8) 0x7F); -const int16 kint16min = (( int16) 0x8000); -const int16 kint16max = (( int16) 0x7FFF); -const int32 kint32min = (( int32) 0x80000000); -const int32 kint32max = (( int32) 0x7FFFFFFF); -const int64 kint64min = (( int64) (0x8000000000000000LL)); -const int64 kint64max = (( int64) (0x7FFFFFFFFFFFFFFFLL)); - -const BaseFloat kBaseFloatMax = std::numeric_limits::max(); -const BaseFloat kBaseFloatMin = std::numeric_limits::min(); +typedef signed int char32; + +const uint8 kuint8max = ((uint8)0xFF); +const uint16 kuint16max = ((uint16)0xFFFF); +const uint32 kuint32max = ((uint32)0xFFFFFFFF); +const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL)); +const int8 kint8min = ((int8)0x80); +const int8 kint8max = ((int8)0x7F); +const int16 kint16min = ((int16)0x8000); +const int16 kint16max = ((int16)0x7FFF); +const int32 kint32min = ((int32)0x80000000); +const int32 kint32max = ((int32)0x7FFFFFFF); +const int64 kint64min = ((int64)(0x8000000000000000LL)); +const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL)); + +const BaseFloat kBaseFloatMax = std::numeric_limits::max(); +const BaseFloat kBaseFloatMin = std::numeric_limits::min(); diff --git a/speechx/speechx/base/common.h b/speechx/speechx/base/common.h index f4261e55..3b58f73c 100644 --- a/speechx/speechx/base/common.h +++ b/speechx/speechx/base/common.h @@ -15,22 +15,22 @@ #pragma once #include +#include #include #include -#include #include #include +#include #include #include #include #include #include -#include #include #include -#include +#include -#include "base/log.h" -#include "base/flags.h" #include "base/basic_types.h" +#include "base/flags.h" +#include "base/log.h" #include "base/macros.h" diff --git a/speechx/speechx/base/macros.h 
b/speechx/speechx/base/macros.h index 17254887..d7d5a78d 100644 --- a/speechx/speechx/base/macros.h +++ b/speechx/speechx/base/macros.h @@ -18,8 +18,8 @@ namespace ppspeech { #ifndef DISALLOW_COPY_AND_ASSIGN #define DISALLOW_COPY_AND_ASSIGN(TypeName) \ - TypeName(const TypeName&) = delete; \ - void operator=(const TypeName&) = delete + TypeName(const TypeName&) = delete; \ + void operator=(const TypeName&) = delete #endif } // namespace pp_speech \ No newline at end of file diff --git a/speechx/speechx/base/thread_pool.h b/speechx/speechx/base/thread_pool.h index 3405af9d..ba895f71 100644 --- a/speechx/speechx/base/thread_pool.h +++ b/speechx/speechx/base/thread_pool.h @@ -23,98 +23,88 @@ #ifndef BASE_THREAD_POOL_H #define BASE_THREAD_POOL_H -#include -#include -#include -#include -#include #include -#include #include +#include +#include +#include +#include #include +#include +#include class ThreadPool { -public: + public: ThreadPool(size_t); - template - auto enqueue(F&& f, Args&&... args) + template + auto enqueue(F&& f, Args&&... 
args) -> std::future::type>; ~ThreadPool(); -private: + + private: // need to keep track of threads so we can join them - std::vector< std::thread > workers; + std::vector workers; // the task queue - std::queue< std::function > tasks; - + std::queue> tasks; + // synchronization std::mutex queue_mutex; std::condition_variable condition; bool stop; }; - + // the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads) - : stop(false) -{ - for(size_t i = 0;i task; + { - std::function task; - - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait(lock, - [this]{ return this->stop || !this->tasks.empty(); }); - if(this->stop && this->tasks.empty()) - return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - - task(); + std::unique_lock lock(this->queue_mutex); + this->condition.wait(lock, [this] { + return this->stop || !this->tasks.empty(); + }); + if (this->stop && this->tasks.empty()) return; + task = std::move(this->tasks.front()); + this->tasks.pop(); } + + task(); } - ); + }); } // add new work item to the pool -template -auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> -{ +template +auto ThreadPool::enqueue(F&& f, Args&&... args) + -> std::future::type> { using return_type = typename std::result_of::type; - auto task = std::make_shared< std::packaged_task >( - std::bind(std::forward(f), std::forward(args)...) 
- ); - + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...)); + std::future res = task->get_future(); { std::unique_lock lock(queue_mutex); // don't allow enqueueing after stopping the pool - if(stop) - throw std::runtime_error("enqueue on stopped ThreadPool"); + if (stop) throw std::runtime_error("enqueue on stopped ThreadPool"); - tasks.emplace([task](){ (*task)(); }); + tasks.emplace([task]() { (*task)(); }); } condition.notify_one(); return res; } // the destructor joins all threads -inline ThreadPool::~ThreadPool() -{ +inline ThreadPool::~ThreadPool() { { std::unique_lock lock(queue_mutex); stop = true; } condition.notify_all(); - for(std::thread &worker: workers) - worker.join(); + for (std::thread& worker : workers) worker.join(); } #endif diff --git a/speechx/speechx/codelab/README.md b/speechx/speechx/codelab/README.md deleted file mode 100644 index 95c95db1..00000000 --- a/speechx/speechx/codelab/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# codelab - -This directory is here for testing some funcitons temporaril. 
- diff --git a/speechx/speechx/codelab/decoder_test/offline_decoder_main.cc b/speechx/speechx/codelab/decoder_test/offline_decoder_main.cc deleted file mode 100644 index 138f5eeb..00000000 --- a/speechx/speechx/codelab/decoder_test/offline_decoder_main.cc +++ /dev/null @@ -1,57 +0,0 @@ -// todo refactor, repalce with gtest - -#include "decoder/ctc_beam_search_decoder.h" -#include "kaldi/util/table-types.h" -#include "base/log.h" -#include "base/flags.h" -#include "nnet/paddle_nnet.h" -#include "nnet/decodable.h" - -DEFINE_string(feature_respecifier, "", "test nnet prob"); - -using kaldi::BaseFloat; -using kaldi::Matrix; -using std::vector; - -//void SplitFeature(kaldi::Matrix feature, -// int32 chunk_size, -// std::vector* feature_chunks) { - -//} - -int main(int argc, char* argv[]) { - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - - kaldi::SequentialBaseFloatMatrixReader feature_reader(FLAGS_feature_respecifier); - - // test nnet_output --> decoder result - int32 num_done = 0, num_err = 0; - - ppspeech::CTCBeamSearchOptions opts; - ppspeech::CTCBeamSearch decoder(opts); - - ppspeech::ModelOptions model_opts; - std::shared_ptr nnet(new ppspeech::PaddleNnet(model_opts)); - - std::shared_ptr decodable(new ppspeech::Decodable(nnet)); - - //int32 chunk_size = 35; - decoder.InitDecoder(); - for (; !feature_reader.Done(); feature_reader.Next()) { - string utt = feature_reader.Key(); - const kaldi::Matrix feature = feature_reader.Value(); - decodable->FeedFeatures(feature); - decoder.AdvanceDecode(decodable, 8); - decodable->InputFinished(); - std::string result; - result = decoder.GetFinalBestPath(); - KALDI_LOG << " the result of " << utt << " is " << result; - decodable->Reset(); - ++num_done; - } - - KALDI_LOG << "Done " << num_done << " utterances, " << num_err - << " with errors."; - return (num_done != 0 ? 
0 : 1); -} \ No newline at end of file diff --git a/speechx/speechx/codelab/feat_test/feature-mfcc-test.cc b/speechx/speechx/codelab/feat_test/feature-mfcc-test.cc deleted file mode 100644 index c4367139..00000000 --- a/speechx/speechx/codelab/feat_test/feature-mfcc-test.cc +++ /dev/null @@ -1,686 +0,0 @@ -// feat/feature-mfcc-test.cc - -// Copyright 2009-2011 Karel Vesely; Petr Motlicek - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. 
- - -#include - -#include "feat/feature-mfcc.h" -#include "base/kaldi-math.h" -#include "matrix/kaldi-matrix-inl.h" -#include "feat/wave-reader.h" - -using namespace kaldi; - - - -static void UnitTestReadWave() { - - std::cout << "=== UnitTestReadWave() ===\n"; - - Vector v, v2; - - std::cout << "<<<=== Reading waveform\n"; - - { - std::ifstream is("test_data/test.wav", std::ios_base::binary); - WaveData wave; - wave.Read(is); - const Matrix data(wave.Data()); - KALDI_ASSERT(data.NumRows() == 1); - v.Resize(data.NumCols()); - v.CopyFromVec(data.Row(0)); - } - - std::cout << "<<<=== Reading Vector waveform, prepared by matlab\n"; - std::ifstream input( - "test_data/test_matlab.ascii" - ); - KALDI_ASSERT(input.good()); - v2.Read(input, false); - input.close(); - - std::cout << "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n"; - KALDI_ASSERT(v.Dim() == v2.Dim()); - for (int32 i = 0; i < v.Dim(); i++) { - KALDI_ASSERT(v(i) == v2(i)); - } - std::cout << "<<<=== Comparing done\n"; - - // std::cout << "== The Waveform Samples == \n"; - // std::cout << v; - - std::cout << "Test passed :)\n\n"; - -} - - - -/** - */ -static void UnitTestSimple() { - std::cout << "=== UnitTestSimple() ===\n"; - - Vector v(100000); - Matrix m; - - // init with noise - for (int32 i = 0; i < v.Dim(); i++) { - v(i) = (abs( i * 433024253 ) % 65535) - (65535 / 2); - } - - std::cout << "<<<=== Just make sure it runs... Nothing is compared\n"; - // the parametrization object - MfccOptions op; - // trying to have same opts as baseline. - op.frame_opts.dither = 0.0; - op.frame_opts.preemph_coeff = 0.0; - op.frame_opts.window_type = "rectangular"; - op.frame_opts.remove_dc_offset = false; - op.frame_opts.round_to_power_of_two = true; - op.mel_opts.low_freq = 0.0; - op.mel_opts.htk_mode = true; - op.htk_compat = true; - - Mfcc mfcc(op); - // use default parameters - - // compute mfccs. 
- mfcc.Compute(v, 1.0, &m); - - // possibly dump - // std::cout << "== Output features == \n" << m; - std::cout << "Test passed :)\n\n"; -} - - -static void UnitTestHTKCompare1() { - std::cout << "=== UnitTestHTKCompare1() ===\n"; - - std::ifstream is("test_data/test.wav", std::ios_base::binary); - WaveData wave; - wave.Read(is); - KALDI_ASSERT(wave.Data().NumRows() == 1); - SubVector waveform(wave.Data(), 0); - - // read the HTK features - Matrix htk_features; - { - std::ifstream is("test_data/test.wav.fea_htk.1", - std::ios::in | std::ios_base::binary); - bool ans = ReadHtk(is, &htk_features, 0); - KALDI_ASSERT(ans); - } - - // use mfcc with default configuration... - MfccOptions op; - op.frame_opts.dither = 0.0; - op.frame_opts.preemph_coeff = 0.0; - op.frame_opts.window_type = "hamming"; - op.frame_opts.remove_dc_offset = false; - op.frame_opts.round_to_power_of_two = true; - op.mel_opts.low_freq = 0.0; - op.mel_opts.htk_mode = true; - op.htk_compat = true; - op.use_energy = false; // C0 not energy. - - Mfcc mfcc(op); - - // calculate kaldi features - Matrix kaldi_raw_features; - mfcc.Compute(waveform, 1.0, &kaldi_raw_features); - - DeltaFeaturesOptions delta_opts; - Matrix kaldi_features; - ComputeDeltas(delta_opts, - kaldi_raw_features, - &kaldi_features); - - // compare the results - bool passed = true; - int32 i_old = -1; - KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); - KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); - // Ignore ends-- we make slightly different choices than - // HTK about how to treat the deltas at the ends. - for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) { - for (int32 j = 0; j < kaldi_features.NumCols(); j++) { - BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); - if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! 
- // print the non-matching data only once per-line - if (i_old != i) { - std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n"; - std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n"; - i_old = i; - } - // print indices of non-matching cells - std::cout << "[" << i << ", " << j << "]"; - passed = false; - }}} - if (!passed) KALDI_ERR << "Test failed"; - - // write the htk features for later inspection - HtkHeader header = { - kaldi_features.NumRows(), - 100000, // 10ms - static_cast(sizeof(float)*kaldi_features.NumCols()), - 021406 // MFCC_D_A_0 - }; - { - std::ofstream os("tmp.test.wav.fea_kaldi.1", - std::ios::out|std::ios::binary); - WriteHtk(os, kaldi_features, header); - } - - std::cout << "Test passed :)\n\n"; - - unlink("tmp.test.wav.fea_kaldi.1"); -} - - -static void UnitTestHTKCompare2() { - std::cout << "=== UnitTestHTKCompare2() ===\n"; - - std::ifstream is("test_data/test.wav", std::ios_base::binary); - WaveData wave; - wave.Read(is); - KALDI_ASSERT(wave.Data().NumRows() == 1); - SubVector waveform(wave.Data(), 0); - - // read the HTK features - Matrix htk_features; - { - std::ifstream is("test_data/test.wav.fea_htk.2", - std::ios::in | std::ios_base::binary); - bool ans = ReadHtk(is, &htk_features, 0); - KALDI_ASSERT(ans); - } - - // use mfcc with default configuration... - MfccOptions op; - op.frame_opts.dither = 0.0; - op.frame_opts.preemph_coeff = 0.0; - op.frame_opts.window_type = "hamming"; - op.frame_opts.remove_dc_offset = false; - op.frame_opts.round_to_power_of_two = true; - op.mel_opts.low_freq = 0.0; - op.mel_opts.htk_mode = true; - op.htk_compat = true; - op.use_energy = true; // Use energy. 
- - Mfcc mfcc(op); - - // calculate kaldi features - Matrix kaldi_raw_features; - mfcc.Compute(waveform, 1.0, &kaldi_raw_features); - - DeltaFeaturesOptions delta_opts; - Matrix kaldi_features; - ComputeDeltas(delta_opts, - kaldi_raw_features, - &kaldi_features); - - // compare the results - bool passed = true; - int32 i_old = -1; - KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); - KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); - // Ignore ends-- we make slightly different choices than - // HTK about how to treat the deltas at the ends. - for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) { - for (int32 j = 0; j < kaldi_features.NumCols(); j++) { - BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); - if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! - // print the non-matching data only once per-line - if (i_old != i) { - std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n"; - std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n"; - i_old = i; - } - // print indices of non-matching cells - std::cout << "[" << i << ", " << j << "]"; - passed = false; - }}} - if (!passed) KALDI_ERR << "Test failed"; - - // write the htk features for later inspection - HtkHeader header = { - kaldi_features.NumRows(), - 100000, // 10ms - static_cast(sizeof(float)*kaldi_features.NumCols()), - 021406 // MFCC_D_A_0 - }; - { - std::ofstream os("tmp.test.wav.fea_kaldi.2", - std::ios::out|std::ios::binary); - WriteHtk(os, kaldi_features, header); - } - - std::cout << "Test passed :)\n\n"; - - unlink("tmp.test.wav.fea_kaldi.2"); -} - - -static void UnitTestHTKCompare3() { - std::cout << "=== UnitTestHTKCompare3() ===\n"; - - std::ifstream is("test_data/test.wav", std::ios_base::binary); - WaveData wave; - wave.Read(is); - KALDI_ASSERT(wave.Data().NumRows() == 1); - SubVector waveform(wave.Data(), 0); - - // read the HTK features - Matrix htk_features; - { - std::ifstream 
is("test_data/test.wav.fea_htk.3", - std::ios::in | std::ios_base::binary); - bool ans = ReadHtk(is, &htk_features, 0); - KALDI_ASSERT(ans); - } - - // use mfcc with default configuration... - MfccOptions op; - op.frame_opts.dither = 0.0; - op.frame_opts.preemph_coeff = 0.0; - op.frame_opts.window_type = "hamming"; - op.frame_opts.remove_dc_offset = false; - op.frame_opts.round_to_power_of_two = true; - op.htk_compat = true; - op.use_energy = true; // Use energy. - op.mel_opts.low_freq = 20.0; - //op.mel_opts.debug_mel = true; - op.mel_opts.htk_mode = true; - - Mfcc mfcc(op); - - // calculate kaldi features - Matrix kaldi_raw_features; - mfcc.Compute(waveform, 1.0, &kaldi_raw_features); - - DeltaFeaturesOptions delta_opts; - Matrix kaldi_features; - ComputeDeltas(delta_opts, - kaldi_raw_features, - &kaldi_features); - - // compare the results - bool passed = true; - int32 i_old = -1; - KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); - KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); - // Ignore ends-- we make slightly different choices than - // HTK about how to treat the deltas at the ends. - for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) { - for (int32 j = 0; j < kaldi_features.NumCols(); j++) { - BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); - if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! 
- // print the non-matching data only once per-line - if (static_cast(i_old) != i) { - std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n"; - std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n"; - i_old = i; - } - // print indices of non-matching cells - std::cout << "[" << i << ", " << j << "]"; - passed = false; - }}} - if (!passed) KALDI_ERR << "Test failed"; - - // write the htk features for later inspection - HtkHeader header = { - kaldi_features.NumRows(), - 100000, // 10ms - static_cast(sizeof(float)*kaldi_features.NumCols()), - 021406 // MFCC_D_A_0 - }; - { - std::ofstream os("tmp.test.wav.fea_kaldi.3", - std::ios::out|std::ios::binary); - WriteHtk(os, kaldi_features, header); - } - - std::cout << "Test passed :)\n\n"; - - unlink("tmp.test.wav.fea_kaldi.3"); -} - - -static void UnitTestHTKCompare4() { - std::cout << "=== UnitTestHTKCompare4() ===\n"; - - std::ifstream is("test_data/test.wav", std::ios_base::binary); - WaveData wave; - wave.Read(is); - KALDI_ASSERT(wave.Data().NumRows() == 1); - SubVector waveform(wave.Data(), 0); - - // read the HTK features - Matrix htk_features; - { - std::ifstream is("test_data/test.wav.fea_htk.4", - std::ios::in | std::ios_base::binary); - bool ans = ReadHtk(is, &htk_features, 0); - KALDI_ASSERT(ans); - } - - // use mfcc with default configuration... - MfccOptions op; - op.frame_opts.dither = 0.0; - op.frame_opts.window_type = "hamming"; - op.frame_opts.remove_dc_offset = false; - op.frame_opts.round_to_power_of_two = true; - op.mel_opts.low_freq = 0.0; - op.htk_compat = true; - op.use_energy = true; // Use energy. 
- op.mel_opts.htk_mode = true; - - Mfcc mfcc(op); - - // calculate kaldi features - Matrix kaldi_raw_features; - mfcc.Compute(waveform, 1.0, &kaldi_raw_features); - - DeltaFeaturesOptions delta_opts; - Matrix kaldi_features; - ComputeDeltas(delta_opts, - kaldi_raw_features, - &kaldi_features); - - // compare the results - bool passed = true; - int32 i_old = -1; - KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); - KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); - // Ignore ends-- we make slightly different choices than - // HTK about how to treat the deltas at the ends. - for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) { - for (int32 j = 0; j < kaldi_features.NumCols(); j++) { - BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); - if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! - // print the non-matching data only once per-line - if (static_cast(i_old) != i) { - std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n"; - std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n"; - i_old = i; - } - // print indices of non-matching cells - std::cout << "[" << i << ", " << j << "]"; - passed = false; - }}} - if (!passed) KALDI_ERR << "Test failed"; - - // write the htk features for later inspection - HtkHeader header = { - kaldi_features.NumRows(), - 100000, // 10ms - static_cast(sizeof(float)*kaldi_features.NumCols()), - 021406 // MFCC_D_A_0 - }; - { - std::ofstream os("tmp.test.wav.fea_kaldi.4", - std::ios::out|std::ios::binary); - WriteHtk(os, kaldi_features, header); - } - - std::cout << "Test passed :)\n\n"; - - unlink("tmp.test.wav.fea_kaldi.4"); -} - - -static void UnitTestHTKCompare5() { - std::cout << "=== UnitTestHTKCompare5() ===\n"; - - std::ifstream is("test_data/test.wav", std::ios_base::binary); - WaveData wave; - wave.Read(is); - KALDI_ASSERT(wave.Data().NumRows() == 1); - SubVector waveform(wave.Data(), 0); - - // read the HTK features - 
Matrix htk_features; - { - std::ifstream is("test_data/test.wav.fea_htk.5", - std::ios::in | std::ios_base::binary); - bool ans = ReadHtk(is, &htk_features, 0); - KALDI_ASSERT(ans); - } - - // use mfcc with default configuration... - MfccOptions op; - op.frame_opts.dither = 0.0; - op.frame_opts.window_type = "hamming"; - op.frame_opts.remove_dc_offset = false; - op.frame_opts.round_to_power_of_two = true; - op.htk_compat = true; - op.use_energy = true; // Use energy. - op.mel_opts.low_freq = 0.0; - op.mel_opts.vtln_low = 100.0; - op.mel_opts.vtln_high = 7500.0; - op.mel_opts.htk_mode = true; - - BaseFloat vtln_warp = 1.1; // our approach identical to htk for warp factor >1, - // differs slightly for higher mel bins if warp_factor <0.9 - - Mfcc mfcc(op); - - // calculate kaldi features - Matrix kaldi_raw_features; - mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features); - - DeltaFeaturesOptions delta_opts; - Matrix kaldi_features; - ComputeDeltas(delta_opts, - kaldi_raw_features, - &kaldi_features); - - // compare the results - bool passed = true; - int32 i_old = -1; - KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); - KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); - // Ignore ends-- we make slightly different choices than - // HTK about how to treat the deltas at the ends. - for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) { - for (int32 j = 0; j < kaldi_features.NumCols(); j++) { - BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); - if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! 
- // print the non-matching data only once per-line - if (static_cast(i_old) != i) { - std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n"; - std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n"; - i_old = i; - } - // print indices of non-matching cells - std::cout << "[" << i << ", " << j << "]"; - passed = false; - }}} - if (!passed) KALDI_ERR << "Test failed"; - - // write the htk features for later inspection - HtkHeader header = { - kaldi_features.NumRows(), - 100000, // 10ms - static_cast(sizeof(float)*kaldi_features.NumCols()), - 021406 // MFCC_D_A_0 - }; - { - std::ofstream os("tmp.test.wav.fea_kaldi.5", - std::ios::out|std::ios::binary); - WriteHtk(os, kaldi_features, header); - } - - std::cout << "Test passed :)\n\n"; - - unlink("tmp.test.wav.fea_kaldi.5"); -} - -static void UnitTestHTKCompare6() { - std::cout << "=== UnitTestHTKCompare6() ===\n"; - - - std::ifstream is("test_data/test.wav", std::ios_base::binary); - WaveData wave; - wave.Read(is); - KALDI_ASSERT(wave.Data().NumRows() == 1); - SubVector waveform(wave.Data(), 0); - - // read the HTK features - Matrix htk_features; - { - std::ifstream is("test_data/test.wav.fea_htk.6", - std::ios::in | std::ios_base::binary); - bool ans = ReadHtk(is, &htk_features, 0); - KALDI_ASSERT(ans); - } - - // use mfcc with default configuration... - MfccOptions op; - op.frame_opts.dither = 0.0; - op.frame_opts.preemph_coeff = 0.97; - op.frame_opts.window_type = "hamming"; - op.frame_opts.remove_dc_offset = false; - op.frame_opts.round_to_power_of_two = true; - op.mel_opts.num_bins = 24; - op.mel_opts.low_freq = 125.0; - op.mel_opts.high_freq = 7800.0; - op.htk_compat = true; - op.use_energy = false; // C0 not energy. 
- - Mfcc mfcc(op); - - // calculate kaldi features - Matrix kaldi_raw_features; - mfcc.Compute(waveform, 1.0, &kaldi_raw_features); - - DeltaFeaturesOptions delta_opts; - Matrix kaldi_features; - ComputeDeltas(delta_opts, - kaldi_raw_features, - &kaldi_features); - - // compare the results - bool passed = true; - int32 i_old = -1; - KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); - KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); - // Ignore ends-- we make slightly different choices than - // HTK about how to treat the deltas at the ends. - for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) { - for (int32 j = 0; j < kaldi_features.NumCols(); j++) { - BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); - if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! - // print the non-matching data only once per-line - if (static_cast(i_old) != i) { - std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n"; - std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n"; - i_old = i; - } - // print indices of non-matching cells - std::cout << "[" << i << ", " << j << "]"; - passed = false; - }}} - if (!passed) KALDI_ERR << "Test failed"; - - // write the htk features for later inspection - HtkHeader header = { - kaldi_features.NumRows(), - 100000, // 10ms - static_cast(sizeof(float)*kaldi_features.NumCols()), - 021406 // MFCC_D_A_0 - }; - { - std::ofstream os("tmp.test.wav.fea_kaldi.6", - std::ios::out|std::ios::binary); - WriteHtk(os, kaldi_features, header); - } - - std::cout << "Test passed :)\n\n"; - - unlink("tmp.test.wav.fea_kaldi.6"); -} - -void UnitTestVtln() { - // Test the function VtlnWarpFreq. 
- BaseFloat low_freq = 10, high_freq = 7800, - vtln_low_cutoff = 20, vtln_high_cutoff = 7400; - - for (size_t i = 0; i < 100; i++) { - BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2; - AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, - low_freq, high_freq, warp_factor, - freq), - freq / warp_factor); - - AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, - low_freq, high_freq, warp_factor, - low_freq), - low_freq); - AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, - low_freq, high_freq, warp_factor, - high_freq), - high_freq); - BaseFloat freq2 = low_freq + (high_freq-low_freq) * RandUniform(), - freq3 = freq2 + (high_freq-freq2) * RandUniform(); // freq3>=freq2 - BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, - low_freq, high_freq, warp_factor, - freq2); - BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, - low_freq, high_freq, warp_factor, - freq3); - KALDI_ASSERT(w3 >= w2); // increasing function. - BaseFloat w3dash = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, - low_freq, high_freq, 1.0, - freq3); - AssertEqual(w3dash, freq3); - } -} - -static void UnitTestFeat() { - UnitTestVtln(); - UnitTestReadWave(); - UnitTestSimple(); - UnitTestHTKCompare1(); - UnitTestHTKCompare2(); - // commenting out this one as it doesn't compare right now I normalized - // the way the FFT bins are treated (removed offset of 0.5)... this seems - // to relate to the way frequency zero behaves. 
- UnitTestHTKCompare3(); - UnitTestHTKCompare4(); - UnitTestHTKCompare5(); - UnitTestHTKCompare6(); - std::cout << "Tests succeeded.\n"; -} - - - -int main() { - try { - for (int i = 0; i < 5; i++) - UnitTestFeat(); - std::cout << "Tests succeeded.\n"; - return 0; - } catch (const std::exception &e) { - std::cerr << e.what(); - return 1; - } -} - - diff --git a/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc b/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc deleted file mode 100644 index 3cd1ae61..00000000 --- a/speechx/speechx/codelab/feat_test/linear_spectrogram_main.cc +++ /dev/null @@ -1,125 +0,0 @@ -// todo refactor, repalce with gtest - -#include "frontend/linear_spectrogram.h" -#include "frontend/normalizer.h" -#include "frontend/feature_extractor_interface.h" -#include "kaldi/util/table-types.h" -#include "base/log.h" -#include "base/flags.h" -#include "kaldi/feat/wave-reader.h" -#include "kaldi/util/kaldi-io.h" - -DEFINE_string(wav_rspecifier, "", "test wav path"); -DEFINE_string(feature_wspecifier, "", "test wav ark"); -DEFINE_string(feature_check_wspecifier, "", "test wav ark"); -DEFINE_string(cmvn_write_path, "./cmvn.ark", "test wav ark"); - - -std::vector mean_{-13730251.531853663, -12982852.199316509, -13673844.299583456, -13089406.559646806, -12673095.524938712, -12823859.223276224, -13590267.158903603, -14257618.467152044, -14374605.116185192, -14490009.21822485, -14849827.158924166, -15354435.470563512, -15834149.206532761, -16172971.985514281, -16348740.496746974, -16423536.699409386, -16556246.263649225, -16744088.772748645, -16916184.08510357, -17054034.840031497, -17165612.509455364, -17255955.470915023, -17322572.527648456, -17408943.862033736, -17521554.799865916, -17620623.254924215, -17699792.395918526, -17723364.411134344, -17741483.4433254, -17747426.888704527, -17733315.928209435, -17748780.160905756, -17808336.883775543, -17895918.671983004, -18009812.59173023, -18098188.66548325, -18195798.958462656, 
-18293617.62980999, -18397432.92077201, -18505834.787318766, -18585451.8100908, -18652438.235649142, -18700960.306275308, -18734944.58792185, -18737426.313365128, -18735347.165987637, -18738813.444170244, -18737086.848890636, -18731576.2474336, -18717405.44095871, -18703089.25545657, -18691014.546456724, -18692460.568905357, -18702119.628629155, -18727710.621126678, -18761582.72034647, -18806745.835547544, -18850674.8692112, -18884431.510951452, -18919999.992506847, -18939303.799078144, -18952946.273760635, -18980289.22996379, -19011610.17803294, -19040948.61805145, -19061021.429847397, -19112055.53768819, -19149667.414264943, -19201127.05091321, -19270250.82564605, -19334606.883057203, -19390513.336589377, -19444176.259208687, -19502755.000038862, -19544333.014549147, -19612668.183176614, -19681902.19006569, -19771969.951249883, -19873329.723376893, -19996752.59235844, -20110031.131400537, -20231658.612529557, -20319378.894054495, -20378534.45718066, -20413332.089584175, -20438147.844177883, -20443710.248040095, -20465457.02238927, -20488610.969337028, -20516295.16424432, -20541423.795738827, -20553192.874953747, -20573605.50701977, -20577871.61936797, -20571807.008916274, -20556242.38912231, -20542199.30819195, -20521239.063551214, -20519150.80004532, -20527204.80248933, -20536933.769257784, -20543470.522332076, -20549700.089992985, -20551525.24958494, -20554873.406493705, -20564277.65794227, -20572211.740052115, -20574305.69550465, -20575494.450104576, -20567092.577932164, -20549302.929608088, -20545445.11878376, -20546625.326603737, -20549190.03499401, -20554824.947828256, -20568341.378989458, -20577582.331383612, -20577980.519402675, -20566603.03458152, -20560131.592262644, -20552166.469060015, -20549063.06763577, -20544490.562339947, -20539817.82346569, -20528747.715731595, -20518026.24576161, -20510977.844974525, -20506874.36087992, -20506731.11977665, -20510482.133420516, -20507760.92101862, -20494644.834457114, -20480107.89304893, -20461312.091867123, 
-20442941.75080173, -20426123.02834838, -20424607.675283, -20426810.369107097, -20434024.50097819, -20437404.75544205, -20447688.63916367, -20460893.335563846, -20482922.735127095, -20503610.119434915, -20527062.76448319, -20557830.035128627, -20593274.72068722, -20632528.452965066, -20673637.471334763, -20733106.97143075, -20842921.0447562, -21054357.83621519, -21416569.534189366, -21978460.272811692, -22753170.052172784, -23671344.10563395, -24613499.293358143, -25406477.12230188, -25884377.82156489, -26049040.62791664, -26996879.104431007}; -std::vector variance_{213747175.10846674, 188395815.34302503, 212706429.10966414, 199109025.81461075, 189235901.23864496, 194901336.53253657, 217481594.29306737, 238689869.12327808, 243977501.24115244, 248479623.6431067, 259766741.47116545, 275516766.7790273, 291271202.3691234, 302693239.8220509, 308627358.3997694, 311143911.38788426, 315446105.07731867, 321705430.9341829, 327458907.4659941, 332245072.43223983, 336251717.5935284, 339694069.7639722, 342188204.4322228, 345587110.31313115, 349903086.2875232, 353660214.20643026, 356700344.5270885, 357665362.3529641, 358493352.05658793, 358857951.620328, 358375239.52774596, 358899733.6342954, 361051818.3511561, 364361716.05025816, 368750322.3771452, 372047800.6462831, 375655861.1349018, 379358519.1980013, 383327605.3935181, 387458599.282341, 390434692.3406868, 392994486.35057056, 394874418.04603153, 396230525.79763395, 396365592.0414835, 396334819.8242737, 396488353.19250053, 396438877.00744957, 396197980.4459586, 395590921.6672991, 395001107.62072515, 394528291.7318225, 394593110.424006, 395018405.59353715, 396110577.5415993, 397506704.0371068, 399400197.4657644, 401243568.2468382, 402687134.7805103, 404136047.2872507, 404883170.001883, 405522253.219517, 406660365.3626476, 407919346.0991902, 409045348.5384909, 409759588.7889818, 411974821.8564483, 413489718.78201455, 415535392.56684107, 418466481.97674364, 421104678.35678065, 423405392.5200779, 425550570.40798235, 
427929423.9579701, 429585274.253478, 432368493.55181056, 435193587.13513297, 438886855.20476013, 443058876.8633751, 448181232.5093362, 452883835.6332396, 458056721.77926534, 461816531.22735566, 464363620.1970998, 465886343.5057493, 466928872.0651, 467180536.42647296, 468111848.70714295, 469138695.3071312, 470378429.6930793, 471517958.7132626, 472109050.4262365, 473087417.0177867, 473381322.04648733, 473220195.85483915, 472666071.8998819, 472124669.87879956, 471298571.411737, 471251033.2902761, 471672676.43128747, 472177147.2193172, 472572361.7711908, 472968783.7751127, 473156295.4164052, 473398034.82676554, 473897703.5203811, 474328271.33112127, 474452670.98002136, 474549003.99284613, 474252887.13567275, 473557462.909069, 473483385.85193115, 473609738.04855174, 473746944.82085115, 474016729.91696435, 474617321.94138587, 475045097.237122, 475125402.586558, 474664112.9824912, 474426247.5800283, 474104075.42796475, 473978219.7273978, 473773171.7798875, 473578534.69508696, 473102924.16904145, 472651240.5232615, 472374383.1810912, 472209479.6956096, 472202298.8921673, 472370090.76781124, 472220933.99374026, 471625467.37106377, 470994646.51883453, 470182428.9637543, 469348211.5939578, 468570387.4467277, 468540442.7225135, 468672018.90414184, 468994346.9533251, 469138757.58201426, 469553915.95710236, 470134523.38582784, 471082421.62055486, 471962316.51804745, 472939745.1708408, 474250621.5944825, 475773933.43199486, 477465399.71087736, 479218782.61382693, 481752299.7930922, 486608947.8984568, 496119403.2067917, 512730085.5704984, 539048915.2641417, 576285298.3548826, 621610270.2240586, 669308196.4436442, 710656993.5957186, 736344437.3725077, 745481288.0241544, 801121432.9925804}; -int count_ = 912592; - -void WriteMatrix() { - kaldi::Matrix cmvn_stats(2, mean_.size()+ 1); - for (size_t idx = 0; idx < mean_.size(); ++idx) { - cmvn_stats(0, idx) = mean_[idx]; - cmvn_stats(1, idx) = variance_[idx]; - } - cmvn_stats(0, mean_.size()) = count_; - 
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, true); -} - -int main(int argc, char* argv[]) { - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - - kaldi::SequentialTableReader wav_reader(FLAGS_wav_rspecifier); - kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier); - kaldi::BaseFloatMatrixWriter feat_cmvn_check_writer(FLAGS_feature_check_wspecifier); - WriteMatrix(); - - // test feature linear_spectorgram: wave --> decibel_normalizer --> hanning window -->linear_spectrogram --> cmvn - int32 num_done = 0, num_err = 0; - ppspeech::LinearSpectrogramOptions opt; - opt.frame_opts.frame_length_ms = 20; - opt.frame_opts.frame_shift_ms = 10; - ppspeech::DecibelNormalizerOptions db_norm_opt; - std::unique_ptr base_feature_extractor( - new ppspeech::DecibelNormalizer(db_norm_opt)); - ppspeech::LinearSpectrogram linear_spectrogram(opt, std::move(base_feature_extractor)); - - ppspeech::CMVN cmvn(FLAGS_cmvn_write_path); - - float streaming_chunk = 0.36; - int sample_rate = 16000; - int chunk_sample_size = streaming_chunk * sample_rate; - - LOG(INFO) << mean_.size(); - for (size_t i = 0; i < mean_.size(); i++) { - mean_[i] /= count_; - variance_[i] = variance_[i] / count_ - mean_[i] * mean_[i]; - if (variance_[i] < 1.0e-20) { - variance_[i] = 1.0e-20; - } - variance_[i] = 1.0 / std::sqrt(variance_[i]); - } - - for (; !wav_reader.Done(); wav_reader.Next()) { - std::string utt = wav_reader.Key(); - const kaldi::WaveData &wave_data = wav_reader.Value(); - - int32 this_channel = 0; - kaldi::SubVector waveform(wave_data.Data(), this_channel); - int tot_samples = waveform.Dim(); - int sample_offset = 0; - std::vector> feats; - int feature_rows = 0; - while (sample_offset < tot_samples) { - int cur_chunk_size = std::min(chunk_sample_size, tot_samples - sample_offset); - kaldi::Vector wav_chunk(cur_chunk_size); - for (int i = 0; i < cur_chunk_size; ++i) { - wav_chunk(i) = waveform(sample_offset + i); - } - 
kaldi::Matrix features; - linear_spectrogram.AcceptWaveform(wav_chunk); - linear_spectrogram.ReadFeats(&features); - - feats.push_back(features); - sample_offset += cur_chunk_size; - feature_rows += features.NumRows(); - } - - int cur_idx = 0; - kaldi::Matrix features(feature_rows, feats[0].NumCols()); - for (auto feat : feats) { - for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) { - for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) { - features(cur_idx, col_idx) = (feat(row_idx, col_idx) - mean_[col_idx])*variance_[col_idx]; - } - ++cur_idx; - } - } - feat_writer.Write(utt, features); - - cur_idx = 0; - kaldi::Matrix features_check(feature_rows, feats[0].NumCols()); - for (auto feat : feats) { - for (int row_idx = 0; row_idx < feat.NumRows(); ++row_idx) { - for (int col_idx = 0; col_idx < feat.NumCols(); ++col_idx) { - features_check(cur_idx, col_idx) = feat(row_idx, col_idx); - } - kaldi::SubVector row_feat(features_check, cur_idx); - cmvn.ApplyCMVN(true, &row_feat); - ++cur_idx; - } - } - feat_cmvn_check_writer.Write(utt, features_check); - - if (num_done % 50 == 0 && num_done != 0) - KALDI_VLOG(2) << "Processed " << num_done << " utterances"; - num_done++; - } - KALDI_LOG << "Done " << num_done << " utterances, " << num_err - << " with errors."; - return (num_done != 0 ? 0 : 1); -} diff --git a/speechx/speechx/codelab/nnet_test/model_test.cc b/speechx/speechx/codelab/nnet_test/model_test.cc deleted file mode 100644 index ce1e7fff..00000000 --- a/speechx/speechx/codelab/nnet_test/model_test.cc +++ /dev/null @@ -1,134 +0,0 @@ -#include "paddle_inference_api.h" -#include -#include -#include -#include -#include -#include -#include -#include - -void produce_data(std::vector>* data); -void model_forward_test(); - -int main(int argc, char* argv[]) { - gflags::ParseCommandLineFlags(&argc, &argv, true); - model_forward_test(); - return 0; -} - -void model_forward_test() { - std::cout << "1. 
read the data" << std::endl; - std::vector> feats; - produce_data(&feats); - - std::cout << "2. load the model" << std::endl;; - std::string model_graph = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel"; - std::string model_params = "../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams"; - paddle_infer::Config config; - config.SetModel(model_graph, model_params); - config.SwitchIrOptim(false); - config.DisableFCPadding(); - auto predictor = paddle_infer::CreatePredictor(config); - - std::cout << "3. feat shape, row=" << feats.size() << ",col=" << feats[0].size() << std::endl; - std::vector paddle_input_feature_matrix; - for(const auto& item : feats) { - paddle_input_feature_matrix.insert(paddle_input_feature_matrix.end(), item.begin(), item.end()); - } - - std::cout << "4. fead the data to model" << std::endl; - int row = feats.size(); - int col = feats[0].size(); - std::vector input_names = predictor->GetInputNames(); - std::vector output_names = predictor->GetOutputNames(); - - std::unique_ptr input_tensor = - predictor->GetInputHandle(input_names[0]); - std::vector INPUT_SHAPE = {1, row, col}; - input_tensor->Reshape(INPUT_SHAPE); - input_tensor->CopyFromCpu(paddle_input_feature_matrix.data()); - - std::unique_ptr input_len = predictor->GetInputHandle(input_names[1]); - std::vector input_len_size = {1}; - input_len->Reshape(input_len_size); - std::vector audio_len; - audio_len.push_back(row); - input_len->CopyFromCpu(audio_len.data()); - - std::unique_ptr chunk_state_h_box = predictor->GetInputHandle(input_names[2]); - std::vector chunk_state_h_box_shape = {3, 1, 1024}; - chunk_state_h_box->Reshape(chunk_state_h_box_shape); - int chunk_state_h_box_size = std::accumulate(chunk_state_h_box_shape.begin(), chunk_state_h_box_shape.end(), - 1, std::multiplies()); - std::vector chunk_state_h_box_data(chunk_state_h_box_size, 0.0f); - chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data()); - - std::unique_ptr chunk_state_c_box 
= predictor->GetInputHandle(input_names[3]); - std::vector chunk_state_c_box_shape = {3, 1, 1024}; - chunk_state_c_box->Reshape(chunk_state_c_box_shape); - int chunk_state_c_box_size = std::accumulate(chunk_state_c_box_shape.begin(), chunk_state_c_box_shape.end(), - 1, std::multiplies()); - std::vector chunk_state_c_box_data(chunk_state_c_box_size, 0.0f); - chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data()); - - bool success = predictor->Run(); - - std::unique_ptr h_out = predictor->GetOutputHandle(output_names[2]); - std::vector h_out_shape = h_out->shape(); - int h_out_size = std::accumulate(h_out_shape.begin(), h_out_shape.end(), - 1, std::multiplies()); - std::vector h_out_data(h_out_size); - h_out->CopyToCpu(h_out_data.data()); - - std::unique_ptr c_out = predictor->GetOutputHandle(output_names[3]); - std::vector c_out_shape = c_out->shape(); - int c_out_size = std::accumulate(c_out_shape.begin(), c_out_shape.end(), - 1, std::multiplies()); - std::vector c_out_data(c_out_size); - c_out->CopyToCpu(c_out_data.data()); - - std::unique_ptr output_tensor = - predictor->GetOutputHandle(output_names[0]); - std::vector output_shape = output_tensor->shape(); - std::vector output_probs; - int output_size = std::accumulate(output_shape.begin(), output_shape.end(), - 1, std::multiplies()); - output_probs.resize(output_size); - output_tensor->CopyToCpu(output_probs.data()); - row = output_shape[1]; - col = output_shape[2]; - - std::vector> probs; - probs.reserve(row); - for (int i = 0; i < row; i++) { - probs.push_back(std::vector()); - probs.back().reserve(col); - - for (int j = 0; j < col; j++) { - probs.back().push_back(output_probs[i * col + j]); - } - } - - std::vector> log_feat = probs; - std::cout << "probs, row: " << log_feat.size() << " col: " << log_feat[0].size() << std::endl; - for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) { - for (size_t col_idx = 0; col_idx < log_feat[row_idx].size(); ++col_idx) { - std::cout << 
log_feat[row_idx][col_idx] << " "; - } - std::cout << std::endl; - } -} - -void produce_data(std::vector>* data) { - int chunk_size = 35; - int col_size = 161; - data->reserve(chunk_size); - data->back().reserve(col_size); - for (int row = 0; row < chunk_size; ++row) { - data->push_back(std::vector()); - for (int col_idx = 0; col_idx < col_size; ++col_idx) { - data->back().push_back(0.201); - } - } -} diff --git a/speechx/speechx/decoder/common.h b/speechx/speechx/decoder/common.h index 4292a871..52deffac 100644 --- a/speechx/speechx/decoder/common.h +++ b/speechx/speechx/decoder/common.h @@ -1,7 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "base/basic_types.h" struct DecoderResult { - BaseFloat acoustic_score; - std::vector words_idx; - std::vector> time_stamp; + BaseFloat acoustic_score; + std::vector words_idx; + std::vector> time_stamp; }; diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index 62abf377..92c57858 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "decoder/ctc_beam_search_decoder.h" #include "base/basic_types.h" @@ -9,292 +23,290 @@ namespace ppspeech { using std::vector; using FSTMATCH = fst::SortedMatcher; -CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) : - opts_(opts), - init_ext_scorer_(nullptr), - blank_id(-1), - space_id(-1), - num_frame_decoded_(0), - root(nullptr) { - +CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) + : opts_(opts), + init_ext_scorer_(nullptr), + blank_id(-1), + space_id(-1), + num_frame_decoded_(0), + root(nullptr) { LOG(INFO) << "dict path: " << opts_.dict_file; if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { LOG(INFO) << "load the dict failed"; } - LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_.size(); + LOG(INFO) << "read the vocabulary success, dict size: " + << vocabulary_.size(); LOG(INFO) << "language model path: " << opts_.lm_path; - init_ext_scorer_ = std::make_shared(opts_.alpha, - opts_.beta, - opts_.lm_path, - vocabulary_); + init_ext_scorer_ = std::make_shared( + opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); } void CTCBeamSearch::Reset() { - num_frame_decoded_ = 0; - ResetPrefixes(); + num_frame_decoded_ = 0; + ResetPrefixes(); } void CTCBeamSearch::InitDecoder() { - blank_id = 0; auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); - + space_id = it - vocabulary_.begin(); // if no space in vocabulary if ((size_t)space_id >= vocabulary_.size()) { space_id = -2; - } + } ResetPrefixes(); - + root = std::make_shared(); root->score = root->log_prob_b_prev = 0.0; 
prefixes.push_back(root.get()); - if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) { + if (init_ext_scorer_ != nullptr && + !init_ext_scorer_->is_character_based()) { auto fst_dict = - static_cast(init_ext_scorer_->dictionary); - fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); + static_cast(init_ext_scorer_->dictionary); + fst::StdVectorFst* dict_ptr = fst_dict->Copy(true); root->set_dictionary(dict_ptr); - + auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); root->set_matcher(matcher); } } -void CTCBeamSearch::Decode(std::shared_ptr decodable) { - return; +void CTCBeamSearch::Decode( + std::shared_ptr decodable) { + return; } -int32 CTCBeamSearch::NumFrameDecoded() { - return num_frame_decoded_; -} +int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; } // todo rename, refactor -void CTCBeamSearch::AdvanceDecode(const std::shared_ptr& decodable, - int max_frames) { +void CTCBeamSearch::AdvanceDecode( + const std::shared_ptr& decodable, + int max_frames) { while (max_frames > 0) { - vector> likelihood; - if (decodable->IsLastFrame(NumFrameDecoded() + 1)) { - break; - } - likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1)); - AdvanceDecoding(likelihood); - max_frames--; + vector> likelihood; + if (decodable->IsLastFrame(NumFrameDecoded() + 1)) { + break; + } + likelihood.push_back( + decodable->FrameLogLikelihood(NumFrameDecoded() + 1)); + AdvanceDecoding(likelihood); + max_frames--; } } void CTCBeamSearch::ResetPrefixes() { - for (size_t i = 0; i < prefixes.size(); i++) { - if (prefixes[i] != nullptr) { - delete prefixes[i]; - prefixes[i] = nullptr; + for (size_t i = 0; i < prefixes.size(); i++) { + if (prefixes[i] != nullptr) { + delete prefixes[i]; + prefixes[i] = nullptr; + } } - } } -int CTCBeamSearch::DecodeLikelihoods(const vector>&probs, - vector& nbest_words) { - kaldi::Timer timer; - timer.Reset(); - AdvanceDecoding(probs); - LOG(INFO) <<"ctc decoding elapsed time(s) " << 
static_cast(timer.Elapsed()) / 1000.0f; - return 0; -} +int CTCBeamSearch::DecodeLikelihoods(const vector>& probs, + vector& nbest_words) { + kaldi::Timer timer; + timer.Reset(); + AdvanceDecoding(probs); + LOG(INFO) << "ctc decoding elapsed time(s) " + << static_cast(timer.Elapsed()) / 1000.0f; + return 0; +} vector> CTCBeamSearch::GetNBestPath() { - return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size); + return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size); } string CTCBeamSearch::GetBestPath() { - std::vector> result; - result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size); - return result[0].second; + std::vector> result; + result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size); + return result[0].second; } string CTCBeamSearch::GetFinalBestPath() { - CalculateApproxScore(); - LMRescore(); - return GetBestPath(); + CalculateApproxScore(); + LMRescore(); + return GetBestPath(); } void CTCBeamSearch::AdvanceDecoding(const vector>& probs) { - size_t num_time_steps = probs.size(); - size_t beam_size = opts_.beam_size; - double cutoff_prob = opts_.cutoff_prob; - size_t cutoff_top_n = opts_.cutoff_top_n; - - vector> probs_seq(probs.size(), vector(probs[0].size(), 0)); - - int row = probs.size(); - int col = probs[0].size(); - for(int i = 0; i < row; i++) { - for (int j = 0; j < col; j++){ - probs_seq[i][j] = static_cast(probs[i][j]); - } - } - - for (size_t time_step = 0; time_step < num_time_steps; time_step++) { - const auto& prob = probs_seq[time_step]; - - float min_cutoff = -NUM_FLT_INF; - bool full_beam = false; - if (init_ext_scorer_ != nullptr) { - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, - prefix_compare); - - if (num_prefixes == 0) { - continue; - } - min_cutoff = prefixes[num_prefixes - 1]->score + - std::log(prob[blank_id]) - - std::max(0.0, init_ext_scorer_->beta); - - full_beam = (num_prefixes == 
beam_size); + size_t num_time_steps = probs.size(); + size_t beam_size = opts_.beam_size; + double cutoff_prob = opts_.cutoff_prob; + size_t cutoff_top_n = opts_.cutoff_top_n; + + vector> probs_seq(probs.size(), + vector(probs[0].size(), 0)); + + int row = probs.size(); + int col = probs[0].size(); + for (int i = 0; i < row; i++) { + for (int j = 0; j < col; j++) { + probs_seq[i][j] = static_cast(probs[i][j]); + } } - - vector> log_prob_idx = - get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); - - // loop over chars - size_t log_prob_idx_len = log_prob_idx.size(); - for (size_t index = 0; index < log_prob_idx_len; index++) { - SearchOneChar(full_beam, log_prob_idx[index], min_cutoff); - } - - prefixes.clear(); - - // update log probs - root->iterate_to_vec(prefixes); - // only preserve top beam_size prefixes - if (prefixes.size() >= beam_size) { - std::nth_element(prefixes.begin(), - prefixes.begin() + beam_size, - prefixes.end(), - prefix_compare); - for (size_t i = beam_size; i < prefixes.size(); ++i) { - prefixes[i]->remove(); - } - } // if - num_frame_decoded_++; - } // for probs_seq -} -int32 CTCBeamSearch::SearchOneChar(const bool& full_beam, - const std::pair& log_prob_idx, - const BaseFloat& min_cutoff) { - size_t beam_size = opts_.beam_size; - const auto& c = log_prob_idx.first; - const auto& log_prob_c = log_prob_idx.second; - size_t prefixes_len = std::min(prefixes.size(), beam_size); - - for (size_t i = 0; i < prefixes_len; ++i) { - auto prefix = prefixes[i]; - if (full_beam && log_prob_c + prefix->score < min_cutoff) { - break; - } + for (size_t time_step = 0; time_step < num_time_steps; time_step++) { + const auto& prob = probs_seq[time_step]; + + float min_cutoff = -NUM_FLT_INF; + bool full_beam = false; + if (init_ext_scorer_ != nullptr) { + size_t num_prefixes = std::min(prefixes.size(), beam_size); + std::sort(prefixes.begin(), + prefixes.begin() + num_prefixes, + prefix_compare); + + if (num_prefixes == 0) { + continue; + } + min_cutoff = 
prefixes[num_prefixes - 1]->score + + std::log(prob[blank_id]) - + std::max(0.0, init_ext_scorer_->beta); + + full_beam = (num_prefixes == beam_size); + } - if (c == blank_id) { - prefix->log_prob_b_cur = log_sum_exp( - prefix->log_prob_b_cur, - log_prob_c + - prefix->score); - continue; - } + vector> log_prob_idx = + get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); - // repeated character - if (c == prefix->character) { - // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1}) - prefix->log_prob_nb_cur = log_sum_exp( - prefix->log_prob_nb_cur, - log_prob_c + - prefix->log_prob_nb_prev); - } + // loop over chars + size_t log_prob_idx_len = log_prob_idx.size(); + for (size_t index = 0; index < log_prob_idx_len; index++) { + SearchOneChar(full_beam, log_prob_idx[index], min_cutoff); + } + + prefixes.clear(); + + // update log probs + root->iterate_to_vec(prefixes); + // only preserve top beam_size prefixes + if (prefixes.size() >= beam_size) { + std::nth_element(prefixes.begin(), + prefixes.begin() + beam_size, + prefixes.end(), + prefix_compare); + for (size_t i = beam_size; i < prefixes.size(); ++i) { + prefixes[i]->remove(); + } + } // if + num_frame_decoded_++; + } // for probs_seq +} - // get new prefix - auto prefix_new = prefix->get_path_trie(c); - if (prefix_new != nullptr) { - float log_p = -NUM_FLT_INF; - if (c == prefix->character && - prefix->log_prob_b_prev > -NUM_FLT_INF) { - // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1}) - log_p = log_prob_c + prefix->log_prob_b_prev; - } else if (c != prefix->character) { - // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1}) - log_p = log_prob_c + prefix->score; - } - - // language model scoring - if (init_ext_scorer_ != nullptr && - (c == space_id || init_ext_scorer_->is_character_based())) { - PathTrie *prefix_to_score = nullptr; - // skip scoring the space - if (init_ext_scorer_->is_character_based()) { - prefix_to_score = prefix_new; - } else { - prefix_to_score = prefix; +int32 
CTCBeamSearch::SearchOneChar( + const bool& full_beam, + const std::pair& log_prob_idx, + const BaseFloat& min_cutoff) { + size_t beam_size = opts_.beam_size; + const auto& c = log_prob_idx.first; + const auto& log_prob_c = log_prob_idx.second; + size_t prefixes_len = std::min(prefixes.size(), beam_size); + + for (size_t i = 0; i < prefixes_len; ++i) { + auto prefix = prefixes[i]; + if (full_beam && log_prob_c + prefix->score < min_cutoff) { + break; } - float score = 0.0; - vector ngram; - ngram = init_ext_scorer_->make_ngram(prefix_to_score); - // lm score: p_{lm}(W)^{\alpha} + \beta - score = init_ext_scorer_->get_log_cond_prob(ngram) * - init_ext_scorer_->alpha; - log_p += score; - log_p += init_ext_scorer_->beta; - } - // p_{nb}(l;x_{1:t}) - prefix_new->log_prob_nb_cur = - log_sum_exp(prefix_new->log_prob_nb_cur, - log_p); - } - } // end of loop over prefix - return 0; + if (c == blank_id) { + prefix->log_prob_b_cur = + log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); + continue; + } + + // repeated character + if (c == prefix->character) { + // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1}) + prefix->log_prob_nb_cur = log_sum_exp( + prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); + } + + // get new prefix + auto prefix_new = prefix->get_path_trie(c); + if (prefix_new != nullptr) { + float log_p = -NUM_FLT_INF; + if (c == prefix->character && + prefix->log_prob_b_prev > -NUM_FLT_INF) { + // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1}) + log_p = log_prob_c + prefix->log_prob_b_prev; + } else if (c != prefix->character) { + // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1}) + log_p = log_prob_c + prefix->score; + } + + // language model scoring + if (init_ext_scorer_ != nullptr && + (c == space_id || init_ext_scorer_->is_character_based())) { + PathTrie* prefix_to_score = nullptr; + // skip scoring the space + if (init_ext_scorer_->is_character_based()) { + prefix_to_score = prefix_new; + } else { + prefix_to_score = 
prefix; + } + + float score = 0.0; + vector ngram; + ngram = init_ext_scorer_->make_ngram(prefix_to_score); + // lm score: p_{lm}(W)^{\alpha} + \beta + score = init_ext_scorer_->get_log_cond_prob(ngram) * + init_ext_scorer_->alpha; + log_p += score; + log_p += init_ext_scorer_->beta; + } + // p_{nb}(l;x_{1:t}) + prefix_new->log_prob_nb_cur = + log_sum_exp(prefix_new->log_prob_nb_cur, log_p); + } + } // end of loop over prefix + return 0; } void CTCBeamSearch::CalculateApproxScore() { - size_t beam_size = opts_.beam_size; - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort( - prefixes.begin(), - prefixes.begin() + num_prefixes, - prefix_compare); - - // compute aproximate ctc score as the return score, without affecting the - // return order of decoding result. To delete when decoder gets stable. - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - double approx_ctc = prefixes[i]->score; - if (init_ext_scorer_ != nullptr) { - vector output; - prefixes[i]->get_path_vec(output); - auto prefix_length = output.size(); - auto words = init_ext_scorer_->split_labels(output); - // remove word insert - approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta; - // remove language model weight: - approx_ctc -= - (init_ext_scorer_->get_sent_log_prob(words)) * init_ext_scorer_->alpha; + size_t beam_size = opts_.beam_size; + size_t num_prefixes = std::min(prefixes.size(), beam_size); + std::sort( + prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); + + // compute aproximate ctc score as the return score, without affecting the + // return order of decoding result. To delete when decoder gets stable. 
+ for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + double approx_ctc = prefixes[i]->score; + if (init_ext_scorer_ != nullptr) { + vector output; + prefixes[i]->get_path_vec(output); + auto prefix_length = output.size(); + auto words = init_ext_scorer_->split_labels(output); + // remove word insert + approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta; + // remove language model weight: + approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) * + init_ext_scorer_->alpha; + } + prefixes[i]->approx_ctc = approx_ctc; } - prefixes[i]->approx_ctc = approx_ctc; - } } void CTCBeamSearch::LMRescore() { - size_t beam_size = opts_.beam_size; - if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - auto prefix = prefixes[i]; - if (!prefix->is_empty() && prefix->character != space_id) { - float score = 0.0; - vector ngram = init_ext_scorer_->make_ngram(prefix); - score = init_ext_scorer_->get_log_cond_prob(ngram) * init_ext_scorer_->alpha; - score += init_ext_scorer_->beta; - prefix->score += score; - } + size_t beam_size = opts_.beam_size; + if (init_ext_scorer_ != nullptr && + !init_ext_scorer_->is_character_based()) { + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + auto prefix = prefixes[i]; + if (!prefix->is_empty() && prefix->character != space_id) { + float score = 0.0; + vector ngram = init_ext_scorer_->make_ngram(prefix); + score = init_ext_scorer_->get_log_cond_prob(ngram) * + init_ext_scorer_->alpha; + score += init_ext_scorer_->beta; + prefix->score += score; + } + } } - } } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h index 53af449e..1e6ac88b 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.h +++ 
b/speechx/speechx/decoder/ctc_beam_search_decoder.h @@ -1,8 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "base/common.h" +#include "decoder/ctc_decoders/path_trie.h" +#include "decoder/ctc_decoders/scorer.h" #include "nnet/decodable-itf.h" #include "util/parse-options.h" -#include "decoder/ctc_decoders/scorer.h" -#include "decoder/ctc_decoders/path_trie.h" #pragma once @@ -17,63 +31,66 @@ struct CTCBeamSearchOptions { int beam_size; int cutoff_top_n; int num_proc_bsearch; - CTCBeamSearchOptions() : - dict_file("./model/words.txt"), - lm_path("./model/lm.arpa"), - alpha(1.9f), - beta(5.0), - beam_size(300), - cutoff_prob(0.99f), - cutoff_top_n(40), - num_proc_bsearch(0) { - } + CTCBeamSearchOptions() + : dict_file("./model/words.txt"), + lm_path("./model/lm.arpa"), + alpha(1.9f), + beta(5.0), + beam_size(300), + cutoff_prob(0.99f), + cutoff_top_n(40), + num_proc_bsearch(0) {} void Register(kaldi::OptionsItf* opts) { opts->Register("dict", &dict_file, "dict file "); opts->Register("lm-path", &lm_path, "language model file"); opts->Register("alpha", &alpha, "alpha"); opts->Register("beta", &beta, "beta"); - opts->Register("beam-size", &beam_size, "beam size for beam search method"); + opts->Register( + "beam-size", &beam_size, "beam size for beam search method"); opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs"); opts->Register("cutoff-top-n", &cutoff_top_n, 
"cutoff top n"); - opts->Register("num-proc-bsearch", &num_proc_bsearch, "num proc bsearch"); + opts->Register( + "num-proc-bsearch", &num_proc_bsearch, "num proc bsearch"); } }; class CTCBeamSearch { - public: + public: explicit CTCBeamSearch(const CTCBeamSearchOptions& opts); ~CTCBeamSearch() {} void InitDecoder(); void Decode(std::shared_ptr decodable); - std::string GetBestPath(); - std::vector> GetNBestPath(); + std::string GetBestPath(); + std::vector> GetNBestPath(); std::string GetFinalBestPath(); int NumFrameDecoded(); - int DecodeLikelihoods(const std::vector>&probs, + int DecodeLikelihoods(const std::vector>& probs, std::vector& nbest_words); - void AdvanceDecode(const std::shared_ptr& decodable, - int max_frames); + void AdvanceDecode( + const std::shared_ptr& decodable, + int max_frames); void Reset(); - private: - void ResetPrefixes(); - int32 SearchOneChar(const bool& full_beam, - const std::pair& log_prob_idx, - const BaseFloat& min_cutoff); - void CalculateApproxScore(); - void LMRescore(); - void AdvanceDecoding(const std::vector>& probs); - - CTCBeamSearchOptions opts_; - std::shared_ptr init_ext_scorer_; // todo separate later - //std::vector decoder_results_; - std::vector vocabulary_; // todo remove later - size_t blank_id; - int space_id; - std::shared_ptr root; - std::vector prefixes; - int num_frame_decoded_; - DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); + + private: + void ResetPrefixes(); + int32 SearchOneChar(const bool& full_beam, + const std::pair& log_prob_idx, + const BaseFloat& min_cutoff); + void CalculateApproxScore(); + void LMRescore(); + void AdvanceDecoding(const std::vector>& probs); + + CTCBeamSearchOptions opts_; + std::shared_ptr init_ext_scorer_; // todo separate later + // std::vector decoder_results_; + std::vector vocabulary_; // todo remove later + size_t blank_id; + int space_id; + std::shared_ptr root; + std::vector prefixes; + int num_frame_decoded_; + DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); }; -} // namespace basr \ 
No newline at end of file +} // namespace basr \ No newline at end of file diff --git a/speechx/speechx/frontend/fbank.h b/speechx/speechx/frontend/fbank.h index 6956690d..7d9cf422 100644 --- a/speechx/speechx/frontend/fbank.h +++ b/speechx/speechx/frontend/fbank.h @@ -22,15 +22,16 @@ namespace ppspeech { class FbankExtractor : FeatureExtractorInterface { public: - explicit FbankExtractor(const FbankOptions& opts, + explicit FbankExtractor(const FbankOptions& opts, share_ptr pre_extractor); - virtual void AcceptWaveform(const kaldi::Vector& input) = 0; + virtual void AcceptWaveform( + const kaldi::Vector& input) = 0; virtual void Read(kaldi::Vector* feat) = 0; virtual size_t Dim() const = 0; private: - bool Compute(const kaldi::Vector& wave, - kaldi::Vector* feat) const; + bool Compute(const kaldi::Vector& wave, + kaldi::Vector* feat) const; }; } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/feature_extractor_controller.h b/speechx/speechx/frontend/feature_extractor_controller.h index e69de29b..5860f391 100644 --- a/speechx/speechx/frontend/feature_extractor_controller.h +++ b/speechx/speechx/frontend/feature_extractor_controller.h @@ -0,0 +1,14 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ diff --git a/speechx/speechx/frontend/feature_extractor_controller_impl.h b/speechx/speechx/frontend/feature_extractor_controller_impl.h index e69de29b..5860f391 100644 --- a/speechx/speechx/frontend/feature_extractor_controller_impl.h +++ b/speechx/speechx/frontend/feature_extractor_controller_impl.h @@ -0,0 +1,14 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + diff --git a/speechx/speechx/frontend/feature_extractor_interface.h b/speechx/speechx/frontend/feature_extractor_interface.h index 7395b792..e39f5e46 100644 --- a/speechx/speechx/frontend/feature_extractor_interface.h +++ b/speechx/speechx/frontend/feature_extractor_interface.h @@ -21,7 +21,8 @@ namespace ppspeech { class FeatureExtractorInterface { public: - virtual void AcceptWaveform(const kaldi::VectorBase& input) = 0; + virtual void AcceptWaveform( + const kaldi::VectorBase& input) = 0; virtual void Read(kaldi::VectorBase* feat) = 0; virtual size_t Dim() const = 0; }; diff --git a/speechx/speechx/frontend/linear_spectrogram.cc b/speechx/speechx/frontend/linear_spectrogram.cc index 8c20985d..6c008c39 100644 --- a/speechx/speechx/frontend/linear_spectrogram.cc +++ b/speechx/speechx/frontend/linear_spectrogram.cc @@ -25,97 +25,97 @@ using kaldi::VectorBase; using kaldi::Matrix; using std::vector; -//todo remove later +// todo remove later void CopyVector2StdVector_(const VectorBase& input, - vector* output) { - if 
(input.Dim() == 0) return; - output->resize(input.Dim()); - for (size_t idx = 0; idx < input.Dim(); ++idx) { - (*output)[idx] = input(idx); - } + vector* output) { + if (input.Dim() == 0) return; + output->resize(input.Dim()); + for (size_t idx = 0; idx < input.Dim(); ++idx) { + (*output)[idx] = input(idx); + } } void CopyStdVector2Vector_(const vector& input, - Vector* output) { - if (input.empty()) return; - output->Resize(input.size()); - for (size_t idx = 0; idx < input.size(); ++idx) { - (*output)(idx) = input[idx]; - } + Vector* output) { + if (input.empty()) return; + output->Resize(input.size()); + for (size_t idx = 0; idx < input.size(); ++idx) { + (*output)(idx) = input[idx]; + } } LinearSpectrogram::LinearSpectrogram( const LinearSpectrogramOptions& opts, std::unique_ptr base_extractor) { - opts_ = opts; - base_extractor_ = std::move(base_extractor); - int32 window_size = opts.frame_opts.WindowSize(); - int32 window_shift = opts.frame_opts.WindowShift(); - fft_points_ = window_size; - hanning_window_.resize(window_size); - - double a = M_2PI / (window_size - 1); - hanning_window_energy_ = 0; - for (int i = 0; i < window_size; ++i) { - hanning_window_[i] = 0.5 - 0.5 * cos(a * i); - hanning_window_energy_ += hanning_window_[i] * hanning_window_[i]; - } - - dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz + opts_ = opts; + base_extractor_ = std::move(base_extractor); + int32 window_size = opts.frame_opts.WindowSize(); + int32 window_shift = opts.frame_opts.WindowShift(); + fft_points_ = window_size; + hanning_window_.resize(window_size); + + double a = M_2PI / (window_size - 1); + hanning_window_energy_ = 0; + for (int i = 0; i < window_size; ++i) { + hanning_window_[i] = 0.5 - 0.5 * cos(a * i); + hanning_window_energy_ += hanning_window_[i] * hanning_window_[i]; + } + + dim_ = fft_points_ / 2 + 1; // the dimension is Fs/2 Hz } void LinearSpectrogram::AcceptWaveform(const VectorBase& input) { - base_extractor_->AcceptWaveform(input); + 
base_extractor_->AcceptWaveform(input); } void LinearSpectrogram::Hanning(vector* data) const { - CHECK_GE(data->size(), hanning_window_.size()); + CHECK_GE(data->size(), hanning_window_.size()); - for (size_t i = 0; i < hanning_window_.size(); ++i) { - data->at(i) *= hanning_window_[i]; - } + for (size_t i = 0; i < hanning_window_.size(); ++i) { + data->at(i) *= hanning_window_[i]; + } } bool LinearSpectrogram::NumpyFft(vector* v, vector* real, vector* img) const { - Vector v_tmp; - CopyStdVector2Vector_(*v, &v_tmp); - RealFft(&v_tmp, true); - CopyVector2StdVector_(v_tmp, v); - real->push_back(v->at(0)); - img->push_back(0); - for (int i = 1; i < v->size() / 2; i++) { - real->push_back(v->at(2 * i)); - img->push_back(v->at(2 * i + 1)); - } - real->push_back(v->at(1)); - img->push_back(0); - - return true; + Vector v_tmp; + CopyStdVector2Vector_(*v, &v_tmp); + RealFft(&v_tmp, true); + CopyVector2StdVector_(v_tmp, v); + real->push_back(v->at(0)); + img->push_back(0); + for (int i = 1; i < v->size() / 2; i++) { + real->push_back(v->at(2 * i)); + img->push_back(v->at(2 * i + 1)); + } + real->push_back(v->at(1)); + img->push_back(0); + + return true; } // todo remove later void LinearSpectrogram::ReadFeats(Matrix* feats) { - Vector tmp; - waveform_.Resize(base_extractor_->Dim()); - Compute(tmp, &waveform_); - vector> result; - vector feats_vec; - CopyVector2StdVector_(waveform_, &feats_vec); - Compute(feats_vec, result); - feats->Resize(result.size(), result[0].size()); - for (int row_idx = 0; row_idx < result.size(); ++row_idx) { - for (int col_idx = 0; col_idx < result[0].size(); ++col_idx) { - (*feats)(row_idx, col_idx) = result[row_idx][col_idx]; + Vector tmp; + waveform_.Resize(base_extractor_->Dim()); + Compute(tmp, &waveform_); + vector> result; + vector feats_vec; + CopyVector2StdVector_(waveform_, &feats_vec); + Compute(feats_vec, result); + feats->Resize(result.size(), result[0].size()); + for (int row_idx = 0; row_idx < result.size(); ++row_idx) { + for (int 
col_idx = 0; col_idx < result[0].size(); ++col_idx) { + (*feats)(row_idx, col_idx) = result[row_idx][col_idx]; + } } - } - waveform_.Resize(0); + waveform_.Resize(0); } void LinearSpectrogram::Read(VectorBase* feat) { - // todo - return; + // todo + return; } // only for test, remove later @@ -129,49 +129,49 @@ void LinearSpectrogram::Compute(const VectorBase& input, // todo: refactor later (SmileGoat) bool LinearSpectrogram::Compute(const vector& wave, vector>& feat) { - int num_samples = wave.size(); - const int& frame_length = opts_.frame_opts.WindowSize(); - const int& sample_rate = opts_.frame_opts.samp_freq; - const int& frame_shift = opts_.frame_opts.WindowShift(); - const int& fft_points = fft_points_; - const float scale = hanning_window_energy_ * sample_rate; - - if (num_samples < frame_length) { - return true; - } - - int num_frames = 1 + ((num_samples - frame_length) / frame_shift); - feat.resize(num_frames); - vector fft_real((fft_points_ / 2 + 1), 0); - vector fft_img((fft_points_ / 2 + 1), 0); - vector v(frame_length, 0); - vector power((fft_points / 2 + 1)); - - for (int i = 0; i < num_frames; ++i) { - vector data(wave.data() + i * frame_shift, - wave.data() + i * frame_shift + frame_length); - Hanning(&data); - fft_img.clear(); - fft_real.clear(); - v.assign(data.begin(), data.end()); - NumpyFft(&v, &fft_real, &fft_img); - - feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz - for (int j = 0; j < (fft_points / 2 + 1); ++j) { - power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j]; - feat[i][j] = power[j]; - - if (j == 0 || j == feat[0].size() - 1) { - feat[i][j] /= scale; - } else { - feat[i][j] *= (2.0 / scale); - } - - // log added eps=1e-14 - feat[i][j] = std::log(feat[i][j] + 1e-14); + int num_samples = wave.size(); + const int& frame_length = opts_.frame_opts.WindowSize(); + const int& sample_rate = opts_.frame_opts.samp_freq; + const int& frame_shift = opts_.frame_opts.WindowShift(); + const int& fft_points = 
fft_points_; + const float scale = hanning_window_energy_ * sample_rate; + + if (num_samples < frame_length) { + return true; + } + + int num_frames = 1 + ((num_samples - frame_length) / frame_shift); + feat.resize(num_frames); + vector fft_real((fft_points_ / 2 + 1), 0); + vector fft_img((fft_points_ / 2 + 1), 0); + vector v(frame_length, 0); + vector power((fft_points / 2 + 1)); + + for (int i = 0; i < num_frames; ++i) { + vector data(wave.data() + i * frame_shift, + wave.data() + i * frame_shift + frame_length); + Hanning(&data); + fft_img.clear(); + fft_real.clear(); + v.assign(data.begin(), data.end()); + NumpyFft(&v, &fft_real, &fft_img); + + feat[i].resize(fft_points / 2 + 1); // the last dimension is Fs/2 Hz + for (int j = 0; j < (fft_points / 2 + 1); ++j) { + power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j]; + feat[i][j] = power[j]; + + if (j == 0 || j == feat[0].size() - 1) { + feat[i][j] /= scale; + } else { + feat[i][j] *= (2.0 / scale); + } + + // log added eps=1e-14 + feat[i][j] = std::log(feat[i][j] + 1e-14); + } } - } - return true; + return true; } } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/linear_spectrogram.h b/speechx/speechx/frontend/linear_spectrogram.h index 0923acee..20b5e4b5 100644 --- a/speechx/speechx/frontend/linear_spectrogram.h +++ b/speechx/speechx/frontend/linear_spectrogram.h @@ -1,32 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once +#include "base/common.h" #include "frontend/feature_extractor_interface.h" #include "kaldi/feat/feature-window.h" -#include "base/common.h" namespace ppspeech { struct LinearSpectrogramOptions { kaldi::FrameExtractionOptions frame_opts; - LinearSpectrogramOptions(): - frame_opts() {} + LinearSpectrogramOptions() : frame_opts() {} - void Register(kaldi::OptionsItf* opts) { - frame_opts.Register(opts); - } + void Register(kaldi::OptionsItf* opts) { frame_opts.Register(opts); } }; class LinearSpectrogram : public FeatureExtractorInterface { public: - explicit LinearSpectrogram(const LinearSpectrogramOptions& opts, - std::unique_ptr base_extractor); - virtual void AcceptWaveform(const kaldi::VectorBase& input); + explicit LinearSpectrogram( + const LinearSpectrogramOptions& opts, + std::unique_ptr base_extractor); + virtual void AcceptWaveform( + const kaldi::VectorBase& input); virtual void Read(kaldi::VectorBase* feat); virtual size_t Dim() const { return dim_; } void ReadFeats(kaldi::Matrix* feats); - private: + private: void Hanning(std::vector* data) const; bool Compute(const std::vector& wave, std::vector>& feat); @@ -41,7 +54,7 @@ class LinearSpectrogram : public FeatureExtractorInterface { std::vector hanning_window_; kaldi::BaseFloat hanning_window_energy_; LinearSpectrogramOptions opts_; - kaldi::Vector waveform_; // remove later, todo(SmileGoat) + kaldi::Vector waveform_; // remove later, todo(SmileGoat) std::unique_ptr base_extractor_; DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram); }; diff --git a/speechx/speechx/frontend/normalizer.cc b/speechx/speechx/frontend/normalizer.cc index 16fc09a8..abf798e5 100644 --- a/speechx/speechx/frontend/normalizer.cc +++ b/speechx/speechx/frontend/normalizer.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "frontend/normalizer.h" #include "kaldi/feat/cmvn.h" @@ -12,169 +26,173 @@ using std::vector; using kaldi::SubVector; DecibelNormalizer::DecibelNormalizer(const DecibelNormalizerOptions& opts) { - opts_ = opts; - dim_ = 0; + opts_ = opts; + dim_ = 0; } - -void DecibelNormalizer::AcceptWaveform(const kaldi::VectorBase& input) { - dim_ = input.Dim(); - waveform_.Resize(input.Dim()); - waveform_.CopyFromVec(input); + +void DecibelNormalizer::AcceptWaveform( + const kaldi::VectorBase& input) { + dim_ = input.Dim(); + waveform_.Resize(input.Dim()); + waveform_.CopyFromVec(input); } void DecibelNormalizer::Read(kaldi::VectorBase* feat) { - if (waveform_.Dim() == 0) return; - Compute(waveform_, feat); + if (waveform_.Dim() == 0) return; + Compute(waveform_, feat); } -//todo remove later +// todo remove later void CopyVector2StdVector(const kaldi::VectorBase& input, vector* output) { - if (input.Dim() == 0) return; - output->resize(input.Dim()); - for (size_t idx = 0; idx < input.Dim(); ++idx) { - (*output)[idx] = input(idx); - } + if (input.Dim() == 0) return; + output->resize(input.Dim()); + for (size_t idx = 0; idx < input.Dim(); ++idx) { + (*output)[idx] = input(idx); + } } void CopyStdVector2Vector(const vector& input, VectorBase* output) { - if (input.empty()) return; - assert(input.size() == output->Dim()); - for (size_t idx = 0; idx < input.size(); ++idx) { - (*output)(idx) = input[idx]; - } + if 
(input.empty()) return; + assert(input.size() == output->Dim()); + for (size_t idx = 0; idx < input.size(); ++idx) { + (*output)(idx) = input[idx]; + } } bool DecibelNormalizer::Compute(const VectorBase& input, VectorBase* feat) const { - // calculate db rms - BaseFloat rms_db = 0.0; - BaseFloat mean_square = 0.0; - BaseFloat gain = 0.0; - BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1)); - - vector samples; - samples.resize(input.Dim()); - for (int32 i = 0; i < samples.size(); ++i) { - samples[i] = input(i); - } - - // square - for (auto &d : samples) { - if (opts_.convert_int_float) { - d = d * wave_float_normlization; + // calculate db rms + BaseFloat rms_db = 0.0; + BaseFloat mean_square = 0.0; + BaseFloat gain = 0.0; + BaseFloat wave_float_normlization = 1.0f / (std::pow(2, 16 - 1)); + + vector samples; + samples.resize(input.Dim()); + for (int32 i = 0; i < samples.size(); ++i) { + samples[i] = input(i); } - mean_square += d * d; - } - - // mean - mean_square /= samples.size(); - rms_db = 10 * std::log10(mean_square); - gain = opts_.target_db - rms_db; - - if (gain > opts_.max_gain_db) { - LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db << "dB," - << "because the the probable gain have exceeds opts_.max_gain_db" - << opts_.max_gain_db << "dB."; - return false; - } - - // Note that this is an in-place transformation. 
- for (auto &item : samples) { - // python item *= 10.0 ** (gain / 20.0) - item *= std::pow(10.0, gain / 20.0); - } - - CopyStdVector2Vector(samples, feat); - return true; + + // square + for (auto& d : samples) { + if (opts_.convert_int_float) { + d = d * wave_float_normlization; + } + mean_square += d * d; + } + + // mean + mean_square /= samples.size(); + rms_db = 10 * std::log10(mean_square); + gain = opts_.target_db - rms_db; + + if (gain > opts_.max_gain_db) { + LOG(ERROR) + << "Unable to normalize segment to " << opts_.target_db << "dB," + << "because the the probable gain have exceeds opts_.max_gain_db" + << opts_.max_gain_db << "dB."; + return false; + } + + // Note that this is an in-place transformation. + for (auto& item : samples) { + // python item *= 10.0 ** (gain / 20.0) + item *= std::pow(10.0, gain / 20.0); + } + + CopyStdVector2Vector(samples, feat); + return true; } CMVN::CMVN(std::string cmvn_file) : var_norm_(true) { - bool binary; - kaldi::Input ki(cmvn_file, &binary); - stats_.Read(ki.Stream(), binary); + bool binary; + kaldi::Input ki(cmvn_file, &binary); + stats_.Read(ki.Stream(), binary); } void CMVN::AcceptWaveform(const kaldi::VectorBase& input) { - return; + return; } -void CMVN::Read(kaldi::VectorBase* feat) { - return; -} +void CMVN::Read(kaldi::VectorBase* feat) { return; } // feats contain num_frames feature. 
void CMVN::ApplyCMVN(bool var_norm, VectorBase* feats) { - KALDI_ASSERT(feats != NULL); - int32 dim = stats_.NumCols() - 1; - if (stats_.NumRows() > 2 || stats_.NumRows() < 1 || feats->Dim() % dim != 0) { - KALDI_ERR << "Dim mismatch: cmvn " - << stats_.NumRows() << 'x' << stats_.NumCols() - << ", feats " << feats->Dim() << 'x'; - } - if (stats_.NumRows() == 1 && var_norm) { - KALDI_ERR << "You requested variance normalization but no variance stats_ " - << "are supplied."; - } - - double count = stats_(0, dim); - // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when - // computing an offset and representing it as stats_, we use a count of one. - if (count < 1.0) - KALDI_ERR << "Insufficient stats_ for cepstral mean and variance normalization: " - << "count = " << count; - - if (!var_norm) { - Vector offset(feats->Dim()); - SubVector mean_stats(stats_.RowData(0), dim); - Vector mean_stats_apply(feats->Dim()); - //fill the datat of mean_stats in mean_stats_appy whose dim is equal with the dim of feature. - //the dim of feats = dim * num_frames; - for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) { - SubVector stats_tmp(mean_stats_apply.Data() + dim*idx, dim); - stats_tmp.CopyFromVec(mean_stats); + KALDI_ASSERT(feats != NULL); + int32 dim = stats_.NumCols() - 1; + if (stats_.NumRows() > 2 || stats_.NumRows() < 1 || + feats->Dim() % dim != 0) { + KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x' + << stats_.NumCols() << ", feats " << feats->Dim() << 'x'; } - offset.AddVec(-1.0 / count, mean_stats_apply); - feats->AddVec(1.0, offset); - return; - } - // norm(0, d) = mean offset; - // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d). 
- kaldi::Matrix norm(2, feats->Dim()); - for (int32 d = 0; d < dim; d++) { - double mean, offset, scale; - mean = stats_(0, d)/count; - double var = (stats_(1, d)/count) - mean*mean, - floor = 1.0e-20; - if (var < floor) { - KALDI_WARN << "Flooring cepstral variance from " << var << " to " - << floor; - var = floor; + if (stats_.NumRows() == 1 && var_norm) { + KALDI_ERR + << "You requested variance normalization but no variance stats_ " + << "are supplied."; + } + + double count = stats_(0, dim); + // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when + // computing an offset and representing it as stats_, we use a count of one. + if (count < 1.0) + KALDI_ERR << "Insufficient stats_ for cepstral mean and variance " + "normalization: " + << "count = " << count; + + if (!var_norm) { + Vector offset(feats->Dim()); + SubVector mean_stats(stats_.RowData(0), dim); + Vector mean_stats_apply(feats->Dim()); + // fill the datat of mean_stats in mean_stats_appy whose dim is equal + // with the dim of feature. + // the dim of feats = dim * num_frames; + for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) { + SubVector stats_tmp(mean_stats_apply.Data() + dim * idx, + dim); + stats_tmp.CopyFromVec(mean_stats); + } + offset.AddVec(-1.0 / count, mean_stats_apply); + feats->AddVec(1.0, offset); + return; } - scale = 1.0 / sqrt(var); - if (scale != scale || 1/scale == 0.0) - KALDI_ERR << "NaN or infinity in cepstral mean/variance computation"; - offset = -(mean*scale); - for (int32 d_skip = d; d_skip < feats->Dim();) { - norm(0, d_skip) = offset; - norm(1, d_skip) = scale; - d_skip = d_skip + dim; + // norm(0, d) = mean offset; + // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d). 
+ kaldi::Matrix norm(2, feats->Dim()); + for (int32 d = 0; d < dim; d++) { + double mean, offset, scale; + mean = stats_(0, d) / count; + double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20; + if (var < floor) { + KALDI_WARN << "Flooring cepstral variance from " << var << " to " + << floor; + var = floor; + } + scale = 1.0 / sqrt(var); + if (scale != scale || 1 / scale == 0.0) + KALDI_ERR + << "NaN or infinity in cepstral mean/variance computation"; + offset = -(mean * scale); + for (int32 d_skip = d; d_skip < feats->Dim();) { + norm(0, d_skip) = offset; + norm(1, d_skip) = scale; + d_skip = d_skip + dim; + } } - } - // Apply the normalization. - feats->MulElements(norm.Row(1)); - feats->AddVec(1.0, norm.Row(0)); + // Apply the normalization. + feats->MulElements(norm.Row(1)); + feats->AddVec(1.0, norm.Row(0)); } void CMVN::ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase* feats) { - ApplyCmvn(stats_, var_norm, feats); + ApplyCmvn(stats_, var_norm, feats); } bool CMVN::Compute(const VectorBase& input, VectorBase* feat) const { - return false; + return false; } -} // namespace ppspeech +} // namespace ppspeech diff --git a/speechx/speechx/frontend/normalizer.h b/speechx/speechx/frontend/normalizer.h index eea03fc1..6af5cdd8 100644 --- a/speechx/speechx/frontend/normalizer.h +++ b/speechx/speechx/frontend/normalizer.h @@ -1,37 +1,55 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include "base/common.h" #include "frontend/feature_extractor_interface.h" -#include "kaldi/util/options-itf.h" #include "kaldi/matrix/kaldi-matrix.h" +#include "kaldi/util/options-itf.h" namespace ppspeech { struct DecibelNormalizerOptions { - float target_db; - float max_gain_db; - bool convert_int_float; - DecibelNormalizerOptions() : - target_db(-20), - max_gain_db(300.0), - convert_int_float(false) {} + float target_db; + float max_gain_db; + bool convert_int_float; + DecibelNormalizerOptions() + : target_db(-20), max_gain_db(300.0), convert_int_float(false) {} void Register(kaldi::OptionsItf* opts) { - opts->Register("target-db", &target_db, "target db for db normalization"); - opts->Register("max-gain-db", &max_gain_db, "max gain db for db normalization"); - opts->Register("convert-int-float", &convert_int_float, "if convert int samples to float"); + opts->Register( + "target-db", &target_db, "target db for db normalization"); + opts->Register( + "max-gain-db", &max_gain_db, "max gain db for db normalization"); + opts->Register("convert-int-float", + &convert_int_float, + "if convert int samples to float"); } }; class DecibelNormalizer : public FeatureExtractorInterface { public: explicit DecibelNormalizer(const DecibelNormalizerOptions& opts); - virtual void AcceptWaveform(const kaldi::VectorBase& input); + virtual void AcceptWaveform( + const kaldi::VectorBase& input); virtual void Read(kaldi::VectorBase* feat); virtual size_t Dim() const { return dim_; } bool Compute(const kaldi::VectorBase& input, kaldi::VectorBase* feat) const; + private: DecibelNormalizerOptions opts_; size_t dim_; @@ -43,7 +61,8 @@ class DecibelNormalizer : public FeatureExtractorInterface { class CMVN : public FeatureExtractorInterface { public: explicit CMVN(std::string cmvn_file); - virtual void AcceptWaveform(const kaldi::VectorBase& input); + virtual void 
AcceptWaveform( + const kaldi::VectorBase& input); virtual void Read(kaldi::VectorBase* feat); virtual size_t Dim() const { return stats_.NumCols() - 1; } bool Compute(const kaldi::VectorBase& input, @@ -51,6 +70,7 @@ class CMVN : public FeatureExtractorInterface { // for test void ApplyCMVN(bool var_norm, kaldi::VectorBase* feats); void ApplyCMVNMatrix(bool var_norm, kaldi::MatrixBase* feats); + private: kaldi::Matrix stats_; std::shared_ptr base_extractor_; diff --git a/speechx/speechx/frontend/window.h b/speechx/speechx/frontend/window.h index 5303cad8..70d6307e 100644 --- a/speechx/speechx/frontend/window.h +++ b/speechx/speechx/frontend/window.h @@ -13,4 +13,3 @@ // limitations under the License. // extract the window of kaldi feat. - diff --git a/speechx/speechx/nnet/decodable-itf.h b/speechx/speechx/nnet/decodable-itf.h index 5f641b6c..3ea9b557 100644 --- a/speechx/speechx/nnet/decodable-itf.h +++ b/speechx/speechx/nnet/decodable-itf.h @@ -1,3 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + // itf/decodable-itf.h // Copyright 2009-2011 Microsoft Corporation; Saarland University; @@ -42,8 +56,10 @@ namespace kaldi { For online decoding, where the features are coming in in real time, it is important to understand the IsLastFrame() and NumFramesReady() functions. 
- There are two ways these are used: the old online-decoding code, in ../online/, - and the new online-decoding code, in ../online2/. In the old online-decoding + There are two ways these are used: the old online-decoding code, in + ../online/, + and the new online-decoding code, in ../online2/. In the old + online-decoding code, the decoder would do: \code{.cc} for (int frame = 0; !decodable.IsLastFrame(frame); frame++) { @@ -52,13 +68,16 @@ namespace kaldi { \endcode and the call to IsLastFrame would block if the features had not arrived yet. The decodable object would have to know when to terminate the decoding. This - online-decoding mode is still supported, it is what happens when you call, for + online-decoding mode is still supported, it is what happens when you call, + for example, LatticeFasterDecoder::Decode(). We realized that this "blocking" mode of decoding is not very convenient because it forces the program to be multi-threaded and makes it complex to - control endpointing. In the "new" decoding code, you don't call (for example) - LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(), + control endpointing. In the "new" decoding code, you don't call (for + example) + LatticeFasterDecoder::Decode(), you call + LatticeFasterDecoder::InitDecoding(), and then each time you get more features, you provide them to the decodable object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does something like this: @@ -68,7 +87,8 @@ namespace kaldi { } \endcode So the decodable object never has IsLastFrame() called. For decoding where - you are starting with a matrix of features, the NumFramesReady() function will + you are starting with a matrix of features, the NumFramesReady() function + will always just return the number of frames in the file, and IsLastFrame() will return true for the last frame. @@ -80,43 +100,52 @@ namespace kaldi { frame of the file once we've decided to terminate decoding. 
*/ class DecodableInterface { - public: - /// Returns the log likelihood, which will be negated in the decoder. - /// The "frame" starts from zero. You should verify that NumFramesReady() > frame - /// before calling this. - virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0; - - /// Returns true if this is the last frame. Frames are zero-based, so the - /// first frame is zero. IsLastFrame(-1) will return false, unless the file - /// is empty (which is a case that I'm not sure all the code will handle, so - /// be careful). Caution: the behavior of this function in an online setting - /// is being changed somewhat. In future it may return false in cases where - /// we haven't yet decided to terminate decoding, but later true if we decide - /// to terminate decoding. The plan in future is to rely more on - /// NumFramesReady(), and in future, IsLastFrame() would always return false - /// in an online-decoding setting, and would only return true in a - /// decoding-from-matrix setting where we want to allow the last delta or LDA - /// features to be flushed out for compatibility with the baseline setup. - virtual bool IsLastFrame(int32 frame) const = 0; - - /// The call NumFramesReady() will return the number of frames currently available - /// for this decodable object. This is for use in setups where you don't want the - /// decoder to block while waiting for input. This is newly added as of Jan 2014, - /// and I hope, going forward, to rely on this mechanism more than IsLastFrame to - /// know when to stop decoding. - virtual int32 NumFramesReady() const { - KALDI_ERR << "NumFramesReady() not implemented for this decodable type."; - return -1; - } - - /// Returns the number of states in the acoustic model - /// (they will be indexed one-based, i.e. from 1 to NumIndices(); - /// this is for compatibility with OpenFst). 
- virtual int32 NumIndices() const = 0; - - virtual std::vector FrameLogLikelihood(int32 frame) = 0; - - virtual ~DecodableInterface() {} + public: + /// Returns the log likelihood, which will be negated in the decoder. + /// The "frame" starts from zero. You should verify that NumFramesReady() > + /// frame + /// before calling this. + virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0; + + /// Returns true if this is the last frame. Frames are zero-based, so the + /// first frame is zero. IsLastFrame(-1) will return false, unless the file + /// is empty (which is a case that I'm not sure all the code will handle, so + /// be careful). Caution: the behavior of this function in an online + /// setting + /// is being changed somewhat. In future it may return false in cases where + /// we haven't yet decided to terminate decoding, but later true if we + /// decide + /// to terminate decoding. The plan in future is to rely more on + /// NumFramesReady(), and in future, IsLastFrame() would always return false + /// in an online-decoding setting, and would only return true in a + /// decoding-from-matrix setting where we want to allow the last delta or + /// LDA + /// features to be flushed out for compatibility with the baseline setup. + virtual bool IsLastFrame(int32 frame) const = 0; + + /// The call NumFramesReady() will return the number of frames currently + /// available + /// for this decodable object. This is for use in setups where you don't + /// want the + /// decoder to block while waiting for input. This is newly added as of Jan + /// 2014, + /// and I hope, going forward, to rely on this mechanism more than + /// IsLastFrame to + /// know when to stop decoding. + virtual int32 NumFramesReady() const { + KALDI_ERR + << "NumFramesReady() not implemented for this decodable type."; + return -1; + } + + /// Returns the number of states in the acoustic model + /// (they will be indexed one-based, i.e. 
from 1 to NumIndices(); + /// this is for compatibility with OpenFst). + virtual int32 NumIndices() const = 0; + + virtual std::vector FrameLogLikelihood(int32 frame) = 0; + + virtual ~DecodableInterface() {} }; /// @} } // namespace Kaldi diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index 45486bc0..d92f4fd3 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ #include "nnet/decodable.h" namespace ppspeech { @@ -5,51 +19,43 @@ namespace ppspeech { using kaldi::BaseFloat; using kaldi::Matrix; -Decodable::Decodable(const std::shared_ptr& nnet): - frontend_(NULL), - nnet_(nnet), - finished_(false), - frames_ready_(0) { -} +Decodable::Decodable(const std::shared_ptr& nnet) + : frontend_(NULL), nnet_(nnet), finished_(false), frames_ready_(0) {} void Decodable::Acceptlikelihood(const Matrix& likelihood) { - frames_ready_ += likelihood.NumRows(); + frames_ready_ += likelihood.NumRows(); } -//Decodable::Init(DecodableConfig config) { +// Decodable::Init(DecodableConfig config) { //} bool Decodable::IsLastFrame(int32 frame) const { - CHECK_LE(frame, frames_ready_); - return finished_ && (frame == frames_ready_ - 1); + CHECK_LE(frame, frames_ready_); + return finished_ && (frame == frames_ready_ - 1); } -int32 Decodable::NumIndices() const { - return 0; -} +int32 Decodable::NumIndices() const { return 0; } -BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { - return 0; -} +BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return 0; } void Decodable::FeedFeatures(const Matrix& features) { - nnet_->FeedForward(features, &nnet_cache_); - frames_ready_ += nnet_cache_.NumRows(); - return ; + nnet_->FeedForward(features, &nnet_cache_); + frames_ready_ += nnet_cache_.NumRows(); + return; } std::vector Decodable::FrameLogLikelihood(int32 frame) { - std::vector result; - result.reserve(nnet_cache_.NumCols()); - for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) { - result[idx] = nnet_cache_(frame, idx); - } - return result; + std::vector result; + result.reserve(nnet_cache_.NumCols()); + for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) { + result[idx] = nnet_cache_(frame, idx); + } + return result; } void Decodable::Reset() { - // frontend_.Reset(); - nnet_->Reset(); + // frontend_.Reset(); + nnet_->Reset(); } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech \ No newline 
at end of file diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h index eb9cddb9..5a59d6ab 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/speechx/speechx/nnet/decodable.h @@ -1,7 +1,21 @@ -#include "nnet/decodable-itf.h" +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "base/common.h" -#include "kaldi/matrix/kaldi-matrix.h" #include "frontend/feature_extractor_interface.h" +#include "kaldi/matrix/kaldi-matrix.h" +#include "nnet/decodable-itf.h" #include "nnet/nnet_interface.h" namespace ppspeech { @@ -9,17 +23,20 @@ namespace ppspeech { struct DecodableOpts; class Decodable : public kaldi::DecodableInterface { - public: + public: explicit Decodable(const std::shared_ptr& nnet); - //void Init(DecodableOpts config); + // void Init(DecodableOpts config); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual bool IsLastFrame(int32 frame) const; virtual int32 NumIndices() const; virtual std::vector FrameLogLikelihood(int32 frame); - void Acceptlikelihood(const kaldi::Matrix& likelihood); // remove later - void FeedFeatures(const kaldi::Matrix& feature); // only for test, todo remove later + void Acceptlikelihood( + const kaldi::Matrix& likelihood); // remove later + void FeedFeatures(const kaldi::Matrix& + feature); // only for test, todo remove later void Reset(); void InputFinished() { finished_ = true; } + private: std::shared_ptr 
frontend_; std::shared_ptr nnet_; diff --git a/speechx/speechx/nnet/nnet_interface.h b/speechx/speechx/nnet/nnet_interface.h index cdc0a6f2..fe669f0a 100644 --- a/speechx/speechx/nnet/nnet_interface.h +++ b/speechx/speechx/nnet/nnet_interface.h @@ -1,3 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once @@ -10,10 +24,9 @@ namespace ppspeech { class NnetInterface { public: virtual void FeedForward(const kaldi::Matrix& features, - kaldi::Matrix* inferences)= 0; + kaldi::Matrix* inferences) = 0; virtual void Reset() = 0; virtual ~NnetInterface() {} - }; } // namespace ppspeech diff --git a/speechx/speechx/nnet/paddle_nnet.cc b/speechx/speechx/nnet/paddle_nnet.cc index c45bb75e..5dea4e51 100644 --- a/speechx/speechx/nnet/paddle_nnet.cc +++ b/speechx/speechx/nnet/paddle_nnet.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + #include "nnet/paddle_nnet.h" #include "absl/strings/str_split.h" @@ -9,43 +23,44 @@ using std::shared_ptr; using kaldi::Matrix; void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { - std::vector cache_names; - cache_names = absl::StrSplit(opts.cache_names, ","); - std::vector cache_shapes; - cache_shapes = absl::StrSplit(opts.cache_shape, ","); - assert(cache_shapes.size() == cache_names.size()); - - cache_encouts_.clear(); - cache_names_idx_.clear(); - for (size_t i = 0; i < cache_shapes.size(); i++) { - std::vector tmp_shape; - tmp_shape = absl::StrSplit(cache_shapes[i], "-"); - std::vector cur_shape; - std::transform(tmp_shape.begin(), tmp_shape.end(), - std::back_inserter(cur_shape), - [](const std::string& s) { - return atoi(s.c_str()); - }); - cache_names_idx_[cache_names[i]] = i; - std::shared_ptr> cache_eout = std::make_shared>(cur_shape); - cache_encouts_.push_back(cache_eout); - } + std::vector cache_names; + cache_names = absl::StrSplit(opts.cache_names, ","); + std::vector cache_shapes; + cache_shapes = absl::StrSplit(opts.cache_shape, ","); + assert(cache_shapes.size() == cache_names.size()); + + cache_encouts_.clear(); + cache_names_idx_.clear(); + for (size_t i = 0; i < cache_shapes.size(); i++) { + std::vector tmp_shape; + tmp_shape = absl::StrSplit(cache_shapes[i], "-"); + std::vector cur_shape; + std::transform(tmp_shape.begin(), + tmp_shape.end(), + std::back_inserter(cur_shape), + [](const std::string& s) { return atoi(s.c_str()); }); + cache_names_idx_[cache_names[i]] = i; + std::shared_ptr> cache_eout = + std::make_shared>(cur_shape); + cache_encouts_.push_back(cache_eout); + } } -PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) { +PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) { paddle_infer::Config config; config.SetModel(opts.model_path, opts.params_path); if (opts.use_gpu) { - 
config.EnableUseGpu(500, 0); + config.EnableUseGpu(500, 0); } config.SwitchIrOptim(opts.switch_ir_optim); if (opts.enable_fc_padding == false) { - config.DisableFCPadding(); + config.DisableFCPadding(); } if (opts.enable_profile) { - config.EnableProfile(); + config.EnableProfile(); } - pool.reset(new paddle_infer::services::PredictorPool(config, opts.thread_num)); + pool.reset( + new paddle_infer::services::PredictorPool(config, opts.thread_num)); if (pool == nullptr) { LOG(ERROR) << "create the predictor pool failed"; } @@ -59,7 +74,7 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) { vector input_names_vec = absl::StrSplit(opts.input_names, ","); vector output_names_vec = absl::StrSplit(opts.output_names, ","); paddle_infer::Predictor* predictor = GetPredictor(); - + std::vector model_input_names = predictor->GetInputNames(); assert(input_names_vec.size() == model_input_names.size()); for (size_t i = 0; i < model_input_names.size(); i++) { @@ -68,16 +83,14 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts):opts_(opts) { std::vector model_output_names = predictor->GetOutputNames(); assert(output_names_vec.size() == model_output_names.size()); - for (size_t i = 0;i < output_names_vec.size(); i++) { + for (size_t i = 0; i < output_names_vec.size(); i++) { assert(output_names_vec[i] == model_output_names[i]); } ReleasePredictor(predictor); InitCacheEncouts(opts); } -void PaddleNnet::Reset() { - InitCacheEncouts(opts_); -} +void PaddleNnet::Reset() { InitCacheEncouts(opts_); } paddle_infer::Predictor* PaddleNnet::GetPredictor() { LOG(INFO) << "attempt to get a new predictor instance " << std::endl; @@ -122,80 +135,88 @@ int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) { } shared_ptr> PaddleNnet::GetCacheEncoder(const string& name) { - auto iter = cache_names_idx_.find(name); - if (iter == cache_names_idx_.end()) { - return nullptr; - } - assert(iter->second < cache_encouts_.size()); - return cache_encouts_[iter->second]; + auto 
iter = cache_names_idx_.find(name); + if (iter == cache_names_idx_.end()) { + return nullptr; + } + assert(iter->second < cache_encouts_.size()); + return cache_encouts_[iter->second]; } -void PaddleNnet::FeedForward(const Matrix& features, Matrix* inferences) { - paddle_infer::Predictor* predictor = GetPredictor(); - int row = features.NumRows(); - int col = features.NumCols(); - std::vector feed_feature; - // todo refactor feed feature: SmileGoat - feed_feature.reserve(row*col); - for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) { - for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) { - feed_feature.push_back(features(row_idx, col_idx)); +void PaddleNnet::FeedForward(const Matrix& features, + Matrix* inferences) { + paddle_infer::Predictor* predictor = GetPredictor(); + int row = features.NumRows(); + int col = features.NumCols(); + std::vector feed_feature; + // todo refactor feed feature: SmileGoat + feed_feature.reserve(row * col); + for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) { + for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) { + feed_feature.push_back(features(row_idx, col_idx)); + } + } + std::vector input_names = predictor->GetInputNames(); + std::vector output_names = predictor->GetOutputNames(); + LOG(INFO) << "feat info: row=" << row << ", col= " << col; + + std::unique_ptr input_tensor = + predictor->GetInputHandle(input_names[0]); + std::vector INPUT_SHAPE = {1, row, col}; + input_tensor->Reshape(INPUT_SHAPE); + input_tensor->CopyFromCpu(feed_feature.data()); + std::unique_ptr input_len = + predictor->GetInputHandle(input_names[1]); + std::vector input_len_size = {1}; + input_len->Reshape(input_len_size); + std::vector audio_len; + audio_len.push_back(row); + input_len->CopyFromCpu(audio_len.data()); + + std::unique_ptr h_box = + predictor->GetInputHandle(input_names[2]); + shared_ptr> h_cache = GetCacheEncoder(input_names[2]); + h_box->Reshape(h_cache->get_shape()); + 
h_box->CopyFromCpu(h_cache->get_data().data()); + std::unique_ptr c_box = + predictor->GetInputHandle(input_names[3]); + shared_ptr> c_cache = GetCacheEncoder(input_names[3]); + c_box->Reshape(c_cache->get_shape()); + c_box->CopyFromCpu(c_cache->get_data().data()); + bool success = predictor->Run(); + + if (success == false) { + LOG(INFO) << "predictor run occurs error"; } - } - std::vector input_names = predictor->GetInputNames(); - std::vector output_names = predictor->GetOutputNames(); - LOG(INFO) << "feat info: row=" << row << ", col= " << col; - - std::unique_ptr input_tensor = predictor->GetInputHandle(input_names[0]); - std::vector INPUT_SHAPE = {1, row, col}; - input_tensor->Reshape(INPUT_SHAPE); - input_tensor->CopyFromCpu(feed_feature.data()); - std::unique_ptr input_len = predictor->GetInputHandle(input_names[1]); - std::vector input_len_size = {1}; - input_len->Reshape(input_len_size); - std::vector audio_len; - audio_len.push_back(row); - input_len->CopyFromCpu(audio_len.data()); - - std::unique_ptr h_box = predictor->GetInputHandle(input_names[2]); - shared_ptr> h_cache = GetCacheEncoder(input_names[2]); - h_box->Reshape(h_cache->get_shape()); - h_box->CopyFromCpu(h_cache->get_data().data()); - std::unique_ptr c_box = predictor->GetInputHandle(input_names[3]); - shared_ptr> c_cache = GetCacheEncoder(input_names[3]); - c_box->Reshape(c_cache->get_shape()); - c_box->CopyFromCpu(c_cache->get_data().data()); - bool success = predictor->Run(); - - if (success == false) { - LOG(INFO) << "predictor run occurs error"; - } - - LOG(INFO) << "get the model success"; - std::unique_ptr h_out = predictor->GetOutputHandle(output_names[2]); - assert(h_cache->get_shape() == h_out->shape()); - h_out->CopyToCpu(h_cache->get_data().data()); - std::unique_ptr c_out = predictor->GetOutputHandle(output_names[3]); - assert(c_cache->get_shape() == c_out->shape()); - c_out->CopyToCpu(c_cache->get_data().data()); - - // get result - std::unique_ptr output_tensor = - 
predictor->GetOutputHandle(output_names[0]); - std::vector output_shape = output_tensor->shape(); - row = output_shape[1]; - col = output_shape[2]; - vector inferences_result; - inferences->Resize(row, col); - inferences_result.resize(row*col); - output_tensor->CopyToCpu(inferences_result.data()); - ReleasePredictor(predictor); - - for (int row_idx = 0; row_idx < row; ++row_idx) { - for (int col_idx = 0; col_idx < col; ++col_idx) { - (*inferences)(row_idx, col_idx) = inferences_result[col*row_idx + col_idx]; + + LOG(INFO) << "get the model success"; + std::unique_ptr h_out = + predictor->GetOutputHandle(output_names[2]); + assert(h_cache->get_shape() == h_out->shape()); + h_out->CopyToCpu(h_cache->get_data().data()); + std::unique_ptr c_out = + predictor->GetOutputHandle(output_names[3]); + assert(c_cache->get_shape() == c_out->shape()); + c_out->CopyToCpu(c_cache->get_data().data()); + + // get result + std::unique_ptr output_tensor = + predictor->GetOutputHandle(output_names[0]); + std::vector output_shape = output_tensor->shape(); + row = output_shape[1]; + col = output_shape[2]; + vector inferences_result; + inferences->Resize(row, col); + inferences_result.resize(row * col); + output_tensor->CopyToCpu(inferences_result.data()); + ReleasePredictor(predictor); + + for (int row_idx = 0; row_idx < row; ++row_idx) { + for (int col_idx = 0; col_idx < col; ++col_idx) { + (*inferences)(row_idx, col_idx) = + inferences_result[col * row_idx + col_idx]; + } } - } } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/nnet/paddle_nnet.h b/speechx/speechx/nnet/paddle_nnet.h index 8eaced07..aec27fd1 100644 --- a/speechx/speechx/nnet/paddle_nnet.h +++ b/speechx/speechx/nnet/paddle_nnet.h @@ -1,8 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once -#include "nnet/nnet_interface.h" #include "base/common.h" +#include "nnet/nnet_interface.h" #include "paddle_inference_api.h" #include "kaldi/matrix/kaldi-matrix.h" @@ -13,71 +27,79 @@ namespace ppspeech { struct ModelOptions { - std::string model_path; - std::string params_path; - int thread_num; - bool use_gpu; - bool switch_ir_optim; - std::string input_names; - std::string output_names; - std::string cache_names; - std::string cache_shape; - bool enable_fc_padding; - bool enable_profile; - ModelOptions() : - model_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdmodel"), - params_path("../../../../model/paddle_online_deepspeech/model/avg_1.jit.pdiparams"), - thread_num(2), - use_gpu(false), - input_names("audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box"), - output_names("save_infer_model/scale_0.tmp_1,save_infer_model/scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/scale_3.tmp_1"), - cache_names("chunk_state_h_box,chunk_state_c_box"), - cache_shape("3-1-1024,3-1-1024"), - switch_ir_optim(false), - enable_fc_padding(false), - enable_profile(false) { - } + std::string model_path; + std::string params_path; + int thread_num; + bool use_gpu; + bool switch_ir_optim; + std::string input_names; + std::string output_names; + std::string cache_names; + std::string cache_shape; + bool enable_fc_padding; + bool enable_profile; + ModelOptions() + : 
model_path( + "../../../../model/paddle_online_deepspeech/model/" + "avg_1.jit.pdmodel"), + params_path( + "../../../../model/paddle_online_deepspeech/model/" + "avg_1.jit.pdiparams"), + thread_num(2), + use_gpu(false), + input_names( + "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_" + "box"), + output_names( + "save_infer_model/scale_0.tmp_1,save_infer_model/" + "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/" + "scale_3.tmp_1"), + cache_names("chunk_state_h_box,chunk_state_c_box"), + cache_shape("3-1-1024,3-1-1024"), + switch_ir_optim(false), + enable_fc_padding(false), + enable_profile(false) {} - void Register(kaldi::OptionsItf* opts) { - opts->Register("model-path", &model_path, "model file path"); - opts->Register("model-params", &params_path, "params model file path"); - opts->Register("thread-num", &thread_num, "thread num"); - opts->Register("use-gpu", &use_gpu, "if use gpu"); - opts->Register("input-names", &input_names, "paddle input names"); - opts->Register("output-names", &output_names, "paddle output names"); - opts->Register("cache-names", &cache_names, "cache names"); - opts->Register("cache-shape", &cache_shape, "cache shape"); - opts->Register("switch-ir-optiom", &switch_ir_optim, "paddle SwitchIrOptim option"); - opts->Register("enable-fc-padding", &enable_fc_padding, "paddle EnableFCPadding option"); - opts->Register("enable-profile", &enable_profile, "paddle EnableProfile option"); - } + void Register(kaldi::OptionsItf* opts) { + opts->Register("model-path", &model_path, "model file path"); + opts->Register("model-params", &params_path, "params model file path"); + opts->Register("thread-num", &thread_num, "thread num"); + opts->Register("use-gpu", &use_gpu, "if use gpu"); + opts->Register("input-names", &input_names, "paddle input names"); + opts->Register("output-names", &output_names, "paddle output names"); + opts->Register("cache-names", &cache_names, "cache names"); + opts->Register("cache-shape", &cache_shape, 
"cache shape"); + opts->Register("switch-ir-optiom", + &switch_ir_optim, + "paddle SwitchIrOptim option"); + opts->Register("enable-fc-padding", + &enable_fc_padding, + "paddle EnableFCPadding option"); + opts->Register( + "enable-profile", &enable_profile, "paddle EnableProfile option"); + } }; -template +template class Tensor { -public: - Tensor() { - } - Tensor(const std::vector& shape) : - _shape(shape) { - int data_size = std::accumulate(_shape.begin(), _shape.end(), - 1, std::multiplies()); + public: + Tensor() {} + Tensor(const std::vector& shape) : _shape(shape) { + int data_size = std::accumulate( + _shape.begin(), _shape.end(), 1, std::multiplies()); LOG(INFO) << "data size: " << data_size; _data.resize(data_size, 0); } void reshape(const std::vector& shape) { _shape = shape; - int data_size = std::accumulate(_shape.begin(), _shape.end(), - 1, std::multiplies()); + int data_size = std::accumulate( + _shape.begin(), _shape.end(), 1, std::multiplies()); _data.resize(data_size, 0); } - const std::vector& get_shape() const { - return _shape; - } - std::vector& get_data() { - return _data; - } -private: + const std::vector& get_shape() const { return _shape; } + std::vector& get_data() { return _data; } + + private: std::vector _shape; std::vector _data; }; @@ -85,15 +107,16 @@ private: class PaddleNnet : public NnetInterface { public: PaddleNnet(const ModelOptions& opts); - virtual void FeedForward(const kaldi::Matrix& features, + virtual void FeedForward(const kaldi::Matrix& features, kaldi::Matrix* inferences); virtual void Reset(); - std::shared_ptr> GetCacheEncoder(const std::string& name); + std::shared_ptr> GetCacheEncoder( + const std::string& name); void InitCacheEncouts(const ModelOptions& opts); - + private: paddle_infer::Predictor* GetPredictor(); - int ReleasePredictor(paddle_infer::Predictor* predictor); + int ReleasePredictor(paddle_infer::Predictor* predictor); std::unique_ptr pool; std::vector pool_usages; @@ -107,4 +130,4 @@ class PaddleNnet 
: public NnetInterface { DISALLOW_COPY_AND_ASSIGN(PaddleNnet); }; -} // namespace ppspeech +} // namespace ppspeech diff --git a/speechx/speechx/utils/file_utils.cc b/speechx/speechx/utils/file_utils.cc index 2d8be727..b8e51760 100644 --- a/speechx/speechx/utils/file_utils.cc +++ b/speechx/speechx/utils/file_utils.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "utils/file_utils.h" namespace ppspeech { @@ -17,5 +31,4 @@ bool ReadFileToVector(const std::string& filename, return true; } - } \ No newline at end of file diff --git a/speechx/speechx/utils/file_utils.h b/speechx/speechx/utils/file_utils.h index 0011b6c5..f82d41a5 100644 --- a/speechx/speechx/utils/file_utils.h +++ b/speechx/speechx/utils/file_utils.h @@ -1,8 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ #include "base/common.h" namespace ppspeech { -bool ReadFileToVector(const std::string& filename, - std::vector* data); - +bool ReadFileToVector(const std::string& filename, + std::vector* data); }