From 869f4267d5fabbdf6f7b18515ebf33f28a755b6c Mon Sep 17 00:00:00 2001 From: YangZhou <56786796+SmileGoat@users.noreply.github.com> Date: Fri, 16 Dec 2022 12:17:03 +0800 Subject: [PATCH 01/50] [speechx]Speechx directory refactor (#2746) * refactor directory --- speechx/requirement.txt | 1 - speechx/speechx/CMakeLists.txt | 51 +- speechx/speechx/asr/CMakeLists.txt | 11 + .../speechx/{ => asr}/decoder/CMakeLists.txt | 0 speechx/speechx/{ => asr}/decoder/common.h | 0 .../decoder/ctc_beam_search_decoder.cc | 0 .../decoder/ctc_beam_search_decoder.h | 0 .../decoder/ctc_beam_search_decoder_main.cc | 0 .../{ => asr}/decoder/ctc_beam_search_opt.h | 0 .../asr/decoder/ctc_decoders/.gitignore | 9 + .../ctc_decoders/ctc_beam_search_decoder.cpp | 607 ++++++++++++++++++ .../ctc_decoders/ctc_beam_search_decoder.h | 175 +++++ .../ctc_decoders/ctc_greedy_decoder.cpp | 61 ++ .../decoder/ctc_decoders/ctc_greedy_decoder.h | 35 + .../decoder/ctc_decoders/decoder_utils.cpp | 193 ++++++ .../asr/decoder/ctc_decoders/decoder_utils.h | 111 ++++ .../asr/decoder/ctc_decoders/path_trie.cpp | 164 +++++ .../asr/decoder/ctc_decoders/path_trie.h | 82 +++ .../asr/decoder/ctc_decoders/scorer.cpp | 232 +++++++ .../speechx/asr/decoder/ctc_decoders/scorer.h | 114 ++++ .../decoder/ctc_prefix_beam_search_decoder.cc | 2 +- .../decoder/ctc_prefix_beam_search_decoder.h | 0 .../ctc_prefix_beam_search_decoder_main.cc | 0 .../decoder/ctc_prefix_beam_search_score.h | 0 .../{ => asr}/decoder/ctc_tlg_decoder.cc | 0 .../{ => asr}/decoder/ctc_tlg_decoder.h | 0 .../{ => asr}/decoder/ctc_tlg_decoder_main.cc | 0 .../speechx/{ => asr}/decoder/decoder_itf.h | 0 .../decoder/nnet_logprob_decoder_main.cc | 0 speechx/speechx/{ => asr}/decoder/param.h | 0 speechx/speechx/{ => asr}/nnet/CMakeLists.txt | 0 speechx/speechx/{ => asr}/nnet/decodable.cc | 0 speechx/speechx/{ => asr}/nnet/decodable.h | 0 speechx/speechx/{ => asr}/nnet/ds2_nnet.cc | 0 speechx/speechx/{ => asr}/nnet/ds2_nnet.h | 0 .../speechx/{ => 
asr}/nnet/ds2_nnet_main.cc | 0 speechx/speechx/{ => asr}/nnet/nnet_itf.h | 0 speechx/speechx/{ => asr}/nnet/u2_nnet.cc | 0 speechx/speechx/{ => asr}/nnet/u2_nnet.h | 0 .../speechx/{ => asr}/nnet/u2_nnet_main.cc | 0 .../{ => asr}/recognizer/CMakeLists.txt | 0 .../{ => asr}/recognizer/recognizer.cc | 0 .../speechx/{ => asr}/recognizer/recognizer.h | 0 .../{ => asr}/recognizer/recognizer_main.cc | 0 .../{ => asr}/recognizer/u2_recognizer.cc | 0 .../{ => asr}/recognizer/u2_recognizer.h | 0 .../recognizer/u2_recognizer_main.cc | 0 .../{protocol => asr/server}/CMakeLists.txt | 0 .../server}/websocket/CMakeLists.txt | 0 .../server}/websocket/websocket_client.cc | 0 .../server}/websocket/websocket_client.h | 0 .../websocket/websocket_client_main.cc | 0 .../server}/websocket/websocket_server.cc | 0 .../server}/websocket/websocket_server.h | 0 .../websocket/websocket_server_main.cc | 0 speechx/speechx/common/CMakeLists.txt | 16 + .../speechx/{ => common}/base/basic_types.h | 0 speechx/speechx/{ => common}/base/common.h | 0 speechx/speechx/{ => common}/base/flags.h | 0 speechx/speechx/{ => common}/base/log.h | 0 speechx/speechx/{ => common}/base/macros.h | 0 .../speechx/{ => common}/base/thread_pool.h | 0 .../{ => common}/frontend/CMakeLists.txt | 0 .../frontend/audio/CMakeLists.txt | 0 .../{ => common}/frontend/audio/assembler.cc | 0 .../{ => common}/frontend/audio/assembler.h | 0 .../frontend/audio/audio_cache.cc | 0 .../{ => common}/frontend/audio/audio_cache.h | 0 .../{ => common}/frontend/audio/cmvn.cc | 0 .../{ => common}/frontend/audio/cmvn.h | 0 .../frontend/audio/cmvn_json2kaldi_main.cc | 0 .../frontend/audio/compute_fbank_main.cc | 0 .../audio/compute_linear_spectrogram_main.cc | 0 .../{ => common}/frontend/audio/data_cache.h | 0 .../{ => common}/frontend/audio/db_norm.cc | 0 .../{ => common}/frontend/audio/db_norm.h | 0 .../{ => common}/frontend/audio/fbank.cc | 0 .../{ => common}/frontend/audio/fbank.h | 0 .../frontend/audio/feature_cache.cc | 0 
.../frontend/audio/feature_cache.h | 0 .../frontend/audio/feature_common.h | 0 .../frontend/audio/feature_common_inl.h | 0 .../frontend/audio/feature_pipeline.cc | 0 .../frontend/audio/feature_pipeline.h | 0 .../frontend/audio/frontend_itf.h | 0 .../frontend/audio/linear_spectrogram.cc | 0 .../frontend/audio/linear_spectrogram.h | 0 .../{ => common}/frontend/audio/mfcc.cc | 0 .../{ => common}/frontend/audio/mfcc.h | 0 .../{ => common}/frontend/audio/normalizer.h | 0 .../speechx/{ => common}/utils/CMakeLists.txt | 0 .../speechx/{ => common}/utils/file_utils.cc | 0 .../speechx/{ => common}/utils/file_utils.h | 0 speechx/speechx/{ => common}/utils/math.cc | 0 speechx/speechx/{ => common}/utils/math.h | 0 speechx/speechx/decoder/ctc_decoders | 1 - speechx/speechx/frontend/text/CMakeLists.txt | 0 speechx/speechx/kaldi/CMakeLists.txt | 5 +- speechx/speechx/third_party/CMakeLists.txt | 0 speechx/speechx/third_party/README.md | 4 - 100 files changed, 1821 insertions(+), 53 deletions(-) delete mode 100644 speechx/requirement.txt create mode 100644 speechx/speechx/asr/CMakeLists.txt rename speechx/speechx/{ => asr}/decoder/CMakeLists.txt (100%) rename speechx/speechx/{ => asr}/decoder/common.h (100%) rename speechx/speechx/{ => asr}/decoder/ctc_beam_search_decoder.cc (100%) rename speechx/speechx/{ => asr}/decoder/ctc_beam_search_decoder.h (100%) rename speechx/speechx/{ => asr}/decoder/ctc_beam_search_decoder_main.cc (100%) rename speechx/speechx/{ => asr}/decoder/ctc_beam_search_opt.h (100%) create mode 100644 speechx/speechx/asr/decoder/ctc_decoders/.gitignore create mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp create mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h create mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp create mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.h create mode 100644 speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.cpp 
create mode 100644 speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h create mode 100644 speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp create mode 100644 speechx/speechx/asr/decoder/ctc_decoders/path_trie.h create mode 100644 speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp create mode 100644 speechx/speechx/asr/decoder/ctc_decoders/scorer.h rename speechx/speechx/{ => asr}/decoder/ctc_prefix_beam_search_decoder.cc (99%) rename speechx/speechx/{ => asr}/decoder/ctc_prefix_beam_search_decoder.h (100%) rename speechx/speechx/{ => asr}/decoder/ctc_prefix_beam_search_decoder_main.cc (100%) rename speechx/speechx/{ => asr}/decoder/ctc_prefix_beam_search_score.h (100%) rename speechx/speechx/{ => asr}/decoder/ctc_tlg_decoder.cc (100%) rename speechx/speechx/{ => asr}/decoder/ctc_tlg_decoder.h (100%) rename speechx/speechx/{ => asr}/decoder/ctc_tlg_decoder_main.cc (100%) rename speechx/speechx/{ => asr}/decoder/decoder_itf.h (100%) rename speechx/speechx/{ => asr}/decoder/nnet_logprob_decoder_main.cc (100%) rename speechx/speechx/{ => asr}/decoder/param.h (100%) rename speechx/speechx/{ => asr}/nnet/CMakeLists.txt (100%) rename speechx/speechx/{ => asr}/nnet/decodable.cc (100%) rename speechx/speechx/{ => asr}/nnet/decodable.h (100%) rename speechx/speechx/{ => asr}/nnet/ds2_nnet.cc (100%) rename speechx/speechx/{ => asr}/nnet/ds2_nnet.h (100%) rename speechx/speechx/{ => asr}/nnet/ds2_nnet_main.cc (100%) rename speechx/speechx/{ => asr}/nnet/nnet_itf.h (100%) rename speechx/speechx/{ => asr}/nnet/u2_nnet.cc (100%) rename speechx/speechx/{ => asr}/nnet/u2_nnet.h (100%) rename speechx/speechx/{ => asr}/nnet/u2_nnet_main.cc (100%) rename speechx/speechx/{ => asr}/recognizer/CMakeLists.txt (100%) rename speechx/speechx/{ => asr}/recognizer/recognizer.cc (100%) rename speechx/speechx/{ => asr}/recognizer/recognizer.h (100%) rename speechx/speechx/{ => asr}/recognizer/recognizer_main.cc (100%) rename speechx/speechx/{ => asr}/recognizer/u2_recognizer.cc 
(100%) rename speechx/speechx/{ => asr}/recognizer/u2_recognizer.h (100%) rename speechx/speechx/{ => asr}/recognizer/u2_recognizer_main.cc (100%) rename speechx/speechx/{protocol => asr/server}/CMakeLists.txt (100%) rename speechx/speechx/{protocol => asr/server}/websocket/CMakeLists.txt (100%) rename speechx/speechx/{protocol => asr/server}/websocket/websocket_client.cc (100%) rename speechx/speechx/{protocol => asr/server}/websocket/websocket_client.h (100%) rename speechx/speechx/{protocol => asr/server}/websocket/websocket_client_main.cc (100%) rename speechx/speechx/{protocol => asr/server}/websocket/websocket_server.cc (100%) rename speechx/speechx/{protocol => asr/server}/websocket/websocket_server.h (100%) rename speechx/speechx/{protocol => asr/server}/websocket/websocket_server_main.cc (100%) create mode 100644 speechx/speechx/common/CMakeLists.txt rename speechx/speechx/{ => common}/base/basic_types.h (100%) rename speechx/speechx/{ => common}/base/common.h (100%) rename speechx/speechx/{ => common}/base/flags.h (100%) rename speechx/speechx/{ => common}/base/log.h (100%) rename speechx/speechx/{ => common}/base/macros.h (100%) rename speechx/speechx/{ => common}/base/thread_pool.h (100%) rename speechx/speechx/{ => common}/frontend/CMakeLists.txt (100%) rename speechx/speechx/{ => common}/frontend/audio/CMakeLists.txt (100%) rename speechx/speechx/{ => common}/frontend/audio/assembler.cc (100%) rename speechx/speechx/{ => common}/frontend/audio/assembler.h (100%) rename speechx/speechx/{ => common}/frontend/audio/audio_cache.cc (100%) rename speechx/speechx/{ => common}/frontend/audio/audio_cache.h (100%) rename speechx/speechx/{ => common}/frontend/audio/cmvn.cc (100%) rename speechx/speechx/{ => common}/frontend/audio/cmvn.h (100%) rename speechx/speechx/{ => common}/frontend/audio/cmvn_json2kaldi_main.cc (100%) rename speechx/speechx/{ => common}/frontend/audio/compute_fbank_main.cc (100%) rename speechx/speechx/{ => 
common}/frontend/audio/compute_linear_spectrogram_main.cc (100%) rename speechx/speechx/{ => common}/frontend/audio/data_cache.h (100%) rename speechx/speechx/{ => common}/frontend/audio/db_norm.cc (100%) rename speechx/speechx/{ => common}/frontend/audio/db_norm.h (100%) rename speechx/speechx/{ => common}/frontend/audio/fbank.cc (100%) rename speechx/speechx/{ => common}/frontend/audio/fbank.h (100%) rename speechx/speechx/{ => common}/frontend/audio/feature_cache.cc (100%) rename speechx/speechx/{ => common}/frontend/audio/feature_cache.h (100%) rename speechx/speechx/{ => common}/frontend/audio/feature_common.h (100%) rename speechx/speechx/{ => common}/frontend/audio/feature_common_inl.h (100%) rename speechx/speechx/{ => common}/frontend/audio/feature_pipeline.cc (100%) rename speechx/speechx/{ => common}/frontend/audio/feature_pipeline.h (100%) rename speechx/speechx/{ => common}/frontend/audio/frontend_itf.h (100%) rename speechx/speechx/{ => common}/frontend/audio/linear_spectrogram.cc (100%) rename speechx/speechx/{ => common}/frontend/audio/linear_spectrogram.h (100%) rename speechx/speechx/{ => common}/frontend/audio/mfcc.cc (100%) rename speechx/speechx/{ => common}/frontend/audio/mfcc.h (100%) rename speechx/speechx/{ => common}/frontend/audio/normalizer.h (100%) rename speechx/speechx/{ => common}/utils/CMakeLists.txt (100%) rename speechx/speechx/{ => common}/utils/file_utils.cc (100%) rename speechx/speechx/{ => common}/utils/file_utils.h (100%) rename speechx/speechx/{ => common}/utils/math.cc (100%) rename speechx/speechx/{ => common}/utils/math.h (100%) delete mode 120000 speechx/speechx/decoder/ctc_decoders delete mode 100644 speechx/speechx/frontend/text/CMakeLists.txt delete mode 100644 speechx/speechx/third_party/CMakeLists.txt delete mode 100644 speechx/speechx/third_party/README.md diff --git a/speechx/requirement.txt b/speechx/requirement.txt deleted file mode 100644 index 6a6db096..00000000 --- a/speechx/requirement.txt +++ /dev/null @@ 
-1 +0,0 @@ -paddlepaddle>=2.4rc diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt index 60c18347..b522e158 100644 --- a/speechx/speechx/CMakeLists.txt +++ b/speechx/speechx/CMakeLists.txt @@ -2,50 +2,11 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) project(speechx LANGUAGES CXX) -include_directories( -${CMAKE_CURRENT_SOURCE_DIR} -${CMAKE_CURRENT_SOURCE_DIR}/kaldi -) -add_subdirectory(kaldi) - -include_directories( -${CMAKE_CURRENT_SOURCE_DIR} -${CMAKE_CURRENT_SOURCE_DIR}/utils -) -add_subdirectory(utils) - -include_directories( -${CMAKE_CURRENT_SOURCE_DIR} -${CMAKE_CURRENT_SOURCE_DIR}/frontend -) -add_subdirectory(frontend) - -include_directories( -${CMAKE_CURRENT_SOURCE_DIR} -${CMAKE_CURRENT_SOURCE_DIR}/nnet -) -add_subdirectory(nnet) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/kaldi) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/common) -include_directories( -${CMAKE_CURRENT_SOURCE_DIR} -${CMAKE_CURRENT_SOURCE_DIR}/decoder -) -add_subdirectory(decoder) - -include_directories( -${CMAKE_CURRENT_SOURCE_DIR} -${CMAKE_CURRENT_SOURCE_DIR}/recognizer -) -add_subdirectory(recognizer) - -include_directories( -${CMAKE_CURRENT_SOURCE_DIR} -${CMAKE_CURRENT_SOURCE_DIR}/protocol -) -add_subdirectory(protocol) - -include_directories( -${CMAKE_CURRENT_SOURCE_DIR} -${CMAKE_CURRENT_SOURCE_DIR}/codelab -) +add_subdirectory(asr) +add_subdirectory(common) +add_subdirectory(kaldi) add_subdirectory(codelab) diff --git a/speechx/speechx/asr/CMakeLists.txt b/speechx/speechx/asr/CMakeLists.txt new file mode 100644 index 00000000..ff4cdecb --- /dev/null +++ b/speechx/speechx/asr/CMakeLists.txt @@ -0,0 +1,11 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +project(ASR LANGUAGES CXX) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/server) + +add_subdirectory(decoder) +add_subdirectory(recognizer) +add_subdirectory(nnet) 
+add_subdirectory(server) diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/asr/decoder/CMakeLists.txt similarity index 100% rename from speechx/speechx/decoder/CMakeLists.txt rename to speechx/speechx/asr/decoder/CMakeLists.txt diff --git a/speechx/speechx/decoder/common.h b/speechx/speechx/asr/decoder/common.h similarity index 100% rename from speechx/speechx/decoder/common.h rename to speechx/speechx/asr/decoder/common.h diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/asr/decoder/ctc_beam_search_decoder.cc similarity index 100% rename from speechx/speechx/decoder/ctc_beam_search_decoder.cc rename to speechx/speechx/asr/decoder/ctc_beam_search_decoder.cc diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/asr/decoder/ctc_beam_search_decoder.h similarity index 100% rename from speechx/speechx/decoder/ctc_beam_search_decoder.h rename to speechx/speechx/asr/decoder/ctc_beam_search_decoder.h diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder_main.cc b/speechx/speechx/asr/decoder/ctc_beam_search_decoder_main.cc similarity index 100% rename from speechx/speechx/decoder/ctc_beam_search_decoder_main.cc rename to speechx/speechx/asr/decoder/ctc_beam_search_decoder_main.cc diff --git a/speechx/speechx/decoder/ctc_beam_search_opt.h b/speechx/speechx/asr/decoder/ctc_beam_search_opt.h similarity index 100% rename from speechx/speechx/decoder/ctc_beam_search_opt.h rename to speechx/speechx/asr/decoder/ctc_beam_search_opt.h diff --git a/speechx/speechx/asr/decoder/ctc_decoders/.gitignore b/speechx/speechx/asr/decoder/ctc_decoders/.gitignore new file mode 100644 index 00000000..0b1046ae --- /dev/null +++ b/speechx/speechx/asr/decoder/ctc_decoders/.gitignore @@ -0,0 +1,9 @@ +ThreadPool/ +build/ +dist/ +kenlm/ +openfst-1.6.3/ +openfst-1.6.3.tar.gz +swig_decoders.egg-info/ +decoders_wrap.cxx +swig_decoders.py diff --git a/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp 
b/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp new file mode 100644 index 00000000..ebea5c22 --- /dev/null +++ b/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp @@ -0,0 +1,607 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ctc_beam_search_decoder.h" + +#include +#include +#include +#include +#include +#include + +#include "ThreadPool.h" +#include "fst/fstlib.h" + +#include "decoder_utils.h" +#include "path_trie.h" + +using FSTMATCH = fst::SortedMatcher; + + +std::vector> ctc_beam_search_decoding( + const std::vector> &probs_seq, + const std::vector &vocabulary, + size_t beam_size, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer, + size_t blank_id) { + // dimension check + size_t num_time_steps = probs_seq.size(); + for (size_t i = 0; i < num_time_steps; ++i) { + VALID_CHECK_EQ(probs_seq[i].size(), + // vocabulary.size() + 1, + vocabulary.size(), + "The shape of probs_seq does not match with " + "the shape of the vocabulary"); + } + + + // assign space id + auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); + int space_id = it - vocabulary.begin(); + // if no space in vocabulary + if ((size_t)space_id >= vocabulary.size()) { + space_id = -2; + } + // init prefixes' root + PathTrie root; + root.score = root.log_prob_b_prev = 0.0; + std::vector prefixes; + prefixes.push_back(&root); + + if 
(ext_scorer != nullptr && !ext_scorer->is_character_based()) { + auto fst_dict = + static_cast(ext_scorer->dictionary); + fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); + root.set_dictionary(dict_ptr); + auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); + root.set_matcher(matcher); + } + + // prefix search over time + for (size_t time_step = 0; time_step < num_time_steps; ++time_step) { + auto &prob = probs_seq[time_step]; + + float min_cutoff = -NUM_FLT_INF; + bool full_beam = false; + if (ext_scorer != nullptr) { + size_t num_prefixes = std::min(prefixes.size(), beam_size); + std::sort(prefixes.begin(), + prefixes.begin() + num_prefixes, + prefix_compare); + min_cutoff = prefixes[num_prefixes - 1]->score + + std::log(prob[blank_id]) - + std::max(0.0, ext_scorer->beta); + full_beam = (num_prefixes == beam_size); + } + + std::vector> log_prob_idx = + get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); + // loop over chars + for (size_t index = 0; index < log_prob_idx.size(); index++) { + auto c = log_prob_idx[index].first; + auto log_prob_c = log_prob_idx[index].second; + + for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { + auto prefix = prefixes[i]; + if (full_beam && log_prob_c + prefix->score < min_cutoff) { + break; + } + // blank + if (c == blank_id) { + prefix->log_prob_b_cur = log_sum_exp( + prefix->log_prob_b_cur, log_prob_c + prefix->score); + continue; + } + // repeated character + if (c == prefix->character) { + prefix->log_prob_nb_cur = + log_sum_exp(prefix->log_prob_nb_cur, + log_prob_c + prefix->log_prob_nb_prev); + } + // get new prefix + auto prefix_new = prefix->get_path_trie(c); + + if (prefix_new != nullptr) { + float log_p = -NUM_FLT_INF; + + if (c == prefix->character && + prefix->log_prob_b_prev > -NUM_FLT_INF) { + log_p = log_prob_c + prefix->log_prob_b_prev; + } else if (c != prefix->character) { + log_p = log_prob_c + prefix->score; + } + + // language model scoring + if (ext_scorer != nullptr && + (c == 
space_id || ext_scorer->is_character_based())) { + PathTrie *prefix_to_score = nullptr; + // skip scoring the space + if (ext_scorer->is_character_based()) { + prefix_to_score = prefix_new; + } else { + prefix_to_score = prefix; + } + + float score = 0.0; + std::vector ngram; + ngram = ext_scorer->make_ngram(prefix_to_score); + score = ext_scorer->get_log_cond_prob(ngram) * + ext_scorer->alpha; + log_p += score; + log_p += ext_scorer->beta; + } + prefix_new->log_prob_nb_cur = + log_sum_exp(prefix_new->log_prob_nb_cur, log_p); + } + } // end of loop over prefix + } // end of loop over vocabulary + + + prefixes.clear(); + // update log probs + root.iterate_to_vec(prefixes); + + // only preserve top beam_size prefixes + if (prefixes.size() >= beam_size) { + std::nth_element(prefixes.begin(), + prefixes.begin() + beam_size, + prefixes.end(), + prefix_compare); + for (size_t i = beam_size; i < prefixes.size(); ++i) { + prefixes[i]->remove(); + } + } + } // end of loop over time + + // score the last word of each prefix that doesn't end with space + if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + auto prefix = prefixes[i]; + if (!prefix->is_empty() && prefix->character != space_id) { + float score = 0.0; + std::vector ngram = ext_scorer->make_ngram(prefix); + score = + ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; + score += ext_scorer->beta; + prefix->score += score; + } + } + } + + size_t num_prefixes = std::min(prefixes.size(), beam_size); + std::sort( + prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); + + // compute approximate ctc score as the return score, without affecting the + // return order of decoding result. To delete when decoder gets stable. 
+ for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + double approx_ctc = prefixes[i]->score; + if (ext_scorer != nullptr) { + std::vector output; + prefixes[i]->get_path_vec(output); + auto prefix_length = output.size(); + auto words = ext_scorer->split_labels(output); + // remove word insert + approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; + // remove language model weight: + approx_ctc -= + (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; + } + prefixes[i]->approx_ctc = approx_ctc; + } + + return get_beam_search_result(prefixes, vocabulary, beam_size); +} + + +std::vector>> +ctc_beam_search_decoding_batch( + const std::vector>> &probs_split, + const std::vector &vocabulary, + size_t beam_size, + size_t num_processes, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer, + size_t blank_id) { + VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); + // thread pool + ThreadPool pool(num_processes); + // number of samples + size_t batch_size = probs_split.size(); + + // enqueue the tasks of decoding + std::vector>>> res; + for (size_t i = 0; i < batch_size; ++i) { + res.emplace_back(pool.enqueue(ctc_beam_search_decoding, + probs_split[i], + vocabulary, + beam_size, + cutoff_prob, + cutoff_top_n, + ext_scorer, + blank_id)); + } + + // get decoding results + std::vector>> batch_results; + for (size_t i = 0; i < batch_size; ++i) { + batch_results.emplace_back(res[i].get()); + } + return batch_results; +} + +void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer) { + if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { + auto fst_dict = + static_cast(ext_scorer->dictionary); + fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); + root->set_dictionary(dict_ptr); + auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); + root->set_matcher(matcher); + } +} + +void ctc_beam_search_decode_chunk( + PathTrie *root, + std::vector &prefixes, + const std::vector> 
&probs_seq, + const std::vector &vocabulary, + size_t beam_size, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer, + size_t blank_id) { + // dimension check + size_t num_time_steps = probs_seq.size(); + for (size_t i = 0; i < num_time_steps; ++i) { + VALID_CHECK_EQ(probs_seq[i].size(), + // vocabulary.size() + 1, + vocabulary.size(), + "The shape of probs_seq does not match with " + "the shape of the vocabulary"); + } + + // assign space id + auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); + int space_id = it - vocabulary.begin(); + // if no space in vocabulary + if ((size_t)space_id >= vocabulary.size()) { + space_id = -2; + } + // init prefixes' root + // + // prefix search over time + for (size_t time_step = 0; time_step < num_time_steps; ++time_step) { + auto &prob = probs_seq[time_step]; + + float min_cutoff = -NUM_FLT_INF; + bool full_beam = false; + if (ext_scorer != nullptr) { + size_t num_prefixes = std::min(prefixes.size(), beam_size); + std::sort(prefixes.begin(), + prefixes.begin() + num_prefixes, + prefix_compare); + min_cutoff = prefixes[num_prefixes - 1]->score + + std::log(prob[blank_id]) - + std::max(0.0, ext_scorer->beta); + full_beam = (num_prefixes == beam_size); + } + + std::vector> log_prob_idx = + get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); + // loop over chars + for (size_t index = 0; index < log_prob_idx.size(); index++) { + auto c = log_prob_idx[index].first; + auto log_prob_c = log_prob_idx[index].second; + + for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { + auto prefix = prefixes[i]; + if (full_beam && log_prob_c + prefix->score < min_cutoff) { + break; + } + // blank + if (c == blank_id) { + prefix->log_prob_b_cur = log_sum_exp( + prefix->log_prob_b_cur, log_prob_c + prefix->score); + continue; + } + // repeated character + if (c == prefix->character) { + prefix->log_prob_nb_cur = + log_sum_exp(prefix->log_prob_nb_cur, + log_prob_c + prefix->log_prob_nb_prev); + } + // get 
new prefix + auto prefix_new = prefix->get_path_trie(c); + + if (prefix_new != nullptr) { + float log_p = -NUM_FLT_INF; + + if (c == prefix->character && + prefix->log_prob_b_prev > -NUM_FLT_INF) { + log_p = log_prob_c + prefix->log_prob_b_prev; + } else if (c != prefix->character) { + log_p = log_prob_c + prefix->score; + } + + // language model scoring + if (ext_scorer != nullptr && + (c == space_id || ext_scorer->is_character_based())) { + PathTrie *prefix_to_score = nullptr; + // skip scoring the space + if (ext_scorer->is_character_based()) { + prefix_to_score = prefix_new; + } else { + prefix_to_score = prefix; + } + + float score = 0.0; + std::vector ngram; + ngram = ext_scorer->make_ngram(prefix_to_score); + score = ext_scorer->get_log_cond_prob(ngram) * + ext_scorer->alpha; + log_p += score; + log_p += ext_scorer->beta; + } + prefix_new->log_prob_nb_cur = + log_sum_exp(prefix_new->log_prob_nb_cur, log_p); + } + } // end of loop over prefix + } // end of loop over vocabulary + + prefixes.clear(); + // update log probs + + root->iterate_to_vec(prefixes); + + // only preserve top beam_size prefixes + if (prefixes.size() >= beam_size) { + std::nth_element(prefixes.begin(), + prefixes.begin() + beam_size, + prefixes.end(), + prefix_compare); + for (size_t i = beam_size; i < prefixes.size(); ++i) { + prefixes[i]->remove(); + } + } + } // end of loop over time + + return; +} + + +std::vector> get_decode_result( + std::vector &prefixes, + const std::vector &vocabulary, + size_t beam_size, + Scorer *ext_scorer) { + auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); + int space_id = it - vocabulary.begin(); + // if no space in vocabulary + if ((size_t)space_id >= vocabulary.size()) { + space_id = -2; + } + // score the last word of each prefix that doesn't end with space + if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + auto prefix = prefixes[i]; + if 
(!prefix->is_empty() && prefix->character != space_id) { + float score = 0.0; + std::vector ngram = ext_scorer->make_ngram(prefix); + score = + ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; + score += ext_scorer->beta; + prefix->score += score; + } + } + } + + size_t num_prefixes = std::min(prefixes.size(), beam_size); + std::sort( + prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); + + // compute aproximate ctc score as the return score, without affecting the + // return order of decoding result. To delete when decoder gets stable. + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + double approx_ctc = prefixes[i]->score; + if (ext_scorer != nullptr) { + std::vector output; + prefixes[i]->get_path_vec(output); + auto prefix_length = output.size(); + auto words = ext_scorer->split_labels(output); + // remove word insert + approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; + // remove language model weight: + approx_ctc -= + (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; + } + prefixes[i]->approx_ctc = approx_ctc; + } + + std::vector> res = + get_beam_search_result(prefixes, vocabulary, beam_size); + + // pay back the last word of each prefix that doesn't end with space (for + // decoding by chunk) + if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + auto prefix = prefixes[i]; + if (!prefix->is_empty() && prefix->character != space_id) { + float score = 0.0; + std::vector ngram = ext_scorer->make_ngram(prefix); + score = + ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; + score += ext_scorer->beta; + prefix->score -= score; + } + } + } + return res; +} + + +void free_storage(std::unique_ptr &storage) { + storage = nullptr; +} + + +CtcBeamSearchDecoderBatch::~CtcBeamSearchDecoderBatch() {} + +CtcBeamSearchDecoderBatch::CtcBeamSearchDecoderBatch( + const std::vector &vocabulary, + size_t batch_size, + size_t 
beam_size, + size_t num_processes, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer, + size_t blank_id) + : batch_size(batch_size), + beam_size(beam_size), + num_processes(num_processes), + cutoff_prob(cutoff_prob), + cutoff_top_n(cutoff_top_n), + ext_scorer(ext_scorer), + blank_id(blank_id) { + VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!"); + VALID_CHECK_GT( + this->num_processes, 0, "num_processes must be nonnegative!"); + this->vocabulary = vocabulary; + for (size_t i = 0; i < batch_size; i++) { + this->decoder_storage_vector.push_back( + std::unique_ptr( + new CtcBeamSearchDecoderStorage())); + ctc_beam_search_decode_chunk_begin( + this->decoder_storage_vector[i]->root, ext_scorer); + } +}; + +/** + * Input + * probs_split: shape [B, T, D] + */ +void CtcBeamSearchDecoderBatch::next( + const std::vector>> &probs_split, + const std::vector &has_value) { + VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); + // thread pool + size_t num_has_value = 0; + for (int i = 0; i < has_value.size(); i++) + if (has_value[i] == "true") num_has_value += 1; + ThreadPool pool(std::min(num_processes, num_has_value)); + // number of samples + size_t probs_num = probs_split.size(); + VALID_CHECK_EQ(this->batch_size, + probs_num, + "The batch size of the current input data should be same " + "with the input data before"); + + // enqueue the tasks of decoding + std::vector> res; + for (size_t i = 0; i < batch_size; ++i) { + if (has_value[i] == "true") { + res.emplace_back(pool.enqueue( + ctc_beam_search_decode_chunk, + std::ref(this->decoder_storage_vector[i]->root), + std::ref(this->decoder_storage_vector[i]->prefixes), + probs_split[i], + this->vocabulary, + this->beam_size, + this->cutoff_prob, + this->cutoff_top_n, + this->ext_scorer, + this->blank_id)); + } + } + + for (size_t i = 0; i < batch_size; ++i) { + res[i].get(); + } + return; +}; + +/** + * Return + * batch_result: shape[B, beam_size,(-approx_ctc score, 
string)] + */ +std::vector>> +CtcBeamSearchDecoderBatch::decode() { + VALID_CHECK_GT( + this->num_processes, 0, "num_processes must be nonnegative!"); + // thread pool + ThreadPool pool(this->num_processes); + // number of samples + // enqueue the tasks of decoding + std::vector>>> res; + for (size_t i = 0; i < this->batch_size; ++i) { + res.emplace_back( + pool.enqueue(get_decode_result, + std::ref(this->decoder_storage_vector[i]->prefixes), + this->vocabulary, + this->beam_size, + this->ext_scorer)); + } + // get decoding results + std::vector>> batch_results; + for (size_t i = 0; i < this->batch_size; ++i) { + batch_results.emplace_back(res[i].get()); + } + return batch_results; +} + + +/** + * reset the state of ctcBeamSearchDecoderBatch + */ +void CtcBeamSearchDecoderBatch::reset_state(size_t batch_size, + size_t beam_size, + size_t num_processes, + double cutoff_prob, + size_t cutoff_top_n) { + this->batch_size = batch_size; + this->beam_size = beam_size; + this->num_processes = num_processes; + this->cutoff_prob = cutoff_prob; + this->cutoff_top_n = cutoff_top_n; + + VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!"); + VALID_CHECK_GT( + this->num_processes, 0, "num_processes must be nonnegative!"); + // thread pool + ThreadPool pool(this->num_processes); + // number of samples + // enqueue the tasks of decoding + std::vector> res; + size_t storage_size = decoder_storage_vector.size(); + for (size_t i = 0; i < storage_size; i++) { + res.emplace_back(pool.enqueue( + free_storage, std::ref(this->decoder_storage_vector[i]))); + } + for (size_t i = 0; i < storage_size; ++i) { + res[i].get(); + } + std::vector>().swap( + decoder_storage_vector); + for (size_t i = 0; i < this->batch_size; i++) { + this->decoder_storage_vector.push_back( + std::unique_ptr( + new CtcBeamSearchDecoderStorage())); + ctc_beam_search_decode_chunk_begin( + this->decoder_storage_vector[i]->root, this->ext_scorer); + } +} \ No newline at end of file diff --git 
a/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h b/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h new file mode 100644 index 00000000..92d2b855 --- /dev/null +++ b/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h @@ -0,0 +1,175 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CTC_BEAM_SEARCH_DECODER_H_ +#define CTC_BEAM_SEARCH_DECODER_H_ + +#include +#include +#include + +#include "scorer.h" + +/* CTC Beam Search Decoder + + * Parameters: + * probs_seq: 2-D vector that each element is a vector of probabilities + * over vocabulary of one time step. + * vocabulary: A vector of vocabulary. + * beam_size: The width of beam search. + * cutoff_prob: Cutoff probability for pruning. + * cutoff_top_n: Cutoff number for pruning. + * ext_scorer: External scorer to evaluate a prefix, which consists of + * n-gram language model scoring and word insertion term. + * Default null, decoding the input sample without scorer. + * Return: + * A vector that each element is a pair of score and decoding result, + * in desending order. 
+*/ +std::vector> ctc_beam_search_decoding( + const std::vector> &probs_seq, + const std::vector &vocabulary, + size_t beam_size, + double cutoff_prob = 1.0, + size_t cutoff_top_n = 40, + Scorer *ext_scorer = nullptr, + size_t blank_id = 0); + + +/* CTC Beam Search Decoder for batch data + + * Parameters: + * probs_seq: 3-D vector that each element is a 2-D vector that can be used + * by ctc_beam_search_decoder(). + * vocabulary: A vector of vocabulary. + * beam_size: The width of beam search. + * num_processes: Number of threads for beam search. + * cutoff_prob: Cutoff probability for pruning. + * cutoff_top_n: Cutoff number for pruning. + * ext_scorer: External scorer to evaluate a prefix, which consists of + * n-gram language model scoring and word insertion term. + * Default null, decoding the input sample without scorer. + * Return: + * A 2-D vector that each element is a vector of beam search decoding + * result for one audio sample. +*/ +std::vector>> +ctc_beam_search_decoding_batch( + const std::vector>> &probs_split, + const std::vector &vocabulary, + size_t beam_size, + size_t num_processes, + double cutoff_prob = 1.0, + size_t cutoff_top_n = 40, + Scorer *ext_scorer = nullptr, + size_t blank_id = 0); + +/** + * Store the root and prefixes for decoder + */ + +class CtcBeamSearchDecoderStorage { + public: + PathTrie *root = nullptr; + std::vector prefixes; + + CtcBeamSearchDecoderStorage() { + // init prefixes' root + this->root = new PathTrie(); + this->root->log_prob_b_prev = 0.0; + // The score of root is in log scale.Since the prob=1.0, the prob score + // in log scale is 0.0 + this->root->score = root->log_prob_b_prev; + // std::vector prefixes; + this->prefixes.push_back(root); + }; + + ~CtcBeamSearchDecoderStorage() { + if (root != nullptr) { + delete root; + root = nullptr; + } + }; +}; + +/** + * The ctc beam search decoder, support batchsize >= 1 + */ +class CtcBeamSearchDecoderBatch { + public: + CtcBeamSearchDecoderBatch(const std::vector 
&vocabulary, + size_t batch_size, + size_t beam_size, + size_t num_processes, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer, + size_t blank_id); + + ~CtcBeamSearchDecoderBatch(); + void next(const std::vector>> &probs_split, + const std::vector &has_value); + + std::vector>> decode(); + + void reset_state(size_t batch_size, + size_t beam_size, + size_t num_processes, + double cutoff_prob, + size_t cutoff_top_n); + + private: + std::vector vocabulary; + size_t batch_size; + size_t beam_size; + size_t num_processes; + double cutoff_prob; + size_t cutoff_top_n; + Scorer *ext_scorer; + size_t blank_id; + std::vector> + decoder_storage_vector; +}; + +/** + * function for chunk decoding + */ +void ctc_beam_search_decode_chunk( + PathTrie *root, + std::vector &prefixes, + const std::vector> &probs_seq, + const std::vector &vocabulary, + size_t beam_size, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer, + size_t blank_id); + +std::vector> get_decode_result( + std::vector &prefixes, + const std::vector &vocabulary, + size_t beam_size, + Scorer *ext_scorer); + +/** + * free the CtcBeamSearchDecoderStorage + */ +void free_storage(std::unique_ptr &storage); + +/** + * initialize the root + */ +void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer); + +#endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp b/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp new file mode 100644 index 00000000..6aa3c996 --- /dev/null +++ b/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp @@ -0,0 +1,61 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ctc_greedy_decoder.h" +#include "decoder_utils.h" + +std::string ctc_greedy_decoding( + const std::vector> &probs_seq, + const std::vector &vocabulary, + size_t blank_id) { + // dimension check + size_t num_time_steps = probs_seq.size(); + for (size_t i = 0; i < num_time_steps; ++i) { + VALID_CHECK_EQ(probs_seq[i].size(), + vocabulary.size(), + "The shape of probs_seq does not match with " + "the shape of the vocabulary"); + } + + // size_t blank_id = vocabulary.size(); + + std::vector max_idx_vec(num_time_steps, 0); + std::vector idx_vec; + for (size_t i = 0; i < num_time_steps; ++i) { + double max_prob = 0.0; + size_t max_idx = 0; + const std::vector &probs_step = probs_seq[i]; + for (size_t j = 0; j < probs_step.size(); ++j) { + if (max_prob < probs_step[j]) { + max_idx = j; + max_prob = probs_step[j]; + } + } + // id with maximum probability in current time step + max_idx_vec[i] = max_idx; + // deduplicate + if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) { + idx_vec.push_back(max_idx_vec[i]); + } + } + + std::string best_path_result; + for (size_t i = 0; i < idx_vec.size(); ++i) { + if (idx_vec[i] != blank_id) { + std::string ch = vocabulary[idx_vec[i]]; + best_path_result += (ch == kSPACE) ? 
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// http://www.apache.org/licenses/LICENSE-2.0

#ifndef CTC_GREEDY_DECODER_H
#define CTC_GREEDY_DECODER_H

#include <string>
#include <vector>

/* CTC Greedy (Best Path) Decoder
 *
 * Parameters:
 *     probs_seq: 2-D vector that each element is a vector of probabilities
 *                over vocabulary of one time step.
 *     vocabulary: A vector of vocabulary.
 *     blank_id: Index of the CTC blank label.
 * Return:
 *     The decoding result in string.
 */
std::string ctc_greedy_decoding(
    const std::vector<std::vector<double>>& probs_seq,
    const std::vector<std::string>& vocabulary,
    size_t blank_id);

#endif  // CTC_GREEDY_DECODER_H
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "decoder_utils.h" + +#include +#include +#include + +std::vector> get_pruned_log_probs( + const std::vector &prob_step, + double cutoff_prob, + size_t cutoff_top_n) { + std::vector> prob_idx; + for (size_t i = 0; i < prob_step.size(); ++i) { + prob_idx.push_back(std::pair(i, prob_step[i])); + } + // pruning of vocabulary + size_t cutoff_len = prob_step.size(); + if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) { + std::sort(prob_idx.begin(), + prob_idx.end(), + pair_comp_second_rev); + if (cutoff_prob < 1.0) { + double cum_prob = 0.0; + cutoff_len = 0; + for (size_t i = 0; i < prob_idx.size(); ++i) { + cum_prob += prob_idx[i].second; + cutoff_len += 1; + if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) + break; + } + } + prob_idx = std::vector>( + prob_idx.begin(), prob_idx.begin() + cutoff_len); + } + std::vector> log_prob_idx; + for (size_t i = 0; i < cutoff_len; ++i) { + log_prob_idx.push_back(std::pair( + prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN))); + } + return log_prob_idx; +} + + +std::vector> get_beam_search_result( + const std::vector &prefixes, + const std::vector &vocabulary, + size_t beam_size) { + // allow for the post processing + std::vector space_prefixes; + if (space_prefixes.empty()) { + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + space_prefixes.push_back(prefixes[i]); + } + } + + std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); + std::vector> output_vecs; + for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) { + std::vector 
output; + space_prefixes[i]->get_path_vec(output); + // convert index to string + std::string output_str; + for (size_t j = 0; j < output.size(); j++) { + std::string ch = vocabulary[output[j]]; + output_str += (ch == kSPACE) ? tSPACE : ch; + } + std::pair output_pair( + -space_prefixes[i]->approx_ctc, output_str); + output_vecs.emplace_back(output_pair); + } + + return output_vecs; +} + +size_t get_utf8_str_len(const std::string &str) { + size_t str_len = 0; + for (char c : str) { + str_len += ((c & 0xc0) != 0x80); + } + return str_len; +} + +std::vector split_utf8_str(const std::string &str) { + std::vector result; + std::string out_str; + + for (char c : str) { + if ((c & 0xc0) != 0x80) // new UTF-8 character + { + if (!out_str.empty()) { + result.push_back(out_str); + out_str.clear(); + } + } + + out_str.append(1, c); + } + result.push_back(out_str); + return result; +} + +std::vector split_str(const std::string &s, + const std::string &delim) { + std::vector result; + std::size_t start = 0, delim_len = delim.size(); + while (true) { + std::size_t end = s.find(delim, start); + if (end == std::string::npos) { + if (start < s.size()) { + result.push_back(s.substr(start)); + } + break; + } + if (end > start) { + result.push_back(s.substr(start, end - start)); + } + start = end + delim_len; + } + return result; +} + +bool prefix_compare(const PathTrie *x, const PathTrie *y) { + if (x->score == y->score) { + if (x->character == y->character) { + return false; + } else { + return (x->character < y->character); + } + } else { + return x->score > y->score; + } +} + +void add_word_to_fst(const std::vector &word, + fst::StdVectorFst *dictionary) { + if (dictionary->NumStates() == 0) { + fst::StdVectorFst::StateId start = dictionary->AddState(); + assert(start == 0); + dictionary->SetStart(start); + } + fst::StdVectorFst::StateId src = dictionary->Start(); + fst::StdVectorFst::StateId dst; + for (auto c : word) { + dst = dictionary->AddState(); + dictionary->AddArc(src, 
fst::StdArc(c, c, 0, dst)); + src = dst; + } + dictionary->SetFinal(dst, fst::StdArc::Weight::One()); +} + +bool add_word_to_dictionary( + const std::string &word, + const std::unordered_map &char_map, + bool add_space, + int SPACE_ID, + fst::StdVectorFst *dictionary) { + auto characters = split_utf8_str(word); + + std::vector int_word; + + for (auto &c : characters) { + if (c == " ") { + int_word.push_back(SPACE_ID); + } else { + auto int_c = char_map.find(c); + if (int_c != char_map.end()) { + int_word.push_back(int_c->second); + } else { + return false; // return without adding + } + } + } + + if (add_space) { + int_word.push_back(SPACE_ID); + } + + add_word_to_fst(int_word, dictionary); + return true; // return with successful adding +} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h b/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h new file mode 100644 index 00000000..09874155 --- /dev/null +++ b/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h @@ -0,0 +1,111 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef DECODER_UTILS_H_ +#define DECODER_UTILS_H_ + +#include +#include +#include "fst/log.h" +#include "path_trie.h" + +const std::string kSPACE = ""; +const std::string tSPACE = " "; +const float NUM_FLT_INF = std::numeric_limits::max(); +const float NUM_FLT_MIN = std::numeric_limits::min(); + +// inline function for validation check +inline void check( + bool x, const char *expr, const char *file, int line, const char *err) { + if (!x) { + std::cout << "[" << file << ":" << line << "] "; + LOG(FATAL) << "\"" << expr << "\" check failed. " << err; + } +} + +#define VALID_CHECK(x, info) \ + check(static_cast(x), #x, __FILE__, __LINE__, info) +#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info) +#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info) +#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info) + + +// Function template for comparing two pairs +template +bool pair_comp_first_rev(const std::pair &a, + const std::pair &b) { + return a.first > b.first; +} + +// Function template for comparing two pairs +template +bool pair_comp_second_rev(const std::pair &a, + const std::pair &b) { + return a.second > b.second; +} + +// Return the sum of two probabilities in log scale +template +T log_sum_exp(const T &x, const T &y) { + static T num_min = -std::numeric_limits::max(); + if (x <= num_min) return y; + if (y <= num_min) return x; + T xmax = std::max(x, y); + return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax; +} + +// Get pruned probability vector for each time step's beam search +std::vector> get_pruned_log_probs( + const std::vector &prob_step, + double cutoff_prob, + size_t cutoff_top_n); + +// Get beam search result from prefixes in trie tree +std::vector> get_beam_search_result( + const std::vector &prefixes, + const std::vector &vocabulary, + size_t beam_size); + +// Functor for prefix comparsion +bool prefix_compare(const PathTrie *x, const PathTrie *y); + +/* Get length of utf8 encoding string + * See: 
http://stackoverflow.com/a/4063229 + */ +size_t get_utf8_str_len(const std::string &str); + +/* Split a string into a list of strings on a given string + * delimiter. NB: delimiters on beginning / end of string are + * trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"]. + */ +std::vector split_str(const std::string &s, + const std::string &delim); + +/* Splits string into vector of strings representing + * UTF-8 characters (not same as chars) + */ +std::vector split_utf8_str(const std::string &str); + +// Add a word in index to the dicionary of fst +void add_word_to_fst(const std::vector &word, + fst::StdVectorFst *dictionary); + +// Add a word in string to dictionary +bool add_word_to_dictionary( + const std::string &word, + const std::unordered_map &char_map, + bool add_space, + int SPACE_ID, + fst::StdVectorFst *dictionary); +#endif // DECODER_UTILS_H diff --git a/speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp b/speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp new file mode 100644 index 00000000..777ca052 --- /dev/null +++ b/speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp @@ -0,0 +1,164 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "path_trie.h" + +#include +#include +#include +#include +#include + +#include "decoder_utils.h" + +PathTrie::PathTrie() { + log_prob_b_prev = -NUM_FLT_INF; + log_prob_nb_prev = -NUM_FLT_INF; + log_prob_b_cur = -NUM_FLT_INF; + log_prob_nb_cur = -NUM_FLT_INF; + score = -NUM_FLT_INF; + + ROOT_ = -1; + character = ROOT_; + exists_ = true; + parent = nullptr; + + dictionary_ = nullptr; + dictionary_state_ = 0; + has_dictionary_ = false; + + matcher_ = nullptr; +} + +PathTrie::~PathTrie() { + for (auto child : children_) { + delete child.second; + child.second = nullptr; + } +} + +PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { + auto child = children_.begin(); + for (child = children_.begin(); child != children_.end(); ++child) { + if (child->first == new_char) { + break; + } + } + if (child != children_.end()) { + if (!child->second->exists_) { + child->second->exists_ = true; + child->second->log_prob_b_prev = -NUM_FLT_INF; + child->second->log_prob_nb_prev = -NUM_FLT_INF; + child->second->log_prob_b_cur = -NUM_FLT_INF; + child->second->log_prob_nb_cur = -NUM_FLT_INF; + } + return (child->second); + } else { + if (has_dictionary_) { + matcher_->SetState(dictionary_state_); + bool found = matcher_->Find(new_char + 1); + if (!found) { + // Adding this character causes word outside dictionary + auto FSTZERO = fst::TropicalWeight::Zero(); + auto final_weight = dictionary_->Final(dictionary_state_); + bool is_final = (final_weight != FSTZERO); + if (is_final && reset) { + dictionary_state_ = dictionary_->Start(); + } + return nullptr; + } else { + PathTrie* new_path = new PathTrie; + new_path->character = new_char; + new_path->parent = this; + new_path->dictionary_ = dictionary_; + new_path->dictionary_state_ = matcher_->Value().nextstate; + new_path->has_dictionary_ = true; + new_path->matcher_ = matcher_; + children_.push_back(std::make_pair(new_char, new_path)); + return new_path; + } + } else { + PathTrie* new_path = new PathTrie; + 
new_path->character = new_char; + new_path->parent = this; + children_.push_back(std::make_pair(new_char, new_path)); + return new_path; + } + } +} + +PathTrie* PathTrie::get_path_vec(std::vector& output) { + return get_path_vec(output, ROOT_); +} + +PathTrie* PathTrie::get_path_vec(std::vector& output, + int stop, + size_t max_steps) { + if (character == stop || character == ROOT_ || output.size() == max_steps) { + std::reverse(output.begin(), output.end()); + return this; + } else { + output.push_back(character); + return parent->get_path_vec(output, stop, max_steps); + } +} + +void PathTrie::iterate_to_vec(std::vector& output) { + if (exists_) { + log_prob_b_prev = log_prob_b_cur; + log_prob_nb_prev = log_prob_nb_cur; + + log_prob_b_cur = -NUM_FLT_INF; + log_prob_nb_cur = -NUM_FLT_INF; + + score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); + output.push_back(this); + } + for (auto child : children_) { + child.second->iterate_to_vec(output); + } +} + +void PathTrie::remove() { + exists_ = false; + if (children_.size() == 0) { + if (parent != nullptr) { + auto child = parent->children_.begin(); + for (child = parent->children_.begin(); + child != parent->children_.end(); + ++child) { + if (child->first == character) { + parent->children_.erase(child); + break; + } + } + if (parent->children_.size() == 0 && !parent->exists_) { + parent->remove(); + } + } + delete this; + } +} + + +void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { + dictionary_ = dictionary; + dictionary_state_ = dictionary->Start(); + has_dictionary_ = true; +} + +using FSTMATCH = fst::SortedMatcher; +void PathTrie::set_matcher(std::shared_ptr matcher) { + matcher_ = matcher; +} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/path_trie.h b/speechx/speechx/asr/decoder/ctc_decoders/path_trie.h new file mode 100644 index 00000000..5193e0a4 --- /dev/null +++ b/speechx/speechx/asr/decoder/ctc_decoders/path_trie.h @@ -0,0 +1,82 @@ +// Copyright (c) 2021 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PATH_TRIE_H +#define PATH_TRIE_H + +#include +#include +#include +#include +#include + +#include "fst/fstlib.h" + +/* Trie tree for prefix storing and manipulating, with a dictionary in + * finite-state transducer for spelling correction. + */ +class PathTrie { + public: + PathTrie(); + ~PathTrie(); + + // get new prefix after appending new char + PathTrie* get_path_trie(int new_char, bool reset = true); + + // get the prefix in index from root to current node + PathTrie* get_path_vec(std::vector& output); + + // get the prefix in index from some stop node to current nodel + PathTrie* get_path_vec( + std::vector& output, + int stop, + size_t max_steps = std::numeric_limits::max()); + + // update log probs + void iterate_to_vec(std::vector& output); + + // set dictionary for FST + void set_dictionary(fst::StdVectorFst* dictionary); + + void set_matcher(std::shared_ptr>); + + bool is_empty() { return ROOT_ == character; } + + // remove current path from root + void remove(); + + float log_prob_b_prev; + float log_prob_nb_prev; + float log_prob_b_cur; + float log_prob_nb_cur; + float score; + float approx_ctc; + int character; + PathTrie* parent; + + private: + int ROOT_; + bool exists_; + bool has_dictionary_; + + std::vector> children_; + + // pointer to dictionary of FST + fst::StdVectorFst* dictionary_; + fst::StdVectorFst::StateId dictionary_state_; + // true if 
finding ars in FST + std::shared_ptr> matcher_; +}; + +#endif // PATH_TRIE_H diff --git a/speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp b/speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp new file mode 100644 index 00000000..6e7f68cf --- /dev/null +++ b/speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp @@ -0,0 +1,232 @@ +// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the +// "COPYING.LESSER.3"); + +#include "scorer.h" + +#include +#include + +#include "lm/config.hh" +#include "lm/model.hh" +#include "lm/state.hh" + +#include "decoder_utils.h" + +using namespace lm::ngram; +// if your platform is windows ,you need add the define +#define F_OK 0 +Scorer::Scorer(double alpha, + double beta, + const std::string& lm_path, + const std::vector& vocab_list) { + this->alpha = alpha; + this->beta = beta; + + dictionary = nullptr; + is_character_based_ = true; + language_model_ = nullptr; + + max_order_ = 0; + dict_size_ = 0; + SPACE_ID_ = -1; + + setup(lm_path, vocab_list); +} + +Scorer::~Scorer() { + if (language_model_ != nullptr) { + delete static_cast(language_model_); + } + if (dictionary != nullptr) { + delete static_cast(dictionary); + } +} + +void Scorer::setup(const std::string& lm_path, + const std::vector& vocab_list) { + // load language model + load_lm(lm_path); + // set char map for scorer + set_char_map(vocab_list); + // fill the dictionary for FST + if (!is_character_based()) { + fill_dictionary(true); + } +} + +void Scorer::load_lm(const std::string& lm_path) { + const char* filename = lm_path.c_str(); + VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path"); + + RetriveStrEnumerateVocab enumerate; + lm::ngram::Config config; + config.enumerate_vocab = &enumerate; + language_model_ = lm::ngram::LoadVirtual(filename, config); + max_order_ = static_cast(language_model_)->Order(); + vocabulary_ = enumerate.vocabulary; + for (size_t i = 0; i < vocabulary_.size(); ++i) { + if (is_character_based_ && 
vocabulary_[i] != UNK_TOKEN && + vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN && + get_utf8_str_len(enumerate.vocabulary[i]) > 1) { + is_character_based_ = false; + } + } +} + +double Scorer::get_log_cond_prob(const std::vector& words) { + lm::base::Model* model = static_cast(language_model_); + double cond_prob; + lm::ngram::State state, tmp_state, out_state; + // avoid to inserting in begin + model->NullContextWrite(&state); + for (size_t i = 0; i < words.size(); ++i) { + lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]); + // encounter OOV + if (word_index == 0) { + return OOV_SCORE; + } + cond_prob = model->BaseScore(&state, word_index, &out_state); + tmp_state = state; + state = out_state; + out_state = tmp_state; + } + // return log10 prob + return cond_prob; +} + +double Scorer::get_sent_log_prob(const std::vector& words) { + std::vector sentence; + if (words.size() == 0) { + for (size_t i = 0; i < max_order_; ++i) { + sentence.push_back(START_TOKEN); + } + } else { + for (size_t i = 0; i < max_order_ - 1; ++i) { + sentence.push_back(START_TOKEN); + } + sentence.insert(sentence.end(), words.begin(), words.end()); + } + sentence.push_back(END_TOKEN); + return get_log_prob(sentence); +} + +double Scorer::get_log_prob(const std::vector& words) { + assert(words.size() > max_order_); + double score = 0.0; + for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) { + std::vector ngram(words.begin() + i, + words.begin() + i + max_order_); + score += get_log_cond_prob(ngram); + } + return score; +} + +void Scorer::reset_params(float alpha, float beta) { + this->alpha = alpha; + this->beta = beta; +} + +std::string Scorer::vec2str(const std::vector& input) { + std::string word; + for (auto ind : input) { + word += char_list_[ind]; + } + return word; +} + +std::vector Scorer::split_labels(const std::vector& labels) { + if (labels.empty()) return {}; + + std::string s = vec2str(labels); + std::vector words; + if (is_character_based_) 
{ + words = split_utf8_str(s); + } else { + words = split_str(s, " "); + } + return words; +} + +void Scorer::set_char_map(const std::vector& char_list) { + char_list_ = char_list; + char_map_.clear(); + + // Set the char map for the FST for spelling correction + for (size_t i = 0; i < char_list_.size(); i++) { + if (char_list_[i] == kSPACE) { + SPACE_ID_ = i; + } + // The initial state of FST is state 0, hence the index of chars in + // the FST should start from 1 to avoid the conflict with the initial + // state, otherwise wrong decoding results would be given. + char_map_[char_list_[i]] = i + 1; + } +} + +std::vector Scorer::make_ngram(PathTrie* prefix) { + std::vector ngram; + PathTrie* current_node = prefix; + PathTrie* new_node = nullptr; + + for (int order = 0; order < max_order_; order++) { + std::vector prefix_vec; + + if (is_character_based_) { + new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1); + current_node = new_node; + } else { + new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_); + current_node = new_node->parent; // Skipping spaces + } + + // reconstruct word + std::string word = vec2str(prefix_vec); + ngram.push_back(word); + + if (new_node->character == -1) { + // No more spaces, but still need order + for (int i = 0; i < max_order_ - order - 1; i++) { + ngram.push_back(START_TOKEN); + } + break; + } + } + std::reverse(ngram.begin(), ngram.end()); + return ngram; +} + +void Scorer::fill_dictionary(bool add_space) { + fst::StdVectorFst dictionary; + // For each unigram convert to ints and put in trie + int dict_size = 0; + for (const auto& word : vocabulary_) { + bool added = add_word_to_dictionary( + word, char_map_, add_space, SPACE_ID_ + 1, &dictionary); + dict_size += added ? 1 : 0; + } + + dict_size_ = dict_size; + + /* Simplify FST + + * This gets rid of "epsilon" transitions in the FST. + * These are transitions that don't require a string input to be taken. 
+ * Getting rid of them is necessary to make the FST deterministic, but + * can greatly increase the size of the FST + */ + fst::RmEpsilon(&dictionary); + fst::StdVectorFst* new_dict = new fst::StdVectorFst; + + /* This makes the FST deterministic, meaning for any string input there's + * only one possible state the FST could be in. It is assumed our + * dictionary is deterministic when using it. + * (lest we'd have to check for multiple transitions at each state) + */ + fst::Determinize(dictionary, new_dict); + + /* Finds the simplest equivalent fst. This is unnecessary but decreases + * memory usage of the dictionary + */ + fst::Minimize(new_dict); + this->dictionary = new_dict; +} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/scorer.h b/speechx/speechx/asr/decoder/ctc_decoders/scorer.h new file mode 100644 index 00000000..08e109b7 --- /dev/null +++ b/speechx/speechx/asr/decoder/ctc_decoders/scorer.h @@ -0,0 +1,114 @@ +// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the +// "COPYING.LESSER.3"); + +#ifndef SCORER_H_ +#define SCORER_H_ + +#include +#include +#include +#include + +#include "lm/enumerate_vocab.hh" +#include "lm/virtual_interface.hh" +#include "lm/word_index.hh" + +#include "path_trie.h" + +const double OOV_SCORE = -1000.0; +const std::string START_TOKEN = ""; +const std::string UNK_TOKEN = ""; +const std::string END_TOKEN = ""; + +// Implement a callback to retrive the dictionary of language model. +class RetriveStrEnumerateVocab : public lm::EnumerateVocab { + public: + RetriveStrEnumerateVocab() {} + + void Add(lm::WordIndex index, const StringPiece &str) { + vocabulary.push_back(std::string(str.data(), str.length())); + } + + std::vector vocabulary; +}; + +/* External scorer to query score for n-gram or sentence, including language + * model scoring and word insertion. 
+ * + * Example: + * Scorer scorer(alpha, beta, "path_of_language_model"); + * scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" }); + * scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); + */ +class Scorer { + public: + Scorer(double alpha, + double beta, + const std::string &lm_path, + const std::vector &vocabulary); + ~Scorer(); + + double get_log_cond_prob(const std::vector &words); + + double get_sent_log_prob(const std::vector &words); + + // return the max order + size_t get_max_order() const { return max_order_; } + + // return the dictionary size of language model + size_t get_dict_size() const { return dict_size_; } + + // retrun true if the language model is character based + bool is_character_based() const { return is_character_based_; } + + // reset params alpha & beta + void reset_params(float alpha, float beta); + + // make ngram for a given prefix + std::vector make_ngram(PathTrie *prefix); + + // trransform the labels in index to the vector of words (word based lm) or + // the vector of characters (character based lm) + std::vector split_labels(const std::vector &labels); + + // language model weight + double alpha; + // word insertion weight + double beta; + + // pointer to the dictionary of FST + void *dictionary; + + protected: + // necessary setup: load language model, set char map, fill FST's dictionary + void setup(const std::string &lm_path, + const std::vector &vocab_list); + + // load language model from given path + void load_lm(const std::string &lm_path); + + // fill dictionary for FST + void fill_dictionary(bool add_space); + + // set char map + void set_char_map(const std::vector &char_list); + + double get_log_prob(const std::vector &words); + + // translate the vector in index to string + std::string vec2str(const std::vector &input); + + private: + void *language_model_; + bool is_character_based_; + size_t max_order_; + size_t dict_size_; + + int SPACE_ID_; + std::vector char_list_; + std::unordered_map char_map_; + + 
std::vector vocabulary_; +}; + +#endif // SCORER_H_ diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc similarity index 99% rename from speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc rename to speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc index 07e8e560..15dbd7e9 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc +++ b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc @@ -84,7 +84,7 @@ void CTCPrefixBeamSearch::AdvanceDecode( timer.Reset(); std::vector> likelihood; - likelihood.push_back(frame_prob); + likelihood.push_back(std::move(frame_prob)); AdvanceDecoding(likelihood); search_cost += timer.Elapsed(); diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h similarity index 100% rename from speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h rename to speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc similarity index 100% rename from speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc rename to speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_score.h b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_score.h similarity index 100% rename from speechx/speechx/decoder/ctc_prefix_beam_search_score.h rename to speechx/speechx/asr/decoder/ctc_prefix_beam_search_score.h diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/asr/decoder/ctc_tlg_decoder.cc similarity index 100% rename from speechx/speechx/decoder/ctc_tlg_decoder.cc rename to speechx/speechx/asr/decoder/ctc_tlg_decoder.cc diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/speechx/speechx/asr/decoder/ctc_tlg_decoder.h 
similarity index 100% rename from speechx/speechx/decoder/ctc_tlg_decoder.h rename to speechx/speechx/asr/decoder/ctc_tlg_decoder.h diff --git a/speechx/speechx/decoder/ctc_tlg_decoder_main.cc b/speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc similarity index 100% rename from speechx/speechx/decoder/ctc_tlg_decoder_main.cc rename to speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc diff --git a/speechx/speechx/decoder/decoder_itf.h b/speechx/speechx/asr/decoder/decoder_itf.h similarity index 100% rename from speechx/speechx/decoder/decoder_itf.h rename to speechx/speechx/asr/decoder/decoder_itf.h diff --git a/speechx/speechx/decoder/nnet_logprob_decoder_main.cc b/speechx/speechx/asr/decoder/nnet_logprob_decoder_main.cc similarity index 100% rename from speechx/speechx/decoder/nnet_logprob_decoder_main.cc rename to speechx/speechx/asr/decoder/nnet_logprob_decoder_main.cc diff --git a/speechx/speechx/decoder/param.h b/speechx/speechx/asr/decoder/param.h similarity index 100% rename from speechx/speechx/decoder/param.h rename to speechx/speechx/asr/decoder/param.h diff --git a/speechx/speechx/nnet/CMakeLists.txt b/speechx/speechx/asr/nnet/CMakeLists.txt similarity index 100% rename from speechx/speechx/nnet/CMakeLists.txt rename to speechx/speechx/asr/nnet/CMakeLists.txt diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/asr/nnet/decodable.cc similarity index 100% rename from speechx/speechx/nnet/decodable.cc rename to speechx/speechx/asr/nnet/decodable.cc diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/asr/nnet/decodable.h similarity index 100% rename from speechx/speechx/nnet/decodable.h rename to speechx/speechx/asr/nnet/decodable.h diff --git a/speechx/speechx/nnet/ds2_nnet.cc b/speechx/speechx/asr/nnet/ds2_nnet.cc similarity index 100% rename from speechx/speechx/nnet/ds2_nnet.cc rename to speechx/speechx/asr/nnet/ds2_nnet.cc diff --git a/speechx/speechx/nnet/ds2_nnet.h b/speechx/speechx/asr/nnet/ds2_nnet.h similarity index 
100% rename from speechx/speechx/nnet/ds2_nnet.h rename to speechx/speechx/asr/nnet/ds2_nnet.h diff --git a/speechx/speechx/nnet/ds2_nnet_main.cc b/speechx/speechx/asr/nnet/ds2_nnet_main.cc similarity index 100% rename from speechx/speechx/nnet/ds2_nnet_main.cc rename to speechx/speechx/asr/nnet/ds2_nnet_main.cc diff --git a/speechx/speechx/nnet/nnet_itf.h b/speechx/speechx/asr/nnet/nnet_itf.h similarity index 100% rename from speechx/speechx/nnet/nnet_itf.h rename to speechx/speechx/asr/nnet/nnet_itf.h diff --git a/speechx/speechx/nnet/u2_nnet.cc b/speechx/speechx/asr/nnet/u2_nnet.cc similarity index 100% rename from speechx/speechx/nnet/u2_nnet.cc rename to speechx/speechx/asr/nnet/u2_nnet.cc diff --git a/speechx/speechx/nnet/u2_nnet.h b/speechx/speechx/asr/nnet/u2_nnet.h similarity index 100% rename from speechx/speechx/nnet/u2_nnet.h rename to speechx/speechx/asr/nnet/u2_nnet.h diff --git a/speechx/speechx/nnet/u2_nnet_main.cc b/speechx/speechx/asr/nnet/u2_nnet_main.cc similarity index 100% rename from speechx/speechx/nnet/u2_nnet_main.cc rename to speechx/speechx/asr/nnet/u2_nnet_main.cc diff --git a/speechx/speechx/recognizer/CMakeLists.txt b/speechx/speechx/asr/recognizer/CMakeLists.txt similarity index 100% rename from speechx/speechx/recognizer/CMakeLists.txt rename to speechx/speechx/asr/recognizer/CMakeLists.txt diff --git a/speechx/speechx/recognizer/recognizer.cc b/speechx/speechx/asr/recognizer/recognizer.cc similarity index 100% rename from speechx/speechx/recognizer/recognizer.cc rename to speechx/speechx/asr/recognizer/recognizer.cc diff --git a/speechx/speechx/recognizer/recognizer.h b/speechx/speechx/asr/recognizer/recognizer.h similarity index 100% rename from speechx/speechx/recognizer/recognizer.h rename to speechx/speechx/asr/recognizer/recognizer.h diff --git a/speechx/speechx/recognizer/recognizer_main.cc b/speechx/speechx/asr/recognizer/recognizer_main.cc similarity index 100% rename from speechx/speechx/recognizer/recognizer_main.cc 
rename to speechx/speechx/asr/recognizer/recognizer_main.cc diff --git a/speechx/speechx/recognizer/u2_recognizer.cc b/speechx/speechx/asr/recognizer/u2_recognizer.cc similarity index 100% rename from speechx/speechx/recognizer/u2_recognizer.cc rename to speechx/speechx/asr/recognizer/u2_recognizer.cc diff --git a/speechx/speechx/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h similarity index 100% rename from speechx/speechx/recognizer/u2_recognizer.h rename to speechx/speechx/asr/recognizer/u2_recognizer.h diff --git a/speechx/speechx/recognizer/u2_recognizer_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_main.cc similarity index 100% rename from speechx/speechx/recognizer/u2_recognizer_main.cc rename to speechx/speechx/asr/recognizer/u2_recognizer_main.cc diff --git a/speechx/speechx/protocol/CMakeLists.txt b/speechx/speechx/asr/server/CMakeLists.txt similarity index 100% rename from speechx/speechx/protocol/CMakeLists.txt rename to speechx/speechx/asr/server/CMakeLists.txt diff --git a/speechx/speechx/protocol/websocket/CMakeLists.txt b/speechx/speechx/asr/server/websocket/CMakeLists.txt similarity index 100% rename from speechx/speechx/protocol/websocket/CMakeLists.txt rename to speechx/speechx/asr/server/websocket/CMakeLists.txt diff --git a/speechx/speechx/protocol/websocket/websocket_client.cc b/speechx/speechx/asr/server/websocket/websocket_client.cc similarity index 100% rename from speechx/speechx/protocol/websocket/websocket_client.cc rename to speechx/speechx/asr/server/websocket/websocket_client.cc diff --git a/speechx/speechx/protocol/websocket/websocket_client.h b/speechx/speechx/asr/server/websocket/websocket_client.h similarity index 100% rename from speechx/speechx/protocol/websocket/websocket_client.h rename to speechx/speechx/asr/server/websocket/websocket_client.h diff --git a/speechx/speechx/protocol/websocket/websocket_client_main.cc b/speechx/speechx/asr/server/websocket/websocket_client_main.cc 
similarity index 100% rename from speechx/speechx/protocol/websocket/websocket_client_main.cc rename to speechx/speechx/asr/server/websocket/websocket_client_main.cc diff --git a/speechx/speechx/protocol/websocket/websocket_server.cc b/speechx/speechx/asr/server/websocket/websocket_server.cc similarity index 100% rename from speechx/speechx/protocol/websocket/websocket_server.cc rename to speechx/speechx/asr/server/websocket/websocket_server.cc diff --git a/speechx/speechx/protocol/websocket/websocket_server.h b/speechx/speechx/asr/server/websocket/websocket_server.h similarity index 100% rename from speechx/speechx/protocol/websocket/websocket_server.h rename to speechx/speechx/asr/server/websocket/websocket_server.h diff --git a/speechx/speechx/protocol/websocket/websocket_server_main.cc b/speechx/speechx/asr/server/websocket/websocket_server_main.cc similarity index 100% rename from speechx/speechx/protocol/websocket/websocket_server_main.cc rename to speechx/speechx/asr/server/websocket/websocket_server_main.cc diff --git a/speechx/speechx/common/CMakeLists.txt b/speechx/speechx/common/CMakeLists.txt new file mode 100644 index 00000000..dea9eb05 --- /dev/null +++ b/speechx/speechx/common/CMakeLists.txt @@ -0,0 +1,16 @@ +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +${CMAKE_CURRENT_SOURCE_DIR}/base +) + +include_directories( +${CMAKE_CURRENT_SOURCE_DIR}/../ +${CMAKE_CURRENT_SOURCE_DIR}/utils +) +add_subdirectory(utils) + +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +${CMAKE_CURRENT_SOURCE_DIR}/frontend +) +add_subdirectory(frontend) diff --git a/speechx/speechx/base/basic_types.h b/speechx/speechx/common/base/basic_types.h similarity index 100% rename from speechx/speechx/base/basic_types.h rename to speechx/speechx/common/base/basic_types.h diff --git a/speechx/speechx/base/common.h b/speechx/speechx/common/base/common.h similarity index 100% rename from speechx/speechx/base/common.h rename to speechx/speechx/common/base/common.h diff --git 
a/speechx/speechx/base/flags.h b/speechx/speechx/common/base/flags.h similarity index 100% rename from speechx/speechx/base/flags.h rename to speechx/speechx/common/base/flags.h diff --git a/speechx/speechx/base/log.h b/speechx/speechx/common/base/log.h similarity index 100% rename from speechx/speechx/base/log.h rename to speechx/speechx/common/base/log.h diff --git a/speechx/speechx/base/macros.h b/speechx/speechx/common/base/macros.h similarity index 100% rename from speechx/speechx/base/macros.h rename to speechx/speechx/common/base/macros.h diff --git a/speechx/speechx/base/thread_pool.h b/speechx/speechx/common/base/thread_pool.h similarity index 100% rename from speechx/speechx/base/thread_pool.h rename to speechx/speechx/common/base/thread_pool.h diff --git a/speechx/speechx/frontend/CMakeLists.txt b/speechx/speechx/common/frontend/CMakeLists.txt similarity index 100% rename from speechx/speechx/frontend/CMakeLists.txt rename to speechx/speechx/common/frontend/CMakeLists.txt diff --git a/speechx/speechx/frontend/audio/CMakeLists.txt b/speechx/speechx/common/frontend/audio/CMakeLists.txt similarity index 100% rename from speechx/speechx/frontend/audio/CMakeLists.txt rename to speechx/speechx/common/frontend/audio/CMakeLists.txt diff --git a/speechx/speechx/frontend/audio/assembler.cc b/speechx/speechx/common/frontend/audio/assembler.cc similarity index 100% rename from speechx/speechx/frontend/audio/assembler.cc rename to speechx/speechx/common/frontend/audio/assembler.cc diff --git a/speechx/speechx/frontend/audio/assembler.h b/speechx/speechx/common/frontend/audio/assembler.h similarity index 100% rename from speechx/speechx/frontend/audio/assembler.h rename to speechx/speechx/common/frontend/audio/assembler.h diff --git a/speechx/speechx/frontend/audio/audio_cache.cc b/speechx/speechx/common/frontend/audio/audio_cache.cc similarity index 100% rename from speechx/speechx/frontend/audio/audio_cache.cc rename to 
speechx/speechx/common/frontend/audio/audio_cache.cc diff --git a/speechx/speechx/frontend/audio/audio_cache.h b/speechx/speechx/common/frontend/audio/audio_cache.h similarity index 100% rename from speechx/speechx/frontend/audio/audio_cache.h rename to speechx/speechx/common/frontend/audio/audio_cache.h diff --git a/speechx/speechx/frontend/audio/cmvn.cc b/speechx/speechx/common/frontend/audio/cmvn.cc similarity index 100% rename from speechx/speechx/frontend/audio/cmvn.cc rename to speechx/speechx/common/frontend/audio/cmvn.cc diff --git a/speechx/speechx/frontend/audio/cmvn.h b/speechx/speechx/common/frontend/audio/cmvn.h similarity index 100% rename from speechx/speechx/frontend/audio/cmvn.h rename to speechx/speechx/common/frontend/audio/cmvn.h diff --git a/speechx/speechx/frontend/audio/cmvn_json2kaldi_main.cc b/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc similarity index 100% rename from speechx/speechx/frontend/audio/cmvn_json2kaldi_main.cc rename to speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc diff --git a/speechx/speechx/frontend/audio/compute_fbank_main.cc b/speechx/speechx/common/frontend/audio/compute_fbank_main.cc similarity index 100% rename from speechx/speechx/frontend/audio/compute_fbank_main.cc rename to speechx/speechx/common/frontend/audio/compute_fbank_main.cc diff --git a/speechx/speechx/frontend/audio/compute_linear_spectrogram_main.cc b/speechx/speechx/common/frontend/audio/compute_linear_spectrogram_main.cc similarity index 100% rename from speechx/speechx/frontend/audio/compute_linear_spectrogram_main.cc rename to speechx/speechx/common/frontend/audio/compute_linear_spectrogram_main.cc diff --git a/speechx/speechx/frontend/audio/data_cache.h b/speechx/speechx/common/frontend/audio/data_cache.h similarity index 100% rename from speechx/speechx/frontend/audio/data_cache.h rename to speechx/speechx/common/frontend/audio/data_cache.h diff --git a/speechx/speechx/frontend/audio/db_norm.cc 
b/speechx/speechx/common/frontend/audio/db_norm.cc similarity index 100% rename from speechx/speechx/frontend/audio/db_norm.cc rename to speechx/speechx/common/frontend/audio/db_norm.cc diff --git a/speechx/speechx/frontend/audio/db_norm.h b/speechx/speechx/common/frontend/audio/db_norm.h similarity index 100% rename from speechx/speechx/frontend/audio/db_norm.h rename to speechx/speechx/common/frontend/audio/db_norm.h diff --git a/speechx/speechx/frontend/audio/fbank.cc b/speechx/speechx/common/frontend/audio/fbank.cc similarity index 100% rename from speechx/speechx/frontend/audio/fbank.cc rename to speechx/speechx/common/frontend/audio/fbank.cc diff --git a/speechx/speechx/frontend/audio/fbank.h b/speechx/speechx/common/frontend/audio/fbank.h similarity index 100% rename from speechx/speechx/frontend/audio/fbank.h rename to speechx/speechx/common/frontend/audio/fbank.h diff --git a/speechx/speechx/frontend/audio/feature_cache.cc b/speechx/speechx/common/frontend/audio/feature_cache.cc similarity index 100% rename from speechx/speechx/frontend/audio/feature_cache.cc rename to speechx/speechx/common/frontend/audio/feature_cache.cc diff --git a/speechx/speechx/frontend/audio/feature_cache.h b/speechx/speechx/common/frontend/audio/feature_cache.h similarity index 100% rename from speechx/speechx/frontend/audio/feature_cache.h rename to speechx/speechx/common/frontend/audio/feature_cache.h diff --git a/speechx/speechx/frontend/audio/feature_common.h b/speechx/speechx/common/frontend/audio/feature_common.h similarity index 100% rename from speechx/speechx/frontend/audio/feature_common.h rename to speechx/speechx/common/frontend/audio/feature_common.h diff --git a/speechx/speechx/frontend/audio/feature_common_inl.h b/speechx/speechx/common/frontend/audio/feature_common_inl.h similarity index 100% rename from speechx/speechx/frontend/audio/feature_common_inl.h rename to speechx/speechx/common/frontend/audio/feature_common_inl.h diff --git 
a/speechx/speechx/frontend/audio/feature_pipeline.cc b/speechx/speechx/common/frontend/audio/feature_pipeline.cc similarity index 100% rename from speechx/speechx/frontend/audio/feature_pipeline.cc rename to speechx/speechx/common/frontend/audio/feature_pipeline.cc diff --git a/speechx/speechx/frontend/audio/feature_pipeline.h b/speechx/speechx/common/frontend/audio/feature_pipeline.h similarity index 100% rename from speechx/speechx/frontend/audio/feature_pipeline.h rename to speechx/speechx/common/frontend/audio/feature_pipeline.h diff --git a/speechx/speechx/frontend/audio/frontend_itf.h b/speechx/speechx/common/frontend/audio/frontend_itf.h similarity index 100% rename from speechx/speechx/frontend/audio/frontend_itf.h rename to speechx/speechx/common/frontend/audio/frontend_itf.h diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.cc b/speechx/speechx/common/frontend/audio/linear_spectrogram.cc similarity index 100% rename from speechx/speechx/frontend/audio/linear_spectrogram.cc rename to speechx/speechx/common/frontend/audio/linear_spectrogram.cc diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.h b/speechx/speechx/common/frontend/audio/linear_spectrogram.h similarity index 100% rename from speechx/speechx/frontend/audio/linear_spectrogram.h rename to speechx/speechx/common/frontend/audio/linear_spectrogram.h diff --git a/speechx/speechx/frontend/audio/mfcc.cc b/speechx/speechx/common/frontend/audio/mfcc.cc similarity index 100% rename from speechx/speechx/frontend/audio/mfcc.cc rename to speechx/speechx/common/frontend/audio/mfcc.cc diff --git a/speechx/speechx/frontend/audio/mfcc.h b/speechx/speechx/common/frontend/audio/mfcc.h similarity index 100% rename from speechx/speechx/frontend/audio/mfcc.h rename to speechx/speechx/common/frontend/audio/mfcc.h diff --git a/speechx/speechx/frontend/audio/normalizer.h b/speechx/speechx/common/frontend/audio/normalizer.h similarity index 100% rename from 
speechx/speechx/frontend/audio/normalizer.h rename to speechx/speechx/common/frontend/audio/normalizer.h diff --git a/speechx/speechx/utils/CMakeLists.txt b/speechx/speechx/common/utils/CMakeLists.txt similarity index 100% rename from speechx/speechx/utils/CMakeLists.txt rename to speechx/speechx/common/utils/CMakeLists.txt diff --git a/speechx/speechx/utils/file_utils.cc b/speechx/speechx/common/utils/file_utils.cc similarity index 100% rename from speechx/speechx/utils/file_utils.cc rename to speechx/speechx/common/utils/file_utils.cc diff --git a/speechx/speechx/utils/file_utils.h b/speechx/speechx/common/utils/file_utils.h similarity index 100% rename from speechx/speechx/utils/file_utils.h rename to speechx/speechx/common/utils/file_utils.h diff --git a/speechx/speechx/utils/math.cc b/speechx/speechx/common/utils/math.cc similarity index 100% rename from speechx/speechx/utils/math.cc rename to speechx/speechx/common/utils/math.cc diff --git a/speechx/speechx/utils/math.h b/speechx/speechx/common/utils/math.h similarity index 100% rename from speechx/speechx/utils/math.h rename to speechx/speechx/common/utils/math.h diff --git a/speechx/speechx/decoder/ctc_decoders b/speechx/speechx/decoder/ctc_decoders deleted file mode 120000 index b280de09..00000000 --- a/speechx/speechx/decoder/ctc_decoders +++ /dev/null @@ -1 +0,0 @@ -../../../third_party/ctc_decoders \ No newline at end of file diff --git a/speechx/speechx/frontend/text/CMakeLists.txt b/speechx/speechx/frontend/text/CMakeLists.txt deleted file mode 100644 index e69de29b..00000000 diff --git a/speechx/speechx/kaldi/CMakeLists.txt b/speechx/speechx/kaldi/CMakeLists.txt index ce6b43f6..d27668fc 100644 --- a/speechx/speechx/kaldi/CMakeLists.txt +++ b/speechx/speechx/kaldi/CMakeLists.txt @@ -1,4 +1,7 @@ project(kaldi) +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +) add_subdirectory(base) add_subdirectory(util) @@ -10,4 +13,4 @@ add_subdirectory(decoder) add_subdirectory(lm) add_subdirectory(fstbin) 
-add_subdirectory(lmbin) \ No newline at end of file +add_subdirectory(lmbin) diff --git a/speechx/speechx/third_party/CMakeLists.txt b/speechx/speechx/third_party/CMakeLists.txt deleted file mode 100644 index e69de29b..00000000 diff --git a/speechx/speechx/third_party/README.md b/speechx/speechx/third_party/README.md deleted file mode 100644 index 2d620335..00000000 --- a/speechx/speechx/third_party/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# third party - -Those libs copied and developed from third pary opensource software projects. -For all of these things, the official websites are the best place to go. From f8caaf46c8c35dbecb879cc2d4acea0de13bb45d Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 27 Dec 2022 16:30:57 +0800 Subject: [PATCH 02/50] refactor cmake, rm absl/linsndfile, add strings unittest (#2765) --- speechx/CMakeLists.txt | 54 ++++----- speechx/README.md | 8 ++ speechx/cmake/gtest.cmake | 8 +- speechx/cmake/openfst.cmake | 2 + speechx/cmake/system.cmake | 106 ++++++++++++++++++ speechx/speechx/asr/decoder/CMakeLists.txt | 2 +- .../decoder/ctc_prefix_beam_search_decoder.cc | 1 - .../ctc_prefix_beam_search_decoder_main.cc | 1 - speechx/speechx/asr/nnet/CMakeLists.txt | 2 +- speechx/speechx/asr/nnet/ds2_nnet.cc | 12 +- .../asr/server/websocket/CMakeLists.txt | 2 +- speechx/speechx/common/utils/CMakeLists.txt | 16 ++- speechx/speechx/common/utils/math.cc | 5 +- speechx/speechx/common/utils/strings.cc | 50 +++++++++ speechx/speechx/common/utils/strings.h | 26 +++++ speechx/speechx/common/utils/strings_test.cc | 35 ++++++ 16 files changed, 283 insertions(+), 47 deletions(-) create mode 100644 speechx/cmake/system.cmake create mode 100644 speechx/speechx/common/utils/strings.cc create mode 100644 speechx/speechx/common/utils/strings.h create mode 100644 speechx/speechx/common/utils/strings_test.cc diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index 6b957160..ed5c38f0 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ 
-1,19 +1,28 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) -project(paddlespeech VERSION 0.1) - set(CMAKE_PROJECT_INCLUDE_BEFORE "${CMAKE_CURRENT_SOURCE_DIR}/cmake/EnableCMP0048.cmake") +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") + +include(system) + +# Ninja Generator will set CMAKE_BUILD_TYPE to Debug +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE + "Release" + CACHE + STRING + "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" + FORCE) +endif() + +project(paddlespeech VERSION 0.1) + set(CMAKE_VERBOSE_MAKEFILE on) # set std-14 set(CMAKE_CXX_STANDARD 14) -# cmake dir -set(speechx_cmake_dir ${PROJECT_SOURCE_DIR}/cmake) - -# Modules -list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir}) include(FetchContent) include(ExternalProject) @@ -33,6 +42,7 @@ SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall ############################################################################### option(TEST_DEBUG "option for debug" OFF) option(USE_PROFILING "enable c++ profling" OFF) +option(WITH_TESTING "unit test" ON) option(USING_U2 "compile u2 model." ON) option(USING_DS2 "compile with ds2 model." ON) @@ -42,26 +52,10 @@ option(USING_GPU "u2 compute on GPU." 
OFF) ############################################################################### # Include third party ############################################################################### -# example for include third party -# FetchContent_MakeAvailable was not added until CMake 3.14 -# FetchContent_MakeAvailable() -# include_directories() - -# gflags include(gflags) -# glog include(glog) -# gtest -include(gtest) - -# ABSEIL-CPP -include(absl) - -# libsndfile -include(libsndfile) - # boost # include(boost) # not work set(boost_SOURCE_DIR ${fc_patch}/boost-src) @@ -87,6 +81,11 @@ add_dependencies(openfst gflags glog) # paddle lib include(paddleinference) +# gtest +if(WITH_TESTING) + include(gtest) # download, build, install gtest +endif() + # python/pybind11/threads find_package(Threads REQUIRED) # https://cmake.org/cmake/help/latest/module/FindPython3.html#module:FindPython3 @@ -165,15 +164,6 @@ message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS}) ############################################################################### # Add local library ############################################################################### -# system lib -#find_package() -# if dir have CmakeLists.txt -#add_subdirectory(speechx) -# if dir do not have CmakeLists.txt -#add_library(lib_name STATIC file.cc) -#target_link_libraries(lib_name item0 item1) -#add_dependencies(lib_name depend-target) - set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx) add_subdirectory(speechx) diff --git a/speechx/README.md b/speechx/README.md index 5d4b5845..70136ea0 100644 --- a/speechx/README.md +++ b/speechx/README.md @@ -113,3 +113,11 @@ apt-get install gfortran-8 4. `Undefined reference to '_gfortran_concat_string'` using gcc 8.2, gfortran 8.2. + +5. `./boost/python/detail/wrap_python.hpp:57:11: fatal error: pyconfig.h: No such file or directory` + +``` +apt-get install python3-dev +``` + +for more info please see [here](https://github.com/okfn/piati/issues/65). 
diff --git a/speechx/cmake/gtest.cmake b/speechx/cmake/gtest.cmake index 1ea8ed0b..365f25cf 100644 --- a/speechx/cmake/gtest.cmake +++ b/speechx/cmake/gtest.cmake @@ -1,3 +1,4 @@ + include(FetchContent) FetchContent_Declare( gtest @@ -6,4 +7,9 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(gtest) -include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src) \ No newline at end of file +include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src) + + +if(WITH_TESTING) + enable_testing() +endif() \ No newline at end of file diff --git a/speechx/cmake/openfst.cmake b/speechx/cmake/openfst.cmake index 07c33a74..8861f4f4 100644 --- a/speechx/cmake/openfst.cmake +++ b/speechx/cmake/openfst.cmake @@ -25,3 +25,5 @@ ExternalProject_Add(openfst ) link_directories(${openfst_PREFIX_DIR}/lib) include_directories(${openfst_PREFIX_DIR}/include) +message(STATUS "OpenFST inc dir: ${openfst_PREFIX_DIR}/include") +message(STATUS "OpenFST lib dir: ${openfst_PREFIX_DIR}/lib") \ No newline at end of file diff --git a/speechx/cmake/system.cmake b/speechx/cmake/system.cmake new file mode 100644 index 00000000..580e07bb --- /dev/null +++ b/speechx/cmake/system.cmake @@ -0,0 +1,106 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Detects the OS and sets appropriate variables. 
+# CMAKE_SYSTEM_NAME only give us a coarse-grained name of the OS CMake is +# building for, but the host processor name like centos is necessary +# in some scenes to distinguish system for customization. +# +# for instance, protobuf libs path is /lib64 +# on CentOS, but /lib on other systems. + +if(UNIX AND NOT APPLE) + # except apple from nix*Os family + set(LINUX TRUE) +endif() + +if(WIN32) + set(HOST_SYSTEM "win32") +else() + if(APPLE) + set(HOST_SYSTEM "macosx") + exec_program( + sw_vers ARGS + -productVersion + OUTPUT_VARIABLE HOST_SYSTEM_VERSION) + string(REGEX MATCH "[0-9]+.[0-9]+" MACOS_VERSION "${HOST_SYSTEM_VERSION}") + if(NOT DEFINED $ENV{MACOSX_DEPLOYMENT_TARGET}) + # Set cache variable - end user may change this during ccmake or cmake-gui configure. + set(CMAKE_OSX_DEPLOYMENT_TARGET + ${MACOS_VERSION} + CACHE + STRING + "Minimum OS X version to target for deployment (at runtime); newer APIs weak linked. Set to empty string for default value." + ) + endif() + set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security") + else() + + if(EXISTS "/etc/issue") + file(READ "/etc/issue" LINUX_ISSUE) + if(LINUX_ISSUE MATCHES "CentOS") + set(HOST_SYSTEM "centos") + elseif(LINUX_ISSUE MATCHES "Debian") + set(HOST_SYSTEM "debian") + elseif(LINUX_ISSUE MATCHES "Ubuntu") + set(HOST_SYSTEM "ubuntu") + elseif(LINUX_ISSUE MATCHES "Red Hat") + set(HOST_SYSTEM "redhat") + elseif(LINUX_ISSUE MATCHES "Fedora") + set(HOST_SYSTEM "fedora") + endif() + + string(REGEX MATCH "(([0-9]+)\\.)+([0-9]+)" HOST_SYSTEM_VERSION + "${LINUX_ISSUE}") + endif() + + if(EXISTS "/etc/redhat-release") + file(READ "/etc/redhat-release" LINUX_ISSUE) + if(LINUX_ISSUE MATCHES "CentOS") + set(HOST_SYSTEM "centos") + endif() + endif() + + if(NOT HOST_SYSTEM) + set(HOST_SYSTEM ${CMAKE_SYSTEM_NAME}) + endif() + + endif() +endif() + +# query number of logical cores +cmake_host_system_information(RESULT CPU_CORES QUERY NUMBER_OF_LOGICAL_CORES) + +mark_as_advanced(HOST_SYSTEM CPU_CORES) + 
+message( + STATUS + "Found Paddle host system: ${HOST_SYSTEM}, version: ${HOST_SYSTEM_VERSION}") +message(STATUS "Found Paddle host system's CPU: ${CPU_CORES} cores") + +# external dependencies log output +set(EXTERNAL_PROJECT_LOG_ARGS + LOG_DOWNLOAD + 0 # Wrap download in script to log output + LOG_UPDATE + 1 # Wrap update in script to log output + LOG_CONFIGURE + 1 # Wrap configure in script to log output + LOG_BUILD + 0 # Wrap build in script to log output + LOG_TEST + 1 # Wrap test in script to log output + LOG_INSTALL + 0 # Wrap install in script to log output +) \ No newline at end of file diff --git a/speechx/speechx/asr/decoder/CMakeLists.txt b/speechx/speechx/asr/decoder/CMakeLists.txt index f0fd32ba..93014fb9 100644 --- a/speechx/speechx/asr/decoder/CMakeLists.txt +++ b/speechx/speechx/asr/decoder/CMakeLists.txt @@ -19,7 +19,7 @@ if (USING_U2) endif() add_library(decoder STATIC ${srcs}) -target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings) +target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder) # test if (USING_DS2) diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc index 15dbd7e9..2cef4972 100644 --- a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc +++ b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc @@ -17,7 +17,6 @@ #include "decoder/ctc_prefix_beam_search_decoder.h" -#include "absl/strings/str_join.h" #include "base/common.h" #include "decoder/ctc_beam_search_opt.h" #include "decoder/ctc_prefix_beam_search_score.h" diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc index c59b1f2e..31c8b19e 100644 --- a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc +++ b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -12,7 +12,6 @@ // See the 
License for the specific language governing permissions and // limitations under the License. -#include "absl/strings/str_split.h" #include "base/common.h" #include "decoder/ctc_prefix_beam_search_decoder.h" #include "frontend/audio/data_cache.h" diff --git a/speechx/speechx/asr/nnet/CMakeLists.txt b/speechx/speechx/asr/nnet/CMakeLists.txt index 43566616..27081086 100644 --- a/speechx/speechx/asr/nnet/CMakeLists.txt +++ b/speechx/speechx/asr/nnet/CMakeLists.txt @@ -9,7 +9,7 @@ if(USING_U2) endif() add_library(nnet STATIC ${srcs}) -target_link_libraries(nnet absl::strings) +target_link_libraries(nnet utils) if(USING_U2) target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS}) diff --git a/speechx/speechx/asr/nnet/ds2_nnet.cc b/speechx/speechx/asr/nnet/ds2_nnet.cc index 22c7f61b..f77c0a60 100644 --- a/speechx/speechx/asr/nnet/ds2_nnet.cc +++ b/speechx/speechx/asr/nnet/ds2_nnet.cc @@ -14,7 +14,7 @@ #include "nnet/ds2_nnet.h" -#include "absl/strings/str_split.h" +#include "utils/strings.h" namespace ppspeech { @@ -26,16 +26,16 @@ using std::vector; void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { std::vector cache_names; - cache_names = absl::StrSplit(opts.cache_names, ","); + cache_names = StrSplit(opts.cache_names, ","); std::vector cache_shapes; - cache_shapes = absl::StrSplit(opts.cache_shape, ","); + cache_shapes = StrSplit(opts.cache_shape, ","); assert(cache_shapes.size() == cache_names.size()); cache_encouts_.clear(); cache_names_idx_.clear(); for (size_t i = 0; i < cache_shapes.size(); i++) { std::vector tmp_shape; - tmp_shape = absl::StrSplit(cache_shapes[i], "-"); + tmp_shape = StrSplit(cache_shapes[i], "-"); std::vector cur_shape; std::transform(tmp_shape.begin(), tmp_shape.end(), @@ -74,8 +74,8 @@ PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) { LOG(INFO) << "start to check the predictor input and output names"; LOG(INFO) << "input names: " << opts.input_names; LOG(INFO) << "output names: " << opts.output_names; - vector 
input_names_vec = absl::StrSplit(opts.input_names, ","); - vector output_names_vec = absl::StrSplit(opts.output_names, ","); + std::vector input_names_vec = StrSplit(opts.input_names, ","); + std::vector output_names_vec = StrSplit(opts.output_names, ","); paddle_infer::Predictor* predictor = GetPredictor(); diff --git a/speechx/speechx/asr/server/websocket/CMakeLists.txt b/speechx/speechx/asr/server/websocket/CMakeLists.txt index cafbbec7..9991e47b 100644 --- a/speechx/speechx/asr/server/websocket/CMakeLists.txt +++ b/speechx/speechx/asr/server/websocket/CMakeLists.txt @@ -10,4 +10,4 @@ target_link_libraries(websocket_server_main PUBLIC fst websocket ${DEPS}) add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc) target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(websocket_client_main PUBLIC fst websocket ${DEPS}) +target_link_libraries(websocket_client_main PUBLIC fst websocket ${DEPS}) \ No newline at end of file diff --git a/speechx/speechx/common/utils/CMakeLists.txt b/speechx/speechx/common/utils/CMakeLists.txt index c1e875be..c47b25c0 100644 --- a/speechx/speechx/common/utils/CMakeLists.txt +++ b/speechx/speechx/common/utils/CMakeLists.txt @@ -2,4 +2,18 @@ add_library(utils file_utils.cc math.cc -) \ No newline at end of file + strings.cc +) + + +if(WITH_TESTING) + enable_testing() + link_libraries(gtest_main gmock) + + add_executable(strings_test strings_test.cc) + target_link_libraries(strings_test PUBLIC utils) + add_test( + NAME strings_test + COMMAND strings_test + ) +endif() \ No newline at end of file diff --git a/speechx/speechx/common/utils/math.cc b/speechx/speechx/common/utils/math.cc index 71656cb3..e5832cbd 100644 --- a/speechx/speechx/common/utils/math.cc +++ b/speechx/speechx/common/utils/math.cc @@ -15,13 +15,14 @@ // limitations under the License. 
#include "utils/math.h" +#include "base/basic_types.h" #include #include #include #include - -#include "base/common.h" +#include +#include namespace ppspeech { diff --git a/speechx/speechx/common/utils/strings.cc b/speechx/speechx/common/utils/strings.cc new file mode 100644 index 00000000..6aa8af47 --- /dev/null +++ b/speechx/speechx/common/utils/strings.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "utils/strings.h" + +namespace ppspeech { + +std::vector StrSplit(const std::string& str, const char *delim, bool omit_empty_string){ + std::vector outs; + int start = 0; + int end = str.size(); + int found = 0; + while(found != std::string::npos){ + found = str.find_first_of(delim, start); + // start != end condition is for when the delimiter is at the end + if (!omit_empty_string || (found != start && start != end)){ + outs.push_back(str.substr(start, found - start)); + } + start = found + 1; + } + + return outs; +} + + +std::string StrJoin(const std::vector& strs, const char* delim) { + std::stringstream ss; + for (ssize_t i = 0; i < strs.size(); ++i){ + ss << strs[i]; + if ( i < strs.size() -1){ + ss << std::string(delim); + } + } + return ss.str(); +} + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/common/utils/strings.h b/speechx/speechx/common/utils/strings.h new file mode 100644 index 00000000..e2629164 --- /dev/null +++ b/speechx/speechx/common/utils/strings.h @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +namespace ppspeech { + +std::vector StrSplit(const std::string& str, const char *delim, bool omit_empty_string=true); + +std::string StrJoin(const std::vector& strs, const char* delim); + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/common/utils/strings_test.cc b/speechx/speechx/common/utils/strings_test.cc new file mode 100644 index 00000000..a2950d32 --- /dev/null +++ b/speechx/speechx/common/utils/strings_test.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "utils/strings.h" + +#include +#include + + +TEST(StringTest, StrSplitTest) { + using ::testing::ElementsAre; + + std::string test_str = "hello world"; + std::vector outs = ppspeech::StrSplit(test_str, " \t"); + EXPECT_THAT(outs, ElementsAre("hello", "world")); +} + + +TEST(StringTest, StrJoinTest) { + std::vector ins{"hello", "world"}; + std::string out = ppspeech::StrJoin(ins, " "); + EXPECT_THAT(out, "hello world"); +} \ No newline at end of file From 5046d8ee9416904fd3fec8d8d802286bc46e84b3 Mon Sep 17 00:00:00 2001 From: YangZhou <56786796+SmileGoat@users.noreply.github.com> Date: Tue, 27 Dec 2022 19:50:59 +0800 Subject: [PATCH 03/50] [Speechx] add nnet prob cache && make 2 thread decode work (#2769) * add nnet cache && make 2 thread work * do not compile websocket --- speechx/CMakeLists.txt | 2 +- .../ctc_prefix_beam_search_decoder_main.cc | 13 +- speechx/speechx/asr/nnet/CMakeLists.txt | 22 ++-- speechx/speechx/asr/nnet/decodable.cc | 88 ++++--------- speechx/speechx/asr/nnet/decodable.h | 16 +-- speechx/speechx/asr/nnet/nnet_producer.cc | 84 ++++++++++++ speechx/speechx/asr/nnet/nnet_producer.h | 73 +++++++++++ speechx/speechx/asr/recognizer/CMakeLists.txt | 1 + .../speechx/asr/recognizer/u2_recognizer.cc | 15 ++- .../speechx/asr/recognizer/u2_recognizer.h | 10 +- .../recognizer/u2_recognizer_thread_main.cc | 123 ++++++++++++++++++ speechx/speechx/asr/server/CMakeLists.txt | 2 +- speechx/speechx/common/base/common.h | 2 +- speechx/speechx/common/base/safe_queue.h | 71 ++++++++++ 14 files changed, 415 insertions(+), 107 deletions(-) create mode 100644 speechx/speechx/asr/nnet/nnet_producer.cc create mode 100644 speechx/speechx/asr/nnet/nnet_producer.h create mode 100644 speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc create mode 100644 speechx/speechx/common/base/safe_queue.h diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index ed5c38f0..45bf5419 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -45,7 
+45,7 @@ option(USE_PROFILING "enable c++ profling" OFF) option(WITH_TESTING "unit test" ON) option(USING_U2 "compile u2 model." ON) -option(USING_DS2 "compile with ds2 model." ON) +option(USING_DS2 "compile with ds2 model." OFF) option(USING_GPU "u2 compute on GPU." OFF) diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc index 31c8b19e..31276895 100644 --- a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc +++ b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -18,6 +18,7 @@ #include "fst/symbol-table.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" +#include "nnet/nnet_producer.h" #include "nnet/u2_nnet.h" DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); @@ -39,7 +40,7 @@ using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; -// test ds2 online decoder by feeding speech feature +// test u2 online decoder by feeding speech feature int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -69,8 +70,10 @@ int main(int argc, char* argv[]) { // decodeable std::shared_ptr raw_data = std::make_shared(); + std::shared_ptr nnet_producer = + std::make_shared(nnet, raw_data); std::shared_ptr decodable = - std::make_shared(nnet, raw_data); + std::make_shared(nnet_producer); // decoder ppspeech::CTCBeamSearchOptions opts; @@ -114,9 +117,9 @@ int main(int argc, char* argv[]) { ori_feature_len - chunk_idx * chunk_stride, chunk_size); } if (this_chunk_size < receptive_field_length) { - LOG(WARNING) - << "utt: " << utt << " skip last " << this_chunk_size - << " frames, expect is " << receptive_field_length; + LOG(WARNING) << "utt: " << utt << " skip last " + << this_chunk_size << " frames, expect is " + << receptive_field_length; break; } diff --git a/speechx/speechx/asr/nnet/CMakeLists.txt b/speechx/speechx/asr/nnet/CMakeLists.txt 
index 27081086..2846540e 100644 --- a/speechx/speechx/asr/nnet/CMakeLists.txt +++ b/speechx/speechx/asr/nnet/CMakeLists.txt @@ -1,4 +1,4 @@ -set(srcs decodable.cc) +set(srcs decodable.cc nnet_producer.cc) if(USING_DS2) list(APPEND srcs ds2_nnet.cc) @@ -27,13 +27,13 @@ if(USING_DS2) endif() # test bin -if(USING_U2) - set(bin_name u2_nnet_main) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) - - target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) - target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) - target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) -endif() +#if(USING_U2) +# set(bin_name u2_nnet_main) +# add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +# target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +# target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) + +# target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) +# target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) +# target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) +#endif() diff --git a/speechx/speechx/asr/nnet/decodable.cc b/speechx/speechx/asr/nnet/decodable.cc index 5fe2b984..f01e9049 100644 --- a/speechx/speechx/asr/nnet/decodable.cc +++ b/speechx/speechx/asr/nnet/decodable.cc @@ -21,19 +21,16 @@ using kaldi::Matrix; using kaldi::Vector; using std::vector; -Decodable::Decodable(const std::shared_ptr& nnet, - const std::shared_ptr& frontend, +Decodable::Decodable(const std::shared_ptr& nnet_producer, kaldi::BaseFloat acoustic_scale) - : frontend_(frontend), - nnet_(nnet), + : nnet_producer_(nnet_producer), frame_offset_(0), frames_ready_(0), 
acoustic_scale_(acoustic_scale) {} // for debug void Decodable::Acceptlikelihood(const Matrix& likelihood) { - nnet_out_cache_ = likelihood; - frames_ready_ += likelihood.NumRows(); + nnet_producer_->Acceptlikelihood(likelihood); } @@ -43,7 +40,7 @@ int32 Decodable::NumFramesReady() const { return frames_ready_; } // frame idx is from 0 to frame_ready_ -1; bool Decodable::IsLastFrame(int32 frame) { - bool flag = EnsureFrameHaveComputed(frame); + EnsureFrameHaveComputed(frame); return frame >= frames_ready_; } @@ -64,32 +61,10 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) { bool Decodable::AdvanceChunk() { kaldi::Timer timer; - // read feats - Vector features; - if (frontend_ == NULL || frontend_->Read(&features) == false) { - // no feat or frontend_ not init. - VLOG(3) << "decodable exit;"; - return false; - } - CHECK_GE(frontend_->Dim(), 0); - VLOG(1) << "AdvanceChunk feat cost: " << timer.Elapsed() << " sec."; - VLOG(2) << "Forward in " << features.Dim() / frontend_->Dim() << " feats."; - - // forward feats - NnetOut out; - nnet_->FeedForward(features, frontend_->Dim(), &out); - int32& vocab_dim = out.vocab_dim; - Vector& logprobs = out.logprobs; - - VLOG(2) << "Forward out " << logprobs.Dim() / vocab_dim - << " decoder frames."; - // cache nnet outupts - nnet_out_cache_.Resize(logprobs.Dim() / vocab_dim, vocab_dim); - nnet_out_cache_.CopyRowsFromVec(logprobs); - - // update state, decoding frame. 
+ bool flag = nnet_producer_->Read(&framelikelihood_); + if (flag == false) return false; frame_offset_ = frames_ready_; - frames_ready_ += nnet_out_cache_.NumRows(); + frames_ready_ += 1; VLOG(1) << "AdvanceChunk feat + forward cost: " << timer.Elapsed() << " sec."; return true; @@ -101,17 +76,17 @@ bool Decodable::AdvanceChunk(kaldi::Vector* logprobs, return false; } - int nrows = nnet_out_cache_.NumRows(); - CHECK(nrows == (frames_ready_ - frame_offset_)); - if (nrows <= 0) { + if (framelikelihood_.empty()) { LOG(WARNING) << "No new nnet out in cache."; return false; } - logprobs->Resize(nnet_out_cache_.NumRows() * nnet_out_cache_.NumCols()); - logprobs->CopyRowsFromMat(nnet_out_cache_); - - *vocab_dim = nnet_out_cache_.NumCols(); + size_t dim = framelikelihood_.size(); + logprobs->Resize(framelikelihood_.size()); + std::memcpy(logprobs->Data(), + framelikelihood_.data(), + dim * sizeof(kaldi::BaseFloat)); + *vocab_dim = framelikelihood_.size(); return true; } @@ -122,19 +97,8 @@ bool Decodable::FrameLikelihood(int32 frame, vector* likelihood) { return false; } - int nrows = nnet_out_cache_.NumRows(); - CHECK(nrows == (frames_ready_ - frame_offset_)); - int vocab_size = nnet_out_cache_.NumCols(); - likelihood->resize(vocab_size); - - for (int32 idx = 0; idx < vocab_size; ++idx) { - (*likelihood)[idx] = - nnet_out_cache_(frame - frame_offset_, idx) * acoustic_scale_; - - VLOG(4) << "nnet out: " << frame << " offset:" << frame_offset_ << " " - << nnet_out_cache_.NumRows() - << " logprob: " << nnet_out_cache_(frame - frame_offset_, idx); - } + CHECK_EQ(1, (frames_ready_ - frame_offset_)); + *likelihood = framelikelihood_; return true; } @@ -143,37 +107,31 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return false; } - CHECK_LE(index, nnet_out_cache_.NumCols()); + CHECK_LE(index, framelikelihood_.size()); CHECK_LE(frame, frames_ready_); // the nnet output is prob ranther than log prob // the index - 1, because the ilabel BaseFloat logprob = 0.0; 
int32 frame_idx = frame - frame_offset_; - BaseFloat nnet_out = nnet_out_cache_(frame_idx, TokenId2NnetId(index)); - if (nnet_->IsLogProb()) { - logprob = nnet_out; - } else { - logprob = std::log(nnet_out + std::numeric_limits::epsilon()); - } - CHECK(!std::isnan(logprob) && !std::isinf(logprob)); + CHECK_EQ(frame_idx, 0); + logprob = framelikelihood_[TokenId2NnetId(index)]; return acoustic_scale_ * logprob; } void Decodable::Reset() { - if (frontend_ != nullptr) frontend_->Reset(); - if (nnet_ != nullptr) nnet_->Reset(); + if (nnet_producer_ != nullptr) nnet_producer_->Reset(); frame_offset_ = 0; frames_ready_ = 0; - nnet_out_cache_.Resize(0, 0); + framelikelihood_.clear(); } void Decodable::AttentionRescoring(const std::vector>& hyps, float reverse_weight, std::vector* rescoring_score) { kaldi::Timer timer; - nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score); + nnet_producer_->AttentionRescoring(hyps, reverse_weight, rescoring_score); VLOG(1) << "Attention Rescoring cost: " << timer.Elapsed() << " sec."; } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/asr/nnet/decodable.h b/speechx/speechx/asr/nnet/decodable.h index dd7b329e..cd498e42 100644 --- a/speechx/speechx/asr/nnet/decodable.h +++ b/speechx/speechx/asr/nnet/decodable.h @@ -13,10 +13,10 @@ // limitations under the License. 
#include "base/common.h" -#include "frontend/audio/frontend_itf.h" #include "kaldi/decoder/decodable-itf.h" #include "kaldi/matrix/kaldi-matrix.h" #include "nnet/nnet_itf.h" +#include "nnet/nnet_producer.h" namespace ppspeech { @@ -24,8 +24,7 @@ struct DecodableOpts; class Decodable : public kaldi::DecodableInterface { public: - explicit Decodable(const std::shared_ptr& nnet, - const std::shared_ptr& frontend, + explicit Decodable(const std::shared_ptr& nnet_producer, kaldi::BaseFloat acoustic_scale = 1.0); // void Init(DecodableOpts config); @@ -57,23 +56,17 @@ class Decodable : public kaldi::DecodableInterface { void Reset(); - bool IsInputFinished() const { return frontend_->IsFinished(); } + bool IsInputFinished() const { return nnet_producer_->IsFinished(); } bool EnsureFrameHaveComputed(int32 frame); int32 TokenId2NnetId(int32 token_id); - std::shared_ptr Nnet() { return nnet_; } - // for offline test void Acceptlikelihood(const kaldi::Matrix& likelihood); private: - std::shared_ptr frontend_; - std::shared_ptr nnet_; - - // nnet outputs' cache - kaldi::Matrix nnet_out_cache_; + std::shared_ptr nnet_producer_; // the frame is nnet prob frame rather than audio feature frame // nnet frame subsample the feature frame @@ -85,6 +78,7 @@ class Decodable : public kaldi::DecodableInterface { // so use subsampled_frame int32 current_log_post_subsampled_offset_; int32 num_chunk_computed_; + std::vector framelikelihood_; kaldi::BaseFloat acoustic_scale_; }; diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/speechx/speechx/asr/nnet/nnet_producer.cc new file mode 100644 index 00000000..3a0c4f18 --- /dev/null +++ b/speechx/speechx/asr/nnet/nnet_producer.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "nnet/nnet_producer.h" + +namespace ppspeech { + +using kaldi::Vector; +using kaldi::BaseFloat; + +NnetProducer::NnetProducer(std::shared_ptr nnet, + std::shared_ptr frontend) + : nnet_(nnet), frontend_(frontend) {} + +void NnetProducer::Accept(const kaldi::VectorBase& inputs) { + frontend_->Accept(inputs); + bool result = false; + do { + result = Compute(); + } while (result); +} + +void NnetProducer::Acceptlikelihood( + const kaldi::Matrix& likelihood) { + std::vector prob; + prob.resize(likelihood.NumCols()); + for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) { + for (size_t col = 0; col < likelihood.NumCols(); ++col) { + prob[col] = likelihood(idx, col); + cache_.push_back(prob); + } + } +} + +bool NnetProducer::Read(std::vector* nnet_prob) { + bool flag = cache_.pop(nnet_prob); + return flag; +} + +bool NnetProducer::Compute() { + Vector features; + if (frontend_ == NULL || frontend_->Read(&features) == false) { + // no feat or frontend_ not init. + VLOG(3) << "no feat avalible"; + return false; + } + CHECK_GE(frontend_->Dim(), 0); + VLOG(2) << "Forward in " << features.Dim() / frontend_->Dim() << " feats."; + + NnetOut out; + nnet_->FeedForward(features, frontend_->Dim(), &out); + int32& vocab_dim = out.vocab_dim; + Vector& logprobs = out.logprobs; + size_t nframes = logprobs.Dim() / vocab_dim; + VLOG(2) << "Forward out " << nframes << " decoder frames."; + std::vector logprob(vocab_dim); + // remove later. 
+ for (size_t idx = 0; idx < nframes; ++idx) { + for (size_t prob_idx = 0; prob_idx < vocab_dim; ++prob_idx) { + logprob[prob_idx] = logprobs(idx * vocab_dim + prob_idx); + } + cache_.push_back(logprob); + } + return true; +} + +void NnetProducer::AttentionRescoring(const std::vector>& hyps, + float reverse_weight, + std::vector* rescoring_score) { + nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score); +} + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/nnet/nnet_producer.h b/speechx/speechx/asr/nnet/nnet_producer.h new file mode 100644 index 00000000..65e9116f --- /dev/null +++ b/speechx/speechx/asr/nnet/nnet_producer.h @@ -0,0 +1,73 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "base/common.h" +#include "base/safe_queue.h" +#include "frontend/audio/frontend_itf.h" +#include "nnet/nnet_itf.h" + +namespace ppspeech { + +class NnetProducer { + public: + explicit NnetProducer(std::shared_ptr nnet, + std::shared_ptr frontend = NULL); + + // Feed feats or waves + void Accept(const kaldi::VectorBase& inputs); + + void Acceptlikelihood(const kaldi::Matrix& likelihood); + + // nnet + bool Read(std::vector* nnet_prob); + + bool Empty() const { return cache_.empty(); } + + void SetFinished() { + LOG(INFO) << "set finished"; + // std::unique_lock lock(mutex_); + frontend_->SetFinished(); + + // read the last chunk data + Compute(); + // ready_feed_condition_.notify_one(); + LOG(INFO) << "compute last feats done."; + } + + bool IsFinished() const { return frontend_->IsFinished(); } + + void Reset() { + frontend_->Reset(); + nnet_->Reset(); + VLOG(3) << "feature cache reset: cache size: " << cache_.size(); + cache_.clear(); + } + + void AttentionRescoring(const std::vector>& hyps, + float reverse_weight, + std::vector* rescoring_score); + + private: + bool Compute(); + + std::shared_ptr frontend_; + std::shared_ptr nnet_; + SafeQueue> cache_; + + DISALLOW_COPY_AND_ASSIGN(NnetProducer); +}; + +} // namespace ppspeech diff --git a/speechx/speechx/asr/recognizer/CMakeLists.txt b/speechx/speechx/asr/recognizer/CMakeLists.txt index 05078873..53e2e58d 100644 --- a/speechx/speechx/asr/recognizer/CMakeLists.txt +++ b/speechx/speechx/asr/recognizer/CMakeLists.txt @@ -30,6 +30,7 @@ endif() if (USING_U2) set(TEST_BINS u2_recognizer_main + u2_recognizer_thread_main ) foreach(bin_name IN LISTS TEST_BINS) diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.cc b/speechx/speechx/asr/recognizer/u2_recognizer.cc index d1d308eb..ea62ae1a 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer.cc @@ -27,13 +27,13 @@ using std::vector; U2Recognizer::U2Recognizer(const 
U2RecognizerResource& resource) : opts_(resource) { + BaseFloat am_scale = resource.acoustic_scale; const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; - feature_pipeline_.reset(new FeaturePipeline(feature_opts)); - + std::shared_ptr feature_pipeline( + new FeaturePipeline(feature_opts)); std::shared_ptr nnet(new U2Nnet(resource.model_opts)); - - BaseFloat am_scale = resource.acoustic_scale; - decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale)); + nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline)); + decodable_.reset(new Decodable(nnet_producer_, am_scale)); CHECK_NE(resource.vocab_path, ""); decoder_.reset(new CTCPrefixBeamSearch( @@ -49,6 +49,7 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) void U2Recognizer::Reset() { global_frame_offset_ = 0; + input_finished_ = false; num_frames_ = 0; result_.clear(); @@ -68,7 +69,7 @@ void U2Recognizer::ResetContinuousDecoding() { void U2Recognizer::Accept(const VectorBase& waves) { kaldi::Timer timer; - feature_pipeline_->Accept(waves); + nnet_producer_->Accept(waves); VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. 
" << waves.Dim() << " samples."; } @@ -210,7 +211,7 @@ std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; } std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; } void U2Recognizer::SetFinished() { - feature_pipeline_->SetFinished(); + nnet_producer_->SetFinished(); input_finished_ = true; } diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h index 25850863..855d161a 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/speechx/speechx/asr/recognizer/u2_recognizer.h @@ -130,11 +130,11 @@ class U2Recognizer { return !result_.empty() && !result_[0].sentence.empty(); } - int FrameShiftInMs() const { - // one decoder frame length in ms - return decodable_->Nnet()->SubsamplingRate() * - feature_pipeline_->FrameShift(); + // one decoder frame length in ms, todo + return 1; + // return decodable_->Nnet()->SubsamplingRate() * + // feature_pipeline_->FrameShift(); } @@ -149,7 +149,7 @@ class U2Recognizer { // std::shared_ptr resource_; // U2RecognizerResource resource_; - std::shared_ptr feature_pipeline_; + std::shared_ptr nnet_producer_; std::shared_ptr decodable_; std::unique_ptr decoder_; diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc new file mode 100644 index 00000000..e73efef1 --- /dev/null +++ b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "recognizer/u2_recognizer.h" +#include "decoder/param.h" +#include "kaldi/feat/wave-reader.h" +#include "kaldi/util/table-types.h" + +DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); + +void decode_func(std::shared_ptr recognizer) { + while (!recognizer->IsFinished()) { + recognizer->Decode(); + usleep(100); + } + recognizer->Decode(); + recognizer->Rescoring(); +} + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int32 num_done = 0, num_err = 0; + double tot_wav_duration = 0.0; + double tot_decode_time = 0.0; + + kaldi::SequentialTableReader wav_reader( + FLAGS_wav_rspecifier); + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + LOG(INFO) << "sr: " << sample_rate; + LOG(INFO) << "chunk size (s): " << streaming_chunk; + LOG(INFO) << "chunk size (sample): " << chunk_sample_size; + + ppspeech::U2RecognizerResource resource = + ppspeech::U2RecognizerResource::InitFromFlags(); + std::shared_ptr recognizer_ptr( + new ppspeech::U2Recognizer(resource)); + + for (; !wav_reader.Done(); wav_reader.Next()) { + 
std::thread recognizer_thread(decode_func, recognizer_ptr); + std::string utt = wav_reader.Key(); + const kaldi::WaveData& wave_data = wav_reader.Value(); + LOG(INFO) << "utt: " << utt; + LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; + double dur = wave_data.Duration(); + tot_wav_duration += dur; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + kaldi::Timer timer; + kaldi::Timer local_timer; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + kaldi::Vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk(i) = waveform(sample_offset + i); + } + // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); + + recognizer_ptr->Accept(wav_chunk); + if (cur_chunk_size < chunk_sample_size) { + recognizer_ptr->SetFinished(); + } + + // no overlap + sample_offset += cur_chunk_size; + } + CHECK(sample_offset == tot_samples); + + recognizer_thread.join(); + std::string result = recognizer_ptr->GetFinalResult(); + recognizer_ptr->Reset(); + if (result.empty()) { + // the TokenWriter can not write empty string. 
+ ++num_err; + LOG(INFO) << " the result of " << utt << " is empty"; + continue; + } + + LOG(INFO) << utt << " " << result; + LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur + << " cost: " << local_timer.Elapsed(); + + result_writer.Write(utt, result); + + ++num_done; + } + + LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); + LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; + LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; +} diff --git a/speechx/speechx/asr/server/CMakeLists.txt b/speechx/speechx/asr/server/CMakeLists.txt index 71b33daa..566b42ee 100644 --- a/speechx/speechx/asr/server/CMakeLists.txt +++ b/speechx/speechx/asr/server/CMakeLists.txt @@ -1 +1 @@ -add_subdirectory(websocket) +#add_subdirectory(websocket) diff --git a/speechx/speechx/common/base/common.h b/speechx/speechx/common/base/common.h index 97bff966..2a066ee6 100644 --- a/speechx/speechx/common/base/common.h +++ b/speechx/speechx/common/base/common.h @@ -48,4 +48,4 @@ #include "base/log.h" #include "base/macros.h" #include "utils/file_utils.h" -#include "utils/math.h" \ No newline at end of file +#include "utils/math.h" diff --git a/speechx/speechx/common/base/safe_queue.h b/speechx/speechx/common/base/safe_queue.h new file mode 100644 index 00000000..25a012af --- /dev/null +++ b/speechx/speechx/common/base/safe_queue.h @@ -0,0 +1,71 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/common.h" + +namespace ppspeech { + +template +class SafeQueue { + public: + explicit SafeQueue(size_t capacity = 0); + void push_back(const T& in); + bool pop(T* out); + bool empty() const { return buffer_.empty(); } + size_t size() const { return buffer_.size(); } + void clear(); + + + private: + std::mutex mutex_; + std::condition_variable condition_; + std::deque buffer_; + size_t capacity_; +}; + +template +SafeQueue::SafeQueue(size_t capacity) : capacity_(capacity) {} + +template +void SafeQueue::push_back(const T& in) { + std::unique_lock lock(mutex_); + if (capacity_ > 0 && buffer_.size() == capacity_) { + condition_.wait(lock, [this] { return capacity_ >= buffer_.size(); }); + } + + buffer_.push_back(in); + condition_.notify_one(); +} + +template +bool SafeQueue::pop(T* out) { + if (buffer_.empty()) { + return false; + } + + std::unique_lock lock(mutex_); + condition_.wait(lock, [this] { return buffer_.size() > 0; }); + *out = std::move(buffer_.front()); + buffer_.pop_front(); + condition_.notify_one(); + return true; +} + +template +void SafeQueue::clear() { + std::unique_lock lock(mutex_); + buffer_.clear(); + condition_.notify_one(); +} +} // namespace ppspeech From acf1d27230bdeb3144dfa88da7843cb22ea0aa9c Mon Sep 17 00:00:00 2001 From: YangZhou <56786796+SmileGoat@users.noreply.github.com> Date: Fri, 30 Dec 2022 15:54:26 +0800 Subject: [PATCH 04/50] [speechx] rm ds2 && rm boost (#2786) * fix openfst download error * add acknowledgments of openfst * refactor directory * clean ctc_decoders dir * add nnet cache && make 2 
thread work * do not compile websocket * rm ds2 && rm boost * rm ds2 example --- .pre-commit-config.yaml | 4 +- speechx/CMakeLists.txt | 18 - speechx/build.sh | 17 +- speechx/examples/ds2_ol/README.md | 7 - speechx/examples/ds2_ol/aishell/.gitignore | 3 - speechx/examples/ds2_ol/aishell/README.md | 133 - .../ds2_ol/aishell/local/aishell_train_lms.sh | 71 - .../ds2_ol/aishell/local/run_build_tlg.sh | 145 - .../ds2_ol/aishell/local/split_data.sh | 30 - speechx/examples/ds2_ol/aishell/path.sh | 24 - speechx/examples/ds2_ol/aishell/run.sh | 180 -- speechx/examples/ds2_ol/aishell/run_fbank.sh | 177 -- speechx/examples/ds2_ol/aishell/utils | 1 - speechx/examples/ds2_ol/onnx/.gitignore | 3 - speechx/examples/ds2_ol/onnx/README.md | 57 - .../examples/ds2_ol/onnx/local/infer_check.py | 100 - speechx/examples/ds2_ol/onnx/local/netron.sh | 14 - .../examples/ds2_ol/onnx/local/onnx_clone.sh | 7 - .../ds2_ol/onnx/local/onnx_convert_opset.py | 37 - .../ds2_ol/onnx/local/onnx_infer_shape.py | 2514 ----------------- .../examples/ds2_ol/onnx/local/onnx_opt.sh | 20 - .../ds2_ol/onnx/local/onnx_prune_model.py | 128 - .../ds2_ol/onnx/local/onnx_rename_model.py | 111 - .../ds2_ol/onnx/local/ort_dyanmic_quant.py | 48 - speechx/examples/ds2_ol/onnx/local/ort_opt.py | 45 - speechx/examples/ds2_ol/onnx/local/tonnx.sh | 26 - speechx/examples/ds2_ol/onnx/path.sh | 14 - speechx/examples/ds2_ol/onnx/run.sh | 91 - speechx/examples/ds2_ol/onnx/utils | 1 - speechx/examples/ds2_ol/websocket/.gitignore | 2 - speechx/examples/ds2_ol/websocket/README.md | 78 - speechx/examples/ds2_ol/websocket/path.sh | 14 - .../ds2_ol/websocket/websocket_client.sh | 35 - .../ds2_ol/websocket/websocket_server.sh | 55 - speechx/examples/u2pp_ol/wenetspeech/path.sh | 4 +- speechx/speechx/asr/decoder/CMakeLists.txt | 59 +- .../asr/decoder/ctc_beam_search_decoder.cc | 313 -- .../asr/decoder/ctc_beam_search_decoder.h | 73 - .../decoder/ctc_beam_search_decoder_main.cc | 167 -- .../asr/decoder/ctc_decoders/.gitignore | 9 - 
.../ctc_decoders/ctc_beam_search_decoder.cpp | 607 ---- .../ctc_decoders/ctc_beam_search_decoder.h | 175 -- .../ctc_decoders/ctc_greedy_decoder.cpp | 61 - .../decoder/ctc_decoders/ctc_greedy_decoder.h | 35 - .../decoder/ctc_decoders/decoder_utils.cpp | 193 -- .../asr/decoder/ctc_decoders/decoder_utils.h | 111 - .../asr/decoder/ctc_decoders/path_trie.cpp | 164 -- .../asr/decoder/ctc_decoders/path_trie.h | 82 - .../asr/decoder/ctc_decoders/scorer.cpp | 232 -- .../speechx/asr/decoder/ctc_decoders/scorer.h | 114 - .../asr/decoder/nnet_logprob_decoder_main.cc | 77 - speechx/speechx/asr/decoder/param.h | 3 +- speechx/speechx/asr/nnet/CMakeLists.txt | 24 +- speechx/speechx/asr/nnet/ds2_nnet.cc | 218 -- speechx/speechx/asr/nnet/ds2_nnet.h | 97 - speechx/speechx/asr/nnet/ds2_nnet_main.cc | 142 - speechx/speechx/asr/nnet/nnet_producer.cc | 1 - speechx/speechx/asr/recognizer/CMakeLists.txt | 50 +- speechx/speechx/asr/recognizer/recognizer.cc | 70 - speechx/speechx/asr/recognizer/recognizer.h | 70 - .../speechx/asr/recognizer/recognizer_main.cc | 105 - speechx/speechx/codelab/CMakeLists.txt | 1 - speechx/speechx/codelab/nnet/CMakeLists.txt | 6 - .../codelab/nnet/ds2_model_test_main.cc | 207 -- .../frontend/audio/cmvn_json2kaldi_main.cc | 46 +- speechx/speechx/common/utils/picojson.h | 1202 ++++++++ 66 files changed, 1265 insertions(+), 7663 deletions(-) delete mode 100644 speechx/examples/ds2_ol/README.md delete mode 100644 speechx/examples/ds2_ol/aishell/.gitignore delete mode 100644 speechx/examples/ds2_ol/aishell/README.md delete mode 100755 speechx/examples/ds2_ol/aishell/local/aishell_train_lms.sh delete mode 100755 speechx/examples/ds2_ol/aishell/local/run_build_tlg.sh delete mode 100755 speechx/examples/ds2_ol/aishell/local/split_data.sh delete mode 100755 speechx/examples/ds2_ol/aishell/path.sh delete mode 100755 speechx/examples/ds2_ol/aishell/run.sh delete mode 100755 speechx/examples/ds2_ol/aishell/run_fbank.sh delete mode 120000 
speechx/examples/ds2_ol/aishell/utils delete mode 100644 speechx/examples/ds2_ol/onnx/.gitignore delete mode 100644 speechx/examples/ds2_ol/onnx/README.md delete mode 100755 speechx/examples/ds2_ol/onnx/local/infer_check.py delete mode 100755 speechx/examples/ds2_ol/onnx/local/netron.sh delete mode 100755 speechx/examples/ds2_ol/onnx/local/onnx_clone.sh delete mode 100755 speechx/examples/ds2_ol/onnx/local/onnx_convert_opset.py delete mode 100755 speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py delete mode 100755 speechx/examples/ds2_ol/onnx/local/onnx_opt.sh delete mode 100755 speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py delete mode 100755 speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py delete mode 100755 speechx/examples/ds2_ol/onnx/local/ort_dyanmic_quant.py delete mode 100755 speechx/examples/ds2_ol/onnx/local/ort_opt.py delete mode 100755 speechx/examples/ds2_ol/onnx/local/tonnx.sh delete mode 100755 speechx/examples/ds2_ol/onnx/path.sh delete mode 100755 speechx/examples/ds2_ol/onnx/run.sh delete mode 120000 speechx/examples/ds2_ol/onnx/utils delete mode 100644 speechx/examples/ds2_ol/websocket/.gitignore delete mode 100644 speechx/examples/ds2_ol/websocket/README.md delete mode 100755 speechx/examples/ds2_ol/websocket/path.sh delete mode 100755 speechx/examples/ds2_ol/websocket/websocket_client.sh delete mode 100755 speechx/examples/ds2_ol/websocket/websocket_server.sh delete mode 100644 speechx/speechx/asr/decoder/ctc_beam_search_decoder.cc delete mode 100644 speechx/speechx/asr/decoder/ctc_beam_search_decoder.h delete mode 100644 speechx/speechx/asr/decoder/ctc_beam_search_decoder_main.cc delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/.gitignore delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp delete mode 100644 
speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.h delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.cpp delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/path_trie.h delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/scorer.h delete mode 100644 speechx/speechx/asr/decoder/nnet_logprob_decoder_main.cc delete mode 100644 speechx/speechx/asr/nnet/ds2_nnet.cc delete mode 100644 speechx/speechx/asr/nnet/ds2_nnet.h delete mode 100644 speechx/speechx/asr/nnet/ds2_nnet_main.cc delete mode 100644 speechx/speechx/asr/recognizer/recognizer.cc delete mode 100644 speechx/speechx/asr/recognizer/recognizer.h delete mode 100644 speechx/speechx/asr/recognizer/recognizer_main.cc delete mode 100644 speechx/speechx/codelab/nnet/CMakeLists.txt delete mode 100644 speechx/speechx/codelab/nnet/ds2_model_test_main.cc create mode 100644 speechx/speechx/common/utils/picojson.h diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 15b842d5..99461947 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -57,13 +57,13 @@ repos: entry: bash .pre-commit-hooks/clang-format.hook -i language: system files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$ - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ + exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders|speechx/speechx/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ - id: cpplint name: cpplint description: Static code analysis of C/C++ files language: python files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$ - exclude: 
(?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ + exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders|speechx/speechx/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent - repo: https://github.com/asottile/reorder_python_imports rev: v2.4.0 diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index 45bf5419..cfce63dd 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -44,9 +44,6 @@ option(TEST_DEBUG "option for debug" OFF) option(USE_PROFILING "enable c++ profling" OFF) option(WITH_TESTING "unit test" ON) -option(USING_U2 "compile u2 model." ON) -option(USING_DS2 "compile with ds2 model." OFF) - option(USING_GPU "u2 compute on GPU." OFF) ############################################################################### @@ -56,21 +53,6 @@ include(gflags) include(glog) -# boost -# include(boost) # not work -set(boost_SOURCE_DIR ${fc_patch}/boost-src) -set(BOOST_ROOT ${boost_SOURCE_DIR}) -include_directories(${boost_SOURCE_DIR}) -link_directories(${boost_SOURCE_DIR}/stage/lib) - -# Eigen -include(eigen) -find_package(Eigen3 REQUIRED) - -# Kenlm -include(kenlm) -add_dependencies(kenlm eigen boost) - #openblas include(openblas) diff --git a/speechx/build.sh b/speechx/build.sh index 7655f963..94d250f5 100755 --- a/speechx/build.sh +++ b/speechx/build.sh @@ -4,20 +4,5 @@ set -xe # the build script had verified in the paddlepaddle docker image. # please follow the instruction below to install PaddlePaddle image. # https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html -boost_SOURCE_DIR=$PWD/fc_patch/boost-src -if [ ! 
-d ${boost_SOURCE_DIR} ]; then wget -c https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz - tar xzfv boost_1_75_0.tar.gz - mkdir -p $PWD/fc_patch - mv boost_1_75_0 ${boost_SOURCE_DIR} - cd ${boost_SOURCE_DIR} - bash ./bootstrap.sh - ./b2 - cd - - echo -e "\n" -fi - -#rm -rf build -mkdir -p build - -cmake -B build -DBOOST_ROOT:STRING=${boost_SOURCE_DIR} +cmake -B build cmake --build build -j diff --git a/speechx/examples/ds2_ol/README.md b/speechx/examples/ds2_ol/README.md deleted file mode 100644 index d1da96cc..00000000 --- a/speechx/examples/ds2_ol/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# Deepspeech2 Streaming ASR - -## Examples - -* `websocket` - Streaming ASR with websocket for deepspeech2_aishell. -* `aishell` - Streaming Decoding under aishell dataset, for local WER test. -* `onnx` - Example to convert deepspeech2 to onnx format. diff --git a/speechx/examples/ds2_ol/aishell/.gitignore b/speechx/examples/ds2_ol/aishell/.gitignore deleted file mode 100644 index 68f993b4..00000000 --- a/speechx/examples/ds2_ol/aishell/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -data -exp -aishell_* diff --git a/speechx/examples/ds2_ol/aishell/README.md b/speechx/examples/ds2_ol/aishell/README.md deleted file mode 100644 index 2ee0bbca..00000000 --- a/speechx/examples/ds2_ol/aishell/README.md +++ /dev/null @@ -1,133 +0,0 @@ -# Aishell - Deepspeech2 Streaming - -> We recommend using U2/U2++ model instead of DS2, please see [here](../../u2pp_ol/wenetspeech/). - -A C++ deployment example for using the deepspeech2 model to recognize `wav` and compute `CER`. We using AISHELL-1 as test data. - -## Source path.sh - -```bash -. path.sh -``` - -SpeechX bins is under `echo $SPEECHX_BUILD`, more info please see `path.sh`. - -## Recognize with linear feature - -```bash -bash run.sh -``` - -`run.sh` has multi stage, for details please see `run.sh`: - -1. donwload dataset, model and lm -2. convert cmvn format and compute feature -3. 
decode w/o lm by feature -4. decode w/ ngram lm by feature -5. decode w/ TLG graph by feature -6. recognize w/ TLG graph by wav input - -### Recognize with `.scp` file for wav - -This sciprt using `recognizer_main` to recognize wav file. - -The input is `scp` file which look like this: -```text -# head data/split1/1/aishell_test.scp -BAC009S0764W0121 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav -BAC009S0764W0122 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav -... -BAC009S0764W0125 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav -``` - -If you want to recognize one wav, you can make `scp` file like this: -```text -key path/to/wav/file -``` - -Then specify `--wav_rspecifier=` param for `recognizer_main` bin. For other flags meaning, please see `help`: -```bash -recognizer_main --help -``` - -For the exmaple to using `recognizer_main` please see `run.sh`. 
- - -### CTC Prefix Beam Search w/o LM - -``` -Overall -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465 -Mandarin -> 16.14 % N=104612 C=88190 S=16110 D=312 I=465 -Other -> 0.00 % N=0 C=0 S=0 D=0 I=0 -``` - -### CTC Prefix Beam Search w/ LM - -LM: zh_giga.no_cna_cmn.prune01244.klm -``` -Overall -> 7.86 % N=104768 C=96865 S=7573 D=330 I=327 -Mandarin -> 7.86 % N=104768 C=96865 S=7573 D=330 I=327 -Other -> 0.00 % N=0 C=0 S=0 D=0 I=0 -``` - -### CTC TLG WFST - -LM: [aishell train](http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/aishell/aishell_graph.zip) ---acoustic_scale=1.2 -``` -Overall -> 11.14 % N=103017 C=93363 S=9583 D=71 I=1819 -Mandarin -> 11.14 % N=103017 C=93363 S=9583 D=71 I=1818 -Other -> 0.00 % N=0 C=0 S=0 D=0 I=1 -``` - -LM: [wenetspeech](http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/aishell/wenetspeech_graph.zip) ---acoustic_scale=1.5 -``` -Overall -> 10.93 % N=104765 C=93410 S=9780 D=1575 I=95 -Mandarin -> 10.93 % N=104762 C=93410 S=9779 D=1573 I=95 -Other -> 100.00 % N=3 C=0 S=1 D=2 I=0 -``` - -## Recognize with fbank feature - -This script is same to `run.sh`, but using fbank feature. 
- -```bash -bash run_fbank.sh -``` - -### CTC Prefix Beam Search w/o LM - -``` -Overall -> 10.44 % N=104765 C=94194 S=10174 D=397 I=369 -Mandarin -> 10.44 % N=104762 C=94194 S=10171 D=397 I=369 -Other -> 100.00 % N=3 C=0 S=3 D=0 I=0 -``` - -### CTC Prefix Beam Search w/ LM - -LM: zh_giga.no_cna_cmn.prune01244.klm - -``` -Overall -> 5.82 % N=104765 C=99386 S=4944 D=435 I=720 -Mandarin -> 5.82 % N=104762 C=99386 S=4941 D=435 I=720 -English -> 0.00 % N=0 C=0 S=0 D=0 I=0 -``` - -### CTC TLG WFST - -LM: [aishell train](https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph2.zip) -``` -Overall -> 9.58 % N=104765 C=94817 S=4326 D=5622 I=84 -Mandarin -> 9.57 % N=104762 C=94817 S=4325 D=5620 I=84 -Other -> 100.00 % N=3 C=0 S=1 D=2 I=0 -``` - -## Build TLG WFST graph - -The script is for building TLG wfst graph, depending on `srilm`, please make sure it is installed. -For more information please see the script below. - -```bash - bash ./local/run_build_tlg.sh -``` diff --git a/speechx/examples/ds2_ol/aishell/local/aishell_train_lms.sh b/speechx/examples/ds2_ol/aishell/local/aishell_train_lms.sh deleted file mode 100755 index 544a1f59..00000000 --- a/speechx/examples/ds2_ol/aishell/local/aishell_train_lms.sh +++ /dev/null @@ -1,71 +0,0 @@ -#!/bin/bash - -# To be run from one directory above this script. -. ./path.sh - -nj=40 -text=data/local/lm/text -lexicon=data/local/dict/lexicon.txt - -for f in "$text" "$lexicon"; do - [ ! -f $x ] && echo "$0: No such file $f" && exit 1; -done - -# Check SRILM tools -if ! which ngram-count > /dev/null; then - echo "srilm tools are not found, please download it and install it from: " - echo "http://www.speech.sri.com/projects/srilm/download.html" - echo "Then add the tools to your PATH" - exit 1 -fi - -# This script takes no arguments. It assumes you have already run -# aishell_data_prep.sh. 
-# It takes as input the files -# data/local/lm/text -# data/local/dict/lexicon.txt -dir=data/local/lm -mkdir -p $dir - -cleantext=$dir/text.no_oov - -# oov to -# lexicon line: word char0 ... charn -# text line: utt word0 ... wordn -> line: word0 ... wordn -text_dir=$(dirname $text) -split_name=$(basename $text) -./local/split_data.sh $text_dir $text $split_name $nj - -utils/run.pl JOB=1:$nj $text_dir/split${nj}/JOB/${split_name}.no_oov.log \ - cat ${text_dir}/split${nj}/JOB/${split_name} \| awk -v lex=$lexicon 'BEGIN{while((getline0){ seen[$1]=1; } } - {for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf(" ");} } printf("\n");}' \ - \> ${text_dir}/split${nj}/JOB/${split_name}.no_oov || exit 1; -cat ${text_dir}/split${nj}/*/${split_name}.no_oov > $cleantext - -# compute word counts, sort in descending order -# line: count word -cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort --parallel=`nproc` | uniq -c | \ - sort --parallel=`nproc` -nr > $dir/word.counts || exit 1; - -# Get counts from acoustic training transcripts, and add one-count -# for each word in the lexicon (but not silence, we don't want it -# in the LM-- we'll add it optionally later). 
-cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \ - cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \ - sort --parallel=`nproc` | uniq -c | sort --parallel=`nproc` -nr > $dir/unigram.counts || exit 1; - -# word with -cat $dir/unigram.counts | awk '{print $2}' | cat - <(echo ""; echo "" ) > $dir/wordlist - -# hold out to compute ppl -heldout_sent=10000 # Don't change this if you want result to be comparable with kaldi_lm results - -mkdir -p $dir -cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n $dir/heldout -cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n $dir/train - -ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \ - -map-unk "" -kndiscount -interpolate -lm $dir/lm.arpa -ngram -lm $dir/lm.arpa -ppl $dir/heldout \ No newline at end of file diff --git a/speechx/examples/ds2_ol/aishell/local/run_build_tlg.sh b/speechx/examples/ds2_ol/aishell/local/run_build_tlg.sh deleted file mode 100755 index 07f47c7e..00000000 --- a/speechx/examples/ds2_ol/aishell/local/run_build_tlg.sh +++ /dev/null @@ -1,145 +0,0 @@ -#!/bin/bash -set -eo pipefail - -. path.sh - -# attention, please replace the vocab is only for this script. -# different acustic model has different vocab -ckpt_dir=data/fbank_model -unit=$ckpt_dir/data/lang_char/vocab.txt # vocab file, line: char/spm_pice -model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ - -stage=-1 -stop_stage=100 -corpus=aishell -lexicon=data/lexicon.txt # line: word ph0 ... phn, aishell/resource_aishell/lexicon.txt -text=data/text # line: utt text, aishell/data_aishell/transcript/aishell_transcript_v0.8.txt - -. utils/parse_options.sh - -data=$PWD/data -mkdir -p $data - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - if [ ! -f $data/speech.ngram.zh.tar.gz ];then - # download ngram - pushd $data - wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ngram/zh/speech.ngram.zh.tar.gz - tar xvzf speech.ngram.zh.tar.gz - popd - fi - - if [ ! 
-f $ckpt_dir/data/mean_std.json ]; then - # download model - mkdir -p $ckpt_dir - pushd $ckpt_dir - wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz - tar xzfv WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz - popd - fi -fi - -if [ ! -f $unit ]; then - echo "$0: No such file $unit" - exit 1; -fi - -if ! which ngram-count; then - # need srilm install - pushd $MAIN_ROOT/tools - make srilm.done - popd -fi - -mkdir -p data/local/dict -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - # Prepare dict - # line: char/spm_pices - cp $unit data/local/dict/units.txt - - if [ ! -f $lexicon ];then - utils/text_to_lexicon.py --has_key true --text $text --lexicon $lexicon - echo "Generate $lexicon from $text" - fi - - # filter by vocab - # line: word ph0 ... phn -> line: word char0 ... charn - utils/fst/prepare_dict.py \ - --unit_file $unit \ - --in_lexicon ${lexicon} \ - --out_lexicon data/local/dict/lexicon.txt -fi - -lm=data/local/lm -mkdir -p $lm - -if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # Train ngram lm - cp $text $lm/text - local/aishell_train_lms.sh - echo "build LM done." -fi - -# build TLG -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - # build T & L - utils/fst/compile_lexicon_token_fst.sh \ - data/local/dict data/local/tmp data/local/lang - - # build G & TLG - utils/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1; - -fi - -aishell_wav_scp=aishell_test.scp -nj=40 -cmvn=$data/cmvn_fbank.ark -wfst=$data/lang_test - -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then - if [ ! -d $data/test ]; then - # download test dataset - pushd $data - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip - unzip aishell_test.zip - popd - - realpath $data/test/*/*.wav > $data/wavlist - awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' 
'{ print $1 }' > $data/utt_id - paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp - fi - - ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj - - # convert cmvn format - cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn -fi - -wer=aishell_wer -label_file=aishell_result -export GLOG_logtostderr=1 - -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - # recognize w/ TLG graph - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/check_tlg.log \ - recognizer_main \ - --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ - --cmvn_file=$cmvn \ - --model_path=$model_dir/avg_5.jit.pdmodel \ - --streaming_chunk=30 \ - --use_fbank=true \ - --param_path=$model_dir/avg_5.jit.pdiparams \ - --word_symbol_table=$wfst/words.txt \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --model_cache_shapes="5-1-2048,5-1-2048" \ - --graph_path=$wfst/TLG.fst --max_active=7500 \ - --acoustic_scale=1.2 \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_check_tlg - - cat $data/split${nj}/*/result_check_tlg > $exp/${label_file}_check_tlg - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_check_tlg > $exp/${wer}.check_tlg - echo "recognizer test have finished!!!" - echo "please checkout in ${exp}/${wer}.check_tlg" -fi - -exit 0 diff --git a/speechx/examples/ds2_ol/aishell/local/split_data.sh b/speechx/examples/ds2_ol/aishell/local/split_data.sh deleted file mode 100755 index 2af6fc5a..00000000 --- a/speechx/examples/ds2_ol/aishell/local/split_data.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env bash - -set -eo pipefail - -data=$1 -scp=$2 -split_name=$3 -numsplit=$4 - -# save in $data/split{n} -# $scp to split -# - -if [[ ! 
$numsplit -gt 0 ]]; then - echo "Invalid num-split argument"; - exit 1; -fi - -directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done) -scp_splits=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/${split_name}; done) - -# if this mkdir fails due to argument-list being too long, iterate. -if ! mkdir -p $directories >&/dev/null; then - for n in `seq $numsplit`; do - mkdir -p $data/split${numsplit}/$n - done -fi - -echo "utils/split_scp.pl $scp $scp_splits" -utils/split_scp.pl $scp $scp_splits diff --git a/speechx/examples/ds2_ol/aishell/path.sh b/speechx/examples/ds2_ol/aishell/path.sh deleted file mode 100755 index 6e803935..00000000 --- a/speechx/examples/ds2_ol/aishell/path.sh +++ /dev/null @@ -1,24 +0,0 @@ -# This contains the locations of binarys build required for running the examples. - -MAIN_ROOT=`realpath $PWD/../../../../` -SPEECHX_ROOT=$PWD/../../../ -SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx - -SPEECHX_TOOLS=$SPEECHX_ROOT/tools -TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin - -[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; } - -export LC_AL=C - -# openfst bin & kaldi bin -KALDI_DIR=$SPEECHX_ROOT/build/speechx/kaldi/ -OPENFST_DIR=$SPEECHX_ROOT/fc_patch/openfst-build/src - -# srilm -export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10 -export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs -export SRILM=${MAIN_ROOT}/tools/srilm - -SPEECHX_BIN=$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio -export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN:${SRILM}/bin:${SRILM}/bin/i686-m64:$KALDI_DIR/lmbin:$KALDI_DIR/fstbin:$OPENFST_DIR/bin diff --git a/speechx/examples/ds2_ol/aishell/run.sh b/speechx/examples/ds2_ol/aishell/run.sh deleted file mode 100755 index 49438cb2..00000000 --- a/speechx/examples/ds2_ol/aishell/run.sh +++ /dev/null @@ -1,180 +0,0 @@ -#!/bin/bash -set -x -set -e - -. path.sh - -nj=40 -stage=0 -stop_stage=100 - -. 
utils/parse_options.sh - -# 1. compile -if [ ! -d ${SPEECHX_BUILD} ]; then - pushd ${SPEECHX_ROOT} - bash build.sh - popd -fi - -# input -mkdir -p data -data=$PWD/data - -ckpt_dir=$data/model -model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ -vocb_dir=$ckpt_dir/data/lang_char/ - -# output -mkdir -p exp -exp=$PWD/exp - -aishell_wav_scp=aishell_test.scp -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then - if [ ! -d $data/test ]; then - # donwload dataset - pushd $data - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip - unzip aishell_test.zip - popd - - realpath $data/test/*/*.wav > $data/wavlist - awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id - paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp - fi - - if [ ! -f $ckpt_dir/data/mean_std.json ]; then - # download model - mkdir -p $ckpt_dir - pushd $ckpt_dir - wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz - tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz - popd - fi - - lm=$data/zh_giga.no_cna_cmn.prune01244.klm - if [ ! -f $lm ]; then - # download kenlm bin - pushd $data - wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm - popd - fi -fi - -# 3. make feature -text=$data/test/text -label_file=./aishell_result -wer=./aishell_wer - -export GLOG_logtostderr=1 - - -cmvn=$data/cmvn.ark -if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # 3. 
convert cmvn format and compute linear feat - cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn - - ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj - - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \ - compute_linear_spectrogram_main \ - --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ - --feature_wspecifier=ark,scp:$data/split${nj}/JOB/feat.ark,$data/split${nj}/JOB/feat.scp \ - --cmvn_file=$cmvn \ - echo "feature make have finished!!!" -fi - -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - # decode w/o lm - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \ - ctc_beam_search_decoder_main \ - --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ - --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --nnet_decoder_chunk=8 \ - --dict_file=$vocb_dir/vocab.txt \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result - - cat $data/split${nj}/*/result > $exp/${label_file} - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer} - echo "ctc-prefix-beam-search-decoder-ol without lm has finished!!!" 
- echo "please checkout in ${exp}/${wer}" - tail -n 7 $exp/${wer} -fi - -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then - # decode w/ ngram lm with feature input - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \ - ctc_beam_search_decoder_main \ - --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ - --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --nnet_decoder_chunk=8 \ - --dict_file=$vocb_dir/vocab.txt \ - --lm_path=$lm \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm - - cat $data/split${nj}/*/result_lm > $exp/${label_file}_lm - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_lm > $exp/${wer}.lm - echo "ctc-prefix-beam-search-decoder-ol with lm test has finished!!!" - echo "please checkout in ${exp}/${wer}.lm" - tail -n 7 $exp/${wer}.lm -fi - -wfst=$data/wfst/ -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - mkdir -p $wfst - if [ ! 
-f $wfst/aishell_graph.zip ]; then - # download TLG graph - pushd $wfst - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip - unzip aishell_graph.zip - mv aishell_graph/* $wfst - popd - fi -fi - -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - # decoder w/ TLG graph with feature input - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \ - ctc_tlg_decoder_main \ - --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ - --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ - --word_symbol_table=$wfst/words.txt \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --graph_path=$wfst/TLG.fst --max_active=7500 \ - --nnet_decoder_chunk=8 \ - --acoustic_scale=1.2 \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg - - cat $data/split${nj}/*/result_tlg > $exp/${label_file}_tlg - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_tlg > $exp/${wer}.tlg - echo "wfst-decoder-ol have finished!!!" - echo "please checkout in ${exp}/${wer}.tlg" - tail -n 7 $exp/${wer}.tlg -fi - -if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then - # recognize from wav file w/ TLG graph - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \ - recognizer_main \ - --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ - --cmvn_file=$cmvn \ - --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ - --word_symbol_table=$wfst/words.txt \ - --nnet_decoder_chunk=8 \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --graph_path=$wfst/TLG.fst --max_active=7500 \ - --acoustic_scale=1.2 \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_recognizer - - cat $data/split${nj}/*/result_recognizer > $exp/${label_file}_recognizer - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_recognizer > $exp/${wer}.recognizer - echo "recognizer test have finished!!!" 
- echo "please checkout in ${exp}/${wer}.recognizer" - tail -n 7 $exp/${wer}.recognizer -fi \ No newline at end of file diff --git a/speechx/examples/ds2_ol/aishell/run_fbank.sh b/speechx/examples/ds2_ol/aishell/run_fbank.sh deleted file mode 100755 index b93d6944..00000000 --- a/speechx/examples/ds2_ol/aishell/run_fbank.sh +++ /dev/null @@ -1,177 +0,0 @@ -#!/bin/bash -set +x -set -e - -. path.sh - -nj=40 -stage=0 -stop_stage=5 - -. utils/parse_options.sh - -# 1. compile -if [ ! -d ${SPEECHX_EXAMPLES} ]; then - pushd ${SPEECHX_ROOT} - bash build.sh - popd -fi - -# input -mkdir -p data -data=$PWD/data - -ckpt_dir=$data/fbank_model -model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ -vocb_dir=$ckpt_dir/data/lang_char/ - -# output -mkdir -p exp -exp=$PWD/exp - -lm=$data/zh_giga.no_cna_cmn.prune01244.klm -aishell_wav_scp=aishell_test.scp -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then - if [ ! -d $data/test ]; then - pushd $data - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip - unzip aishell_test.zip - popd - - realpath $data/test/*/*.wav > $data/wavlist - awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id - paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp - fi - - if [ ! -f $ckpt_dir/data/mean_std.json ]; then - mkdir -p $ckpt_dir - pushd $ckpt_dir - wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz - tar xzfv WIP1_asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz - popd - fi - - if [ ! -f $lm ]; then - pushd $data - wget -c https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm - popd - fi -fi - -# 3. make feature -text=$data/test/text -label_file=./aishell_result_fbank -wer=./aishell_wer_fbank - -export GLOG_logtostderr=1 - - -cmvn=$data/cmvn_fbank.ark -if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # 3. 
convert cmvn format and compute fbank feat - cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn --binary=false - - ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj - - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \ - compute_fbank_main \ - --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ - --feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank_feat.ark,$data/split${nj}/JOB/fbank_feat.scp \ - --cmvn_file=$cmvn \ - --streaming_chunk=36 -fi - -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - # decode w/ lm by feature - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wolm.log \ - ctc_beam_search_decoder_main \ - --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ - --model_path=$model_dir/avg_5.jit.pdmodel \ - --param_path=$model_dir/avg_5.jit.pdiparams \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --model_cache_shapes="5-1-2048,5-1-2048" \ - --nnet_decoder_chunk=8 \ - --dict_file=$vocb_dir/vocab.txt \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_fbank - - cat $data/split${nj}/*/result_fbank > $exp/${label_file} - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer} - tail -n 7 $exp/${wer} -fi - -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then - # decode with ngram lm by feature - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.lm.log \ - ctc_beam_search_decoder_main \ - --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ - --model_path=$model_dir/avg_5.jit.pdmodel \ - --param_path=$model_dir/avg_5.jit.pdiparams \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --model_cache_shapes="5-1-2048,5-1-2048" \ - --nnet_decoder_chunk=8 \ - --dict_file=$vocb_dir/vocab.txt \ - --lm_path=$lm \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/fbank_result_lm - - cat $data/split${nj}/*/fbank_result_lm > $exp/${label_file}_lm - utils/compute-wer.py 
--char=1 --v=1 $text $exp/${label_file}_lm > $exp/${wer}.lm - tail -n 7 $exp/${wer}.lm -fi - -wfst=$data/wfst_fbank/ -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - mkdir -p $wfst - if [ ! -f $wfst/aishell_graph2.zip ]; then - pushd $wfst - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph2.zip - unzip aishell_graph2.zip - mv aishell_graph2/* $wfst - popd - fi -fi - -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - # decode w/ TLG graph by feature - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \ - ctc_tlg_decoder_main \ - --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ - --model_path=$model_dir/avg_5.jit.pdmodel \ - --param_path=$model_dir/avg_5.jit.pdiparams \ - --word_symbol_table=$wfst/words.txt \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --model_cache_shapes="5-1-2048,5-1-2048" \ - --nnet_decoder_chunk=8 \ - --graph_path=$wfst/TLG.fst --max_active=7500 \ - --acoustic_scale=1.2 \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg - - cat $data/split${nj}/*/result_tlg > $exp/${label_file}_tlg - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_tlg > $exp/${wer}.tlg - echo "wfst-decoder-ol have finished!!!" 
- echo "please checkout in ${exp}/${wer}.tlg" - tail -n 7 $exp/${wer}.tlg -fi - -if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then - # recgonize w/ TLG graph by wav - utils/run.pl JOB=1:$nj $data/split${nj}/JOB/fbank_recognizer.log \ - recognizer_main \ - --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ - --cmvn_file=$cmvn \ - --model_path=$model_dir/avg_5.jit.pdmodel \ - --use_fbank=true \ - --param_path=$model_dir/avg_5.jit.pdiparams \ - --word_symbol_table=$wfst/words.txt \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --model_cache_shapes="5-1-2048,5-1-2048" \ - --nnet_decoder_chunk=8 \ - --graph_path=$wfst/TLG.fst --max_active=7500 \ - --acoustic_scale=1.2 \ - --result_wspecifier=ark,t:$data/split${nj}/JOB/result_fbank_recognizer - - cat $data/split${nj}/*/result_fbank_recognizer > $exp/${label_file}_recognizer - utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file}_recognizer > $exp/${wer}.recognizer - echo "recognizer test have finished!!!" - echo "please checkout in ${exp}/${wer}.recognizer" - tail -n 7 $exp/${wer}.recognizer -fi diff --git a/speechx/examples/ds2_ol/aishell/utils b/speechx/examples/ds2_ol/aishell/utils deleted file mode 120000 index c2519a9d..00000000 --- a/speechx/examples/ds2_ol/aishell/utils +++ /dev/null @@ -1 +0,0 @@ -../../../../utils/ \ No newline at end of file diff --git a/speechx/examples/ds2_ol/onnx/.gitignore b/speechx/examples/ds2_ol/onnx/.gitignore deleted file mode 100644 index f862f73e..00000000 --- a/speechx/examples/ds2_ol/onnx/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -data -log -exp diff --git a/speechx/examples/ds2_ol/onnx/README.md b/speechx/examples/ds2_ol/onnx/README.md deleted file mode 100644 index b98b74b6..00000000 --- a/speechx/examples/ds2_ol/onnx/README.md +++ /dev/null @@ -1,57 +0,0 @@ -# Convert DeepSpeech2 model to ONNX format - -> We recommend using U2/U2++ model instead of DS2, please see [here](../../u2pp_ol/wenetspeech/). 
- -This example demonstrate converting ds2 model to ONNX fromat. - -Please make sure [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and [onnx-simplifier](https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape) version is correct. - -The example test with these packages installed: -``` -paddle2onnx 0.9.8 # develop 62c5424e22cd93968dc831216fc9e0f0fce3d819 -paddleaudio 0.2.1 -paddlefsl 1.1.0 -paddlenlp 2.2.6 -paddlepaddle-gpu 2.2.2 -paddlespeech 0.0.0 # develop -paddlespeech-ctcdecoders 0.2.0 -paddlespeech-feat 0.1.0 -onnx 1.11.0 -onnx-simplifier 0.0.0 # https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape -onnxoptimizer 0.2.7 -onnxruntime 1.11.0 -``` - - -## Using - -``` -bash run.sh --stage 0 --stop_stage 5 -``` - -1. convert deepspeech2 model to ONNX, using Paddle2ONNX. -2. check paddleinference and onnxruntime output equal. -3. optimize onnx model -4. check paddleinference and optimized onnxruntime output equal. -5. quantize onnx model -6. check paddleinference and optimized onnxruntime output equal. - -For more details please see `run.sh`. - -## Outputs -The optimized onnx model is `exp/model.opt.onnx`, quanted model is `exp/model.optset11.quant.onnx`. 
- - -## [Results](https://github.com/PaddlePaddle/PaddleSpeech/wiki/ASR-Benchmark#streaming-asr) - -机器硬件:`CPU:Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz` -测试脚本:`Streaming Server` - -Acoustic Model | Model Size | enigne | dedoding_method | ctc_weight | decoding_chunk_size | num_decoding_left_chunk | RTF | -|:-------------:| :-----: | :-----: | :------------:| :-----: | :-----: | :-----: |:-----:| -| deepspeech2online_wenetspeech | 659MB | infernece | ctc_prefix_beam_search | - | 1 | - | 1.9108175171428279(utts=80) | -| deepspeech2online_wenetspeech | 659MB | onnx | ctc_prefix_beam_search | - | 1 | - | 0.5617182449999291 (utts=80) | -| deepspeech2online_wenetspeech | 166MB | onnx quant | ctc_prefix_beam_search | - | 1 | - | 0.44507715475808385 (utts=80) | - -> quant 和机器有关,不是所有机器都支持。ONNX quant测试机器指令集支持: -> Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology eagerfpu pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 arat umip pku ospke avx512_vnni spec_ctrl diff --git a/speechx/examples/ds2_ol/onnx/local/infer_check.py b/speechx/examples/ds2_ol/onnx/local/infer_check.py deleted file mode 100755 index f821baa1..00000000 --- a/speechx/examples/ds2_ol/onnx/local/infer_check.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import os -import pickle - -import numpy as np -import onnxruntime -import paddle - - -def parse_args(): - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - '--input_file', - type=str, - default="static_ds2online_inputs.pickle", - help="aishell ds2 input data file. For wenetspeech, we only feed for infer model", - ) - parser.add_argument( - '--model_type', - type=str, - default="aishell", - help="aishell(1024) or wenetspeech(2048)", ) - parser.add_argument( - '--model_dir', type=str, default=".", help="paddle model dir.") - parser.add_argument( - '--model_prefix', - type=str, - default="avg_1.jit", - help="paddle model prefix.") - parser.add_argument( - '--onnx_model', - type=str, - default='./model.old.onnx', - help="onnx model.") - - return parser.parse_args() - - -if __name__ == '__main__': - FLAGS = parse_args() - - # input and output - with open(FLAGS.input_file, 'rb') as f: - iodict = pickle.load(f) - print(iodict.keys()) - - audio_chunk = iodict['audio_chunk'] - audio_chunk_lens = iodict['audio_chunk_lens'] - chunk_state_h_box = iodict['chunk_state_h_box'] - chunk_state_c_box = iodict['chunk_state_c_bos'] - print("raw state shape: ", chunk_state_c_box.shape) - - if FLAGS.model_type == 'wenetspeech': - chunk_state_h_box = np.repeat(chunk_state_h_box, 2, axis=-1) - chunk_state_c_box = np.repeat(chunk_state_c_box, 2, axis=-1) - print("state shape: ", chunk_state_c_box.shape) - - # paddle - model = paddle.jit.load(os.path.join(FLAGS.model_dir, FLAGS.model_prefix)) - res_chunk, res_lens, chunk_state_h, 
chunk_state_c = model( - paddle.to_tensor(audio_chunk), - paddle.to_tensor(audio_chunk_lens), - paddle.to_tensor(chunk_state_h_box), - paddle.to_tensor(chunk_state_c_box), ) - - # onnxruntime - options = onnxruntime.SessionOptions() - options.enable_profiling = True - sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options) - ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run( - ['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], { - "audio_chunk": audio_chunk, - "audio_chunk_lens": audio_chunk_lens, - "chunk_state_h_box": chunk_state_h_box, - "chunk_state_c_box": chunk_state_c_box - }) - - print(sess.end_profiling()) - - # assert paddle equal ort - print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6)) - print(np.allclose(ort_res_lens, res_lens, atol=1e-6)) - - if FLAGS.model_type == 'aishell': - print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6)) - print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6)) diff --git a/speechx/examples/ds2_ol/onnx/local/netron.sh b/speechx/examples/ds2_ol/onnx/local/netron.sh deleted file mode 100755 index 6dd9a39c..00000000 --- a/speechx/examples/ds2_ol/onnx/local/netron.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -# show model - -if [ $# != 1 ];then - echo "usage: $0 model_path" - exit 1 -fi - - -file=$1 - -pip install netron -netron -p 8082 --host $(hostname -i) $file \ No newline at end of file diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_clone.sh b/speechx/examples/ds2_ol/onnx/local/onnx_clone.sh deleted file mode 100755 index bce22dbc..00000000 --- a/speechx/examples/ds2_ol/onnx/local/onnx_clone.sh +++ /dev/null @@ -1,7 +0,0 @@ - -#!/bin/bash - -# clone onnx repos -git clone https://github.com/onnx/onnx.git -git clone https://github.com/microsoft/onnxruntime.git -git clone https://github.com/PaddlePaddle/Paddle2ONNX.git \ No newline at end of file diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_convert_opset.py 
b/speechx/examples/ds2_ol/onnx/local/onnx_convert_opset.py deleted file mode 100755 index 00b5cf77..00000000 --- a/speechx/examples/ds2_ol/onnx/local/onnx_convert_opset.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python3 -import argparse - -import onnx -from onnx import version_converter - -if __name__ == '__main__': - parser = argparse.ArgumentParser(prog=__doc__) - parser.add_argument( - "--model-file", type=str, required=True, help='path/to/the/model.onnx.') - parser.add_argument( - "--save-model", - type=str, - required=True, - help='path/to/saved/model.onnx.') - # Models must be opset10 or higher to be quantized. - parser.add_argument( - "--target-opset", type=int, default=11, help='path/to/the/model.onnx.') - - args = parser.parse_args() - - print(f"to opset: {args.target_opset}") - - # Preprocessing: load the model to be converted. - model_path = args.model_file - original_model = onnx.load(model_path) - - # print('The model before conversion:\n{}'.format(original_model)) - - # A full list of supported adapters can be found here: - # https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21 - # Apply the version conversion on the original model - converted_model = version_converter.convert_version(original_model, - args.target_opset) - - # print('The model after conversion:\n{}'.format(converted_model)) - onnx.save(converted_model, args.save_model) diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py b/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py deleted file mode 100755 index c53e9ec9..00000000 --- a/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py +++ /dev/null @@ -1,2514 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# flake8: noqa -import argparse -import logging - -import numpy as np -import onnx -import sympy -from onnx import helper -from onnx import numpy_helper -from onnx import shape_inference -from packaging import version -assert version.parse(onnx.__version__) >= version.parse("1.8.0") - -logger = logging.getLogger(__name__) - - -def get_attribute(node, attr_name, default_value=None): - found = [attr for attr in node.attribute if attr.name == attr_name] - if found: - return helper.get_attribute_value(found[0]) - return default_value - - -def get_dim_from_proto(dim): - return getattr(dim, dim.WhichOneof('value')) if type( - dim.WhichOneof('value')) == str else None - - -def is_sequence(type_proto): - cls_type = type_proto.WhichOneof('value') - assert cls_type in ['tensor_type', 'sequence_type'] - return cls_type == 'sequence_type' - - -def get_shape_from_type_proto(type_proto): - assert not is_sequence(type_proto) - if type_proto.tensor_type.HasField('shape'): - return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim] - else: - return None # note no shape is different from shape without dim (scalar) - - -def get_shape_from_value_info(vi): - cls_type = vi.type.WhichOneof('value') - if cls_type is None: - return None - if is_sequence(vi.type): - if 'tensor_type' == vi.type.sequence_type.elem_type.WhichOneof('value'): - return get_shape_from_type_proto(vi.type.sequence_type.elem_type) - else: - return None - else: - return get_shape_from_type_proto(vi.type) - - -def make_named_value_info(name): - vi = onnx.ValueInfoProto() - vi.name = name - return vi - - -def get_shape_from_sympy_shape(sympy_shape): - return [ - None if i is None else (int(i) if is_literal(i) else str(i)) - for i in sympy_shape - ] - - -def is_literal(dim): - return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr( - dim, 'is_number') and dim.is_number) - - -def handle_negative_axis(axis, rank): - assert axis < rank and axis >= -rank - return axis if axis >= 0 else rank 
+ axis - - -def get_opset(mp, domain=None): - domain = domain or ['', 'onnx', 'ai.onnx'] - if type(domain) != list: - domain = [domain] - for opset in mp.opset_import: - if opset.domain in domain: - return opset.version - - return None - - -def as_scalar(x): - if type(x) == list: - assert len(x) == 1 - return x[0] - elif type(x) == np.ndarray: - return x.item() - else: - return x - - -def as_list(x, keep_none): - if type(x) == list: - return x - elif type(x) == np.ndarray: - return list(x) - elif keep_none and x is None: - return None - else: - return [x] - - -def sympy_reduce_product(x): - if type(x) == list: - value = sympy.Integer(1) - for v in x: - value = value * v - else: - value = x - return value - - -class SymbolicShapeInference: - def __init__(self, - int_max, - auto_merge, - guess_output_rank, - verbose, - prefix=''): - self.dispatcher_ = { - 'Add': - self._infer_symbolic_compute_ops, - 'ArrayFeatureExtractor': - self._infer_ArrayFeatureExtractor, - 'AveragePool': - self._infer_Pool, - 'BatchNormalization': - self._infer_BatchNormalization, - 'Cast': - self._infer_Cast, - 'CategoryMapper': - self._infer_CategoryMapper, - 'Compress': - self._infer_Compress, - 'Concat': - self._infer_Concat, - 'ConcatFromSequence': - self._infer_ConcatFromSequence, - 'Constant': - self._infer_Constant, - 'ConstantOfShape': - self._infer_ConstantOfShape, - 'Conv': - self._infer_Conv, - 'CumSum': - self._pass_on_shape_and_type, - 'Div': - self._infer_symbolic_compute_ops, - 'Einsum': - self._infer_Einsum, - 'Expand': - self._infer_Expand, - 'Equal': - self._infer_symbolic_compute_ops, - 'Floor': - self._infer_symbolic_compute_ops, - 'Gather': - self._infer_Gather, - 'GatherElements': - self._infer_GatherElements, - 'GatherND': - self._infer_GatherND, - 'Gelu': - self._pass_on_shape_and_type, - 'If': - self._infer_If, - 'Loop': - self._infer_Loop, - 'MatMul': - self._infer_MatMul, - 'MatMulInteger16': - self._infer_MatMulInteger, - 'MaxPool': - self._infer_Pool, - 'Max': - 
self._infer_symbolic_compute_ops, - 'Min': - self._infer_symbolic_compute_ops, - 'Mul': - self._infer_symbolic_compute_ops, - 'NonMaxSuppression': - self._infer_NonMaxSuppression, - 'NonZero': - self._infer_NonZero, - 'OneHot': - self._infer_OneHot, - 'Pad': - self._infer_Pad, - 'Range': - self._infer_Range, - 'Reciprocal': - self._pass_on_shape_and_type, - 'ReduceSum': - self._infer_ReduceSum, - 'ReduceProd': - self._infer_ReduceProd, - 'Reshape': - self._infer_Reshape, - 'Resize': - self._infer_Resize, - 'Round': - self._pass_on_shape_and_type, - 'Scan': - self._infer_Scan, - 'ScatterElements': - self._infer_ScatterElements, - 'SequenceAt': - self._infer_SequenceAt, - 'SequenceInsert': - self._infer_SequenceInsert, - 'Shape': - self._infer_Shape, - 'Size': - self._infer_Size, - 'Slice': - self._infer_Slice, - 'SoftmaxCrossEntropyLoss': - self._infer_SoftmaxCrossEntropyLoss, - 'SoftmaxCrossEntropyLossInternal': - self._infer_SoftmaxCrossEntropyLoss, - 'NegativeLogLikelihoodLossInternal': - self._infer_SoftmaxCrossEntropyLoss, - 'Split': - self._infer_Split, - 'SplitToSequence': - self._infer_SplitToSequence, - 'Squeeze': - self._infer_Squeeze, - 'Sub': - self._infer_symbolic_compute_ops, - 'Tile': - self._infer_Tile, - 'TopK': - self._infer_TopK, - 'Transpose': - self._infer_Transpose, - 'Unsqueeze': - self._infer_Unsqueeze, - 'Where': - self._infer_symbolic_compute_ops, - 'ZipMap': - self._infer_ZipMap, - 'Neg': - self._infer_symbolic_compute_ops, - # contrib ops: - 'Attention': - self._infer_Attention, - 'BiasGelu': - self._infer_BiasGelu, - 'EmbedLayerNormalization': - self._infer_EmbedLayerNormalization, - 'FastGelu': - self._infer_FastGelu, - 'Gelu': - self._infer_Gelu, - 'LayerNormalization': - self._infer_LayerNormalization, - 'LongformerAttention': - self._infer_LongformerAttention, - 'PythonOp': - self._infer_PythonOp, - 'SkipLayerNormalization': - self._infer_SkipLayerNormalization - } - self.aten_op_dispatcher_ = { - 'aten::embedding': 
self._infer_Gather, - 'aten::bitwise_or': self._infer_aten_bitwise_or, - 'aten::diagonal': self._infer_aten_diagonal, - 'aten::max_pool2d_with_indices': self._infer_aten_pool2d, - 'aten::multinomial': self._infer_aten_multinomial, - 'aten::unfold': self._infer_aten_unfold, - 'aten::argmax': self._infer_aten_argmax, - 'aten::avg_pool2d': self._infer_aten_pool2d, - 'aten::_adaptive_avg_pool2d': self._infer_aten_pool2d, - 'aten::binary_cross_entropy_with_logits': self._infer_aten_bce, - 'aten::numpy_T': self._infer_Transpose, - } - self.run_ = True - self.suggested_merge_ = {} - self.symbolic_dims_ = {} - self.input_symbols_ = {} - self.auto_merge_ = auto_merge - self.guess_output_rank_ = guess_output_rank - self.verbose_ = verbose - self.int_max_ = int_max - self.subgraph_id_ = 0 - self.prefix_ = prefix - - def _add_suggested_merge(self, symbols, apply=False): - assert all([(type(s) == str and s in self.symbolic_dims_) or - is_literal(s) for s in symbols]) - symbols = set(symbols) - for k, v in self.suggested_merge_.items(): - if k in symbols: - symbols.remove(k) - symbols.add(v) - map_to = None - # if there is literal, map to it first - for s in symbols: - if is_literal(s): - map_to = s - break - # when no literals, map to input symbolic dims, then existing symbolic dims - if map_to is None: - for s in symbols: - if s in self.input_symbols_: - map_to = s - break - if map_to is None: - for s in symbols: - if type(self.symbolic_dims_[s]) == sympy.Symbol: - map_to = s - break - # when nothing to map to, use the shorter one - if map_to is None: - if self.verbose_ > 0: - logger.warning( - 'Potential unsafe merge between symbolic expressions: ({})'. 
- format(','.join(symbols))) - symbols_list = list(symbols) - lens = [len(s) for s in symbols_list] - map_to = symbols_list[lens.index(min(lens))] - symbols.remove(map_to) - - for s in symbols: - if s == map_to: - continue - if is_literal(map_to) and is_literal(s): - assert int(map_to) == int(s) - self.suggested_merge_[s] = int(map_to) if is_literal( - map_to) else map_to - for k, v in self.suggested_merge_.items(): - if v == s: - self.suggested_merge_[k] = map_to - if apply and self.auto_merge_: - self._apply_suggested_merge() - - def _apply_suggested_merge(self, graph_input_only=False): - if not self.suggested_merge_: - return - for i in list(self.out_mp_.graph.input) + ( - [] if graph_input_only else list(self.out_mp_.graph.value_info)): - for d in i.type.tensor_type.shape.dim: - if d.dim_param in self.suggested_merge_: - v = self.suggested_merge_[d.dim_param] - if is_literal(v): - d.dim_value = int(v) - else: - d.dim_param = v - - def _preprocess(self, in_mp): - self.out_mp_ = onnx.ModelProto() - self.out_mp_.CopyFrom(in_mp) - self.graph_inputs_ = dict( - [(i.name, i) for i in list(self.out_mp_.graph.input)]) - self.initializers_ = dict( - [(i.name, i) for i in self.out_mp_.graph.initializer]) - self.known_vi_ = dict( - [(i.name, i) for i in list(self.out_mp_.graph.input)]) - self.known_vi_.update( - dict([(i.name, helper.make_tensor_value_info(i.name, i.data_type, - list(i.dims))) - for i in self.out_mp_.graph.initializer])) - - def _merge_symbols(self, dims): - if not all([type(d) == str for d in dims]): - if self.auto_merge_: - unique_dims = list(set(dims)) - is_int = [is_literal(d) for d in unique_dims] - assert sum( - is_int - ) <= 1 # if there are more than 1 unique ints, something is wrong - if sum(is_int) == 1: - int_dim = is_int.index(1) - if self.verbose_ > 0: - logger.debug('dim {} has been merged with value {}'. 
- format(unique_dims[:int_dim] + unique_dims[ - int_dim + 1:], unique_dims[int_dim])) - self._check_merged_dims(unique_dims, allow_broadcast=False) - return unique_dims[int_dim] - else: - if self.verbose_ > 0: - logger.debug('dim {} has been mergd with dim {}'.format( - unique_dims[1:], unique_dims[0])) - return dims[0] - else: - return None - if all([d == dims[0] for d in dims]): - return dims[0] - merged = [ - self.suggested_merge_[d] if d in self.suggested_merge_ else d - for d in dims - ] - if all([d == merged[0] for d in merged]): - assert merged[0] in self.symbolic_dims_ - return merged[0] - else: - return None - - # broadcast from right to left, and merge symbolic dims if needed - def _broadcast_shapes(self, shape1, shape2): - new_shape = [] - rank1 = len(shape1) - rank2 = len(shape2) - new_rank = max(rank1, rank2) - for i in range(new_rank): - dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1 - dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1 - if dim1 == 1 or dim1 == dim2: - new_dim = dim2 - elif dim2 == 1: - new_dim = dim1 - else: - new_dim = self._merge_symbols([dim1, dim2]) - if not new_dim: - # warning about unsupported broadcast when not auto merge - # note that auto merge has the risk of incorrectly merge symbols while one of them being 1 - # for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b' - if self.auto_merge_: - self._add_suggested_merge([dim1, dim2], apply=True) - else: - logger.warning('unsupported broadcast between ' + str( - dim1) + ' ' + str(dim2)) - new_shape = [new_dim] + new_shape - return new_shape - - def _get_shape(self, node, idx): - name = node.input[idx] - if name in self.known_vi_: - vi = self.known_vi_[name] - return get_shape_from_value_info(vi) - else: - assert name in self.initializers_ - return list(self.initializers_[name].dims) - - def _get_shape_rank(self, node, idx): - return len(self._get_shape(node, idx)) - - def _get_sympy_shape(self, node, idx): - sympy_shape = [] - for d in 
self._get_shape(node, idx): - if type(d) == str: - sympy_shape.append(self.symbolic_dims_[d] if d in - self.symbolic_dims_ else sympy.Symbol( - d, integer=True, nonnegative=True)) - else: - assert None != d - sympy_shape.append(d) - return sympy_shape - - def _get_value(self, node, idx): - name = node.input[idx] - assert name in self.sympy_data_ or name in self.initializers_ - return self.sympy_data_[ - name] if name in self.sympy_data_ else numpy_helper.to_array( - self.initializers_[name]) - - def _try_get_value(self, node, idx): - if idx >= len(node.input): - return None - name = node.input[idx] - if name in self.sympy_data_ or name in self.initializers_: - return self._get_value(node, idx) - return None - - def _update_computed_dims(self, new_sympy_shape): - for i, new_dim in enumerate(new_sympy_shape): - if not is_literal(new_dim) and not type(new_dim) == str: - str_dim = str(new_dim) - if str_dim in self.suggested_merge_: - if is_literal(self.suggested_merge_[str_dim]): - continue # no need to create dim for literals - new_sympy_shape[i] = self.symbolic_dims_[ - self.suggested_merge_[str_dim]] - else: - # add new_dim if it's a computational expression - if not str(new_dim) in self.symbolic_dims_: - self.symbolic_dims_[str(new_dim)] = new_dim - - def _onnx_infer_single_node(self, node): - # skip onnx shape inference for some ops, as they are handled in _infer_* - skip_infer = node.op_type in [ - 'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', 'Attention', - 'BiasGelu', 'EmbedLayerNormalization', 'FastGelu', 'Gelu', - 'LayerNormalization', 'LongformerAttention', - 'SkipLayerNormalization', 'PythonOp' - ] - - if not skip_infer: - # Only pass initializers that satisfy the following condition: - # (1) Operator need value of some input for shape inference. - # For example, Unsqueeze in opset 13 uses the axes input to calculate shape of output. - # (2) opset version >= 9. In older version, initializer is required in graph input by onnx spec. 
- # (3) The initializer is not in graph input. The means the node input is "constant" in inference. - initializers = [] - if (get_opset(self.out_mp_) >= 9) and node.op_type in ['Unsqueeze']: - initializers = [ - self.initializers_[name] for name in node.input - if (name in self.initializers_ and name not in - self.graph_inputs_) - ] - - # run single node inference with self.known_vi_ shapes - tmp_graph = helper.make_graph( - [node], 'tmp', [self.known_vi_[i] for i in node.input if i], - [make_named_value_info(i) for i in node.output], initializers) - - self.tmp_mp_.graph.CopyFrom(tmp_graph) - - self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_) - - for i_o in range(len(node.output)): - o = node.output[i_o] - vi = self.out_mp_.graph.value_info.add() - if not skip_infer: - vi.CopyFrom(self.tmp_mp_.graph.output[i_o]) - else: - vi.name = o - self.known_vi_[o] = vi - - def _onnx_infer_subgraph(self, - node, - subgraph, - use_node_input=True, - inc_subgraph_id=True): - if self.verbose_ > 2: - logger.debug( - 'Inferencing subgraph of node {} with output({}...): {}'.format( - node.name, node.output[0], node.op_type)) - # node inputs are not passed directly to the subgraph - # it's up to the node dispatcher to prepare subgraph input - # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape - # besides, inputs in subgraph could shadow implicit inputs - subgraph_inputs = set( - [i.name for i in list(subgraph.initializer) + list(subgraph.input)]) - subgraph_implicit_input = set([ - name for name in self.known_vi_.keys() - if not name in subgraph_inputs - ]) - tmp_graph = helper.make_graph( - list(subgraph.node), 'tmp', - list(subgraph.input) + - [self.known_vi_[i] for i in subgraph_implicit_input], - [make_named_value_info(i.name) for i in subgraph.output]) - tmp_graph.initializer.extend([ - i for i in self.out_mp_.graph.initializer - if i.name in subgraph_implicit_input - ]) - tmp_graph.initializer.extend(subgraph.initializer) - 
self.tmp_mp_.graph.CopyFrom(tmp_graph) - - symbolic_shape_inference = SymbolicShapeInference( - self.int_max_, - self.auto_merge_, - self.guess_output_rank_, - self.verbose_, - prefix=self.prefix_ + '_' + str(self.subgraph_id_)) - if inc_subgraph_id: - self.subgraph_id_ += 1 - - all_shapes_inferred = False - symbolic_shape_inference._preprocess(self.tmp_mp_) - symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy() - while symbolic_shape_inference.run_: - all_shapes_inferred = symbolic_shape_inference._infer_impl( - self.sympy_data_.copy()) - symbolic_shape_inference._update_output_from_vi() - if use_node_input: - # if subgraph uses node input, it needs to update to merged dims - subgraph.ClearField('input') - subgraph.input.extend( - symbolic_shape_inference.out_mp_.graph.input[:len(node.input)]) - subgraph.ClearField('output') - subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output) - subgraph.ClearField('value_info') - subgraph.value_info.extend( - symbolic_shape_inference.out_mp_.graph.value_info) - subgraph.ClearField('node') - subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node) - # for new symbolic dims from subgraph output, add to main graph symbolic dims - subgraph_shapes = [ - get_shape_from_value_info(o) - for o in symbolic_shape_inference.out_mp_.graph.output - ] - subgraph_new_symbolic_dims = set([ - d for s in subgraph_shapes - if s for d in s if type(d) == str and not d in self.symbolic_dims_ - ]) - new_dims = {} - for d in subgraph_new_symbolic_dims: - assert d in symbolic_shape_inference.symbolic_dims_ - new_dims[d] = symbolic_shape_inference.symbolic_dims_[d] - self.symbolic_dims_.update(new_dims) - return symbolic_shape_inference - - def _get_int_values(self, node, broadcast=False): - values = [self._try_get_value(node, i) for i in range(len(node.input))] - if all([v is not None for v in values]): - # some shape compute is in floating point, cast to int for sympy - for i, v in enumerate(values): - if 
type(v) != np.ndarray: - continue - if len(v.shape) > 1: - new_v = None # ignore value for rank > 1 - elif len(v.shape) == 0: - new_v = int(v.item()) - else: - assert len(v.shape) == 1 - new_v = [int(vv) for vv in v] - values[i] = new_v - values_len = [len(v) if type(v) == list else 0 for v in values] - max_len = max(values_len) - if max_len >= 1 and broadcast: - # broadcast - for i, v in enumerate(values): - if v is None: - continue # don't broadcast if value is unknown - if type(v) == list: - if len(v) < max_len: - values[i] = v * max_len - else: - assert len(v) == max_len - else: - values[i] = [v] * max_len - return values - - def _compute_on_sympy_data(self, node, op_func): - assert len(node.output) == 1 - values = self._get_int_values(node, broadcast=True) - if all([v is not None for v in values]): - is_list = [type(v) == list for v in values] - as_list = any(is_list) - if as_list: - self.sympy_data_[node.output[ - 0]] = [op_func(vs) for vs in zip(*values)] - else: - self.sympy_data_[node.output[0]] = op_func(values) - - def _pass_on_sympy_data(self, node): - assert len( - node. 
- input) == 1 or node.op_type in ['Reshape', 'Unsqueeze', 'Squeeze'] - self._compute_on_sympy_data(node, lambda x: x[0]) - - def _pass_on_shape_and_type(self, node): - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, - self._get_shape(node, 0))) - - def _new_symbolic_dim(self, prefix, dim): - new_dim = '{}_d{}'.format(prefix, dim) - if new_dim in self.suggested_merge_: - v = self.suggested_merge_[new_dim] - new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v - else: - new_symbolic_dim = sympy.Symbol( - new_dim, integer=True, nonnegative=True) - self.symbolic_dims_[new_dim] = new_symbolic_dim - return new_symbolic_dim - - def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0): - return self._new_symbolic_dim('{}{}_{}_o{}_'.format( - node.op_type, self.prefix_, - list(self.out_mp_.graph.node).index(node), out_idx), dim) - - def _new_symbolic_shape(self, rank, node, out_idx=0): - return [ - self._new_symbolic_dim_from_output(node, out_idx, i) - for i in range(rank) - ] - - def _compute_conv_pool_shape(self, node): - sympy_shape = self._get_sympy_shape(node, 0) - if len(node.input) > 1: - W_shape = self._get_sympy_shape(node, 1) - rank = len(W_shape) - 2 # number of spatial axes - kernel_shape = W_shape[-rank:] - sympy_shape[1] = W_shape[0] - else: - W_shape = None - kernel_shape = get_attribute(node, 'kernel_shape') - rank = len(kernel_shape) - - assert len(sympy_shape) == rank + 2 - - # only need to symbolic shape inference if input has symbolic dims in spatial axes - is_symbolic_dims = [not is_literal(i) for i in sympy_shape[-rank:]] - - if not any(is_symbolic_dims): - shape = get_shape_from_value_info(self.known_vi_[node.output[0]]) - if len(shape) > 0: - assert len(sympy_shape) == len(shape) - sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]] - return sympy_shape - - dilations = get_attribute(node, 'dilations', [1] * 
rank) - strides = get_attribute(node, 'strides', [1] * rank) - effective_kernel_shape = [(k - 1) * d + 1 - for k, d in zip(kernel_shape, dilations)] - pads = get_attribute(node, 'pads') - if pads is None: - pads = [0] * (2 * rank) - auto_pad = get_attribute(node, 'auto_pad', - b'NOTSET').decode('utf-8') - if auto_pad != 'VALID' and auto_pad != 'NOTSET': - try: - residual = [ - sympy.Mod(d, s) - for d, s in zip(sympy_shape[-rank:], strides) - ] - total_pads = [ - max(0, (k - s) if r == 0 else (k - r)) - for k, s, r in zip(effective_kernel_shape, strides, - residual) - ] - except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational - total_pads = [ - max(0, (k - s)) - for k, s in zip(effective_kernel_shape, strides) - ] # assuming no residual if sympy throws error - elif auto_pad == 'VALID': - total_pads = [] - else: - total_pads = [0] * rank - else: - assert len(pads) == 2 * rank - total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])] - - ceil_mode = get_attribute(node, 'ceil_mode', 0) - for i in range(rank): - effective_input_size = sympy_shape[-rank + i] - if len(total_pads) > 0: - effective_input_size = effective_input_size + total_pads[i] - if ceil_mode: - strided_kernel_positions = sympy.ceiling( - (effective_input_size - effective_kernel_shape[i]) / - strides[i]) - else: - strided_kernel_positions = ( - effective_input_size - effective_kernel_shape[i] - ) // strides[i] - sympy_shape[-rank + i] = strided_kernel_positions + 1 - return sympy_shape - - def _check_merged_dims(self, dims, allow_broadcast=True): - if allow_broadcast: - dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)] - if not all([d == dims[0] for d in dims]): - self._add_suggested_merge(dims, apply=True) - - def _compute_matmul_shape(self, node, output_dtype=None): - lhs_shape = self._get_shape(node, 0) - rhs_shape = self._get_shape(node, 1) - lhs_rank = len(lhs_shape) - rhs_rank = len(rhs_shape) - lhs_reduce_dim = 0 - rhs_reduce_dim = 0 - 
assert lhs_rank > 0 and rhs_rank > 0 - if lhs_rank == 1 and rhs_rank == 1: - new_shape = [] - elif lhs_rank == 1: - rhs_reduce_dim = -2 - new_shape = rhs_shape[:rhs_reduce_dim] + [rhs_shape[-1]] - elif rhs_rank == 1: - lhs_reduce_dim = -1 - new_shape = lhs_shape[:lhs_reduce_dim] - else: - lhs_reduce_dim = -1 - rhs_reduce_dim = -2 - new_shape = self._broadcast_shapes( - lhs_shape[:-2], - rhs_shape[:-2]) + [lhs_shape[-2]] + [rhs_shape[-1]] - # merge reduce dim - self._check_merged_dims( - [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]], - allow_broadcast=False) - if output_dtype is None: - # infer output_dtype from input type when not specified - output_dtype = self.known_vi_[node.input[ - 0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, - new_shape)) - - def _fuse_tensor_type(self, node, out_idx, dst_type, src_type): - ''' - update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches - ''' - dst_tensor_type = dst_type.sequence_type.elem_type.tensor_type if is_sequence( - dst_type) else dst_type.tensor_type - src_tensor_type = src_type.sequence_type.elem_type.tensor_type if is_sequence( - src_type) else src_type.tensor_type - if dst_tensor_type.elem_type != src_tensor_type.elem_type: - node_id = node.name if node.name else node.op_type - raise ValueError( - f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: " - f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs " - f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}" - ) - if dst_tensor_type.HasField('shape'): - for di, ds in enumerate( - zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)): - if ds[0] != ds[1]: - # create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type - # for sequence_type, clear the dimension - new_dim = onnx.TensorShapeProto.Dimension() - if not 
is_sequence(dst_type): - new_dim.dim_param = str( - self._new_symbolic_dim_from_output(node, out_idx, - di)) - dst_tensor_type.shape.dim[di].CopyFrom(new_dim) - else: - dst_tensor_type.CopyFrom(src_tensor_type) - - def _infer_ArrayFeatureExtractor(self, node): - data_shape = self._get_shape(node, 0) - indices_shape = self._get_shape(node, 1) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, data_shape[:-1] + - indices_shape)) - - def _infer_symbolic_compute_ops(self, node): - funcs = { - 'Add': - lambda l: l[0] + l[1], - 'Div': - lambda l: l[0] // l[1], # integer div in sympy - 'Equal': - lambda l: l[0] == l[1], - 'Floor': - lambda l: sympy.floor(l[0]), - 'Max': - lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])), - 'Min': - lambda l: l[1] if is_literal(l[0]) and int(l[0]) > self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])), - 'Mul': - lambda l: l[0] * l[1], - 'Sub': - lambda l: l[0] - l[1], - 'Where': - lambda l: l[1] if l[0] else l[2], - 'Neg': - lambda l: -l[0] - } - assert node.op_type in funcs - self._compute_on_sympy_data(node, funcs[node.op_type]) - - def _infer_Cast(self, node): - self._pass_on_sympy_data(node) - - def _infer_CategoryMapper(self, node): - input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type - if input_type == onnx.TensorProto.STRING: - output_type = onnx.TensorProto.INT64 - else: - output_type = onnx.TensorProto.STRING - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_type, - self._get_shape(node, 0))) - - def _infer_Compress(self, node): - input_shape = self._get_shape(node, 0) - # create a new symbolic dimension for Compress output - compress_len = str(self._new_symbolic_dim_from_output(node)) - 
axis = get_attribute(node, 'axis') - if axis == None: - # when axis is not specified, input is flattened before compress so output is 1D - output_shape = [compress_len] - else: - output_shape = input_shape - output_shape[handle_negative_axis(axis, len( - input_shape))] = compress_len - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, output_shape)) - - def _infer_Concat(self, node): - if any([ - i in self.sympy_data_ or i in self.initializers_ - for i in node.input - ]): - values = self._get_int_values(node) - print("=======", values, node.name, get_attribute(node, 'axis')) - if all([v is not None for v in values]): - axis = get_attribute(node, 'axis') - if axis < 0: - axis = axis + len(values[0]) - assert 0 == axis - self.sympy_data_[node.output[0]] = [] - for i in range(len(node.input)): - value = values[i] - if type(value) == list: - self.sympy_data_[node.output[0]].extend(value) - else: - self.sympy_data_[node.output[0]].append(value) - - sympy_shape = self._get_sympy_shape(node, 0) - axis = handle_negative_axis( - get_attribute(node, 'axis'), len(sympy_shape)) - for i_idx in range(1, len(node.input)): - input_shape = self._get_sympy_shape(node, i_idx) - if input_shape: - sympy_shape[axis] = sympy_shape[axis] + input_shape[axis] - self._update_computed_dims(sympy_shape) - # merge symbolic dims for non-concat axes - for d in range(len(sympy_shape)): - if d == axis: - continue - dims = [ - self._get_shape(node, i_idx)[d] - for i_idx in range(len(node.input)) - if self._get_shape(node, i_idx) - ] - if all([d == dims[0] for d in dims]): - continue - merged = self._merge_symbols(dims) - if type(merged) == str: - sympy_shape[d] = self.symbolic_dims_[merged] if merged else None - else: - sympy_shape[d] = merged - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], 
self.known_vi_[node.input[0]].type.tensor_type. - elem_type, get_shape_from_sympy_shape(sympy_shape))) - - def _infer_ConcatFromSequence(self, node): - seq_shape = self._get_shape(node, 0) - new_axis = 1 if get_attribute(node, 'new_axis') else 0 - axis = handle_negative_axis( - get_attribute(node, 'axis'), len(seq_shape) + new_axis) - concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis)) - new_shape = seq_shape - if new_axis: - new_shape = seq_shape[:axis] + [concat_dim] + seq_shape[axis:] - else: - new_shape[axis] = concat_dim - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], self.known_vi_[node.input[0]] - .type.sequence_type.elem_type.tensor_type.elem_type, new_shape)) - - def _infer_Constant(self, node): - t = get_attribute(node, 'value') - self.sympy_data_[node.output[0]] = numpy_helper.to_array(t) - - def _infer_ConstantOfShape(self, node): - sympy_shape = self._get_int_values(node)[0] - vi = self.known_vi_[node.output[0]] - if sympy_shape is not None: - if type(sympy_shape) != list: - sympy_shape = [sympy_shape] - self._update_computed_dims(sympy_shape) - # update sympy data if output type is int, and shape is known - if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all( - [is_literal(x) for x in sympy_shape]): - self.sympy_data_[node.output[0]] = np.ones( - [int(x) for x in sympy_shape], - dtype=np.int64) * numpy_helper.to_array( - get_attribute(node, 'value', 0)) - else: - # create new dynamic shape - # note input0 is a 1D vector of shape, the new symbolic shape has the rank of the shape vector length - sympy_shape = self._new_symbolic_shape( - self._get_shape(node, 0)[0], node) - - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape))) - - def _infer_Conv(self, node): - sympy_shape = self._compute_conv_pool_shape(node) - self._update_computed_dims(sympy_shape) - vi = 
self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape))) - - def _infer_Einsum(self, node): - # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275 - equation = get_attribute(node, 'equation') - equation = equation.replace(b' ', b'') - mid_index = equation.find(b'->') - left_equation = equation[:mid_index] if mid_index != -1 else equation - - num_operands = 0 - num_ellipsis = 0 - num_ellipsis_indices = 0 - - letter_to_dim = {} - - terms = left_equation.split(b',') - for term in terms: - ellipsis_index = term.find(b'...') - shape = self._get_shape(node, num_operands) - rank = len(shape) - if ellipsis_index != -1: - if num_ellipsis == 0: - num_ellipsis_indices = rank - len(term) + 3 - num_ellipsis = num_ellipsis + 1 - for i in range(1, rank + 1): - letter = term[-i] - if letter != 46: # letter != b'.' - dim = shape[-i] - if letter not in letter_to_dim.keys(): - letter_to_dim[letter] = dim - elif type(dim) != sympy.Symbol: - letter_to_dim[letter] = dim - num_operands = num_operands + 1 - - new_sympy_shape = [] - from collections import OrderedDict - num_letter_occurrences = OrderedDict() - if mid_index != -1: - right_equation = equation[mid_index + 2:] - right_ellipsis_index = right_equation.find(b'...') - if right_ellipsis_index != -1: - for i in range(num_ellipsis_indices): - new_sympy_shape.append(shape[i]) - for c in right_equation: - if c != 46: # c != b'.' 
- new_sympy_shape.append(letter_to_dim[c]) - else: - for i in range(num_ellipsis_indices): - new_sympy_shape.append(shape[i]) - for c in left_equation: - if c != 44 and c != 46: # c != b',' and c != b'.': - if c in num_letter_occurrences: - num_letter_occurrences[c] = num_letter_occurrences[ - c] + 1 - else: - num_letter_occurrences[c] = 1 - for key, value in num_letter_occurrences.items(): - if value == 1: - new_sympy_shape.append(letter_to_dim[key]) - - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, - new_sympy_shape)) - - def _infer_Expand(self, node): - expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True) - if expand_to_shape is not None: - # new_shape's dim can come from shape value - self._update_computed_dims(expand_to_shape) - shape = self._get_shape(node, 0) - new_shape = self._broadcast_shapes( - shape, get_shape_from_sympy_shape(expand_to_shape)) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, new_shape)) - - def _infer_Gather(self, node): - data_shape = self._get_shape(node, 0) - axis = handle_negative_axis( - get_attribute(node, 'axis', 0), len(data_shape)) - indices_shape = self._get_shape(node, 1) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, data_shape[:axis] + - indices_shape + data_shape[axis + - 1:])) - # for 1D input, do some sympy compute - if node.input[0] in self.sympy_data_ and len( - data_shape) == 1 and 0 == get_attribute(node, 'axis', 0): - idx = self._try_get_value(node, 1) - if idx is not None: - data = self.sympy_data_[node.input[0]] - if type(data) == list: - if type(idx) == np.ndarray and len(idx.shape) == 1: - self.sympy_data_[node.output[ - 0]] = 
[data[int(i)] for i in idx] - else: - self.sympy_data_[node.output[0]] = data[int(idx)] - else: - assert idx == 0 or idx == -1 - self.sympy_data_[node.output[0]] = data - - def _infer_GatherElements(self, node): - indices_shape = self._get_shape(node, 1) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, indices_shape)) - - def _infer_GatherND(self, node): - data_shape = self._get_shape(node, 0) - data_rank = len(data_shape) - indices_shape = self._get_shape(node, 1) - indices_rank = len(indices_shape) - last_index_dimension = indices_shape[-1] - assert is_literal( - last_index_dimension) and last_index_dimension <= data_rank - new_shape = indices_shape[:-1] + data_shape[last_index_dimension:] - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, new_shape)) - - def _infer_If(self, node): - # special case for constant condition, in case there are mismatching shape from the non-executed branch - subgraphs = [ - get_attribute(node, 'then_branch'), get_attribute(node, - 'else_branch') - ] - cond = self._try_get_value(node, 0) - if cond is not None: - if as_scalar(cond) > 0: - subgraphs[1].CopyFrom(subgraphs[0]) - else: - subgraphs[0].CopyFrom(subgraphs[1]) - - for i_sub, subgraph in enumerate(subgraphs): - subgraph_infer = self._onnx_infer_subgraph( - node, subgraph, use_node_input=False) - for i_out in range(len(node.output)): - vi = self.known_vi_[node.output[i_out]] - if i_sub == 0: - vi.CopyFrom(subgraph.output[i_out]) - vi.name = node.output[i_out] - else: - self._fuse_tensor_type(node, i_out, vi.type, - subgraph.output[i_out].type) - - # pass on sympy data from subgraph, if cond is constant - if cond is not None and i_sub == (0 if as_scalar(cond) > 0 else - 1): - if subgraph.output[ - i_out].name in subgraph_infer.sympy_data_: - 
self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[ - subgraph.output[i_out].name] - - def _infer_Loop(self, node): - subgraph = get_attribute(node, 'body') - assert len(subgraph.input) == len(node.input) - num_loop_carried = len( - node.input) - 2 # minus the length and initial loop condition - # when sequence_type is used as loop carried input - # needs to run subgraph infer twice if the tensor shape in sequence contains None - for i, si in enumerate(subgraph.input): - si_name = si.name - si.CopyFrom(self.known_vi_[node.input[i]]) - si.name = si_name - - self._onnx_infer_subgraph(node, subgraph) - - # check subgraph input/output for shape changes in loop carried variables - # for tensor_type, create new symbolic dim when changing, i.e., output = Concat(input, a) - # for sequence_type, propagate from output to input - need_second_infer = False - for i_out in range(1, num_loop_carried + 1): - so = subgraph.output[i_out] - so_shape = get_shape_from_value_info(so) - if is_sequence(so.type): - if so_shape and None in so_shape: - # copy shape from output to input - # note that loop input is [loop_len, cond, input_0, input_1, ...] - # while loop output is [cond, output_0, output_1, ...] - subgraph.input[i_out + - 1].type.sequence_type.elem_type.CopyFrom( - so.type.sequence_type.elem_type) - need_second_infer = True - else: - si = subgraph.input[i_out + 1] - si_shape = get_shape_from_value_info(si) - for di, dims in enumerate(zip(si_shape, so_shape)): - if dims[0] != dims[1]: - new_dim = onnx.TensorShapeProto.Dimension() - new_dim.dim_param = str( - self._new_symbolic_dim_from_output(node, i_out, di)) - si.type.tensor_type.shape.dim[di].CopyFrom(new_dim) - so.type.tensor_type.shape.dim[di].CopyFrom(new_dim) - need_second_infer = True - - if need_second_infer: - if self.verbose_ > 2: - logger.debug( - "Rerun Loop: {}({}...), because of sequence in loop carried variables". 
- format(node.name, node.output[0])) - self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False) - - # create a new symbolic dimension for iteration dependent dimension - loop_iter_dim = str(self._new_symbolic_dim_from_output(node)) - for i in range(len(node.output)): - vi = self.known_vi_[node.output[i]] - vi.CopyFrom(subgraph.output[ - i + - 1]) # first subgraph output is condition, not in node output - if i >= num_loop_carried: - assert not is_sequence( - vi.type) # TODO: handle loop accumulation in sequence_type - subgraph_vi_dim = subgraph.output[i + - 1].type.tensor_type.shape.dim - vi.type.tensor_type.shape.ClearField('dim') - vi_dim = vi.type.tensor_type.shape.dim - vi_dim.add().dim_param = loop_iter_dim - vi_dim.extend(list(subgraph_vi_dim)) - vi.name = node.output[i] - - def _infer_MatMul(self, node): - self._compute_matmul_shape(node) - - def _infer_MatMulInteger(self, node): - self._compute_matmul_shape(node, onnx.TensorProto.INT32) - - def _infer_NonMaxSuppression(self, node): - selected = str(self._new_symbolic_dim_from_output(node)) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - 0], onnx.TensorProto.INT64, [selected, 3])) - - def _infer_NonZero(self, node): - input_rank = self._get_shape_rank(node, 0) - # create a new symbolic dimension for NonZero output - nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1)) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - 0], vi.type.tensor_type.elem_type, [input_rank, nz_len])) - - def _infer_OneHot(self, node): - sympy_shape = self._get_sympy_shape(node, 0) - depth = self._try_get_value(node, 1) - axis = get_attribute(node, 'axis', -1) - axis = handle_negative_axis(axis, len(sympy_shape) + 1) - new_shape = get_shape_from_sympy_shape(sympy_shape[:axis] + [ - self._new_symbolic_dim_from_output(node) - if not is_literal(depth) else depth - ] + sympy_shape[axis:]) - vi = 
self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[2]].type.tensor_type.elem_type, new_shape)) - - def _infer_Pad(self, node): - if get_opset(self.out_mp_) <= 10: - pads = get_attribute(node, 'pads') - else: - pads = self._try_get_value(node, 1) - - sympy_shape = self._get_sympy_shape(node, 0) - rank = len(sympy_shape) - - if pads is not None: - assert len(pads) == 2 * rank - new_sympy_shape = [ - d + pad_up + pad_down - for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[ - rank:]) - ] - self._update_computed_dims(new_sympy_shape) - else: - # dynamic pads, create new symbolic dimensions - new_sympy_shape = self._new_symbolic_shape(rank, node) - output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type - - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - 0], output_tp, get_shape_from_sympy_shape(new_sympy_shape))) - - def _infer_Pool(self, node): - sympy_shape = self._compute_conv_pool_shape(node) - self._update_computed_dims(sympy_shape) - for o in node.output: - if not o: - continue - vi = self.known_vi_[o] - vi.CopyFrom( - helper.make_tensor_value_info(o, vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape( - sympy_shape))) - - def _infer_aten_bitwise_or(self, node): - shape0 = self._get_shape(node, 0) - shape1 = self._get_shape(node, 1) - new_shape = self._broadcast_shapes(shape0, shape1) - t0 = self.known_vi_[node.input[0]] - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - 0], t0.type.tensor_type.elem_type, new_shape)) - - def _infer_aten_diagonal(self, node): - sympy_shape = self._get_sympy_shape(node, 0) - rank = len(sympy_shape) - offset = self._try_get_value(node, 1) - dim1 = self._try_get_value(node, 2) - dim2 = self._try_get_value(node, 3) - - assert offset is not None and dim1 is not None and dim2 is not None - dim1 = handle_negative_axis(dim1, rank) - dim2 
= handle_negative_axis(dim2, rank) - - new_shape = [] - for dim, val in enumerate(sympy_shape): - if dim not in [dim1, dim2]: - new_shape.append(val) - - shape1 = sympy_shape[dim1] - shape2 = sympy_shape[dim2] - if offset >= 0: - diag_shape = sympy.Max(0, sympy.Min(shape1, shape2 - offset)) - else: - diag_shape = sympy.Max(0, sympy.Min(shape1 + offset, shape2)) - new_shape.append(diag_shape) - - if node.output[0]: - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - new_shape))) - - def _infer_aten_multinomial(self, node): - sympy_shape = self._get_sympy_shape(node, 0) - rank = len(sympy_shape) - assert rank in [1, 2] - num_samples = self._try_get_value(node, 1) - di = rank - 1 - last_dim = num_samples if num_samples else str( - self._new_symbolic_dim_from_output(node, 0, di)) - output_shape = sympy_shape[:-1] + [last_dim] - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], onnx.TensorProto.INT64, - get_shape_from_sympy_shape(output_shape))) - - def _infer_aten_pool2d(self, node): - sympy_shape = self._get_sympy_shape(node, 0) - assert len(sympy_shape) == 4 - sympy_shape[-2:] = [ - self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3] - ] - self._update_computed_dims(sympy_shape) - for i, o in enumerate(node.output): - if not o: - continue - vi = self.known_vi_[o] - elem_type = onnx.TensorProto.INT64 if i == 1 else self.known_vi_[ - node.input[0]].type.tensor_type.elem_type - vi.CopyFrom( - helper.make_tensor_value_info( - o, elem_type, get_shape_from_sympy_shape(sympy_shape))) - - def _infer_aten_unfold(self, node): - sympy_shape = self._get_sympy_shape(node, 0) - dimension = self._try_get_value(node, 1) - size = self._try_get_value(node, 2) - step = self._try_get_value(node, 3) - if dimension is not None and size is not None and step is not None: - assert dimension 
< len(sympy_shape) - sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1 - sympy_shape.append(size) - else: - rank = len(sympy_shape) - sympy_shape = self._new_symbolic_shape(rank + 1, node) - self._update_computed_dims(sympy_shape) - if node.output[0]: - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - sympy_shape))) - - def _infer_aten_argmax(self, node): - new_shape = None - if node.input[1] == '': - # The argmax of the flattened input is returned. - new_shape = [] - else: - dim = self._try_get_value(node, 1) - keepdim = self._try_get_value(node, 2) - if keepdim is not None: - sympy_shape = self._get_sympy_shape(node, 0) - if dim is not None: - dim = handle_negative_axis(dim, len(sympy_shape)) - if keepdim: - sympy_shape[dim] = 1 - else: - del sympy_shape[dim] - else: - rank = len(sympy_shape) - sympy_shape = self._new_symbolic_shape(rank if keepdim else - rank - 1, node) - self._update_computed_dims(sympy_shape) - new_shape = get_shape_from_sympy_shape(sympy_shape) - if node.output[0] and new_shape is not None: - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - 0], onnx.TensorProto.INT64, new_shape)) - - def _infer_aten_bce(self, node): - reduction = self._try_get_value(node, 4) - if reduction is None: - reduction = 1 - elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - if reduction == 0: - vi.type.tensor_type.elem_type = elem_type - vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto()) - else: - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, elem_type, - self._get_shape(node, 0))) - - def _infer_BatchNormalization(self, node): - self._propagate_shape_and_type(node) - - # this works for opsets < 14 and 14 since we check i < len(node.output) in the loop - for i in [1, 2, 3, 4]: - if i < 
len(node.output) and node.output[i] != "": - # all of these parameters have the same shape as the 1st input - self._propagate_shape_and_type( - node, input_index=1, output_index=i) - - def _infer_Range(self, node): - vi = self.known_vi_[node.output[0]] - input_data = self._get_int_values(node) - if all([i is not None for i in input_data]): - start = as_scalar(input_data[0]) - limit = as_scalar(input_data[1]) - delta = as_scalar(input_data[2]) - new_sympy_shape = [ - sympy.Max(sympy.ceiling((limit - start) / delta), 0) - ] - else: - new_sympy_shape = [self._new_symbolic_dim_from_output(node)] - self._update_computed_dims(new_sympy_shape) - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], self.known_vi_[node.input[0]].type.tensor_type. - elem_type, get_shape_from_sympy_shape(new_sympy_shape))) - - def _infer_ReduceSum(self, node): - keep_dims = get_attribute(node, 'keepdims', 1) - if get_opset(self.out_mp_) >= 13 and len(node.input) > 1: - # ReduceSum changes axes to input[1] in opset 13 - axes = self._try_get_value(node, 1) - vi = self.known_vi_[node.output[0]] - if axes is None: - assert keep_dims # can only handle keep_dims==True when axes is unknown, by generating new ranks - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], self.known_vi_[node.input[ - 0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - self._new_symbolic_shape( - self._get_shape_rank(node, 0), node)))) - else: - shape = self._get_shape(node, 0) - output_shape = [] - axes = [handle_negative_axis(a, len(shape)) for a in axes] - for i, d in enumerate(shape): - if i in axes: - if keep_dims: - output_shape.append(1) - else: - output_shape.append(d) - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - 0], self.known_vi_[node.input[ - 0]].type.tensor_type.elem_type, output_shape)) - - def _infer_ReduceProd(self, node): - axes = get_attribute(node, 'axes') - keep_dims = get_attribute(node, 'keepdims', 1) - if keep_dims == 0 and axes == [0]: - data = 
self._get_int_values(node)[0] - if data is not None: - self.sympy_data_[node.output[0]] = sympy_reduce_product(data) - - def _infer_Reshape(self, node): - shape_value = self._try_get_value(node, 1) - vi = self.known_vi_[node.output[0]] - if shape_value is None: - shape_shape = self._get_shape(node, 1) - assert len(shape_shape) == 1 - shape_rank = shape_shape[0] - assert is_literal(shape_rank) - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape( - self._new_symbolic_shape(shape_rank, node)))) - else: - input_sympy_shape = self._get_sympy_shape(node, 0) - total = int(1) - for d in input_sympy_shape: - total = total * d - new_sympy_shape = [] - deferred_dim_idx = -1 - non_deferred_size = int(1) - for i, d in enumerate(shape_value): - if type(d) == sympy.Symbol: - new_sympy_shape.append(d) - elif d == 0: - new_sympy_shape.append(input_sympy_shape[i]) - non_deferred_size = non_deferred_size * input_sympy_shape[i] - else: - new_sympy_shape.append(d) - if d == -1: - deferred_dim_idx = i - elif d != 0: - non_deferred_size = non_deferred_size * d - - assert new_sympy_shape.count(-1) < 2 - if -1 in new_sympy_shape: - new_dim = total // non_deferred_size - new_sympy_shape[deferred_dim_idx] = new_dim - - self._update_computed_dims(new_sympy_shape) - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape))) - - self._pass_on_sympy_data(node) - - def _infer_Resize(self, node): - vi = self.known_vi_[node.output[0]] - input_sympy_shape = self._get_sympy_shape(node, 0) - if get_opset(self.out_mp_) <= 10: - scales = self._try_get_value(node, 1) - if scales is not None: - new_sympy_shape = [ - sympy.simplify(sympy.floor(d * s)) - for d, s in zip(input_sympy_shape, scales) - ] - self._update_computed_dims(new_sympy_shape) - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], self.known_vi_[node.input[ - 
0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape))) - else: - roi = self._try_get_value(node, 1) - scales = self._try_get_value(node, 2) - sizes = self._try_get_value(node, 3) - if sizes is not None: - new_sympy_shape = [ - sympy.simplify(sympy.floor(s)) for s in sizes - ] - self._update_computed_dims(new_sympy_shape) - elif scales is not None: - rank = len(scales) - if get_attribute(node, 'coordinate_transformation_mode' - ) == 'tf_crop_and_resize': - assert len(roi) == 2 * rank - roi_start = list(roi)[:rank] - roi_end = list(roi)[rank:] - else: - roi_start = [0] * rank - roi_end = [1] * rank - scales = list(scales) - new_sympy_shape = [ - sympy.simplify(sympy.floor(d * (end - start) * scale)) - for d, start, end, scale in zip(input_sympy_shape, - roi_start, roi_end, scales) - ] - self._update_computed_dims(new_sympy_shape) - else: - new_sympy_shape = self._new_symbolic_shape( - self._get_shape_rank(node, 0), node) - - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - new_sympy_shape))) - - def _infer_Scan(self, node): - subgraph = get_attribute(node, 'body') - num_scan_inputs = get_attribute(node, 'num_scan_inputs') - scan_input_axes = get_attribute(node, 'scan_input_axes', - [0] * num_scan_inputs) - num_scan_states = len(node.input) - num_scan_inputs - scan_input_axes = [ - handle_negative_axis( - ax, self._get_shape_rank(node, i + num_scan_states)) - for i, ax in enumerate(scan_input_axes) - ] - # We may have cases where the subgraph has optionial inputs that appear in both subgraph's input and initializer, - # but not in the node's input. In such cases, the input model might be invalid, but let's skip those optional inputs. 
- assert len(subgraph.input) >= len(node.input) - subgraph_inputs = subgraph.input[:len(node.input)] - for i, si in enumerate(subgraph_inputs): - subgraph_name = si.name - si.CopyFrom(self.known_vi_[node.input[i]]) - if i >= num_scan_states: - scan_input_dim = si.type.tensor_type.shape.dim - scan_input_dim.remove( - scan_input_dim[scan_input_axes[i - num_scan_states]]) - si.name = subgraph_name - self._onnx_infer_subgraph(node, subgraph) - num_scan_outputs = len(node.output) - num_scan_states - scan_output_axes = get_attribute(node, 'scan_output_axes', - [0] * num_scan_outputs) - scan_input_dim = get_shape_from_type_proto( - self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]] - for i, o in enumerate(node.output): - vi = self.known_vi_[o] - if i >= num_scan_states: - shape = get_shape_from_type_proto(subgraph.output[i].type) - new_dim = handle_negative_axis( - scan_output_axes[i - num_scan_states], len(shape) + 1) - shape = shape[:new_dim] + [scan_input_dim] + shape[new_dim:] - vi.CopyFrom( - helper.make_tensor_value_info(o, subgraph.output[ - i].type.tensor_type.elem_type, shape)) - else: - vi.CopyFrom(subgraph.output[i]) - vi.name = o - - def _infer_ScatterElements(self, node): - data_shape = self._get_shape(node, 0) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, data_shape)) - - def _infer_SequenceAt(self, node): - # need to create new symbolic dimension if sequence shape has None: - seq_shape = self._get_shape(node, 0) - vi = self.known_vi_[node.output[0]] - if seq_shape is not None: - for di, d in enumerate(seq_shape): - if d is not None: - continue - new_dim = onnx.TensorShapeProto.Dimension() - new_dim.dim_param = str( - self._new_symbolic_dim_from_output(node, 0, di)) - vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim) - - def _infer_SequenceInsert(self, node): - # workaround bug in onnx's shape inference - vi_seq = 
self.known_vi_[node.input[0]] - vi_tensor = self.known_vi_[node.input[1]] - vi_out_seq = self.known_vi_[node.output[0]] - vi_out_seq.CopyFrom(vi_seq) - vi_out_seq.name = node.output[0] - self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type) - - def _infer_Shape(self, node): - self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0) - - def _infer_Size(self, node): - sympy_shape = self._get_sympy_shape(node, 0) - self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape) - self.known_vi_[node.output[0]].CopyFrom( - helper.make_tensor_value_info(node.output[0], - onnx.TensorProto.INT64, [])) - - def _infer_Slice(self, node): - def less_equal(x, y): - try: - return bool(x <= y) - except TypeError: - pass - try: - return bool(y >= x) - except TypeError: - pass - try: - return bool(-x >= -y) - except TypeError: - pass - try: - return bool(-y <= -x) - except TypeError: - # the last attempt; this may raise TypeError - return bool(y - x >= 0) - - def handle_negative_index(index, bound): - """ normalizes a negative index to be in [0, bound) """ - try: - if not less_equal(0, index): - if is_literal(index) and index <= -self.int_max_: - # this case is handled separately - return index - return bound + index - except TypeError: - logger.warning("Cannot determine if {} < 0".format(index)) - return index - - if get_opset(self.out_mp_) <= 9: - axes = get_attribute(node, 'axes') - starts = get_attribute(node, 'starts') - ends = get_attribute(node, 'ends') - if not axes: - axes = list(range(len(starts))) - steps = [1] * len(axes) - else: - starts = as_list(self._try_get_value(node, 1), keep_none=True) - ends = as_list(self._try_get_value(node, 2), keep_none=True) - axes = self._try_get_value(node, 3) - steps = self._try_get_value(node, 4) - if axes is None and not (starts is None and ends is None): - axes = list( - range(0, len(starts if starts is not None else ends))) - if steps is None and not (starts is None and ends is None): - steps = [1] * 
len(starts if starts is not None else ends) - axes = as_list(axes, keep_none=True) - steps = as_list(steps, keep_none=True) - - new_sympy_shape = self._get_sympy_shape(node, 0) - if starts is None or ends is None: - if axes is None: - for i in range(len(new_sympy_shape)): - new_sympy_shape[i] = self._new_symbolic_dim_from_output( - node, 0, i) - else: - new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape) - for i in axes: - new_sympy_shape[i] = self._new_symbolic_dim_from_output( - node, 0, i) - else: - for i, s, e, t in zip(axes, starts, ends, steps): - e = handle_negative_index(e, new_sympy_shape[i]) - if is_literal(e): - if e >= self.int_max_: - e = new_sympy_shape[i] - elif e <= -self.int_max_: - e = 0 if s > 0 else -1 - elif is_literal(new_sympy_shape[i]): - if e < 0: - e = max(0, e + new_sympy_shape[i]) - e = min(e, new_sympy_shape[i]) - else: - if e > 0: - e = sympy.Min( - e, new_sympy_shape[i] - ) if e > 1 else e #special case for slicing first to make computation easier - else: - if is_literal(new_sympy_shape[i]): - e = sympy.Min(e, new_sympy_shape[i]) - else: - try: - if not less_equal(e, new_sympy_shape[i]): - e = new_sympy_shape[i] - except Exception: - logger.warning( - 'Unable to determine if {} <= {}, treat as equal'. 
- format(e, new_sympy_shape[i])) - e = new_sympy_shape[i] - - s = handle_negative_index(s, new_sympy_shape[i]) - if is_literal(new_sympy_shape[i]) and is_literal(s): - s = max(0, min(s, new_sympy_shape[i])) - - new_sympy_shape[i] = sympy.simplify( - (e - s + t + (-1 if t > 0 else 1)) // t) - - self._update_computed_dims(new_sympy_shape) - - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape))) - - # handle sympy_data if needed, for slice in shape computation - if (node.input[0] in self.sympy_data_ and [0] == axes and - len(starts) == 1 and len(ends) == 1 and len(steps) == 1): - input_sympy_data = self.sympy_data_[node.input[0]] - if type(input_sympy_data) == list or ( - type(input_sympy_data) == np.array and - len(input_sympy_data.shape) == 1): - self.sympy_data_[node.output[0]] = input_sympy_data[starts[ - 0]:ends[0]:steps[0]] - - def _infer_SoftmaxCrossEntropyLoss(self, node): - vi = self.known_vi_[node.output[0]] - elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi.type.tensor_type.elem_type = elem_type - vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto()) - - if len(node.output) > 1: - data_shape = self._get_shape(node, 0) - vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, elem_type, data_shape)) - - def _infer_Split_Common(self, node, make_value_info_func): - input_sympy_shape = self._get_sympy_shape(node, 0) - axis = handle_negative_axis( - get_attribute(node, 'axis', 0), len(input_sympy_shape)) - split = get_attribute(node, 'split') - if not split: - num_outputs = len(node.output) - split = [input_sympy_shape[axis] / - sympy.Integer(num_outputs)] * num_outputs - self._update_computed_dims(split) - else: - split = [sympy.Integer(s) for s in split] - - for i_o in range(len(split)): - vi = self.known_vi_[node.output[i_o]] - vi.CopyFrom( - 
make_value_info_func(node.output[i_o], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - input_sympy_shape[:axis] + [ - split[i_o] - ] + input_sympy_shape[axis + 1:]))) - self.known_vi_[vi.name] = vi - - def _infer_Split(self, node): - self._infer_Split_Common(node, helper.make_tensor_value_info) - - def _infer_SplitToSequence(self, node): - self._infer_Split_Common(node, helper.make_sequence_value_info) - - def _infer_Squeeze(self, node): - input_shape = self._get_shape(node, 0) - op_set = get_opset(self.out_mp_) - - # Depending on op-version 'axes' are provided as attribute or via 2nd input - if op_set < 13: - axes = get_attribute(node, 'axes') - assert self._try_get_value(node, 1) is None - else: - axes = self._try_get_value(node, 1) - assert get_attribute(node, 'axes') is None - - if axes is None: - # No axes have been provided (neither via attribute nor via input). - # In this case the 'Shape' op should remove all axis with dimension 1. - # For symbolic dimensions we guess they are !=1. - output_shape = [s for s in input_shape if s != 1] - if self.verbose_ > 0: - symbolic_dimensions = [s for s in input_shape if type(s) != int] - if len(symbolic_dimensions) > 0: - logger.debug( - f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " - + - f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}" - ) - else: - axes = [handle_negative_axis(a, len(input_shape)) for a in axes] - output_shape = [] - for i in range(len(input_shape)): - if i not in axes: - output_shape.append(input_shape[i]) - else: - assert input_shape[i] == 1 or type(input_shape[i]) != int - if self.verbose_ > 0 and type(input_shape[i]) != int: - logger.debug( - f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " - + - f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1." 
- ) - - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, output_shape)) - self._pass_on_sympy_data(node) - - def _infer_Tile(self, node): - repeats_value = self._try_get_value(node, 1) - new_sympy_shape = [] - if repeats_value is not None: - input_sympy_shape = self._get_sympy_shape(node, 0) - for i, d in enumerate(input_sympy_shape): - new_dim = d * repeats_value[i] - new_sympy_shape.append(new_dim) - self._update_computed_dims(new_sympy_shape) - else: - new_sympy_shape = self._new_symbolic_shape( - self._get_shape_rank(node, 0), node) - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape))) - - def _infer_TopK(self, node): - rank = self._get_shape_rank(node, 0) - axis = handle_negative_axis(get_attribute(node, 'axis', -1), rank) - new_shape = self._get_shape(node, 0) - - if get_opset(self.out_mp_) <= 9: - k = get_attribute(node, 'k') - else: - k = self._get_int_values(node)[1] - - if k == None: - k = self._new_symbolic_dim_from_output(node) - else: - k = as_scalar(k) - - if type(k) in [int, str]: - new_shape[axis] = k - else: - new_sympy_shape = self._get_sympy_shape(node, 0) - new_sympy_shape[axis] = k - self._update_computed_dims( - new_sympy_shape - ) # note that TopK dim could be computed in sympy_data, so need to update computed_dims when it enters shape - new_shape = get_shape_from_sympy_shape(new_sympy_shape) - - for i_o in range(len(node.output)): - vi = self.known_vi_[node.output[i_o]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - i_o], vi.type.tensor_type.elem_type, new_shape)) - - def _infer_Transpose(self, node): - if node.input[0] in self.sympy_data_: - data_shape = self._get_shape(node, 0) - perm = get_attribute(node, 'perm', - reversed(list(range(len(data_shape))))) - input_data = 
self.sympy_data_[node.input[0]] - self.sympy_data_[node.output[0]] = np.transpose( - np.array(input_data).reshape(*data_shape), - axes=tuple(perm)).flatten().tolist() - - def _infer_Unsqueeze(self, node): - input_shape = self._get_shape(node, 0) - op_set = get_opset(self.out_mp_) - - # Depending on op-version 'axes' are provided as attribute or via 2nd input - if op_set < 13: - axes = get_attribute(node, 'axes') - assert self._try_get_value(node, 1) is None - else: - axes = self._try_get_value(node, 1) - assert get_attribute(node, 'axes') is None - - output_rank = len(input_shape) + len(axes) - axes = [handle_negative_axis(a, output_rank) for a in axes] - - input_axis = 0 - output_shape = [] - for i in range(output_rank): - if i in axes: - output_shape.append(1) - else: - output_shape.append(input_shape[input_axis]) - input_axis += 1 - - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[ - node.input[0]].type.tensor_type.elem_type, output_shape)) - - self._pass_on_sympy_data(node) - - def _infer_ZipMap(self, node): - map_key_type = None - if get_attribute(node, 'classlabels_int64s') is not None: - map_key_type = onnx.TensorProto.INT64 - elif get_attribute(node, 'classlabels_strings') is not None: - map_key_type = onnx.TensorProto.STRING - - assert map_key_type is not None - new_vi = onnx.ValueInfoProto() - new_vi.name = node.output[0] - new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT - new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(new_vi) - - def _infer_Attention(self, node): - shape = self._get_shape(node, 0) - shape_bias = self._get_shape(node, 2) - assert len(shape) == 3 and len(shape_bias) == 1 - qkv_hidden_sizes_attr = get_attribute(node, 'qkv_hidden_sizes') - if qkv_hidden_sizes_attr is not None: - assert len(qkv_hidden_sizes_attr) == 3 - shape[2] = 
int(qkv_hidden_sizes_attr[2]) - else: - shape[2] = int(shape_bias[0] / 3) - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, shape)) - - if len(node.output) > 1: - # input shape: (batch_size, sequence_length, hidden_size) - # past shape: (2, batch_size, num_heads, past_sequence_length, head_size) - # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len) - # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length - input_shape = self._get_shape(node, 0) - past_shape = self._get_shape(node, 4) - mask_shape = self._get_shape(node, 3) - if len(past_shape) == 5: - if len(mask_shape) in [2, 3]: - past_shape[3] = mask_shape[-1] - elif isinstance(input_shape[1], int) and isinstance( - past_shape[3], int): - past_shape[3] = input_shape[1] + past_shape[3] - else: - past_shape[3] = f"{past_shape[3]}+{input_shape[1]}" - vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, - past_shape)) - - def _infer_BiasGelu(self, node): - self._propagate_shape_and_type(node) - - def _infer_FastGelu(self, node): - self._propagate_shape_and_type(node) - - def _infer_Gelu(self, node): - self._propagate_shape_and_type(node) - - def _infer_LayerNormalization(self, node): - self._propagate_shape_and_type(node) - - def _infer_LongformerAttention(self, node): - self._propagate_shape_and_type(node) - - def _infer_EmbedLayerNormalization(self, node): - input_ids_shape = self._get_shape(node, 0) - word_embedding_shape = self._get_shape(node, 2) - assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2 - output_shape = input_ids_shape + [word_embedding_shape[1]] - - word_embedding_dtype = self.known_vi_[node.input[ - 
2]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], word_embedding_dtype, - output_shape)) - - mask_index_shape = [input_ids_shape[0]] - vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - 1], onnx.TensorProto.INT32, mask_index_shape)) - - if len(node.output) > 2: - # Optional output of add before layer nomalization is done - # shape is same as the output - vi = self.known_vi_[node.output[2]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[ - 2], word_embedding_dtype, output_shape)) - - def _infer_SkipLayerNormalization(self, node): - self._propagate_shape_and_type(node) - - def _infer_PythonOp(self, node): - output_tensor_types = get_attribute(node, 'output_tensor_types') - assert output_tensor_types - output_tensor_ranks = get_attribute(node, 'output_tensor_ranks') - assert output_tensor_ranks - - # set the context output seperately. - # The first output is autograd's context. - vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], - onnx.TensorProto.INT64, [])) - - # Outputs after autograd's context are tensors. - # We assume their ranks are fixed for different model inputs. - for i in range(len(node.output) - 1): - # Process the i-th tensor outputs. 
- vi = self.known_vi_[node.output[i + 1]] - sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node) - shape = get_shape_from_sympy_shape(sympy_shape) - value_info = helper.make_tensor_value_info( - node.output[i + 1], output_tensor_types[i], shape) - vi.CopyFrom(value_info) - - def _propagate_shape_and_type(self, node, input_index=0, output_index=0): - shape = self._get_shape(node, input_index) - output_dtype = self.known_vi_[node.input[ - input_index]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[output_index]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[output_index], - output_dtype, shape)) - - def _is_none_dim(self, dim_value): - if type(dim_value) != str: - return False - if "unk__" not in dim_value: - return False - if dim_value in self.symbolic_dims_.keys(): - return False - return True - - def _is_shape_contains_none_dim(self, out_shape): - for out in out_shape: - if self._is_none_dim(out): - return out - return None - - def _infer_impl(self, start_sympy_data=None): - self.sympy_data_ = start_sympy_data or {} - self.out_mp_.graph.ClearField('value_info') - self._apply_suggested_merge(graph_input_only=True) - self.input_symbols_ = set() - for i in self.out_mp_.graph.input: - input_shape = get_shape_from_value_info(i) - if input_shape is None: - continue - - if is_sequence(i.type): - input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim - else: - input_dims = i.type.tensor_type.shape.dim - - for i_dim, dim in enumerate(input_shape): - if dim is None: - # some models use None for symbolic dim in input, replace it with a string - input_dims[i_dim].dim_param = str( - self._new_symbolic_dim(i.name, i_dim)) - - self.input_symbols_.update( - [d for d in input_shape if type(d) == str]) - - for s in self.input_symbols_: - if s in self.suggested_merge_: - s_merge = self.suggested_merge_[s] - assert s_merge in self.symbolic_dims_ - self.symbolic_dims_[s] = self.symbolic_dims_[s_merge] - else: - # Since inputs are 
not produced by other ops, we can assume positivity - self.symbolic_dims_[s] = sympy.Symbol( - s, integer=True, positive=True) - # create a temporary ModelProto for single node inference - # note that we remove initializer to have faster inference - # for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways - self.tmp_mp_ = onnx.ModelProto() - self.tmp_mp_.CopyFrom(self.out_mp_) - self.tmp_mp_.graph.ClearField('initializer') - - # compute prerequesite for node for topological sort - # node with subgraphs may have dependency on implicit inputs, which will affect topological sort - prereq_for_node = { - } # map from node to all its inputs, including implicit ones in subgraph - - def get_prereq(node): - names = set(i for i in node.input if i) - subgraphs = [] - if 'If' == node.op_type: - subgraphs = [ - get_attribute(node, 'then_branch'), - get_attribute(node, 'else_branch') - ] - elif node.op_type in ['Loop', 'Scan']: - subgraphs = [get_attribute(node, 'body')] - for g in subgraphs: - g_outputs_and_initializers = {i.name for i in g.initializer} - g_prereq = set() - for n in g.node: - g_outputs_and_initializers.update(n.output) - for n in g.node: - g_prereq.update([ - i for i in get_prereq(n) - if i not in g_outputs_and_initializers - ]) - names.update(g_prereq) - # remove subgraph inputs from g_prereq since those are local-only - for i in g.input: - if i.name in names: - names.remove(i.name) - return names - - for n in self.tmp_mp_.graph.node: - prereq_for_node[n.output[0]] = get_prereq(n) - - # topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate - sorted_nodes = [] - sorted_known_vi = set([ - i.name - for i in list(self.out_mp_.graph.input) + list( - self.out_mp_.graph.initializer) - ]) - if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): - # Loop/Scan will have some graph output in graph inputs, so don't do topological sort 
- sorted_nodes = self.out_mp_.graph.node - else: - while not all( - [o.name in sorted_known_vi for o in self.out_mp_.graph.output]): - old_sorted_nodes_len = len(sorted_nodes) - for node in self.out_mp_.graph.node: - if (node.output[0] not in sorted_known_vi) and all([ - i in sorted_known_vi - for i in prereq_for_node[node.output[0]] if i - ]): - sorted_known_vi.update(node.output) - sorted_nodes.append(node) - if old_sorted_nodes_len == len(sorted_nodes) and not all([ - o.name in sorted_known_vi - for o in self.out_mp_.graph.output - ]): - raise Exception('Invalid model with cyclic graph') - - for node in sorted_nodes: - assert all([i in self.known_vi_ for i in node.input if i]) - self._onnx_infer_single_node(node) - known_aten_op = False - if node.op_type in self.dispatcher_: - self.dispatcher_[node.op_type](node) - elif node.op_type in ['ConvTranspose']: - # onnx shape inference ops like ConvTranspose may have empty shape for symbolic input - # before adding symbolic compute for them - # mark the output type as UNDEFINED to allow guessing of rank - vi = self.known_vi_[node.output[0]] - if len(vi.type.tensor_type.shape.dim) == 0: - vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - elif node.op_type == 'ATen' and node.domain == 'org.pytorch.aten': - for attr in node.attribute: - # TODO: Is overload_name needed? - if attr.name == 'operator': - aten_op_name = attr.s.decode('utf-8') if isinstance( - attr.s, bytes) else attr.s - if aten_op_name in self.aten_op_dispatcher_: - known_aten_op = True - self.aten_op_dispatcher_[aten_op_name](node) - break - - if self.verbose_ > 2: - logger.debug(node.op_type + ': ' + node.name) - for i, name in enumerate(node.input): - logger.debug(' Input {}: {} {}'.format( - i, name, 'initializer' - if name in self.initializers_ else '')) - - # onnx automatically merge dims with value, i.e. 
Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb'] - # symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case - if node.op_type in [ - 'Add', 'Sub', 'Mul', 'Div', 'MatMul', 'MatMulInteger', - 'MatMulInteger16', 'Where', 'Sum' - ]: - vi = self.known_vi_[node.output[0]] - out_rank = len(get_shape_from_type_proto(vi.type)) - in_shapes = [ - self._get_shape(node, i) for i in range(len(node.input)) - ] - for d in range(out_rank - (2 if node.op_type in [ - 'MatMul', 'MatMulInteger', 'MatMulInteger16' - ] else 0)): - in_dims = [ - s[len(s) - out_rank + d] for s in in_shapes - if len(s) + d >= out_rank - ] - if len(in_dims) > 1: - self._check_merged_dims(in_dims, allow_broadcast=True) - - for i_o in range(len(node.output)): - vi = self.known_vi_[node.output[i_o]] - out_type = vi.type - out_type_kind = out_type.WhichOneof('value') - - # do not process shape for non-tensors - if out_type_kind not in [ - 'tensor_type', 'sparse_tensor_type', None - ]: - if self.verbose_ > 2: - if out_type_kind == 'sequence_type': - seq_cls_type = out_type.sequence_type.elem_type.WhichOneof( - 'value') - if 'tensor_type' == seq_cls_type: - logger.debug(' {}: sequence of {} {}'.format( - node.output[i_o], - str(get_shape_from_value_info(vi)), - onnx.TensorProto.DataType.Name( - vi.type.sequence_type.elem_type. 
- tensor_type.elem_type))) - else: - logger.debug(' {}: sequence of {}'.format( - node.output[i_o], seq_cls_type)) - else: - logger.debug(' {}: {}'.format(node.output[i_o], - out_type_kind)) - continue - - out_shape = get_shape_from_value_info(vi) - out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED - if self.verbose_ > 2: - logger.debug(' {}: {} {}'.format( - node.output[i_o], - str(out_shape), - onnx.TensorProto.DataType.Name( - vi.type.tensor_type.elem_type))) - if node.output[i_o] in self.sympy_data_: - logger.debug(' Sympy Data: ' + str(self.sympy_data_[ - node.output[i_o]])) - - # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain - if (out_shape is not None and - (None in out_shape or - self._is_shape_contains_none_dim(out_shape)) - ) or out_type_undefined: - if self.auto_merge_: - if node.op_type in [ - 'Add', 'Sub', 'Mul', 'Div', 'MatMul', - 'MatMulInteger', 'MatMulInteger16', 'Concat', - 'Where', 'Sum', 'Equal', 'Less', 'Greater', - 'LessOrEqual', 'GreaterOrEqual' - ]: - shapes = [ - self._get_shape(node, i) - for i in range(len(node.input)) - ] - if node.op_type in [ - 'MatMul', 'MatMulInteger', 'MatMulInteger16' - ]: - if None in out_shape or self._is_shape_contains_none_dim( - out_shape): - if None in out_shape: - idx = out_shape.index(None) - else: - idx = out_shape.index( - self._is_shape_contains_none_dim( - out_shape)) - dim_idx = [ - len(s) - len(out_shape) + idx - for s in shapes - ] - # only support auto merge for MatMul for dim < rank-2 when rank > 2 - assert len( - shapes[0]) > 2 and dim_idx[0] < len( - shapes[0]) - 2 - assert len( - shapes[1]) > 2 and dim_idx[1] < len( - shapes[1]) - 2 - elif node.op_type == 'Expand': - # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq]) - shapes = [ - self._get_shape(node, 0), self._get_value(node, - 1) - ] - else: - shapes = [] - - if shapes: - for idx in range(len(out_shape)): - if out_shape[ - idx] is not None and not 
self._is_none_dim( - out_shape[idx]): - continue - # note that the broadcasting rule aligns from right to left - # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge - dim_idx = [ - len(s) - len(out_shape) + idx - for s in shapes - ] - if len(dim_idx) > 0: - self._add_suggested_merge([ - s[i] if is_literal(s[i]) else str(s[i]) - for s, i in zip(shapes, dim_idx) - if i >= 0 - ]) - self.run_ = True - else: - self.run_ = False - else: - self.run_ = False - - # create new dynamic dims for ops not handled by symbolic shape inference - if self.run_ == False and not node.op_type in self.dispatcher_ and not known_aten_op: - is_unknown_op = out_type_undefined and ( - out_shape is None or len(out_shape) == 0) - if is_unknown_op: - # unknown op to ONNX, maybe from higher opset or other domain - # only guess the output rank from input 0 when using guess_output_rank option - out_rank = self._get_shape_rank( - node, 0) if self.guess_output_rank_ else -1 - else: - # valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape - out_rank = len(out_shape) - - if out_rank >= 0: - new_shape = self._new_symbolic_shape(out_rank, node, - i_o) - if out_type_undefined: - # guess output data type from input vi if not defined - out_dtype = self.known_vi_[node.input[ - 0]].type.tensor_type.elem_type - else: - # otherwise, use original data type - out_dtype = vi.type.tensor_type.elem_type - vi.CopyFrom( - helper.make_tensor_value_info( - vi.name, out_dtype, - get_shape_from_sympy_shape(new_shape))) - - if self.verbose_ > 0: - if is_unknown_op: - logger.debug( - "Possible unknown op: {} node: {}, guessing {} shape". 
- format(node.op_type, node.name, - vi.name)) - if self.verbose_ > 2: - logger.debug(' {}: {} {}'.format( - node.output[i_o], - str(new_shape), - vi.type.tensor_type.elem_type)) - - self.run_ = True - continue # continue the inference after guess, no need to stop as no merge is needed - - if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined: - logger.debug( - 'Stopping at incomplete shape inference at ' + - node.op_type + ': ' + node.name) - logger.debug('node inputs:') - for i in node.input: - logger.debug(self.known_vi_[i]) - logger.debug('node outputs:') - for o in node.output: - logger.debug(self.known_vi_[o]) - if self.auto_merge_ and not out_type_undefined: - logger.debug('Merging: ' + str( - self.suggested_merge_)) - return False - - self.run_ = False - return True - - def _update_output_from_vi(self): - for output in self.out_mp_.graph.output: - if output.name in self.known_vi_: - output.CopyFrom(self.known_vi_[output.name]) - - @staticmethod - def infer_shapes(in_mp, - int_max=2**31 - 1, - auto_merge=False, - guess_output_rank=False, - verbose=0): - onnx_opset = get_opset(in_mp) - if (not onnx_opset) or onnx_opset < 7: - logger.warning('Only support models of onnx opset 7 and above.') - return None - symbolic_shape_inference = SymbolicShapeInference( - int_max, auto_merge, guess_output_rank, verbose) - all_shapes_inferred = False - symbolic_shape_inference._preprocess(in_mp) - while symbolic_shape_inference.run_: - all_shapes_inferred = symbolic_shape_inference._infer_impl() - symbolic_shape_inference._update_output_from_vi() - if not all_shapes_inferred: - raise Exception("Incomplete symbolic shape inference") - return symbolic_shape_inference.out_mp_ - - -def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument('--input', required=True, help='The input model file') - parser.add_argument('--output', help='The output model file') - parser.add_argument( - '--auto_merge', - help='Automatically merge symbolic dims when 
confliction happens', - action='store_true', - default=False) - parser.add_argument( - '--int_max', - help='maximum value for integer to be treated as boundless for ops like slice', - type=int, - default=2**31 - 1) - parser.add_argument( - '--guess_output_rank', - help='guess output rank to be the same as input 0 for unknown ops', - action='store_true', - default=False) - parser.add_argument( - '--verbose', - help='Prints detailed logs of inference, 0: turn off, 1: warnings, 3: detailed', - type=int, - default=0) - return parser.parse_args() - - -if __name__ == '__main__': - args = parse_arguments() - logger.info('input model: ' + args.input) - if args.output: - logger.info('output model ' + args.output) - logger.info('Doing symbolic shape inference...') - out_mp = SymbolicShapeInference.infer_shapes( - onnx.load(args.input), args.int_max, args.auto_merge, - args.guess_output_rank, args.verbose) - if args.output and out_mp: - onnx.save(out_mp, args.output) - logger.info('Done!') diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_opt.sh b/speechx/examples/ds2_ol/onnx/local/onnx_opt.sh deleted file mode 100755 index ce2f24e5..00000000 --- a/speechx/examples/ds2_ol/onnx/local/onnx_opt.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -set -e - -if [ $# != 3 ];then - # ./local/onnx_opt.sh model.old.onnx model.opt.onnx "audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024" - echo "usage: $0 onnx.model.in onnx.model.out input_shape " - exit 1 -fi - -# onnx optimizer -pip install onnx-simplifier - -in=$1 -out=$2 -input_shape=$3 - -check_n=3 - -onnxsim $in $out $check_n --dynamic-input-shape --input-shape $input_shape \ No newline at end of file diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py b/speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py deleted file mode 100755 index 5b85eef3..00000000 --- a/speechx/examples/ds2_ol/onnx/local/onnx_prune_model.py +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/env python3 
-W ignore::DeprecationWarning -# prune model by output names -import argparse -import copy -import sys - -import onnx - - -def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument( - '--model', - required=True, - help='Path of directory saved the input model.') - parser.add_argument( - '--output_names', - required=True, - nargs='+', - help='The outputs of pruned model.') - parser.add_argument( - '--save_file', required=True, help='Path to save the new onnx model.') - return parser.parse_args() - - -if __name__ == '__main__': - args = parse_arguments() - - if len(set(args.output_names)) < len(args.output_names): - print( - "[ERROR] There's dumplicate name in --output_names, which is not allowed." - ) - sys.exit(-1) - - model = onnx.load(args.model) - - # collect all node outputs and graph output - output_tensor_names = set() - for node in model.graph.node: - for out in node.output: - # may contain model output - output_tensor_names.add(out) - - # for out in model.graph.output: - # output_tensor_names.add(out.name) - - for output_name in args.output_names: - if output_name not in output_tensor_names: - print( - "[ERROR] Cannot find output tensor name '{}' in onnx model graph.". 
- format(output_name)) - sys.exit(-1) - - output_node_indices = set() # has output names - output_to_node = dict() # all node outputs - for i, node in enumerate(model.graph.node): - for out in node.output: - output_to_node[out] = i - if out in args.output_names: - output_node_indices.add(i) - - # from outputs find all the ancestors - reserved_node_indices = copy.deepcopy( - output_node_indices) # nodes need to keep - reserved_inputs = set() # model input to keep - new_output_node_indices = copy.deepcopy(output_node_indices) - - while True and len(new_output_node_indices) > 0: - output_node_indices = copy.deepcopy(new_output_node_indices) - - new_output_node_indices = set() - - for out_node_idx in output_node_indices: - # backtrace to parenet - for ipt in model.graph.node[out_node_idx].input: - if ipt in output_to_node: - reserved_node_indices.add(output_to_node[ipt]) - new_output_node_indices.add(output_to_node[ipt]) - else: - reserved_inputs.add(ipt) - - num_inputs = len(model.graph.input) - num_outputs = len(model.graph.output) - num_nodes = len(model.graph.node) - print( - f"old graph has {num_inputs} inputs, {num_outputs} outpus, {num_nodes} nodes" - ) - print(f"{len(reserved_node_indices)} node to keep.") - - # del node not to keep - for idx in range(num_nodes - 1, -1, -1): - if idx not in reserved_node_indices: - del model.graph.node[idx] - - # del graph input not to keep - for idx in range(num_inputs - 1, -1, -1): - if model.graph.input[idx].name not in reserved_inputs: - del model.graph.input[idx] - - # del old graph outputs - for i in range(num_outputs): - del model.graph.output[0] - - # new graph output as user input - for out in args.output_names: - model.graph.output.extend([onnx.ValueInfoProto(name=out)]) - - # infer shape - try: - from onnx_infer_shape import SymbolicShapeInference - model = SymbolicShapeInference.infer_shapes( - model, - int_max=2**31 - 1, - auto_merge=True, - guess_output_rank=False, - verbose=1) - except Exception as e: - 
print(f"skip infer shape step: {e}") - - # check onnx model - onnx.checker.check_model(model) - # save onnx model - onnx.save(model, args.save_file) - print("[Finished] The new model saved in {}.".format(args.save_file)) - print("[DEBUG INFO] The inputs of new model: {}".format( - [x.name for x in model.graph.input])) - print("[DEBUG INFO] The outputs of new model: {}".format( - [x.name for x in model.graph.output])) diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py b/speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py deleted file mode 100755 index fc00a82e..00000000 --- a/speechx/examples/ds2_ol/onnx/local/onnx_rename_model.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python3 -W ignore::DeprecationWarning -# rename node to new names -import argparse -import sys - -import onnx - - -def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument( - '--model', - required=True, - help='Path of directory saved the input model.') - parser.add_argument( - '--origin_names', - required=True, - nargs='+', - help='The original name you want to modify.') - parser.add_argument( - '--new_names', - required=True, - nargs='+', - help='The new name you want change to, the number of new_names should be same with the number of origin_names' - ) - parser.add_argument( - '--save_file', required=True, help='Path to save the new onnx model.') - return parser.parse_args() - - -if __name__ == '__main__': - args = parse_arguments() - - if len(set(args.origin_names)) < len(args.origin_names): - print( - "[ERROR] There's dumplicate name in --origin_names, which is not allowed." - ) - sys.exit(-1) - - if len(set(args.new_names)) < len(args.new_names): - print( - "[ERROR] There's dumplicate name in --new_names, which is not allowed." - ) - sys.exit(-1) - - if len(args.new_names) != len(args.origin_names): - print( - "[ERROR] Number of --new_names must be same with the number of --origin_names." 
- ) - sys.exit(-1) - - model = onnx.load(args.model) - - # collect input and all node output - output_tensor_names = set() - for ipt in model.graph.input: - output_tensor_names.add(ipt.name) - - for node in model.graph.node: - for out in node.output: - output_tensor_names.add(out) - - for origin_name in args.origin_names: - if origin_name not in output_tensor_names: - print( - f"[ERROR] Cannot find tensor name '{origin_name}' in onnx model graph." - ) - sys.exit(-1) - - for new_name in args.new_names: - if new_name in output_tensor_names: - print( - "[ERROR] The defined new_name '{}' is already exist in the onnx model, which is not allowed." - ) - sys.exit(-1) - - # rename graph input - for i, ipt in enumerate(model.graph.input): - if ipt.name in args.origin_names: - idx = args.origin_names.index(ipt.name) - model.graph.input[i].name = args.new_names[idx] - - # rename node input and output - for i, node in enumerate(model.graph.node): - for j, ipt in enumerate(node.input): - if ipt in args.origin_names: - idx = args.origin_names.index(ipt) - model.graph.node[i].input[j] = args.new_names[idx] - - for j, out in enumerate(node.output): - if out in args.origin_names: - idx = args.origin_names.index(out) - model.graph.node[i].output[j] = args.new_names[idx] - - # rename graph output - for i, out in enumerate(model.graph.output): - if out.name in args.origin_names: - idx = args.origin_names.index(out.name) - model.graph.output[i].name = args.new_names[idx] - - # check onnx model - onnx.checker.check_model(model) - - # save model - onnx.save(model, args.save_file) - - print("[Finished] The new model saved in {}.".format(args.save_file)) - print("[DEBUG INFO] The inputs of new model: {}".format( - [x.name for x in model.graph.input])) - print("[DEBUG INFO] The outputs of new model: {}".format( - [x.name for x in model.graph.output])) diff --git a/speechx/examples/ds2_ol/onnx/local/ort_dyanmic_quant.py b/speechx/examples/ds2_ol/onnx/local/ort_dyanmic_quant.py deleted file 
mode 100755 index 2c569236..00000000 --- a/speechx/examples/ds2_ol/onnx/local/ort_dyanmic_quant.py +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env python3 -import argparse - -from onnxruntime.quantization import quantize_dynamic -from onnxruntime.quantization import QuantType - - -def quantize_onnx_model(onnx_model_path, - quantized_model_path, - nodes_to_exclude=[]): - print("Starting quantization...") - - quantize_dynamic( - onnx_model_path, - quantized_model_path, - weight_type=QuantType.QInt8, - nodes_to_exclude=nodes_to_exclude) - - print(f"Quantized model saved to: {quantized_model_path}") - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--model-in", - type=str, - required=True, - help="ONNX model", ) - parser.add_argument( - "--model-out", - type=str, - required=True, - default='model.quant.onnx', - help="ONNX model", ) - parser.add_argument( - "--nodes-to-exclude", - type=str, - required=True, - help="nodes to exclude. e.g. conv,linear.", ) - - args = parser.parse_args() - - nodes_to_exclude = args.nodes_to_exclude.split(',') - quantize_onnx_model(args.model_in, args.model_out, nodes_to_exclude) - - -if __name__ == "__main__": - main() diff --git a/speechx/examples/ds2_ol/onnx/local/ort_opt.py b/speechx/examples/ds2_ol/onnx/local/ort_opt.py deleted file mode 100755 index 8e995bcf..00000000 --- a/speechx/examples/ds2_ol/onnx/local/ort_opt.py +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env python3 -import argparse - -import onnxruntime as ort - -# onnxruntime optimizer. 
-# https://onnxruntime.ai/docs/performance/graph-optimizations.html -# https://onnxruntime.ai/docs/api/python/api_summary.html#api - - -def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument( - '--model_in', required=True, type=str, help='Path to onnx model.') - parser.add_argument( - '--opt_level', - required=True, - type=int, - default=0, - choices=[0, 1, 2], - help='Path to onnx model.') - parser.add_argument( - '--model_out', required=True, help='path to save the optimized model.') - parser.add_argument('--debug', default=False, help='output debug info.') - return parser.parse_args() - - -if __name__ == '__main__': - args = parse_arguments() - - sess_options = ort.SessionOptions() - - # Set graph optimization level - print(f"opt level: {args.opt_level}") - if args.opt_level == 0: - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC - elif args.opt_level == 1: - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED - else: - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL - - # To enable model serialization after graph optimization set this - sess_options.optimized_model_filepath = args.model_out - - session = ort.InferenceSession(args.model_in, sess_options) diff --git a/speechx/examples/ds2_ol/onnx/local/tonnx.sh b/speechx/examples/ds2_ol/onnx/local/tonnx.sh deleted file mode 100755 index 10487230..00000000 --- a/speechx/examples/ds2_ol/onnx/local/tonnx.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/bash - -if [ $# != 4 ];then - # local/tonnx.sh data/exp/deepspeech2_online/checkpoints avg_1.jit.pdmodel avg_1.jit.pdiparams exp/model.onnx - echo "usage: $0 model_dir model_name param_name onnx_output_name" - exit 1 -fi - -dir=$1 -model=$2 -param=$3 -output=$4 - -pip install paddle2onnx -pip install onnx - -# https://github.com/PaddlePaddle/Paddle2ONNX#%E5%91%BD%E4%BB%A4%E8%A1%8C%E8%BD%AC%E6%8D%A2 - # opset10 support quantize -paddle2onnx 
--model_dir $dir \ - --model_filename $model \ - --params_filename $param \ - --save_file $output \ - --enable_dev_version True \ - --opset_version 11 \ - --enable_onnx_checker True - \ No newline at end of file diff --git a/speechx/examples/ds2_ol/onnx/path.sh b/speechx/examples/ds2_ol/onnx/path.sh deleted file mode 100755 index 97d48737..00000000 --- a/speechx/examples/ds2_ol/onnx/path.sh +++ /dev/null @@ -1,14 +0,0 @@ -# This contains the locations of binarys build required for running the examples. - -MAIN_ROOT=`realpath $PWD/../../../../` -SPEECHX_ROOT=$PWD/../../../ -SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx - -SPEECHX_TOOLS=$SPEECHX_ROOT/tools -TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin - -[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; } - -export LC_AL=C - -export PATH=$PATH:$TOOLS_BIN diff --git a/speechx/examples/ds2_ol/onnx/run.sh b/speechx/examples/ds2_ol/onnx/run.sh deleted file mode 100755 index 3dc5e910..00000000 --- a/speechx/examples/ds2_ol/onnx/run.sh +++ /dev/null @@ -1,91 +0,0 @@ -#!/bin/bash - -set -e - -. path.sh - -stage=0 -stop_stage=50 -tarfile=asr0_deepspeech2_online_wenetspeech_ckpt_1.0.2.model.tar.gz -#tarfile=asr0_deepspeech2_online_aishell_fbank161_ckpt_1.0.1.model.tar.gz -model_prefix=avg_10.jit -#model_prefix=avg_1.jit -model=${model_prefix}.pdmodel -param=${model_prefix}.pdiparams - -. 
utils/parse_options.sh - -data=data -exp=exp - -mkdir -p $data $exp - -dir=$data/exp/deepspeech2_online/checkpoints - -# wenetspeech or aishell -model_type=$(echo $tarfile | cut -d '_' -f 4) - -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then - test -f $data/$tarfile || wget -P $data -c https://paddlespeech.bj.bcebos.com/s2t/$model_type/asr0/$tarfile - - # wenetspeech ds2 model - pushd $data - tar zxvf $tarfile - popd - - # ds2 model demo inputs - pushd $exp - wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ds2_ol/onnx/static_ds2online_inputs.pickle - popd -fi - -input_file=$exp/static_ds2online_inputs.pickle -test -e $input_file - -if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ];then - # to onnx - ./local/tonnx.sh $dir $model $param $exp/model.onnx - - ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.onnx -fi - - -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ] ;then - # ort graph optmize - ./local/ort_opt.py --model_in $exp/model.onnx --opt_level 0 --model_out $exp/model.ort.opt.onnx - - ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.ort.opt.onnx -fi - - -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ];then - # convert opset_num to 11 - ./local/onnx_convert_opset.py --target-opset 11 --model-file $exp/model.ort.opt.onnx --save-model $exp/model.optset11.onnx - - # quant model - nodes_to_exclude='p2o.Conv.0,p2o.Conv.2' - ./local/ort_dyanmic_quant.py --model-in $exp/model.optset11.onnx --model-out $exp/model.optset11.quant.onnx --nodes-to-exclude "${nodes_to_exclude}" - - ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.optset11.quant.onnx -fi - - -# aishell rnn hidden is 1024 -# wenetspeech rnn hiddn is 2048 -if [ $model_type == 'aishell' ];then - 
input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,1024 chunk_state_h_box:5,1,1024" -elif [ $model_type == 'wenetspeech' ];then - input_shape="audio_chunk:1,-1,161 audio_chunk_lens:1 chunk_state_c_box:5,1,2048 chunk_state_h_box:5,1,2048" -else - echo "not support: $model_type" - exit -1 -fi - - -if [ ${stage} -le 51 ] && [ ${stop_stage} -ge 51 ] ;then - # wenetspeech ds2 model execed 2GB limit, will error. - # simplifying onnx model - ./local/onnx_opt.sh $exp/model.onnx $exp/model.opt.onnx "$input_shape" - - ./local/infer_check.py --input_file $input_file --model_type $model_type --model_dir $dir --model_prefix $model_prefix --onnx_model $exp/model.opt.onnx -fi diff --git a/speechx/examples/ds2_ol/onnx/utils b/speechx/examples/ds2_ol/onnx/utils deleted file mode 120000 index c2519a9d..00000000 --- a/speechx/examples/ds2_ol/onnx/utils +++ /dev/null @@ -1 +0,0 @@ -../../../../utils/ \ No newline at end of file diff --git a/speechx/examples/ds2_ol/websocket/.gitignore b/speechx/examples/ds2_ol/websocket/.gitignore deleted file mode 100644 index bbd86a25..00000000 --- a/speechx/examples/ds2_ol/websocket/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -data -exp diff --git a/speechx/examples/ds2_ol/websocket/README.md b/speechx/examples/ds2_ol/websocket/README.md deleted file mode 100644 index 3fa84135..00000000 --- a/speechx/examples/ds2_ol/websocket/README.md +++ /dev/null @@ -1,78 +0,0 @@ -# Streaming DeepSpeech2 Server with WebSocket - -This example is about using `websocket` as streaming deepspeech2 server. For deepspeech2 model training please see [here](../../../../examples/aishell/asr0/). - -The websocket protocal is same to [PaddleSpeech Server](../../../../demos/streaming_asr_server/), -for detail of implementation please see [here](../../../speechx/protocol/websocket/). - - -## Source path.sh - -```bash -. path.sh -``` - -SpeechX bins is under `echo $SPEECHX_BUILD`, more info please see `path.sh`. 
- - -## Start WebSocket Server - -```bash -bash websoket_server.sh -``` - -The output is like below: - -```text -I1130 02:19:32.029882 12856 cmvn_json2kaldi_main.cc:39] cmvn josn path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/data/mean_std.json -I1130 02:19:32.032230 12856 cmvn_json2kaldi_main.cc:73] nframe: 907497 -I1130 02:19:32.032564 12856 cmvn_json2kaldi_main.cc:85] cmvn stats have write into: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/cmvn.ark -I1130 02:19:32.032579 12856 cmvn_json2kaldi_main.cc:86] Binary: 1 -I1130 02:19:32.798342 12937 feature_pipeline.h:53] cmvn file: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/cmvn.ark -I1130 02:19:32.798542 12937 feature_pipeline.h:58] dither: 0 -I1130 02:19:32.798583 12937 feature_pipeline.h:60] frame shift ms: 10 -I1130 02:19:32.798588 12937 feature_pipeline.h:62] feature type: linear -I1130 02:19:32.798596 12937 feature_pipeline.h:80] frame length ms: 20 -I1130 02:19:32.798601 12937 feature_pipeline.h:88] subsampling rate: 4 -I1130 02:19:32.798606 12937 feature_pipeline.h:90] nnet receptive filed length: 7 -I1130 02:19:32.798611 12937 feature_pipeline.h:92] nnet chunk size: 1 -I1130 02:19:32.798615 12937 feature_pipeline.h:94] frontend fill zeros: 0 -I1130 02:19:32.798630 12937 nnet_itf.h:52] subsampling rate: 4 -I1130 02:19:32.798635 12937 nnet_itf.h:54] model path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/exp/deepspeech2_online/checkpoints//avg_1.jit.pdmodel -I1130 02:19:32.798640 12937 nnet_itf.h:57] param path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/model/exp/deepspeech2_online/checkpoints//avg_1.jit.pdiparams -I1130 02:19:32.798643 12937 nnet_itf.h:59] DS2 param: -I1130 02:19:32.798647 12937 nnet_itf.h:61] cache names: chunk_state_h_box,chunk_state_c_box -I1130 02:19:32.798652 12937 nnet_itf.h:63] cache shape: 5-1-1024,5-1-1024 -I1130 02:19:32.798656 
12937 nnet_itf.h:65] input names: audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box -I1130 02:19:32.798660 12937 nnet_itf.h:67] output names: softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 -I1130 02:19:32.798664 12937 ctc_tlg_decoder.h:41] fst path: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/wfst//TLG.fst -I1130 02:19:32.798669 12937 ctc_tlg_decoder.h:42] fst symbole table: /workspace/zhanghui/PaddleSpeech/speechx/examples/ds2_ol/websocket/data/wfst//words.txt -I1130 02:19:32.798673 12937 ctc_tlg_decoder.h:47] LatticeFasterDecoder max active: 7500 -I1130 02:19:32.798677 12937 ctc_tlg_decoder.h:49] LatticeFasterDecoder beam: 15 -I1130 02:19:32.798681 12937 ctc_tlg_decoder.h:50] LatticeFasterDecoder lattice_beam: 7.5 -I1130 02:19:32.798708 12937 websocket_server_main.cc:37] Listening at port 8082 -``` - -## Start WebSocket Client - -```bash -bash websocket_client.sh -``` - -This script using AISHELL-1 test data to call websocket server. - -The input is specific by `--wav_rspecifier=scp:$data/$aishell_wav_scp`. - -The `scp` file which look like this: -```text -# head data/split1/1/aishell_test.scp -BAC009S0764W0121 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav -BAC009S0764W0122 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav -... -BAC009S0764W0125 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav -``` - -If you want to recognize one wav, you can make `scp` file like this: -```text -key path/to/wav/file -``` diff --git a/speechx/examples/ds2_ol/websocket/path.sh b/speechx/examples/ds2_ol/websocket/path.sh deleted file mode 100755 index 6dd6bddb..00000000 --- a/speechx/examples/ds2_ol/websocket/path.sh +++ /dev/null @@ -1,14 +0,0 @@ -# This contains the locations of binarys build required for running the examples. 
- -SPEECHX_ROOT=$PWD/../../../ -SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx - -SPEECHX_TOOLS=$SPEECHX_ROOT/tools -TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin - -[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; } - -export LC_AL=C - -SPEECHX_BIN=$SPEECHX_BUILD/protocol/websocket:$SPEECHX_BUILD/frontend/audio -export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN diff --git a/speechx/examples/ds2_ol/websocket/websocket_client.sh b/speechx/examples/ds2_ol/websocket/websocket_client.sh deleted file mode 100755 index a508adfb..00000000 --- a/speechx/examples/ds2_ol/websocket/websocket_client.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -set +x -set -e - -. path.sh - -# 1. compile -if [ ! -d ${SPEECHX_EXAMPLES} ]; then - pushd ${SPEECHX_ROOT} - bash build.sh - popd -fi - -# input -mkdir -p data -data=$PWD/data - -# output -aishell_wav_scp=aishell_test.scp -if [ ! -d $data/test ]; then - pushd $data - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip - unzip aishell_test.zip - popd - - realpath $data/test/*/*.wav > $data/wavlist - awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id - paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp -fi - -export GLOG_logtostderr=1 - -# websocket client -websocket_client_main \ - --wav_rspecifier=scp:$data/$aishell_wav_scp --streaming_chunk=0.5 diff --git a/speechx/examples/ds2_ol/websocket/websocket_server.sh b/speechx/examples/ds2_ol/websocket/websocket_server.sh deleted file mode 100755 index 18d29857..00000000 --- a/speechx/examples/ds2_ol/websocket/websocket_server.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash -set +x -set -e - -. path.sh - -# 1. compile -if [ ! 
-d ${SPEECHX_EXAMPLES} ]; then - pushd ${SPEECHX_ROOT} - bash build.sh - popd -fi - -# input -mkdir -p data -data=$PWD/data -ckpt_dir=$data/model -model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ -vocb_dir=$ckpt_dir/data/lang_char/ - - -if [ ! -f $ckpt_dir/data/mean_std.json ]; then - mkdir -p $ckpt_dir - pushd $ckpt_dir - wget -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz - tar xzfv asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz - popd -fi - -export GLOG_logtostderr=1 - -# 3. gen cmvn -cmvn=$data/cmvn.ark -cmvn_json2kaldi_main --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn - - -wfst=$data/wfst/ -mkdir -p $wfst -if [ ! -f $wfst/aishell_graph.zip ]; then - pushd $wfst - wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip - unzip aishell_graph.zip - mv aishell_graph/* $wfst - popd -fi - -# 5. test websocket server -websocket_server_main \ - --cmvn_file=$cmvn \ - --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ - --word_symbol_table=$wfst/words.txt \ - --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ - --graph_path=$wfst/TLG.fst --max_active=7500 \ - --acoustic_scale=1.2 diff --git a/speechx/examples/u2pp_ol/wenetspeech/path.sh b/speechx/examples/u2pp_ol/wenetspeech/path.sh index ec278bd3..9518db11 100644 --- a/speechx/examples/u2pp_ol/wenetspeech/path.sh +++ b/speechx/examples/u2pp_ol/wenetspeech/path.sh @@ -3,7 +3,7 @@ unset GREP_OPTIONS SPEECHX_ROOT=$PWD/../../../ -SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx +SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx/asr SPEECHX_TOOLS=$SPEECHX_ROOT/tools TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin @@ -12,7 +12,7 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin export LC_AL=C -export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio:$SPEECHX_BUILD/recognizer +export 
PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/../common/frontend/audio:$SPEECHX_BUILD/recognizer PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);") export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH diff --git a/speechx/speechx/asr/decoder/CMakeLists.txt b/speechx/speechx/asr/decoder/CMakeLists.txt index 93014fb9..b2f50708 100644 --- a/speechx/speechx/asr/decoder/CMakeLists.txt +++ b/speechx/speechx/asr/decoder/CMakeLists.txt @@ -1,55 +1,22 @@ -include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders}) - set(srcs) - -if (USING_DS2) list(APPEND srcs - ctc_decoders/decoder_utils.cpp - ctc_decoders/path_trie.cpp - ctc_decoders/scorer.cpp - ctc_beam_search_decoder.cc - ctc_tlg_decoder.cc + ctc_prefix_beam_search_decoder.cc ) -endif() - -if (USING_U2) - list(APPEND srcs - ctc_prefix_beam_search_decoder.cc - ) -endif() add_library(decoder STATIC ${srcs}) -target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder) +target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder) # test -if (USING_DS2) - set(BINS - ctc_beam_search_decoder_main - nnet_logprob_decoder_main - ctc_tlg_decoder_main - ) - - foreach(bin_name IN LISTS BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) - endforeach() -endif() - - -if (USING_U2) - set(TEST_BINS - ctc_prefix_beam_search_decoder_main - ) - - foreach(bin_name IN LISTS TEST_BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE 
${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) - target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) - target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) - target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) - endforeach() +set(TEST_BINS + ctc_prefix_beam_search_decoder_main +) -endif() +foreach(bin_name IN LISTS TEST_BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) + target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) +endforeach() diff --git a/speechx/speechx/asr/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/asr/decoder/ctc_beam_search_decoder.cc deleted file mode 100644 index 6e3a0d13..00000000 --- a/speechx/speechx/asr/decoder/ctc_beam_search_decoder.cc +++ /dev/null @@ -1,313 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- - -#include "decoder/ctc_beam_search_decoder.h" - -#include "base/common.h" -#include "decoder/ctc_decoders/decoder_utils.h" -#include "utils/file_utils.h" - -namespace ppspeech { - -using std::vector; -using FSTMATCH = fst::SortedMatcher; - -CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) - : opts_(opts), init_ext_scorer_(nullptr), space_id_(-1), root_(nullptr) { - LOG(INFO) << "dict path: " << opts_.dict_file; - if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { - LOG(INFO) << "load the dict failed"; - } - LOG(INFO) << "read the vocabulary success, dict size: " - << vocabulary_.size(); - - LOG(INFO) << "language model path: " << opts_.lm_path; - if (opts_.lm_path != "") { - init_ext_scorer_ = std::make_shared( - opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); - } - - CHECK_EQ(opts_.blank, 0); - - auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); - space_id_ = it - vocabulary_.begin(); - // if no space in vocabulary - if (static_cast(space_id_) >= vocabulary_.size()) { - space_id_ = -2; - } -} - -void CTCBeamSearch::Reset() { - // num_frame_decoded_ = 0; - // ResetPrefixes(); - InitDecoder(); -} - -void CTCBeamSearch::InitDecoder() { - num_frame_decoded_ = 0; - // ResetPrefixes(); - prefixes_.clear(); - - root_ = std::make_shared(); - root_->score = root_->log_prob_b_prev = 0.0; - prefixes_.push_back(root_.get()); - if (init_ext_scorer_ != nullptr && - !init_ext_scorer_->is_character_based()) { - auto fst_dict = - static_cast(init_ext_scorer_->dictionary); - fst::StdVectorFst* dict_ptr = fst_dict->Copy(true); - root_->set_dictionary(dict_ptr); - - auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); - root_->set_matcher(matcher); - } -} - -void CTCBeamSearch::Decode( - std::shared_ptr decodable) { - return; -} - -// todo rename, refactor -void CTCBeamSearch::AdvanceDecode( - const std::shared_ptr& decodable) { - while (1) { - vector> likelihood; - vector frame_prob; - bool flag = 
decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); - if (flag == false) break; - likelihood.push_back(frame_prob); - AdvanceDecoding(likelihood); - } -} - -void CTCBeamSearch::ResetPrefixes() { - for (size_t i = 0; i < prefixes_.size(); i++) { - if (prefixes_[i] != nullptr) { - delete prefixes_[i]; - prefixes_[i] = nullptr; - } - } - prefixes_.clear(); -} - -int CTCBeamSearch::DecodeLikelihoods(const vector>& probs, - const vector& nbest_words) { - kaldi::Timer timer; - AdvanceDecoding(probs); - LOG(INFO) << "ctc decoding elapsed time(s) " - << static_cast(timer.Elapsed()) / 1000.0f; - return 0; -} - -vector> CTCBeamSearch::GetNBestPath(int n) { - int beam_size = n == -1 ? opts_.beam_size : std::min(n, opts_.beam_size); - return get_beam_search_result(prefixes_, vocabulary_, beam_size); -} - -vector> CTCBeamSearch::GetNBestPath() { - return GetNBestPath(-1); -} - -string CTCBeamSearch::GetBestPath() { - std::vector> result; - result = get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size); - return result[0].second; -} - -string CTCBeamSearch::GetFinalBestPath() { - CalculateApproxScore(); - LMRescore(); - return GetBestPath(); -} - -void CTCBeamSearch::AdvanceDecoding(const vector>& probs) { - size_t num_time_steps = probs.size(); - size_t beam_size = opts_.beam_size; - double cutoff_prob = opts_.cutoff_prob; - size_t cutoff_top_n = opts_.cutoff_top_n; - - vector> probs_seq(probs.size(), - vector(probs[0].size(), 0)); - - int row = probs.size(); - int col = probs[0].size(); - for (int i = 0; i < row; i++) { - for (int j = 0; j < col; j++) { - probs_seq[i][j] = static_cast(probs[i][j]); - } - } - - for (size_t time_step = 0; time_step < num_time_steps; time_step++) { - const auto& prob = probs_seq[time_step]; - - float min_cutoff = -NUM_FLT_INF; - bool full_beam = false; - if (init_ext_scorer_ != nullptr) { - size_t num_prefixes_ = std::min(prefixes_.size(), beam_size); - std::sort(prefixes_.begin(), - prefixes_.begin() + num_prefixes_, - 
prefix_compare); - - if (num_prefixes_ == 0) { - continue; - } - min_cutoff = prefixes_[num_prefixes_ - 1]->score + - std::log(prob[opts_.blank]) - - std::max(0.0, init_ext_scorer_->beta); - - full_beam = (num_prefixes_ == beam_size); - } - - vector> log_prob_idx = - get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); - - // loop over chars - size_t log_prob_idx_len = log_prob_idx.size(); - for (size_t index = 0; index < log_prob_idx_len; index++) { - SearchOneChar(full_beam, log_prob_idx[index], min_cutoff); - } - - prefixes_.clear(); - - // update log probs - root_->iterate_to_vec(prefixes_); - // only preserve top beam_size prefixes_ - if (prefixes_.size() >= beam_size) { - std::nth_element(prefixes_.begin(), - prefixes_.begin() + beam_size, - prefixes_.end(), - prefix_compare); - for (size_t i = beam_size; i < prefixes_.size(); ++i) { - prefixes_[i]->remove(); - } - } // end if - num_frame_decoded_++; - } // end for probs_seq -} - -int32 CTCBeamSearch::SearchOneChar( - const bool& full_beam, - const std::pair& log_prob_idx, - const BaseFloat& min_cutoff) { - size_t beam_size = opts_.beam_size; - const auto& c = log_prob_idx.first; - const auto& log_prob_c = log_prob_idx.second; - size_t prefixes_len = std::min(prefixes_.size(), beam_size); - - for (size_t i = 0; i < prefixes_len; ++i) { - auto prefix = prefixes_[i]; - if (full_beam && log_prob_c + prefix->score < min_cutoff) { - break; - } - - if (c == opts_.blank) { - prefix->log_prob_b_cur = - log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); - continue; - } - - // repeated character - if (c == prefix->character) { - // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1}) - prefix->log_prob_nb_cur = log_sum_exp( - prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); - } - - // get new prefix - auto prefix_new = prefix->get_path_trie(c); - if (prefix_new != nullptr) { - float log_p = -NUM_FLT_INF; - if (c == prefix->character && - prefix->log_prob_b_prev > -NUM_FLT_INF) { - // 
p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1}) - log_p = log_prob_c + prefix->log_prob_b_prev; - } else if (c != prefix->character) { - // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1}) - log_p = log_prob_c + prefix->score; - } - - // language model scoring - if (init_ext_scorer_ != nullptr && - (c == space_id_ || init_ext_scorer_->is_character_based())) { - PathTrie* prefix_to_score = nullptr; - // skip scoring the space - if (init_ext_scorer_->is_character_based()) { - prefix_to_score = prefix_new; - } else { - prefix_to_score = prefix; - } - - float score = 0.0; - vector ngram; - ngram = init_ext_scorer_->make_ngram(prefix_to_score); - // lm score: p_{lm}(W)^{\alpha} + \beta - score = init_ext_scorer_->get_log_cond_prob(ngram) * - init_ext_scorer_->alpha; - log_p += score; - log_p += init_ext_scorer_->beta; - } - // p_{nb}(l;x_{1:t}) - prefix_new->log_prob_nb_cur = - log_sum_exp(prefix_new->log_prob_nb_cur, log_p); - } - } // end of loop over prefix - return 0; -} - -void CTCBeamSearch::CalculateApproxScore() { - size_t beam_size = opts_.beam_size; - size_t num_prefixes_ = std::min(prefixes_.size(), beam_size); - std::sort( - prefixes_.begin(), prefixes_.begin() + num_prefixes_, prefix_compare); - - // compute aproximate ctc score as the return score, without affecting the - // return order of decoding result. To delete when decoder gets stable. 
- for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) { - double approx_ctc = prefixes_[i]->score; - if (init_ext_scorer_ != nullptr) { - vector output; - prefixes_[i]->get_path_vec(output); - auto prefix_length = output.size(); - auto words = init_ext_scorer_->split_labels(output); - // remove word insert - approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta; - // remove language model weight: - approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) * - init_ext_scorer_->alpha; - } - prefixes_[i]->approx_ctc = approx_ctc; - } -} - -void CTCBeamSearch::LMRescore() { - size_t beam_size = opts_.beam_size; - if (init_ext_scorer_ != nullptr && - !init_ext_scorer_->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) { - auto prefix = prefixes_[i]; - if (!prefix->is_empty() && prefix->character != space_id_) { - float score = 0.0; - vector ngram = init_ext_scorer_->make_ngram(prefix); - score = init_ext_scorer_->get_log_cond_prob(ngram) * - init_ext_scorer_->alpha; - score += init_ext_scorer_->beta; - prefix->score += score; - } - } - } -} - -} // namespace ppspeech diff --git a/speechx/speechx/asr/decoder/ctc_beam_search_decoder.h b/speechx/speechx/asr/decoder/ctc_beam_search_decoder.h deleted file mode 100644 index f06d88e3..00000000 --- a/speechx/speechx/asr/decoder/ctc_beam_search_decoder.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -// used by deepspeech2 - -#pragma once - -#include "decoder/ctc_beam_search_opt.h" -#include "decoder/ctc_decoders/path_trie.h" -#include "decoder/ctc_decoders/scorer.h" -#include "decoder/decoder_itf.h" - -namespace ppspeech { - -class CTCBeamSearch : public DecoderBase { - public: - explicit CTCBeamSearch(const CTCBeamSearchOptions& opts); - ~CTCBeamSearch() {} - - void InitDecoder(); - - void Reset(); - - void AdvanceDecode( - const std::shared_ptr& decodable); - - void Decode(std::shared_ptr decodable); - - std::string GetBestPath(); - std::vector> GetNBestPath(); - std::vector> GetNBestPath(int n); - std::string GetFinalBestPath(); - - std::string GetPartialResult() { - CHECK(false) << "Not implement."; - return {}; - } - - int DecodeLikelihoods(const std::vector>& probs, - const std::vector& nbest_words); - - private: - void ResetPrefixes(); - - int32 SearchOneChar(const bool& full_beam, - const std::pair& log_prob_idx, - const BaseFloat& min_cutoff); - void CalculateApproxScore(); - void LMRescore(); - void AdvanceDecoding(const std::vector>& probs); - - CTCBeamSearchOptions opts_; - std::shared_ptr init_ext_scorer_; // todo separate later - std::vector vocabulary_; // todo remove later - int space_id_; - std::shared_ptr root_; - std::vector prefixes_; - - DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); -}; - -} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/decoder/ctc_beam_search_decoder_main.cc b/speechx/speechx/asr/decoder/ctc_beam_search_decoder_main.cc deleted file mode 100644 index ab0376b6..00000000 --- a/speechx/speechx/asr/decoder/ctc_beam_search_decoder_main.cc +++ /dev/null @@ -1,167 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// used by deepspeech2 - -#include "base/flags.h" -#include "base/log.h" -#include "decoder/ctc_beam_search_decoder.h" -#include "frontend/audio/data_cache.h" -#include "kaldi/util/table-types.h" -#include "nnet/decodable.h" -#include "nnet/ds2_nnet.h" - -DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); -DEFINE_string(result_wspecifier, "", "test result wspecifier"); -DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); -DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); -DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); -DEFINE_string(lm_path, "", "language model"); -DEFINE_int32(receptive_field_length, - 7, - "receptive field of two CNN(kernel=3) downsampling module."); -DEFINE_int32(subsampling_rate, - 4, - "two CNN(kernel=3) module downsampling rate."); -DEFINE_string( - model_input_names, - "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box", - "model input names"); -DEFINE_string(model_output_names, - "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0", - "model output names"); -DEFINE_string(model_cache_names, - "chunk_state_h_box,chunk_state_c_box", - "model cache names"); -DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); -DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); - -using kaldi::BaseFloat; -using kaldi::Matrix; -using std::vector; - -// test ds2 online decoder by feeding speech feature -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, 
&argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - CHECK_NE(FLAGS_result_wspecifier, ""); - CHECK_NE(FLAGS_feature_rspecifier, ""); - - kaldi::SequentialBaseFloatMatrixReader feature_reader( - FLAGS_feature_rspecifier); - kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); - std::string model_path = FLAGS_model_path; - std::string model_params = FLAGS_param_path; - std::string dict_file = FLAGS_dict_file; - std::string lm_path = FLAGS_lm_path; - LOG(INFO) << "model path: " << model_path; - LOG(INFO) << "model param: " << model_params; - LOG(INFO) << "dict path: " << dict_file; - LOG(INFO) << "lm path: " << lm_path; - - int32 num_done = 0, num_err = 0; - - ppspeech::CTCBeamSearchOptions opts; - opts.dict_file = dict_file; - opts.lm_path = lm_path; - ppspeech::CTCBeamSearch decoder(opts); - - ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); - - std::shared_ptr nnet( - new ppspeech::PaddleNnet(model_opts)); - std::shared_ptr raw_data(new ppspeech::DataCache()); - std::shared_ptr decodable( - new ppspeech::Decodable(nnet, raw_data)); - - int32 chunk_size = FLAGS_receptive_field_length + - (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate; - int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; - int32 receptive_field_length = FLAGS_receptive_field_length; - LOG(INFO) << "chunk size (frame): " << chunk_size; - LOG(INFO) << "chunk stride (frame): " << chunk_stride; - LOG(INFO) << "receptive field (frame): " << receptive_field_length; - decoder.InitDecoder(); - - kaldi::Timer timer; - for (; !feature_reader.Done(); feature_reader.Next()) { - string utt = feature_reader.Key(); - kaldi::Matrix feature = feature_reader.Value(); - raw_data->SetDim(feature.NumCols()); - LOG(INFO) << "process utt: " << utt; - LOG(INFO) << "rows: " << feature.NumRows(); - LOG(INFO) << "cols: " << feature.NumCols(); - - int32 row_idx = 0; - int32 padding_len = 0; - 
int32 ori_feature_len = feature.NumRows(); - if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { - padding_len = - chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; - feature.Resize(feature.NumRows() + padding_len, - feature.NumCols(), - kaldi::kCopyData); - } - int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1; - for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { - kaldi::Vector feature_chunk(chunk_size * - feature.NumCols()); - int32 feature_chunk_size = 0; - if (ori_feature_len > chunk_idx * chunk_stride) { - feature_chunk_size = std::min( - ori_feature_len - chunk_idx * chunk_stride, chunk_size); - } - if (feature_chunk_size < receptive_field_length) break; - - int32 start = chunk_idx * chunk_stride; - - for (int row_id = 0; row_id < chunk_size; ++row_id) { - kaldi::SubVector tmp(feature, start); - kaldi::SubVector f_chunk_tmp( - feature_chunk.Data() + row_id * feature.NumCols(), - feature.NumCols()); - f_chunk_tmp.CopyFromVec(tmp); - ++start; - } - raw_data->Accept(feature_chunk); - if (chunk_idx == num_chunks - 1) { - raw_data->SetFinished(); - } - decoder.AdvanceDecode(decodable); - } - std::string result; - result = decoder.GetFinalBestPath(); - decodable->Reset(); - decoder.Reset(); - if (result.empty()) { - // the TokenWriter can not write empty string. - ++num_err; - KALDI_LOG << " the result of " << utt << " is empty"; - continue; - } - KALDI_LOG << " the result of " << utt << " is " << result; - result_writer.Write(utt, result); - ++num_done; - } - - KALDI_LOG << "Done " << num_done << " utterances, " << num_err - << " with errors."; - double elapsed = timer.Elapsed(); - KALDI_LOG << " cost:" << elapsed << " s"; - return (num_done != 0 ? 
0 : 1); -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/.gitignore b/speechx/speechx/asr/decoder/ctc_decoders/.gitignore deleted file mode 100644 index 0b1046ae..00000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/.gitignore +++ /dev/null @@ -1,9 +0,0 @@ -ThreadPool/ -build/ -dist/ -kenlm/ -openfst-1.6.3/ -openfst-1.6.3.tar.gz -swig_decoders.egg-info/ -decoders_wrap.cxx -swig_decoders.py diff --git a/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp b/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp deleted file mode 100644 index ebea5c22..00000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp +++ /dev/null @@ -1,607 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "ctc_beam_search_decoder.h" - -#include -#include -#include -#include -#include -#include - -#include "ThreadPool.h" -#include "fst/fstlib.h" - -#include "decoder_utils.h" -#include "path_trie.h" - -using FSTMATCH = fst::SortedMatcher; - - -std::vector> ctc_beam_search_decoding( - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id) { - // dimension check - size_t num_time_steps = probs_seq.size(); - for (size_t i = 0; i < num_time_steps; ++i) { - VALID_CHECK_EQ(probs_seq[i].size(), - // vocabulary.size() + 1, - vocabulary.size(), - "The shape of probs_seq does not match with " - "the shape of the vocabulary"); - } - - - // assign space id - auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); - int space_id = it - vocabulary.begin(); - // if no space in vocabulary - if ((size_t)space_id >= vocabulary.size()) { - space_id = -2; - } - // init prefixes' root - PathTrie root; - root.score = root.log_prob_b_prev = 0.0; - std::vector prefixes; - prefixes.push_back(&root); - - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - auto fst_dict = - static_cast(ext_scorer->dictionary); - fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); - root.set_dictionary(dict_ptr); - auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); - root.set_matcher(matcher); - } - - // prefix search over time - for (size_t time_step = 0; time_step < num_time_steps; ++time_step) { - auto &prob = probs_seq[time_step]; - - float min_cutoff = -NUM_FLT_INF; - bool full_beam = false; - if (ext_scorer != nullptr) { - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort(prefixes.begin(), - prefixes.begin() + num_prefixes, - prefix_compare); - min_cutoff = prefixes[num_prefixes - 1]->score + - std::log(prob[blank_id]) - - std::max(0.0, ext_scorer->beta); - full_beam = (num_prefixes == beam_size); - } - - std::vector> 
log_prob_idx = - get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); - // loop over chars - for (size_t index = 0; index < log_prob_idx.size(); index++) { - auto c = log_prob_idx[index].first; - auto log_prob_c = log_prob_idx[index].second; - - for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { - auto prefix = prefixes[i]; - if (full_beam && log_prob_c + prefix->score < min_cutoff) { - break; - } - // blank - if (c == blank_id) { - prefix->log_prob_b_cur = log_sum_exp( - prefix->log_prob_b_cur, log_prob_c + prefix->score); - continue; - } - // repeated character - if (c == prefix->character) { - prefix->log_prob_nb_cur = - log_sum_exp(prefix->log_prob_nb_cur, - log_prob_c + prefix->log_prob_nb_prev); - } - // get new prefix - auto prefix_new = prefix->get_path_trie(c); - - if (prefix_new != nullptr) { - float log_p = -NUM_FLT_INF; - - if (c == prefix->character && - prefix->log_prob_b_prev > -NUM_FLT_INF) { - log_p = log_prob_c + prefix->log_prob_b_prev; - } else if (c != prefix->character) { - log_p = log_prob_c + prefix->score; - } - - // language model scoring - if (ext_scorer != nullptr && - (c == space_id || ext_scorer->is_character_based())) { - PathTrie *prefix_to_score = nullptr; - // skip scoring the space - if (ext_scorer->is_character_based()) { - prefix_to_score = prefix_new; - } else { - prefix_to_score = prefix; - } - - float score = 0.0; - std::vector ngram; - ngram = ext_scorer->make_ngram(prefix_to_score); - score = ext_scorer->get_log_cond_prob(ngram) * - ext_scorer->alpha; - log_p += score; - log_p += ext_scorer->beta; - } - prefix_new->log_prob_nb_cur = - log_sum_exp(prefix_new->log_prob_nb_cur, log_p); - } - } // end of loop over prefix - } // end of loop over vocabulary - - - prefixes.clear(); - // update log probs - root.iterate_to_vec(prefixes); - - // only preserve top beam_size prefixes - if (prefixes.size() >= beam_size) { - std::nth_element(prefixes.begin(), - prefixes.begin() + beam_size, - prefixes.end(), - 
prefix_compare); - for (size_t i = beam_size; i < prefixes.size(); ++i) { - prefixes[i]->remove(); - } - } - } // end of loop over time - - // score the last word of each prefix that doesn't end with space - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - auto prefix = prefixes[i]; - if (!prefix->is_empty() && prefix->character != space_id) { - float score = 0.0; - std::vector ngram = ext_scorer->make_ngram(prefix); - score = - ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; - score += ext_scorer->beta; - prefix->score += score; - } - } - } - - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort( - prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); - - // compute approximate ctc score as the return score, without affecting the - // return order of decoding result. To delete when decoder gets stable. - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - double approx_ctc = prefixes[i]->score; - if (ext_scorer != nullptr) { - std::vector output; - prefixes[i]->get_path_vec(output); - auto prefix_length = output.size(); - auto words = ext_scorer->split_labels(output); - // remove word insert - approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; - // remove language model weight: - approx_ctc -= - (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; - } - prefixes[i]->approx_ctc = approx_ctc; - } - - return get_beam_search_result(prefixes, vocabulary, beam_size); -} - - -std::vector>> -ctc_beam_search_decoding_batch( - const std::vector>> &probs_split, - const std::vector &vocabulary, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id) { - VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); - // thread pool - ThreadPool pool(num_processes); - // number of samples - size_t batch_size = probs_split.size(); - - // 
enqueue the tasks of decoding - std::vector>>> res; - for (size_t i = 0; i < batch_size; ++i) { - res.emplace_back(pool.enqueue(ctc_beam_search_decoding, - probs_split[i], - vocabulary, - beam_size, - cutoff_prob, - cutoff_top_n, - ext_scorer, - blank_id)); - } - - // get decoding results - std::vector>> batch_results; - for (size_t i = 0; i < batch_size; ++i) { - batch_results.emplace_back(res[i].get()); - } - return batch_results; -} - -void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer) { - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - auto fst_dict = - static_cast(ext_scorer->dictionary); - fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); - root->set_dictionary(dict_ptr); - auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); - root->set_matcher(matcher); - } -} - -void ctc_beam_search_decode_chunk( - PathTrie *root, - std::vector &prefixes, - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id) { - // dimension check - size_t num_time_steps = probs_seq.size(); - for (size_t i = 0; i < num_time_steps; ++i) { - VALID_CHECK_EQ(probs_seq[i].size(), - // vocabulary.size() + 1, - vocabulary.size(), - "The shape of probs_seq does not match with " - "the shape of the vocabulary"); - } - - // assign space id - auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); - int space_id = it - vocabulary.begin(); - // if no space in vocabulary - if ((size_t)space_id >= vocabulary.size()) { - space_id = -2; - } - // init prefixes' root - // - // prefix search over time - for (size_t time_step = 0; time_step < num_time_steps; ++time_step) { - auto &prob = probs_seq[time_step]; - - float min_cutoff = -NUM_FLT_INF; - bool full_beam = false; - if (ext_scorer != nullptr) { - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort(prefixes.begin(), - prefixes.begin() + num_prefixes, - 
prefix_compare); - min_cutoff = prefixes[num_prefixes - 1]->score + - std::log(prob[blank_id]) - - std::max(0.0, ext_scorer->beta); - full_beam = (num_prefixes == beam_size); - } - - std::vector> log_prob_idx = - get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); - // loop over chars - for (size_t index = 0; index < log_prob_idx.size(); index++) { - auto c = log_prob_idx[index].first; - auto log_prob_c = log_prob_idx[index].second; - - for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { - auto prefix = prefixes[i]; - if (full_beam && log_prob_c + prefix->score < min_cutoff) { - break; - } - // blank - if (c == blank_id) { - prefix->log_prob_b_cur = log_sum_exp( - prefix->log_prob_b_cur, log_prob_c + prefix->score); - continue; - } - // repeated character - if (c == prefix->character) { - prefix->log_prob_nb_cur = - log_sum_exp(prefix->log_prob_nb_cur, - log_prob_c + prefix->log_prob_nb_prev); - } - // get new prefix - auto prefix_new = prefix->get_path_trie(c); - - if (prefix_new != nullptr) { - float log_p = -NUM_FLT_INF; - - if (c == prefix->character && - prefix->log_prob_b_prev > -NUM_FLT_INF) { - log_p = log_prob_c + prefix->log_prob_b_prev; - } else if (c != prefix->character) { - log_p = log_prob_c + prefix->score; - } - - // language model scoring - if (ext_scorer != nullptr && - (c == space_id || ext_scorer->is_character_based())) { - PathTrie *prefix_to_score = nullptr; - // skip scoring the space - if (ext_scorer->is_character_based()) { - prefix_to_score = prefix_new; - } else { - prefix_to_score = prefix; - } - - float score = 0.0; - std::vector ngram; - ngram = ext_scorer->make_ngram(prefix_to_score); - score = ext_scorer->get_log_cond_prob(ngram) * - ext_scorer->alpha; - log_p += score; - log_p += ext_scorer->beta; - } - prefix_new->log_prob_nb_cur = - log_sum_exp(prefix_new->log_prob_nb_cur, log_p); - } - } // end of loop over prefix - } // end of loop over vocabulary - - prefixes.clear(); - // update log probs - - 
root->iterate_to_vec(prefixes); - - // only preserve top beam_size prefixes - if (prefixes.size() >= beam_size) { - std::nth_element(prefixes.begin(), - prefixes.begin() + beam_size, - prefixes.end(), - prefix_compare); - for (size_t i = beam_size; i < prefixes.size(); ++i) { - prefixes[i]->remove(); - } - } - } // end of loop over time - - return; -} - - -std::vector> get_decode_result( - std::vector &prefixes, - const std::vector &vocabulary, - size_t beam_size, - Scorer *ext_scorer) { - auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); - int space_id = it - vocabulary.begin(); - // if no space in vocabulary - if ((size_t)space_id >= vocabulary.size()) { - space_id = -2; - } - // score the last word of each prefix that doesn't end with space - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - auto prefix = prefixes[i]; - if (!prefix->is_empty() && prefix->character != space_id) { - float score = 0.0; - std::vector ngram = ext_scorer->make_ngram(prefix); - score = - ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; - score += ext_scorer->beta; - prefix->score += score; - } - } - } - - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort( - prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); - - // compute aproximate ctc score as the return score, without affecting the - // return order of decoding result. To delete when decoder gets stable. 
- for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - double approx_ctc = prefixes[i]->score; - if (ext_scorer != nullptr) { - std::vector output; - prefixes[i]->get_path_vec(output); - auto prefix_length = output.size(); - auto words = ext_scorer->split_labels(output); - // remove word insert - approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; - // remove language model weight: - approx_ctc -= - (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; - } - prefixes[i]->approx_ctc = approx_ctc; - } - - std::vector> res = - get_beam_search_result(prefixes, vocabulary, beam_size); - - // pay back the last word of each prefix that doesn't end with space (for - // decoding by chunk) - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - auto prefix = prefixes[i]; - if (!prefix->is_empty() && prefix->character != space_id) { - float score = 0.0; - std::vector ngram = ext_scorer->make_ngram(prefix); - score = - ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; - score += ext_scorer->beta; - prefix->score -= score; - } - } - } - return res; -} - - -void free_storage(std::unique_ptr &storage) { - storage = nullptr; -} - - -CtcBeamSearchDecoderBatch::~CtcBeamSearchDecoderBatch() {} - -CtcBeamSearchDecoderBatch::CtcBeamSearchDecoderBatch( - const std::vector &vocabulary, - size_t batch_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id) - : batch_size(batch_size), - beam_size(beam_size), - num_processes(num_processes), - cutoff_prob(cutoff_prob), - cutoff_top_n(cutoff_top_n), - ext_scorer(ext_scorer), - blank_id(blank_id) { - VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!"); - VALID_CHECK_GT( - this->num_processes, 0, "num_processes must be nonnegative!"); - this->vocabulary = vocabulary; - for (size_t i = 0; i < batch_size; i++) { - 
this->decoder_storage_vector.push_back( - std::unique_ptr( - new CtcBeamSearchDecoderStorage())); - ctc_beam_search_decode_chunk_begin( - this->decoder_storage_vector[i]->root, ext_scorer); - } -}; - -/** - * Input - * probs_split: shape [B, T, D] - */ -void CtcBeamSearchDecoderBatch::next( - const std::vector>> &probs_split, - const std::vector &has_value) { - VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); - // thread pool - size_t num_has_value = 0; - for (int i = 0; i < has_value.size(); i++) - if (has_value[i] == "true") num_has_value += 1; - ThreadPool pool(std::min(num_processes, num_has_value)); - // number of samples - size_t probs_num = probs_split.size(); - VALID_CHECK_EQ(this->batch_size, - probs_num, - "The batch size of the current input data should be same " - "with the input data before"); - - // enqueue the tasks of decoding - std::vector> res; - for (size_t i = 0; i < batch_size; ++i) { - if (has_value[i] == "true") { - res.emplace_back(pool.enqueue( - ctc_beam_search_decode_chunk, - std::ref(this->decoder_storage_vector[i]->root), - std::ref(this->decoder_storage_vector[i]->prefixes), - probs_split[i], - this->vocabulary, - this->beam_size, - this->cutoff_prob, - this->cutoff_top_n, - this->ext_scorer, - this->blank_id)); - } - } - - for (size_t i = 0; i < batch_size; ++i) { - res[i].get(); - } - return; -}; - -/** - * Return - * batch_result: shape[B, beam_size,(-approx_ctc score, string)] - */ -std::vector>> -CtcBeamSearchDecoderBatch::decode() { - VALID_CHECK_GT( - this->num_processes, 0, "num_processes must be nonnegative!"); - // thread pool - ThreadPool pool(this->num_processes); - // number of samples - // enqueue the tasks of decoding - std::vector>>> res; - for (size_t i = 0; i < this->batch_size; ++i) { - res.emplace_back( - pool.enqueue(get_decode_result, - std::ref(this->decoder_storage_vector[i]->prefixes), - this->vocabulary, - this->beam_size, - this->ext_scorer)); - } - // get decoding results - 
std::vector>> batch_results; - for (size_t i = 0; i < this->batch_size; ++i) { - batch_results.emplace_back(res[i].get()); - } - return batch_results; -} - - -/** - * reset the state of ctcBeamSearchDecoderBatch - */ -void CtcBeamSearchDecoderBatch::reset_state(size_t batch_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n) { - this->batch_size = batch_size; - this->beam_size = beam_size; - this->num_processes = num_processes; - this->cutoff_prob = cutoff_prob; - this->cutoff_top_n = cutoff_top_n; - - VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!"); - VALID_CHECK_GT( - this->num_processes, 0, "num_processes must be nonnegative!"); - // thread pool - ThreadPool pool(this->num_processes); - // number of samples - // enqueue the tasks of decoding - std::vector> res; - size_t storage_size = decoder_storage_vector.size(); - for (size_t i = 0; i < storage_size; i++) { - res.emplace_back(pool.enqueue( - free_storage, std::ref(this->decoder_storage_vector[i]))); - } - for (size_t i = 0; i < storage_size; ++i) { - res[i].get(); - } - std::vector>().swap( - decoder_storage_vector); - for (size_t i = 0; i < this->batch_size; i++) { - this->decoder_storage_vector.push_back( - std::unique_ptr( - new CtcBeamSearchDecoderStorage())); - ctc_beam_search_decode_chunk_begin( - this->decoder_storage_vector[i]->root, this->ext_scorer); - } -} \ No newline at end of file diff --git a/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h b/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h deleted file mode 100644 index 92d2b855..00000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef CTC_BEAM_SEARCH_DECODER_H_ -#define CTC_BEAM_SEARCH_DECODER_H_ - -#include -#include -#include - -#include "scorer.h" - -/* CTC Beam Search Decoder - - * Parameters: - * probs_seq: 2-D vector that each element is a vector of probabilities - * over vocabulary of one time step. - * vocabulary: A vector of vocabulary. - * beam_size: The width of beam search. - * cutoff_prob: Cutoff probability for pruning. - * cutoff_top_n: Cutoff number for pruning. - * ext_scorer: External scorer to evaluate a prefix, which consists of - * n-gram language model scoring and word insertion term. - * Default null, decoding the input sample without scorer. - * Return: - * A vector that each element is a pair of score and decoding result, - * in desending order. -*/ -std::vector> ctc_beam_search_decoding( - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob = 1.0, - size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr, - size_t blank_id = 0); - - -/* CTC Beam Search Decoder for batch data - - * Parameters: - * probs_seq: 3-D vector that each element is a 2-D vector that can be used - * by ctc_beam_search_decoder(). - * vocabulary: A vector of vocabulary. - * beam_size: The width of beam search. - * num_processes: Number of threads for beam search. - * cutoff_prob: Cutoff probability for pruning. - * cutoff_top_n: Cutoff number for pruning. - * ext_scorer: External scorer to evaluate a prefix, which consists of - * n-gram language model scoring and word insertion term. 
- * Default null, decoding the input sample without scorer. - * Return: - * A 2-D vector that each element is a vector of beam search decoding - * result for one audio sample. -*/ -std::vector>> -ctc_beam_search_decoding_batch( - const std::vector>> &probs_split, - const std::vector &vocabulary, - size_t beam_size, - size_t num_processes, - double cutoff_prob = 1.0, - size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr, - size_t blank_id = 0); - -/** - * Store the root and prefixes for decoder - */ - -class CtcBeamSearchDecoderStorage { - public: - PathTrie *root = nullptr; - std::vector prefixes; - - CtcBeamSearchDecoderStorage() { - // init prefixes' root - this->root = new PathTrie(); - this->root->log_prob_b_prev = 0.0; - // The score of root is in log scale.Since the prob=1.0, the prob score - // in log scale is 0.0 - this->root->score = root->log_prob_b_prev; - // std::vector prefixes; - this->prefixes.push_back(root); - }; - - ~CtcBeamSearchDecoderStorage() { - if (root != nullptr) { - delete root; - root = nullptr; - } - }; -}; - -/** - * The ctc beam search decoder, support batchsize >= 1 - */ -class CtcBeamSearchDecoderBatch { - public: - CtcBeamSearchDecoderBatch(const std::vector &vocabulary, - size_t batch_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id); - - ~CtcBeamSearchDecoderBatch(); - void next(const std::vector>> &probs_split, - const std::vector &has_value); - - std::vector>> decode(); - - void reset_state(size_t batch_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n); - - private: - std::vector vocabulary; - size_t batch_size; - size_t beam_size; - size_t num_processes; - double cutoff_prob; - size_t cutoff_top_n; - Scorer *ext_scorer; - size_t blank_id; - std::vector> - decoder_storage_vector; -}; - -/** - * function for chunk decoding - */ -void ctc_beam_search_decode_chunk( - PathTrie *root, - std::vector 
&prefixes, - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id); - -std::vector> get_decode_result( - std::vector &prefixes, - const std::vector &vocabulary, - size_t beam_size, - Scorer *ext_scorer); - -/** - * free the CtcBeamSearchDecoderStorage - */ -void free_storage(std::unique_ptr &storage); - -/** - * initialize the root - */ -void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer); - -#endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp b/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp deleted file mode 100644 index 6aa3c996..00000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "ctc_greedy_decoder.h" -#include "decoder_utils.h" - -std::string ctc_greedy_decoding( - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t blank_id) { - // dimension check - size_t num_time_steps = probs_seq.size(); - for (size_t i = 0; i < num_time_steps; ++i) { - VALID_CHECK_EQ(probs_seq[i].size(), - vocabulary.size(), - "The shape of probs_seq does not match with " - "the shape of the vocabulary"); - } - - // size_t blank_id = vocabulary.size(); - - std::vector max_idx_vec(num_time_steps, 0); - std::vector idx_vec; - for (size_t i = 0; i < num_time_steps; ++i) { - double max_prob = 0.0; - size_t max_idx = 0; - const std::vector &probs_step = probs_seq[i]; - for (size_t j = 0; j < probs_step.size(); ++j) { - if (max_prob < probs_step[j]) { - max_idx = j; - max_prob = probs_step[j]; - } - } - // id with maximum probability in current time step - max_idx_vec[i] = max_idx; - // deduplicate - if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) { - idx_vec.push_back(max_idx_vec[i]); - } - } - - std::string best_path_result; - for (size_t i = 0; i < idx_vec.size(); ++i) { - if (idx_vec[i] != blank_id) { - std::string ch = vocabulary[idx_vec[i]]; - best_path_result += (ch == kSPACE) ? tSPACE : ch; - } - } - return best_path_result; -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.h b/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.h deleted file mode 100644 index 4451600d..00000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef CTC_GREEDY_DECODER_H -#define CTC_GREEDY_DECODER_H - -#include -#include - -/* CTC Greedy (Best Path) Decoder - * - * Parameters: - * probs_seq: 2-D vector that each element is a vector of probabilities - * over vocabulary of one time step. - * vocabulary: A vector of vocabulary. - * Return: - * The decoding result in string - */ -std::string ctc_greedy_decoding( - const std::vector>& probs_seq, - const std::vector& vocabulary, - size_t blank_id); - -#endif // CTC_GREEDY_DECODER_H diff --git a/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.cpp b/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.cpp deleted file mode 100644 index c7ef6542..00000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.cpp +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "decoder_utils.h" - -#include -#include -#include - -std::vector> get_pruned_log_probs( - const std::vector &prob_step, - double cutoff_prob, - size_t cutoff_top_n) { - std::vector> prob_idx; - for (size_t i = 0; i < prob_step.size(); ++i) { - prob_idx.push_back(std::pair(i, prob_step[i])); - } - // pruning of vocabulary - size_t cutoff_len = prob_step.size(); - if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) { - std::sort(prob_idx.begin(), - prob_idx.end(), - pair_comp_second_rev); - if (cutoff_prob < 1.0) { - double cum_prob = 0.0; - cutoff_len = 0; - for (size_t i = 0; i < prob_idx.size(); ++i) { - cum_prob += prob_idx[i].second; - cutoff_len += 1; - if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) - break; - } - } - prob_idx = std::vector>( - prob_idx.begin(), prob_idx.begin() + cutoff_len); - } - std::vector> log_prob_idx; - for (size_t i = 0; i < cutoff_len; ++i) { - log_prob_idx.push_back(std::pair( - prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN))); - } - return log_prob_idx; -} - - -std::vector> get_beam_search_result( - const std::vector &prefixes, - const std::vector &vocabulary, - size_t beam_size) { - // allow for the post processing - std::vector space_prefixes; - if (space_prefixes.empty()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - space_prefixes.push_back(prefixes[i]); - } - } - - std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); - std::vector> output_vecs; - for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) { - std::vector output; - space_prefixes[i]->get_path_vec(output); - // convert index to string - std::string output_str; - for (size_t j = 0; j < output.size(); j++) { - std::string ch = vocabulary[output[j]]; - output_str += (ch == kSPACE) ? 
tSPACE : ch; - } - std::pair output_pair( - -space_prefixes[i]->approx_ctc, output_str); - output_vecs.emplace_back(output_pair); - } - - return output_vecs; -} - -size_t get_utf8_str_len(const std::string &str) { - size_t str_len = 0; - for (char c : str) { - str_len += ((c & 0xc0) != 0x80); - } - return str_len; -} - -std::vector split_utf8_str(const std::string &str) { - std::vector result; - std::string out_str; - - for (char c : str) { - if ((c & 0xc0) != 0x80) // new UTF-8 character - { - if (!out_str.empty()) { - result.push_back(out_str); - out_str.clear(); - } - } - - out_str.append(1, c); - } - result.push_back(out_str); - return result; -} - -std::vector split_str(const std::string &s, - const std::string &delim) { - std::vector result; - std::size_t start = 0, delim_len = delim.size(); - while (true) { - std::size_t end = s.find(delim, start); - if (end == std::string::npos) { - if (start < s.size()) { - result.push_back(s.substr(start)); - } - break; - } - if (end > start) { - result.push_back(s.substr(start, end - start)); - } - start = end + delim_len; - } - return result; -} - -bool prefix_compare(const PathTrie *x, const PathTrie *y) { - if (x->score == y->score) { - if (x->character == y->character) { - return false; - } else { - return (x->character < y->character); - } - } else { - return x->score > y->score; - } -} - -void add_word_to_fst(const std::vector &word, - fst::StdVectorFst *dictionary) { - if (dictionary->NumStates() == 0) { - fst::StdVectorFst::StateId start = dictionary->AddState(); - assert(start == 0); - dictionary->SetStart(start); - } - fst::StdVectorFst::StateId src = dictionary->Start(); - fst::StdVectorFst::StateId dst; - for (auto c : word) { - dst = dictionary->AddState(); - dictionary->AddArc(src, fst::StdArc(c, c, 0, dst)); - src = dst; - } - dictionary->SetFinal(dst, fst::StdArc::Weight::One()); -} - -bool add_word_to_dictionary( - const std::string &word, - const std::unordered_map &char_map, - bool add_space, - int 
SPACE_ID, - fst::StdVectorFst *dictionary) { - auto characters = split_utf8_str(word); - - std::vector int_word; - - for (auto &c : characters) { - if (c == " ") { - int_word.push_back(SPACE_ID); - } else { - auto int_c = char_map.find(c); - if (int_c != char_map.end()) { - int_word.push_back(int_c->second); - } else { - return false; // return without adding - } - } - } - - if (add_space) { - int_word.push_back(SPACE_ID); - } - - add_word_to_fst(int_word, dictionary); - return true; // return with successful adding -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h b/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h deleted file mode 100644 index 09874155..00000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef DECODER_UTILS_H_ -#define DECODER_UTILS_H_ - -#include -#include -#include "fst/log.h" -#include "path_trie.h" - -const std::string kSPACE = ""; -const std::string tSPACE = " "; -const float NUM_FLT_INF = std::numeric_limits::max(); -const float NUM_FLT_MIN = std::numeric_limits::min(); - -// inline function for validation check -inline void check( - bool x, const char *expr, const char *file, int line, const char *err) { - if (!x) { - std::cout << "[" << file << ":" << line << "] "; - LOG(FATAL) << "\"" << expr << "\" check failed. 
" << err; - } -} - -#define VALID_CHECK(x, info) \ - check(static_cast(x), #x, __FILE__, __LINE__, info) -#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info) -#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info) -#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info) - - -// Function template for comparing two pairs -template -bool pair_comp_first_rev(const std::pair &a, - const std::pair &b) { - return a.first > b.first; -} - -// Function template for comparing two pairs -template -bool pair_comp_second_rev(const std::pair &a, - const std::pair &b) { - return a.second > b.second; -} - -// Return the sum of two probabilities in log scale -template -T log_sum_exp(const T &x, const T &y) { - static T num_min = -std::numeric_limits::max(); - if (x <= num_min) return y; - if (y <= num_min) return x; - T xmax = std::max(x, y); - return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax; -} - -// Get pruned probability vector for each time step's beam search -std::vector> get_pruned_log_probs( - const std::vector &prob_step, - double cutoff_prob, - size_t cutoff_top_n); - -// Get beam search result from prefixes in trie tree -std::vector> get_beam_search_result( - const std::vector &prefixes, - const std::vector &vocabulary, - size_t beam_size); - -// Functor for prefix comparsion -bool prefix_compare(const PathTrie *x, const PathTrie *y); - -/* Get length of utf8 encoding string - * See: http://stackoverflow.com/a/4063229 - */ -size_t get_utf8_str_len(const std::string &str); - -/* Split a string into a list of strings on a given string - * delimiter. NB: delimiters on beginning / end of string are - * trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"]. 
- */ -std::vector split_str(const std::string &s, - const std::string &delim); - -/* Splits string into vector of strings representing - * UTF-8 characters (not same as chars) - */ -std::vector split_utf8_str(const std::string &str); - -// Add a word in index to the dicionary of fst -void add_word_to_fst(const std::vector &word, - fst::StdVectorFst *dictionary); - -// Add a word in string to dictionary -bool add_word_to_dictionary( - const std::string &word, - const std::unordered_map &char_map, - bool add_space, - int SPACE_ID, - fst::StdVectorFst *dictionary); -#endif // DECODER_UTILS_H diff --git a/speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp b/speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp deleted file mode 100644 index 777ca052..00000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "path_trie.h" - -#include -#include -#include -#include -#include - -#include "decoder_utils.h" - -PathTrie::PathTrie() { - log_prob_b_prev = -NUM_FLT_INF; - log_prob_nb_prev = -NUM_FLT_INF; - log_prob_b_cur = -NUM_FLT_INF; - log_prob_nb_cur = -NUM_FLT_INF; - score = -NUM_FLT_INF; - - ROOT_ = -1; - character = ROOT_; - exists_ = true; - parent = nullptr; - - dictionary_ = nullptr; - dictionary_state_ = 0; - has_dictionary_ = false; - - matcher_ = nullptr; -} - -PathTrie::~PathTrie() { - for (auto child : children_) { - delete child.second; - child.second = nullptr; - } -} - -PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { - auto child = children_.begin(); - for (child = children_.begin(); child != children_.end(); ++child) { - if (child->first == new_char) { - break; - } - } - if (child != children_.end()) { - if (!child->second->exists_) { - child->second->exists_ = true; - child->second->log_prob_b_prev = -NUM_FLT_INF; - child->second->log_prob_nb_prev = -NUM_FLT_INF; - child->second->log_prob_b_cur = -NUM_FLT_INF; - child->second->log_prob_nb_cur = -NUM_FLT_INF; - } - return (child->second); - } else { - if (has_dictionary_) { - matcher_->SetState(dictionary_state_); - bool found = matcher_->Find(new_char + 1); - if (!found) { - // Adding this character causes word outside dictionary - auto FSTZERO = fst::TropicalWeight::Zero(); - auto final_weight = dictionary_->Final(dictionary_state_); - bool is_final = (final_weight != FSTZERO); - if (is_final && reset) { - dictionary_state_ = dictionary_->Start(); - } - return nullptr; - } else { - PathTrie* new_path = new PathTrie; - new_path->character = new_char; - new_path->parent = this; - new_path->dictionary_ = dictionary_; - new_path->dictionary_state_ = matcher_->Value().nextstate; - new_path->has_dictionary_ = true; - new_path->matcher_ = matcher_; - children_.push_back(std::make_pair(new_char, new_path)); - return new_path; - } - } else { - PathTrie* new_path = new PathTrie; - 
new_path->character = new_char; - new_path->parent = this; - children_.push_back(std::make_pair(new_char, new_path)); - return new_path; - } - } -} - -PathTrie* PathTrie::get_path_vec(std::vector& output) { - return get_path_vec(output, ROOT_); -} - -PathTrie* PathTrie::get_path_vec(std::vector& output, - int stop, - size_t max_steps) { - if (character == stop || character == ROOT_ || output.size() == max_steps) { - std::reverse(output.begin(), output.end()); - return this; - } else { - output.push_back(character); - return parent->get_path_vec(output, stop, max_steps); - } -} - -void PathTrie::iterate_to_vec(std::vector& output) { - if (exists_) { - log_prob_b_prev = log_prob_b_cur; - log_prob_nb_prev = log_prob_nb_cur; - - log_prob_b_cur = -NUM_FLT_INF; - log_prob_nb_cur = -NUM_FLT_INF; - - score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); - output.push_back(this); - } - for (auto child : children_) { - child.second->iterate_to_vec(output); - } -} - -void PathTrie::remove() { - exists_ = false; - if (children_.size() == 0) { - if (parent != nullptr) { - auto child = parent->children_.begin(); - for (child = parent->children_.begin(); - child != parent->children_.end(); - ++child) { - if (child->first == character) { - parent->children_.erase(child); - break; - } - } - if (parent->children_.size() == 0 && !parent->exists_) { - parent->remove(); - } - } - delete this; - } -} - - -void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { - dictionary_ = dictionary; - dictionary_state_ = dictionary->Start(); - has_dictionary_ = true; -} - -using FSTMATCH = fst::SortedMatcher; -void PathTrie::set_matcher(std::shared_ptr matcher) { - matcher_ = matcher; -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/path_trie.h b/speechx/speechx/asr/decoder/ctc_decoders/path_trie.h deleted file mode 100644 index 5193e0a4..00000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/path_trie.h +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. 
All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef PATH_TRIE_H -#define PATH_TRIE_H - -#include -#include -#include -#include -#include - -#include "fst/fstlib.h" - -/* Trie tree for prefix storing and manipulating, with a dictionary in - * finite-state transducer for spelling correction. - */ -class PathTrie { - public: - PathTrie(); - ~PathTrie(); - - // get new prefix after appending new char - PathTrie* get_path_trie(int new_char, bool reset = true); - - // get the prefix in index from root to current node - PathTrie* get_path_vec(std::vector& output); - - // get the prefix in index from some stop node to current nodel - PathTrie* get_path_vec( - std::vector& output, - int stop, - size_t max_steps = std::numeric_limits::max()); - - // update log probs - void iterate_to_vec(std::vector& output); - - // set dictionary for FST - void set_dictionary(fst::StdVectorFst* dictionary); - - void set_matcher(std::shared_ptr>); - - bool is_empty() { return ROOT_ == character; } - - // remove current path from root - void remove(); - - float log_prob_b_prev; - float log_prob_nb_prev; - float log_prob_b_cur; - float log_prob_nb_cur; - float score; - float approx_ctc; - int character; - PathTrie* parent; - - private: - int ROOT_; - bool exists_; - bool has_dictionary_; - - std::vector> children_; - - // pointer to dictionary of FST - fst::StdVectorFst* dictionary_; - fst::StdVectorFst::StateId dictionary_state_; - // true if 
finding ars in FST - std::shared_ptr> matcher_; -}; - -#endif // PATH_TRIE_H diff --git a/speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp b/speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp deleted file mode 100644 index 6e7f68cf..00000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp +++ /dev/null @@ -1,232 +0,0 @@ -// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the -// "COPYING.LESSER.3"); - -#include "scorer.h" - -#include -#include - -#include "lm/config.hh" -#include "lm/model.hh" -#include "lm/state.hh" - -#include "decoder_utils.h" - -using namespace lm::ngram; -// if your platform is windows ,you need add the define -#define F_OK 0 -Scorer::Scorer(double alpha, - double beta, - const std::string& lm_path, - const std::vector& vocab_list) { - this->alpha = alpha; - this->beta = beta; - - dictionary = nullptr; - is_character_based_ = true; - language_model_ = nullptr; - - max_order_ = 0; - dict_size_ = 0; - SPACE_ID_ = -1; - - setup(lm_path, vocab_list); -} - -Scorer::~Scorer() { - if (language_model_ != nullptr) { - delete static_cast(language_model_); - } - if (dictionary != nullptr) { - delete static_cast(dictionary); - } -} - -void Scorer::setup(const std::string& lm_path, - const std::vector& vocab_list) { - // load language model - load_lm(lm_path); - // set char map for scorer - set_char_map(vocab_list); - // fill the dictionary for FST - if (!is_character_based()) { - fill_dictionary(true); - } -} - -void Scorer::load_lm(const std::string& lm_path) { - const char* filename = lm_path.c_str(); - VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path"); - - RetriveStrEnumerateVocab enumerate; - lm::ngram::Config config; - config.enumerate_vocab = &enumerate; - language_model_ = lm::ngram::LoadVirtual(filename, config); - max_order_ = static_cast(language_model_)->Order(); - vocabulary_ = enumerate.vocabulary; - for (size_t i = 0; i < vocabulary_.size(); ++i) { - if (is_character_based_ && 
vocabulary_[i] != UNK_TOKEN && - vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN && - get_utf8_str_len(enumerate.vocabulary[i]) > 1) { - is_character_based_ = false; - } - } -} - -double Scorer::get_log_cond_prob(const std::vector& words) { - lm::base::Model* model = static_cast(language_model_); - double cond_prob; - lm::ngram::State state, tmp_state, out_state; - // avoid to inserting in begin - model->NullContextWrite(&state); - for (size_t i = 0; i < words.size(); ++i) { - lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]); - // encounter OOV - if (word_index == 0) { - return OOV_SCORE; - } - cond_prob = model->BaseScore(&state, word_index, &out_state); - tmp_state = state; - state = out_state; - out_state = tmp_state; - } - // return log10 prob - return cond_prob; -} - -double Scorer::get_sent_log_prob(const std::vector& words) { - std::vector sentence; - if (words.size() == 0) { - for (size_t i = 0; i < max_order_; ++i) { - sentence.push_back(START_TOKEN); - } - } else { - for (size_t i = 0; i < max_order_ - 1; ++i) { - sentence.push_back(START_TOKEN); - } - sentence.insert(sentence.end(), words.begin(), words.end()); - } - sentence.push_back(END_TOKEN); - return get_log_prob(sentence); -} - -double Scorer::get_log_prob(const std::vector& words) { - assert(words.size() > max_order_); - double score = 0.0; - for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) { - std::vector ngram(words.begin() + i, - words.begin() + i + max_order_); - score += get_log_cond_prob(ngram); - } - return score; -} - -void Scorer::reset_params(float alpha, float beta) { - this->alpha = alpha; - this->beta = beta; -} - -std::string Scorer::vec2str(const std::vector& input) { - std::string word; - for (auto ind : input) { - word += char_list_[ind]; - } - return word; -} - -std::vector Scorer::split_labels(const std::vector& labels) { - if (labels.empty()) return {}; - - std::string s = vec2str(labels); - std::vector words; - if (is_character_based_) 
{ - words = split_utf8_str(s); - } else { - words = split_str(s, " "); - } - return words; -} - -void Scorer::set_char_map(const std::vector& char_list) { - char_list_ = char_list; - char_map_.clear(); - - // Set the char map for the FST for spelling correction - for (size_t i = 0; i < char_list_.size(); i++) { - if (char_list_[i] == kSPACE) { - SPACE_ID_ = i; - } - // The initial state of FST is state 0, hence the index of chars in - // the FST should start from 1 to avoid the conflict with the initial - // state, otherwise wrong decoding results would be given. - char_map_[char_list_[i]] = i + 1; - } -} - -std::vector Scorer::make_ngram(PathTrie* prefix) { - std::vector ngram; - PathTrie* current_node = prefix; - PathTrie* new_node = nullptr; - - for (int order = 0; order < max_order_; order++) { - std::vector prefix_vec; - - if (is_character_based_) { - new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1); - current_node = new_node; - } else { - new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_); - current_node = new_node->parent; // Skipping spaces - } - - // reconstruct word - std::string word = vec2str(prefix_vec); - ngram.push_back(word); - - if (new_node->character == -1) { - // No more spaces, but still need order - for (int i = 0; i < max_order_ - order - 1; i++) { - ngram.push_back(START_TOKEN); - } - break; - } - } - std::reverse(ngram.begin(), ngram.end()); - return ngram; -} - -void Scorer::fill_dictionary(bool add_space) { - fst::StdVectorFst dictionary; - // For each unigram convert to ints and put in trie - int dict_size = 0; - for (const auto& word : vocabulary_) { - bool added = add_word_to_dictionary( - word, char_map_, add_space, SPACE_ID_ + 1, &dictionary); - dict_size += added ? 1 : 0; - } - - dict_size_ = dict_size; - - /* Simplify FST - - * This gets rid of "epsilon" transitions in the FST. - * These are transitions that don't require a string input to be taken. 
- * Getting rid of them is necessary to make the FST deterministic, but - * can greatly increase the size of the FST - */ - fst::RmEpsilon(&dictionary); - fst::StdVectorFst* new_dict = new fst::StdVectorFst; - - /* This makes the FST deterministic, meaning for any string input there's - * only one possible state the FST could be in. It is assumed our - * dictionary is deterministic when using it. - * (lest we'd have to check for multiple transitions at each state) - */ - fst::Determinize(dictionary, new_dict); - - /* Finds the simplest equivalent fst. This is unnecessary but decreases - * memory usage of the dictionary - */ - fst::Minimize(new_dict); - this->dictionary = new_dict; -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/scorer.h b/speechx/speechx/asr/decoder/ctc_decoders/scorer.h deleted file mode 100644 index 08e109b7..00000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/scorer.h +++ /dev/null @@ -1,114 +0,0 @@ -// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the -// "COPYING.LESSER.3"); - -#ifndef SCORER_H_ -#define SCORER_H_ - -#include -#include -#include -#include - -#include "lm/enumerate_vocab.hh" -#include "lm/virtual_interface.hh" -#include "lm/word_index.hh" - -#include "path_trie.h" - -const double OOV_SCORE = -1000.0; -const std::string START_TOKEN = ""; -const std::string UNK_TOKEN = ""; -const std::string END_TOKEN = ""; - -// Implement a callback to retrive the dictionary of language model. -class RetriveStrEnumerateVocab : public lm::EnumerateVocab { - public: - RetriveStrEnumerateVocab() {} - - void Add(lm::WordIndex index, const StringPiece &str) { - vocabulary.push_back(std::string(str.data(), str.length())); - } - - std::vector vocabulary; -}; - -/* External scorer to query score for n-gram or sentence, including language - * model scoring and word insertion. 
- * - * Example: - * Scorer scorer(alpha, beta, "path_of_language_model"); - * scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" }); - * scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); - */ -class Scorer { - public: - Scorer(double alpha, - double beta, - const std::string &lm_path, - const std::vector &vocabulary); - ~Scorer(); - - double get_log_cond_prob(const std::vector &words); - - double get_sent_log_prob(const std::vector &words); - - // return the max order - size_t get_max_order() const { return max_order_; } - - // return the dictionary size of language model - size_t get_dict_size() const { return dict_size_; } - - // retrun true if the language model is character based - bool is_character_based() const { return is_character_based_; } - - // reset params alpha & beta - void reset_params(float alpha, float beta); - - // make ngram for a given prefix - std::vector make_ngram(PathTrie *prefix); - - // trransform the labels in index to the vector of words (word based lm) or - // the vector of characters (character based lm) - std::vector split_labels(const std::vector &labels); - - // language model weight - double alpha; - // word insertion weight - double beta; - - // pointer to the dictionary of FST - void *dictionary; - - protected: - // necessary setup: load language model, set char map, fill FST's dictionary - void setup(const std::string &lm_path, - const std::vector &vocab_list); - - // load language model from given path - void load_lm(const std::string &lm_path); - - // fill dictionary for FST - void fill_dictionary(bool add_space); - - // set char map - void set_char_map(const std::vector &char_list); - - double get_log_prob(const std::vector &words); - - // translate the vector in index to string - std::string vec2str(const std::vector &input); - - private: - void *language_model_; - bool is_character_based_; - size_t max_order_; - size_t dict_size_; - - int SPACE_ID_; - std::vector char_list_; - std::unordered_map char_map_; - - 
std::vector vocabulary_; -}; - -#endif // SCORER_H_ diff --git a/speechx/speechx/asr/decoder/nnet_logprob_decoder_main.cc b/speechx/speechx/asr/decoder/nnet_logprob_decoder_main.cc deleted file mode 100644 index e0acbe77..00000000 --- a/speechx/speechx/asr/decoder/nnet_logprob_decoder_main.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// todo refactor, repalce with gtest - -#include "base/flags.h" -#include "base/log.h" -#include "decoder/ctc_beam_search_decoder.h" -#include "kaldi/util/table-types.h" -#include "nnet/decodable.h" - -DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier"); -DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); -DEFINE_string(lm_path, "lm.klm", "language model"); - -using kaldi::BaseFloat; -using kaldi::Matrix; -using std::vector; - -// test decoder by feeding nnet posterior probability -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - kaldi::SequentialBaseFloatMatrixReader likelihood_reader( - FLAGS_nnet_prob_respecifier); - std::string dict_file = FLAGS_dict_file; - std::string lm_path = FLAGS_lm_path; - LOG(INFO) << "dict path: " << dict_file; - LOG(INFO) << "lm path: " << lm_path; - - int32 num_done = 0, num_err = 0; - 
- ppspeech::CTCBeamSearchOptions opts; - opts.dict_file = dict_file; - opts.lm_path = lm_path; - ppspeech::CTCBeamSearch decoder(opts); - - std::shared_ptr decodable( - new ppspeech::Decodable(nullptr, nullptr)); - - decoder.InitDecoder(); - - for (; !likelihood_reader.Done(); likelihood_reader.Next()) { - string utt = likelihood_reader.Key(); - const kaldi::Matrix likelihood = likelihood_reader.Value(); - LOG(INFO) << "process utt: " << utt; - LOG(INFO) << "rows: " << likelihood.NumRows(); - LOG(INFO) << "cols: " << likelihood.NumCols(); - decodable->Acceptlikelihood(likelihood); - decoder.AdvanceDecode(decodable); - std::string result; - result = decoder.GetFinalBestPath(); - KALDI_LOG << " the result of " << utt << " is " << result; - decodable->Reset(); - decoder.Reset(); - ++num_done; - } - - KALDI_LOG << "Done " << num_done << " utterances, " << num_err - << " with errors."; - return (num_done != 0 ? 0 : 1); -} diff --git a/speechx/speechx/asr/decoder/param.h b/speechx/speechx/asr/decoder/param.h index ebdd7119..cad6dbd8 100644 --- a/speechx/speechx/asr/decoder/param.h +++ b/speechx/speechx/asr/decoder/param.h @@ -15,8 +15,7 @@ #pragma once #include "base/common.h" -#include "decoder/ctc_beam_search_decoder.h" -#include "decoder/ctc_tlg_decoder.h" +//#include "decoder/ctc_tlg_decoder.h" // feature DEFINE_bool(use_fbank, false, "False for fbank; or linear feature"); diff --git a/speechx/speechx/asr/nnet/CMakeLists.txt b/speechx/speechx/asr/nnet/CMakeLists.txt index 2846540e..819cc2e8 100644 --- a/speechx/speechx/asr/nnet/CMakeLists.txt +++ b/speechx/speechx/asr/nnet/CMakeLists.txt @@ -1,30 +1,12 @@ set(srcs decodable.cc nnet_producer.cc) -if(USING_DS2) - list(APPEND srcs ds2_nnet.cc) -endif() - -if(USING_U2) - list(APPEND srcs u2_nnet.cc) -endif() +list(APPEND srcs u2_nnet.cc) add_library(nnet STATIC ${srcs}) target_link_libraries(nnet utils) -if(USING_U2) - target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS}) - target_include_directories(nnet PUBLIC 
${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) -endif() - - -if(USING_DS2) - set(bin_name ds2_nnet_main) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) - - target_link_libraries(${bin_name} ${DEPS}) -endif() +target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS}) +target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) # test bin #if(USING_U2) diff --git a/speechx/speechx/asr/nnet/ds2_nnet.cc b/speechx/speechx/asr/nnet/ds2_nnet.cc deleted file mode 100644 index f77c0a60..00000000 --- a/speechx/speechx/asr/nnet/ds2_nnet.cc +++ /dev/null @@ -1,218 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "nnet/ds2_nnet.h" - -#include "utils/strings.h" - -namespace ppspeech { - -using kaldi::Matrix; -using kaldi::Vector; -using std::shared_ptr; -using std::string; -using std::vector; - -void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { - std::vector cache_names; - cache_names = StrSplit(opts.cache_names, ","); - std::vector cache_shapes; - cache_shapes = StrSplit(opts.cache_shape, ","); - assert(cache_shapes.size() == cache_names.size()); - - cache_encouts_.clear(); - cache_names_idx_.clear(); - for (size_t i = 0; i < cache_shapes.size(); i++) { - std::vector tmp_shape; - tmp_shape = StrSplit(cache_shapes[i], "-"); - std::vector cur_shape; - std::transform(tmp_shape.begin(), - tmp_shape.end(), - std::back_inserter(cur_shape), - [](const std::string& s) { return atoi(s.c_str()); }); - cache_names_idx_[cache_names[i]] = i; - std::shared_ptr> cache_eout = - std::make_shared>(cur_shape); - cache_encouts_.push_back(cache_eout); - } -} - -PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) { - subsampling_rate_ = opts.subsample_rate; - paddle_infer::Config config; - config.SetModel(opts.model_path, opts.param_path); - if (opts.use_gpu) { - config.EnableUseGpu(500, 0); - } - config.SwitchIrOptim(opts.switch_ir_optim); - if (opts.enable_fc_padding == false) { - config.DisableFCPadding(); - } - if (opts.enable_profile) { - config.EnableProfile(); - } - pool.reset( - new paddle_infer::services::PredictorPool(config, opts.thread_num)); - if (pool == nullptr) { - LOG(ERROR) << "create the predictor pool failed"; - } - pool_usages.resize(opts.thread_num); - std::fill(pool_usages.begin(), pool_usages.end(), false); - LOG(INFO) << "load paddle model success"; - - LOG(INFO) << "start to check the predictor input and output names"; - LOG(INFO) << "input names: " << opts.input_names; - LOG(INFO) << "output names: " << opts.output_names; - std::vector input_names_vec = StrSplit(opts.input_names, ","); - std::vector output_names_vec = 
StrSplit(opts.output_names, ","); - - paddle_infer::Predictor* predictor = GetPredictor(); - - std::vector model_input_names = predictor->GetInputNames(); - assert(input_names_vec.size() == model_input_names.size()); - for (size_t i = 0; i < model_input_names.size(); i++) { - assert(input_names_vec[i] == model_input_names[i]); - } - - std::vector model_output_names = predictor->GetOutputNames(); - assert(output_names_vec.size() == model_output_names.size()); - for (size_t i = 0; i < output_names_vec.size(); i++) { - assert(output_names_vec[i] == model_output_names[i]); - } - - ReleasePredictor(predictor); - InitCacheEncouts(opts); -} - -void PaddleNnet::Reset() { InitCacheEncouts(opts_); } - -paddle_infer::Predictor* PaddleNnet::GetPredictor() { - paddle_infer::Predictor* predictor = nullptr; - - std::lock_guard guard(pool_mutex); - int pred_id = 0; - - while (pred_id < pool_usages.size()) { - if (pool_usages[pred_id] == false) { - predictor = pool->Retrive(pred_id); - break; - } - ++pred_id; - } - - if (predictor) { - pool_usages[pred_id] = true; - predictor_to_thread_id[predictor] = pred_id; - } else { - LOG(INFO) << "Failed to get predictor from pool !!!"; - } - - return predictor; -} - -int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) { - std::lock_guard guard(pool_mutex); - auto iter = predictor_to_thread_id.find(predictor); - - if (iter == predictor_to_thread_id.end()) { - LOG(INFO) << "there is no such predictor"; - return 0; - } - - pool_usages[iter->second] = false; - predictor_to_thread_id.erase(predictor); - return 0; -} - -shared_ptr> PaddleNnet::GetCacheEncoder(const string& name) { - auto iter = cache_names_idx_.find(name); - if (iter == cache_names_idx_.end()) { - return nullptr; - } - assert(iter->second < cache_encouts_.size()); - return cache_encouts_[iter->second]; -} - -void PaddleNnet::FeedForward(const Vector& features, - const int32& feature_dim, - NnetOut* out) { - paddle_infer::Predictor* predictor = GetPredictor(); - - 
int feat_row = features.Dim() / feature_dim; - - std::vector input_names = predictor->GetInputNames(); - std::vector output_names = predictor->GetOutputNames(); - - // feed inputs - std::unique_ptr input_tensor = - predictor->GetInputHandle(input_names[0]); - std::vector INPUT_SHAPE = {1, feat_row, feature_dim}; - input_tensor->Reshape(INPUT_SHAPE); - input_tensor->CopyFromCpu(features.Data()); - - std::unique_ptr input_len = - predictor->GetInputHandle(input_names[1]); - std::vector input_len_size = {1}; - input_len->Reshape(input_len_size); - std::vector audio_len; - audio_len.push_back(feat_row); - input_len->CopyFromCpu(audio_len.data()); - - std::unique_ptr state_h = - predictor->GetInputHandle(input_names[2]); - shared_ptr> h_cache = GetCacheEncoder(input_names[2]); - state_h->Reshape(h_cache->get_shape()); - state_h->CopyFromCpu(h_cache->get_data().data()); - - std::unique_ptr state_c = - predictor->GetInputHandle(input_names[3]); - shared_ptr> c_cache = GetCacheEncoder(input_names[3]); - state_c->Reshape(c_cache->get_shape()); - state_c->CopyFromCpu(c_cache->get_data().data()); - - // forward - bool success = predictor->Run(); - - if (success == false) { - LOG(INFO) << "predictor run occurs error"; - } - - // fetch outpus - std::unique_ptr h_out = - predictor->GetOutputHandle(output_names[2]); - assert(h_cache->get_shape() == h_out->shape()); - h_out->CopyToCpu(h_cache->get_data().data()); - - std::unique_ptr c_out = - predictor->GetOutputHandle(output_names[3]); - assert(c_cache->get_shape() == c_out->shape()); - c_out->CopyToCpu(c_cache->get_data().data()); - - std::unique_ptr output_tensor = - predictor->GetOutputHandle(output_names[0]); - std::vector output_shape = output_tensor->shape(); - int32 row = output_shape[1]; - int32 col = output_shape[2]; - - - // inferences->Resize(row * col); - // *inference_dim = col; - out->logprobs.Resize(row * col); - out->vocab_dim = col; - output_tensor->CopyToCpu(out->logprobs.Data()); - - 
ReleasePredictor(predictor); -} - -} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/nnet/ds2_nnet.h b/speechx/speechx/asr/nnet/ds2_nnet.h deleted file mode 100644 index 420fa177..00000000 --- a/speechx/speechx/asr/nnet/ds2_nnet.h +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include - -#include "base/common.h" -#include "kaldi/matrix/kaldi-matrix.h" -#include "nnet/nnet_itf.h" -#include "paddle_inference_api.h" - -namespace ppspeech { - - -template -class Tensor { - public: - Tensor() {} - explicit Tensor(const std::vector& shape) : _shape(shape) { - int neml = std::accumulate( - _shape.begin(), _shape.end(), 1, std::multiplies()); - LOG(INFO) << "Tensor neml: " << neml; - _data.resize(neml, 0); - } - - void reshape(const std::vector& shape) { - _shape = shape; - int neml = std::accumulate( - _shape.begin(), _shape.end(), 1, std::multiplies()); - _data.resize(neml, 0); - } - - const std::vector& get_shape() const { return _shape; } - std::vector& get_data() { return _data; } - - private: - std::vector _shape; - std::vector _data; -}; - -class PaddleNnet : public NnetBase { - public: - explicit PaddleNnet(const ModelOptions& opts); - - void FeedForward(const kaldi::Vector& features, - const int32& feature_dim, - NnetOut* out) override; - - void AttentionRescoring(const std::vector>& hyps, - float 
reverse_weight, - std::vector* rescoring_score) override { - VLOG(2) << "deepspeech2 not has AttentionRescoring."; - } - - void Dim(); - - void Reset() override; - - bool IsLogProb() override { return false; } - - - std::shared_ptr> GetCacheEncoder( - const std::string& name); - - void InitCacheEncouts(const ModelOptions& opts); - - void EncoderOuts(std::vector>* encoder_out) - const override {} - - private: - paddle_infer::Predictor* GetPredictor(); - int ReleasePredictor(paddle_infer::Predictor* predictor); - - std::unique_ptr pool; - std::vector pool_usages; - std::mutex pool_mutex; - std::map predictor_to_thread_id; - std::map cache_names_idx_; - std::vector>> cache_encouts_; - - ModelOptions opts_; - - public: - DISALLOW_COPY_AND_ASSIGN(PaddleNnet); -}; - -} // namespace ppspeech diff --git a/speechx/speechx/asr/nnet/ds2_nnet_main.cc b/speechx/speechx/asr/nnet/ds2_nnet_main.cc deleted file mode 100644 index 6092b8a4..00000000 --- a/speechx/speechx/asr/nnet/ds2_nnet_main.cc +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "base/common.h" -#include "decoder/param.h" -#include "frontend/audio/assembler.h" -#include "frontend/audio/data_cache.h" -#include "kaldi/util/table-types.h" -#include "nnet/decodable.h" -#include "nnet/ds2_nnet.h" - -DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); -DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); - -using kaldi::BaseFloat; -using kaldi::Matrix; -using std::vector; - -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - kaldi::SequentialBaseFloatMatrixReader feature_reader( - FLAGS_feature_rspecifier); - kaldi::BaseFloatMatrixWriter nnet_writer(FLAGS_nnet_prob_wspecifier); - std::string model_graph = FLAGS_model_path; - std::string model_params = FLAGS_param_path; - LOG(INFO) << "model path: " << model_graph; - LOG(INFO) << "model param: " << model_params; - - int32 num_done = 0, num_err = 0; - - ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); - - std::shared_ptr nnet( - new ppspeech::PaddleNnet(model_opts)); - std::shared_ptr raw_data(new ppspeech::DataCache()); - std::shared_ptr decodable( - new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); - - int32 chunk_size = FLAGS_receptive_field_length + - (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate; - int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; - int32 receptive_field_length = FLAGS_receptive_field_length; - LOG(INFO) << "chunk size (frame): " << chunk_size; - LOG(INFO) << "chunk stride (frame): " << chunk_stride; - LOG(INFO) << "receptive field (frame): " << receptive_field_length; - kaldi::Timer timer; - for (; !feature_reader.Done(); feature_reader.Next()) { - string utt = feature_reader.Key(); - kaldi::Matrix feature = feature_reader.Value(); - raw_data->SetDim(feature.NumCols()); - LOG(INFO) << 
"process utt: " << utt; - LOG(INFO) << "rows: " << feature.NumRows(); - LOG(INFO) << "cols: " << feature.NumCols(); - - int32 row_idx = 0; - int32 padding_len = 0; - int32 ori_feature_len = feature.NumRows(); - if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { - padding_len = - chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; - feature.Resize(feature.NumRows() + padding_len, - feature.NumCols(), - kaldi::kCopyData); - } - int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1; - int32 frame_idx = 0; - std::vector> prob_vec; - for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { - kaldi::Vector feature_chunk(chunk_size * - feature.NumCols()); - int32 feature_chunk_size = 0; - if (ori_feature_len > chunk_idx * chunk_stride) { - feature_chunk_size = std::min( - ori_feature_len - chunk_idx * chunk_stride, chunk_size); - } - if (feature_chunk_size < receptive_field_length) break; - - int32 start = chunk_idx * chunk_stride; - for (int row_id = 0; row_id < chunk_size; ++row_id) { - kaldi::SubVector tmp(feature, start); - kaldi::SubVector f_chunk_tmp( - feature_chunk.Data() + row_id * feature.NumCols(), - feature.NumCols()); - f_chunk_tmp.CopyFromVec(tmp); - ++start; - } - raw_data->Accept(feature_chunk); - if (chunk_idx == num_chunks - 1) { - raw_data->SetFinished(); - } - vector prob; - while (decodable->FrameLikelihood(frame_idx, &prob)) { - kaldi::Vector vec_tmp(prob.size()); - std::memcpy(vec_tmp.Data(), - prob.data(), - sizeof(kaldi::BaseFloat) * prob.size()); - prob_vec.push_back(vec_tmp); - frame_idx++; - } - } - decodable->Reset(); - if (prob_vec.size() == 0) { - // the TokenWriter can not write empty string. 
- ++num_err; - KALDI_LOG << " the nnet prob of " << utt << " is empty"; - continue; - } - kaldi::Matrix result(prob_vec.size(), - prob_vec[0].Dim()); - for (int row_idx = 0; row_idx < prob_vec.size(); ++row_idx) { - for (int32 col_idx = 0; col_idx < prob_vec[0].Dim(); ++col_idx) { - result(row_idx, col_idx) = prob_vec[row_idx](col_idx); - } - } - - nnet_writer.Write(utt, result); - ++num_done; - } - - double elapsed = timer.Elapsed(); - KALDI_LOG << " cost:" << elapsed << " s"; - - KALDI_LOG << "Done " << num_done << " utterances, " << num_err - << " with errors."; - return (num_done != 0 ? 0 : 1); -} diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/speechx/speechx/asr/nnet/nnet_producer.cc index 3a0c4f18..95507591 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.cc +++ b/speechx/speechx/asr/nnet/nnet_producer.cc @@ -65,7 +65,6 @@ bool NnetProducer::Compute() { size_t nframes = logprobs.Dim() / vocab_dim; VLOG(2) << "Forward out " << nframes << " decoder frames."; std::vector logprob(vocab_dim); - // remove later. 
for (size_t idx = 0; idx < nframes; ++idx) { for (size_t prob_idx = 0; prob_idx < vocab_dim; ++prob_idx) { logprob[prob_idx] = logprobs(idx * vocab_dim + prob_idx); diff --git a/speechx/speechx/asr/recognizer/CMakeLists.txt b/speechx/speechx/asr/recognizer/CMakeLists.txt index 53e2e58d..6d8db93c 100644 --- a/speechx/speechx/asr/recognizer/CMakeLists.txt +++ b/speechx/speechx/asr/recognizer/CMakeLists.txt @@ -1,46 +1,22 @@ set(srcs) -if (USING_DS2) list(APPEND srcs -recognizer.cc + u2_recognizer.cc ) -endif() - -if (USING_U2) - list(APPEND srcs - u2_recognizer.cc - ) -endif() add_library(recognizer STATIC ${srcs}) target_link_libraries(recognizer PUBLIC decoder) -# test -if (USING_DS2) - set(BINS recognizer_main) - - foreach(bin_name IN LISTS BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} PUBLIC recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) - endforeach() -endif() - - -if (USING_U2) - set(TEST_BINS - u2_recognizer_main - u2_recognizer_thread_main - ) - - foreach(bin_name IN LISTS TEST_BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) - target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) - target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) - target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) - endforeach() - -endif() +set(TEST_BINS + u2_recognizer_main + u2_recognizer_thread_main +) +foreach(bin_name IN LISTS TEST_BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} 
${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) + target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) +endforeach() \ No newline at end of file diff --git a/speechx/speechx/asr/recognizer/recognizer.cc b/speechx/speechx/asr/recognizer/recognizer.cc deleted file mode 100644 index c6631813..00000000 --- a/speechx/speechx/asr/recognizer/recognizer.cc +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "recognizer/recognizer.h" - - -namespace ppspeech { - -using kaldi::BaseFloat; -using kaldi::SubVector; -using kaldi::Vector; -using kaldi::VectorBase; -using std::unique_ptr; -using std::vector; - - -Recognizer::Recognizer(const RecognizerResource& resource) { - // resource_ = resource; - const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; - feature_pipeline_.reset(new FeaturePipeline(feature_opts)); - - std::shared_ptr nnet(new PaddleNnet(resource.model_opts)); - - BaseFloat ac_scale = resource.acoustic_scale; - decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale)); - - decoder_.reset(new TLGDecoder(resource.tlg_opts)); - - input_finished_ = false; -} - -void Recognizer::Accept(const Vector& waves) { - feature_pipeline_->Accept(waves); -} - -void Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); } - -std::string Recognizer::GetFinalResult() { - return decoder_->GetFinalBestPath(); -} - -std::string Recognizer::GetPartialResult() { - return decoder_->GetPartialResult(); -} - -void Recognizer::SetFinished() { - feature_pipeline_->SetFinished(); - input_finished_ = true; -} - -bool Recognizer::IsFinished() { return input_finished_; } - -void Recognizer::Reset() { - feature_pipeline_->Reset(); - decodable_->Reset(); - decoder_->Reset(); -} - -} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/recognizer/recognizer.h b/speechx/speechx/asr/recognizer/recognizer.h deleted file mode 100644 index 57d5bb36..00000000 --- a/speechx/speechx/asr/recognizer/recognizer.h +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// todo refactor later (SGoat) - -#pragma once - -#include "decoder/ctc_beam_search_decoder.h" -#include "decoder/ctc_tlg_decoder.h" -#include "frontend/audio/feature_pipeline.h" -#include "nnet/decodable.h" -#include "nnet/ds2_nnet.h" - -DECLARE_double(acoustic_scale); - -namespace ppspeech { - -struct RecognizerResource { - kaldi::BaseFloat acoustic_scale{1.0}; - FeaturePipelineOptions feature_pipeline_opts{}; - ModelOptions model_opts{}; - TLGDecoderOptions tlg_opts{}; - // CTCBeamSearchOptions beam_search_opts; - - static RecognizerResource InitFromFlags() { - RecognizerResource resource; - resource.acoustic_scale = FLAGS_acoustic_scale; - resource.feature_pipeline_opts = - FeaturePipelineOptions::InitFromFlags(); - resource.feature_pipeline_opts.assembler_opts.fill_zero = true; - LOG(INFO) << "ds2 need fill zero be true: " - << resource.feature_pipeline_opts.assembler_opts.fill_zero; - resource.model_opts = ModelOptions::InitFromFlags(); - resource.tlg_opts = TLGDecoderOptions::InitFromFlags(); - return resource; - } -}; - -class Recognizer { - public: - explicit Recognizer(const RecognizerResource& resouce); - void Accept(const kaldi::Vector& waves); - void Decode(); - std::string GetFinalResult(); - std::string GetPartialResult(); - void SetFinished(); - bool IsFinished(); - void Reset(); - - private: - // std::shared_ptr resource_; - // RecognizerResource resource_; - std::shared_ptr feature_pipeline_; - std::shared_ptr decodable_; - std::unique_ptr decoder_; - bool input_finished_; -}; - -} // namespace ppspeech \ No newline at end of 
file diff --git a/speechx/speechx/asr/recognizer/recognizer_main.cc b/speechx/speechx/asr/recognizer/recognizer_main.cc deleted file mode 100644 index cb0de2d6..00000000 --- a/speechx/speechx/asr/recognizer/recognizer_main.cc +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "decoder/param.h" -#include "kaldi/feat/wave-reader.h" -#include "kaldi/util/table-types.h" -#include "recognizer/recognizer.h" - -DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); -DEFINE_string(result_wspecifier, "", "test result wspecifier"); -DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); -DEFINE_int32(sample_rate, 16000, "sample rate"); - - -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - ppspeech::RecognizerResource resource = - ppspeech::RecognizerResource::InitFromFlags(); - ppspeech::Recognizer recognizer(resource); - - kaldi::SequentialTableReader wav_reader( - FLAGS_wav_rspecifier); - kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); - - int sample_rate = FLAGS_sample_rate; - float streaming_chunk = FLAGS_streaming_chunk; - int chunk_sample_size = streaming_chunk * sample_rate; - LOG(INFO) << "sr: " << sample_rate; - LOG(INFO) << "chunk size 
(s): " << streaming_chunk; - LOG(INFO) << "chunk size (sample): " << chunk_sample_size; - - int32 num_done = 0, num_err = 0; - double tot_wav_duration = 0.0; - - kaldi::Timer timer; - - for (; !wav_reader.Done(); wav_reader.Next()) { - std::string utt = wav_reader.Key(); - const kaldi::WaveData& wave_data = wav_reader.Value(); - - int32 this_channel = 0; - kaldi::SubVector waveform(wave_data.Data(), - this_channel); - int tot_samples = waveform.Dim(); - tot_wav_duration += tot_samples * 1.0 / sample_rate; - LOG(INFO) << "wav len (sample): " << tot_samples; - - int sample_offset = 0; - std::vector> feats; - int feature_rows = 0; - while (sample_offset < tot_samples) { - int cur_chunk_size = - std::min(chunk_sample_size, tot_samples - sample_offset); - - kaldi::Vector wav_chunk(cur_chunk_size); - for (int i = 0; i < cur_chunk_size; ++i) { - wav_chunk(i) = waveform(sample_offset + i); - } - // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); - - recognizer.Accept(wav_chunk); - if (cur_chunk_size < chunk_sample_size) { - recognizer.SetFinished(); - } - recognizer.Decode(); - - // no overlap - sample_offset += cur_chunk_size; - } - - std::string result; - result = recognizer.GetFinalResult(); - recognizer.Reset(); - if (result.empty()) { - // the TokenWriter can not write empty string. 
- ++num_err; - KALDI_LOG << " the result of " << utt << " is empty"; - continue; - } - KALDI_LOG << " the result of " << utt << " is " << result; - result_writer.Write(utt, result); - ++num_done; - } - double elapsed = timer.Elapsed(); - KALDI_LOG << "Done " << num_done << " out of " << (num_err + num_done); - KALDI_LOG << " cost:" << elapsed << " s"; - KALDI_LOG << "total wav duration is: " << tot_wav_duration << " s"; - KALDI_LOG << "the RTF is: " << elapsed / tot_wav_duration; -} diff --git a/speechx/speechx/codelab/CMakeLists.txt b/speechx/speechx/codelab/CMakeLists.txt index 95043263..c8445fb8 100644 --- a/speechx/speechx/codelab/CMakeLists.txt +++ b/speechx/speechx/codelab/CMakeLists.txt @@ -1,4 +1,3 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) add_subdirectory(glog) -add_subdirectory(nnet) diff --git a/speechx/speechx/codelab/nnet/CMakeLists.txt b/speechx/speechx/codelab/nnet/CMakeLists.txt deleted file mode 100644 index dcad8a9c..00000000 --- a/speechx/speechx/codelab/nnet/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) - -set(bin_name ds2_model_test_main) -add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(${bin_name} PUBLIC nnet gflags glog ${DEPS}) diff --git a/speechx/speechx/codelab/nnet/ds2_model_test_main.cc b/speechx/speechx/codelab/nnet/ds2_model_test_main.cc deleted file mode 100644 index ab7b2cb5..00000000 --- a/speechx/speechx/codelab/nnet/ds2_model_test_main.cc +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// deepspeech2 online model info - -#include -#include -#include -#include -#include -#include -#include - -#include "base/flags.h" -#include "base/log.h" -#include "paddle_inference_api.h" - -using std::cout; -using std::endl; - - -DEFINE_string(model_path, "", "xxx.pdmodel"); -DEFINE_string(param_path, "", "xxx.pdiparams"); -DEFINE_int32(chunk_size, 35, "feature chunk size, unit:frame"); -DEFINE_int32(feat_dim, 161, "feature dim"); - - -void produce_data(std::vector>* data); -void model_forward_test(); - -void produce_data(std::vector>* data) { - int chunk_size = FLAGS_chunk_size; // chunk_size in frame - int col_size = FLAGS_feat_dim; // feat dim - cout << "chunk size: " << chunk_size << endl; - cout << "feat dim: " << col_size << endl; - - data->reserve(chunk_size); - data->back().reserve(col_size); - for (int row = 0; row < chunk_size; ++row) { - data->push_back(std::vector()); - for (int col_idx = 0; col_idx < col_size; ++col_idx) { - data->back().push_back(0.201); - } - } -} - -void model_forward_test() { - std::cout << "1. read the data" << std::endl; - std::vector> feats; - produce_data(&feats); - - std::cout << "2. 
load the model" << std::endl; - ; - std::string model_graph = FLAGS_model_path; - std::string model_params = FLAGS_param_path; - CHECK_NE(model_graph, ""); - CHECK_NE(model_params, ""); - cout << "model path: " << model_graph << endl; - cout << "model param path : " << model_params << endl; - - paddle_infer::Config config; - config.SetModel(model_graph, model_params); - config.SwitchIrOptim(false); - cout << "SwitchIrOptim: " << false << endl; - config.DisableFCPadding(); - cout << "DisableFCPadding: " << endl; - auto predictor = paddle_infer::CreatePredictor(config); - - std::cout << "3. feat shape, row=" << feats.size() - << ",col=" << feats[0].size() << std::endl; - std::vector pp_input_mat; - for (const auto& item : feats) { - pp_input_mat.insert(pp_input_mat.end(), item.begin(), item.end()); - } - - std::cout << "4. fead the data to model" << std::endl; - int row = feats.size(); - int col = feats[0].size(); - std::vector input_names = predictor->GetInputNames(); - std::vector output_names = predictor->GetOutputNames(); - for (auto name : input_names) { - cout << "model input names: " << name << endl; - } - for (auto name : output_names) { - cout << "model output names: " << name << endl; - } - - // input - std::unique_ptr input_tensor = - predictor->GetInputHandle(input_names[0]); - std::vector INPUT_SHAPE = {1, row, col}; - input_tensor->Reshape(INPUT_SHAPE); - input_tensor->CopyFromCpu(pp_input_mat.data()); - - // input length - std::unique_ptr input_len = - predictor->GetInputHandle(input_names[1]); - std::vector input_len_size = {1}; - input_len->Reshape(input_len_size); - std::vector audio_len; - audio_len.push_back(row); - input_len->CopyFromCpu(audio_len.data()); - - // state_h - std::unique_ptr chunk_state_h_box = - predictor->GetInputHandle(input_names[2]); - std::vector chunk_state_h_box_shape = {5, 1, 1024}; - chunk_state_h_box->Reshape(chunk_state_h_box_shape); - int chunk_state_h_box_size = - std::accumulate(chunk_state_h_box_shape.begin(), - 
chunk_state_h_box_shape.end(), - 1, - std::multiplies()); - std::vector chunk_state_h_box_data(chunk_state_h_box_size, 0.0f); - chunk_state_h_box->CopyFromCpu(chunk_state_h_box_data.data()); - - // state_c - std::unique_ptr chunk_state_c_box = - predictor->GetInputHandle(input_names[3]); - std::vector chunk_state_c_box_shape = {5, 1, 1024}; - chunk_state_c_box->Reshape(chunk_state_c_box_shape); - int chunk_state_c_box_size = - std::accumulate(chunk_state_c_box_shape.begin(), - chunk_state_c_box_shape.end(), - 1, - std::multiplies()); - std::vector chunk_state_c_box_data(chunk_state_c_box_size, 0.0f); - chunk_state_c_box->CopyFromCpu(chunk_state_c_box_data.data()); - - // run - bool success = predictor->Run(); - - // state_h out - std::unique_ptr h_out = - predictor->GetOutputHandle(output_names[2]); - std::vector h_out_shape = h_out->shape(); - int h_out_size = std::accumulate( - h_out_shape.begin(), h_out_shape.end(), 1, std::multiplies()); - std::vector h_out_data(h_out_size); - h_out->CopyToCpu(h_out_data.data()); - - // stage_c out - std::unique_ptr c_out = - predictor->GetOutputHandle(output_names[3]); - std::vector c_out_shape = c_out->shape(); - int c_out_size = std::accumulate( - c_out_shape.begin(), c_out_shape.end(), 1, std::multiplies()); - std::vector c_out_data(c_out_size); - c_out->CopyToCpu(c_out_data.data()); - - // output tensor - std::unique_ptr output_tensor = - predictor->GetOutputHandle(output_names[0]); - std::vector output_shape = output_tensor->shape(); - std::vector output_probs; - int output_size = std::accumulate( - output_shape.begin(), output_shape.end(), 1, std::multiplies()); - output_probs.resize(output_size); - output_tensor->CopyToCpu(output_probs.data()); - row = output_shape[1]; - col = output_shape[2]; - - // probs - std::vector> probs; - probs.reserve(row); - for (int i = 0; i < row; i++) { - probs.push_back(std::vector()); - probs.back().reserve(col); - - for (int j = 0; j < col; j++) { - probs.back().push_back(output_probs[i 
* col + j]); - } - } - - std::vector> log_feat = probs; - std::cout << "probs, row: " << log_feat.size() - << " col: " << log_feat[0].size() << std::endl; - for (size_t row_idx = 0; row_idx < log_feat.size(); ++row_idx) { - for (size_t col_idx = 0; col_idx < log_feat[row_idx].size(); - ++col_idx) { - std::cout << log_feat[row_idx][col_idx] << " "; - } - std::cout << std::endl; - } -} - -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - model_forward_test(); - return 0; -} diff --git a/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc b/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc index 713c9ef1..8c65b346 100644 --- a/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc +++ b/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc @@ -20,15 +20,12 @@ #include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/util/kaldi-io.h" #include "utils/file_utils.h" -// #include "boost/json.hpp" -#include +#include "utils/picojson.h" DEFINE_string(json_file, "", "cmvn json file"); DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn"); DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)"); -using namespace boost::json; // from - int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -40,36 +37,49 @@ int main(int argc, char* argv[]) { auto ifs = std::ifstream(FLAGS_json_file); std::string json_str = ppspeech::ReadFile2String(FLAGS_json_file); - auto value = boost::json::parse(json_str); - if (!value.is_object()) { + picojson::value value; + std::string err; + const char* json_end = picojson::parse( + value, json_str.c_str(), json_str.c_str() + json_str.size(), &err); + if (!value.is()) { LOG(ERROR) << "Input json file format error."; } - for (auto obj : 
value.as_object()) { - if (obj.key() == "mean_stat") { - VLOG(2) << "mean_stat:" << obj.value(); + const picojson::value::object& obj = value.get(); + for (picojson::value::object::const_iterator elem = obj.begin(); + elem != obj.end(); + ++elem) { + if (elem->first == "mean_stat") { + VLOG(2) << "mean_stat:" << elem->second; + // const picojson::value tmp = + // elem->second.get(0);//(); + double tmp = + elem->second.get(0).get(); //(); + VLOG(2) << "tmp: " << tmp; } - if (obj.key() == "var_stat") { - VLOG(2) << "var_stat: " << obj.value(); + if (elem->first == "var_stat") { + VLOG(2) << "var_stat: " << elem->second; } - if (obj.key() == "frame_num") { - VLOG(2) << "frame_num: " << obj.value(); + if (elem->first == "frame_num") { + VLOG(2) << "frame_num: " << elem->second; } } - boost::json::array mean_stat = value.at("mean_stat").as_array(); + const picojson::value::array& mean_stat = + value.get("mean_stat").get(); std::vector mean_stat_vec; for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) { - mean_stat_vec.push_back(it->as_double()); + mean_stat_vec.push_back((*it).get()); } - boost::json::array var_stat = value.at("var_stat").as_array(); + const picojson::value::array& var_stat = + value.get("var_stat").get(); std::vector var_stat_vec; for (auto it = var_stat.begin(); it != var_stat.end(); it++) { - var_stat_vec.push_back(it->as_double()); + var_stat_vec.push_back((*it).get()); } - kaldi::int32 frame_num = uint64_t(value.at("frame_num").as_int64()); + kaldi::int32 frame_num = value.get("frame_num").get(); LOG(INFO) << "nframe: " << frame_num; size_t mean_size = mean_stat_vec.size(); diff --git a/speechx/speechx/common/utils/picojson.h b/speechx/speechx/common/utils/picojson.h new file mode 100644 index 00000000..28c5b7fa --- /dev/null +++ b/speechx/speechx/common/utils/picojson.h @@ -0,0 +1,1202 @@ +/* + * Copyright 2009-2010 Cybozu Labs, Inc. + * Copyright 2011-2014 Kazuho Oku + * All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ +#ifndef picojson_h +#define picojson_h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define PICOJSON_USE_INT64 1 + +// for isnan/isinf +#if __cplusplus >= 201103L +#include +#else +extern "C" { +#ifdef _MSC_VER +#include +#elif defined(__INTEL_COMPILER) +#include +#else +#include +#endif +} +#endif + +#ifndef PICOJSON_USE_RVALUE_REFERENCE +#if (defined(__cpp_rvalue_references) && __cpp_rvalue_references >= 200610) || (defined(_MSC_VER) && _MSC_VER >= 1600) +#define PICOJSON_USE_RVALUE_REFERENCE 1 +#else +#define PICOJSON_USE_RVALUE_REFERENCE 0 +#endif +#endif // PICOJSON_USE_RVALUE_REFERENCE + +#ifndef PICOJSON_NOEXCEPT +#if PICOJSON_USE_RVALUE_REFERENCE +#define PICOJSON_NOEXCEPT noexcept +#else +#define PICOJSON_NOEXCEPT throw() +#endif +#endif + +// experimental support for int64_t (see README.mkdn for detail) +#ifdef PICOJSON_USE_INT64 +#define __STDC_FORMAT_MACROS +#include +#if __cplusplus >= 201103L +#include +#else +extern "C" { +#include +} +#endif +#endif + +// to disable the use of localeconv(3), set PICOJSON_USE_LOCALE to 0 +#ifndef PICOJSON_USE_LOCALE +#define PICOJSON_USE_LOCALE 1 +#endif +#if PICOJSON_USE_LOCALE +extern "C" { +#include +} +#endif + +#ifndef PICOJSON_ASSERT +#define PICOJSON_ASSERT(e) \ + do { \ + if (!(e)) \ + throw std::runtime_error(#e); \ + } while (0) +#endif + +#ifdef _MSC_VER +#define SNPRINTF _snprintf_s +#pragma warning(push) +#pragma warning(disable : 4244) // conversion from int to char +#pragma warning(disable : 4127) // conditional expression is constant +#pragma warning(disable : 4702) // unreachable code +#pragma warning(disable : 4706) // assignment within conditional expression +#else +#define SNPRINTF snprintf +#endif + +namespace picojson { + +enum { + null_type, + boolean_type, + number_type, + string_type, + array_type, + object_type +#ifdef PICOJSON_USE_INT64 + , + int64_type +#endif +}; + +enum { INDENT_WIDTH = 2, 
DEFAULT_MAX_DEPTHS = 100 }; + +struct null {}; + +class value { +public: + typedef std::vector array; + typedef std::map object; + union _storage { + bool boolean_; + double number_; +#ifdef PICOJSON_USE_INT64 + int64_t int64_; +#endif + std::string *string_; + array *array_; + object *object_; + }; + +protected: + int type_; + _storage u_; + +public: + value(); + value(int type, bool); + explicit value(bool b); +#ifdef PICOJSON_USE_INT64 + explicit value(int64_t i); +#endif + explicit value(double n); + explicit value(const std::string &s); + explicit value(const array &a); + explicit value(const object &o); +#if PICOJSON_USE_RVALUE_REFERENCE + explicit value(std::string &&s); + explicit value(array &&a); + explicit value(object &&o); +#endif + explicit value(const char *s); + value(const char *s, size_t len); + ~value(); + value(const value &x); + value &operator=(const value &x); +#if PICOJSON_USE_RVALUE_REFERENCE + value(value &&x) PICOJSON_NOEXCEPT; + value &operator=(value &&x) PICOJSON_NOEXCEPT; +#endif + void swap(value &x) PICOJSON_NOEXCEPT; + template bool is() const; + template const T &get() const; + template T &get(); + template void set(const T &); +#if PICOJSON_USE_RVALUE_REFERENCE + template void set(T &&); +#endif + bool evaluate_as_boolean() const; + const value &get(const size_t idx) const; + const value &get(const std::string &key) const; + value &get(const size_t idx); + value &get(const std::string &key); + + bool contains(const size_t idx) const; + bool contains(const std::string &key) const; + std::string to_str() const; + template void serialize(Iter os, bool prettify = false) const; + std::string serialize(bool prettify = false) const; + +private: + template value(const T *); // intentionally defined to block implicit conversion of pointer to bool + template static void _indent(Iter os, int indent); + template void _serialize(Iter os, int indent) const; + std::string _serialize(int indent) const; + void clear(); +}; + +typedef value::array 
array; +typedef value::object object; + +inline value::value() : type_(null_type), u_() { +} + +inline value::value(int type, bool) : type_(type), u_() { + switch (type) { +#define INIT(p, v) \ + case p##type: \ + u_.p = v; \ + break + INIT(boolean_, false); + INIT(number_, 0.0); +#ifdef PICOJSON_USE_INT64 + INIT(int64_, 0); +#endif + INIT(string_, new std::string()); + INIT(array_, new array()); + INIT(object_, new object()); +#undef INIT + default: + break; + } +} + +inline value::value(bool b) : type_(boolean_type), u_() { + u_.boolean_ = b; +} + +#ifdef PICOJSON_USE_INT64 +inline value::value(int64_t i) : type_(int64_type), u_() { + u_.int64_ = i; +} +#endif + +inline value::value(double n) : type_(number_type), u_() { + if ( +#ifdef _MSC_VER + !_finite(n) +#elif __cplusplus >= 201103L + std::isnan(n) || std::isinf(n) +#else + isnan(n) || isinf(n) +#endif + ) { + throw std::overflow_error(""); + } + u_.number_ = n; +} + +inline value::value(const std::string &s) : type_(string_type), u_() { + u_.string_ = new std::string(s); +} + +inline value::value(const array &a) : type_(array_type), u_() { + u_.array_ = new array(a); +} + +inline value::value(const object &o) : type_(object_type), u_() { + u_.object_ = new object(o); +} + +#if PICOJSON_USE_RVALUE_REFERENCE +inline value::value(std::string &&s) : type_(string_type), u_() { + u_.string_ = new std::string(std::move(s)); +} + +inline value::value(array &&a) : type_(array_type), u_() { + u_.array_ = new array(std::move(a)); +} + +inline value::value(object &&o) : type_(object_type), u_() { + u_.object_ = new object(std::move(o)); +} +#endif + +inline value::value(const char *s) : type_(string_type), u_() { + u_.string_ = new std::string(s); +} + +inline value::value(const char *s, size_t len) : type_(string_type), u_() { + u_.string_ = new std::string(s, len); +} + +inline void value::clear() { + switch (type_) { +#define DEINIT(p) \ + case p##type: \ + delete u_.p; \ + break + DEINIT(string_); + DEINIT(array_); 
+ DEINIT(object_); +#undef DEINIT + default: + break; + } +} + +inline value::~value() { + clear(); +} + +inline value::value(const value &x) : type_(x.type_), u_() { + switch (type_) { +#define INIT(p, v) \ + case p##type: \ + u_.p = v; \ + break + INIT(string_, new std::string(*x.u_.string_)); + INIT(array_, new array(*x.u_.array_)); + INIT(object_, new object(*x.u_.object_)); +#undef INIT + default: + u_ = x.u_; + break; + } +} + +inline value &value::operator=(const value &x) { + if (this != &x) { + value t(x); + swap(t); + } + return *this; +} + +#if PICOJSON_USE_RVALUE_REFERENCE +inline value::value(value &&x) PICOJSON_NOEXCEPT : type_(null_type), u_() { + swap(x); +} +inline value &value::operator=(value &&x) PICOJSON_NOEXCEPT { + swap(x); + return *this; +} +#endif +inline void value::swap(value &x) PICOJSON_NOEXCEPT { + std::swap(type_, x.type_); + std::swap(u_, x.u_); +} + +#define IS(ctype, jtype) \ + template <> inline bool value::is() const { \ + return type_ == jtype##_type; \ + } +IS(null, null) +IS(bool, boolean) +#ifdef PICOJSON_USE_INT64 +IS(int64_t, int64) +#endif +IS(std::string, string) +IS(array, array) +IS(object, object) +#undef IS +template <> inline bool value::is() const { + return type_ == number_type +#ifdef PICOJSON_USE_INT64 + || type_ == int64_type +#endif + ; +} + +#define GET(ctype, var) \ + template <> inline const ctype &value::get() const { \ + PICOJSON_ASSERT("type mismatch! call is() before get()" && is()); \ + return var; \ + } \ + template <> inline ctype &value::get() { \ + PICOJSON_ASSERT("type mismatch! 
call is() before get()" && is()); \ + return var; \ + } +GET(bool, u_.boolean_) +GET(std::string, *u_.string_) +GET(array, *u_.array_) +GET(object, *u_.object_) +#ifdef PICOJSON_USE_INT64 +GET(double, + (type_ == int64_type && (const_cast(this)->type_ = number_type, (const_cast(this)->u_.number_ = u_.int64_)), + u_.number_)) +GET(int64_t, u_.int64_) +#else +GET(double, u_.number_) +#endif +#undef GET + +#define SET(ctype, jtype, setter) \ + template <> inline void value::set(const ctype &_val) { \ + clear(); \ + type_ = jtype##_type; \ + setter \ + } +SET(bool, boolean, u_.boolean_ = _val;) +SET(std::string, string, u_.string_ = new std::string(_val);) +SET(array, array, u_.array_ = new array(_val);) +SET(object, object, u_.object_ = new object(_val);) +SET(double, number, u_.number_ = _val;) +#ifdef PICOJSON_USE_INT64 +SET(int64_t, int64, u_.int64_ = _val;) +#endif +#undef SET + +#if PICOJSON_USE_RVALUE_REFERENCE +#define MOVESET(ctype, jtype, setter) \ + template <> inline void value::set(ctype && _val) { \ + clear(); \ + type_ = jtype##_type; \ + setter \ + } +MOVESET(std::string, string, u_.string_ = new std::string(std::move(_val));) +MOVESET(array, array, u_.array_ = new array(std::move(_val));) +MOVESET(object, object, u_.object_ = new object(std::move(_val));) +#undef MOVESET +#endif + +inline bool value::evaluate_as_boolean() const { + switch (type_) { + case null_type: + return false; + case boolean_type: + return u_.boolean_; + case number_type: + return u_.number_ != 0; +#ifdef PICOJSON_USE_INT64 + case int64_type: + return u_.int64_ != 0; +#endif + case string_type: + return !u_.string_->empty(); + default: + return true; + } +} + +inline const value &value::get(const size_t idx) const { + static value s_null; + PICOJSON_ASSERT(is()); + return idx < u_.array_->size() ? (*u_.array_)[idx] : s_null; +} + +inline value &value::get(const size_t idx) { + static value s_null; + PICOJSON_ASSERT(is()); + return idx < u_.array_->size() ? 
(*u_.array_)[idx] : s_null; +} + +inline const value &value::get(const std::string &key) const { + static value s_null; + PICOJSON_ASSERT(is()); + object::const_iterator i = u_.object_->find(key); + return i != u_.object_->end() ? i->second : s_null; +} + +inline value &value::get(const std::string &key) { + static value s_null; + PICOJSON_ASSERT(is()); + object::iterator i = u_.object_->find(key); + return i != u_.object_->end() ? i->second : s_null; +} + +inline bool value::contains(const size_t idx) const { + PICOJSON_ASSERT(is()); + return idx < u_.array_->size(); +} + +inline bool value::contains(const std::string &key) const { + PICOJSON_ASSERT(is()); + object::const_iterator i = u_.object_->find(key); + return i != u_.object_->end(); +} + +inline std::string value::to_str() const { + switch (type_) { + case null_type: + return "null"; + case boolean_type: + return u_.boolean_ ? "true" : "false"; +#ifdef PICOJSON_USE_INT64 + case int64_type: { + char buf[sizeof("-9223372036854775808")]; + SNPRINTF(buf, sizeof(buf), "%" PRId64, u_.int64_); + return buf; + } +#endif + case number_type: { + char buf[256]; + double tmp; + SNPRINTF(buf, sizeof(buf), fabs(u_.number_) < (1ULL << 53) && modf(u_.number_, &tmp) == 0 ? "%.f" : "%.17g", u_.number_); +#if PICOJSON_USE_LOCALE + char *decimal_point = localeconv()->decimal_point; + if (strcmp(decimal_point, ".") != 0) { + size_t decimal_point_len = strlen(decimal_point); + for (char *p = buf; *p != '\0'; ++p) { + if (strncmp(p, decimal_point, decimal_point_len) == 0) { + return std::string(buf, p) + "." 
+ (p + decimal_point_len); + } + } + } +#endif + return buf; + } + case string_type: + return *u_.string_; + case array_type: + return "array"; + case object_type: + return "object"; + default: + PICOJSON_ASSERT(0); +#ifdef _MSC_VER + __assume(0); +#endif + } + return std::string(); +} + +template void copy(const std::string &s, Iter oi) { + std::copy(s.begin(), s.end(), oi); +} + +template struct serialize_str_char { + Iter oi; + void operator()(char c) { + switch (c) { +#define MAP(val, sym) \ + case val: \ + copy(sym, oi); \ + break + MAP('"', "\\\""); + MAP('\\', "\\\\"); + MAP('/', "\\/"); + MAP('\b', "\\b"); + MAP('\f', "\\f"); + MAP('\n', "\\n"); + MAP('\r', "\\r"); + MAP('\t', "\\t"); +#undef MAP + default: + if (static_cast(c) < 0x20 || c == 0x7f) { + char buf[7]; + SNPRINTF(buf, sizeof(buf), "\\u%04x", c & 0xff); + copy(buf, buf + 6, oi); + } else { + *oi++ = c; + } + break; + } + } +}; + +template void serialize_str(const std::string &s, Iter oi) { + *oi++ = '"'; + serialize_str_char process_char = {oi}; + std::for_each(s.begin(), s.end(), process_char); + *oi++ = '"'; +} + +template void value::serialize(Iter oi, bool prettify) const { + return _serialize(oi, prettify ? 0 : -1); +} + +inline std::string value::serialize(bool prettify) const { + return _serialize(prettify ? 
0 : -1); +} + +template void value::_indent(Iter oi, int indent) { + *oi++ = '\n'; + for (int i = 0; i < indent * INDENT_WIDTH; ++i) { + *oi++ = ' '; + } +} + +template void value::_serialize(Iter oi, int indent) const { + switch (type_) { + case string_type: + serialize_str(*u_.string_, oi); + break; + case array_type: { + *oi++ = '['; + if (indent != -1) { + ++indent; + } + for (array::const_iterator i = u_.array_->begin(); i != u_.array_->end(); ++i) { + if (i != u_.array_->begin()) { + *oi++ = ','; + } + if (indent != -1) { + _indent(oi, indent); + } + i->_serialize(oi, indent); + } + if (indent != -1) { + --indent; + if (!u_.array_->empty()) { + _indent(oi, indent); + } + } + *oi++ = ']'; + break; + } + case object_type: { + *oi++ = '{'; + if (indent != -1) { + ++indent; + } + for (object::const_iterator i = u_.object_->begin(); i != u_.object_->end(); ++i) { + if (i != u_.object_->begin()) { + *oi++ = ','; + } + if (indent != -1) { + _indent(oi, indent); + } + serialize_str(i->first, oi); + *oi++ = ':'; + if (indent != -1) { + *oi++ = ' '; + } + i->second._serialize(oi, indent); + } + if (indent != -1) { + --indent; + if (!u_.object_->empty()) { + _indent(oi, indent); + } + } + *oi++ = '}'; + break; + } + default: + copy(to_str(), oi); + break; + } + if (indent == 0) { + *oi++ = '\n'; + } +} + +inline std::string value::_serialize(int indent) const { + std::string s; + _serialize(std::back_inserter(s), indent); + return s; +} + +template class input { +protected: + Iter cur_, end_; + bool consumed_; + int line_; + +public: + input(const Iter &first, const Iter &last) : cur_(first), end_(last), consumed_(false), line_(1) { + } + int getc() { + if (consumed_) { + if (*cur_ == '\n') { + ++line_; + } + ++cur_; + } + if (cur_ == end_) { + consumed_ = false; + return -1; + } + consumed_ = true; + return *cur_ & 0xff; + } + void ungetc() { + consumed_ = false; + } + Iter cur() const { + if (consumed_) { + input *self = const_cast *>(this); + self->consumed_ = false; 
+ ++self->cur_; + } + return cur_; + } + int line() const { + return line_; + } + void skip_ws() { + while (1) { + int ch = getc(); + if (!(ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r')) { + ungetc(); + break; + } + } + } + bool expect(const int expected) { + skip_ws(); + if (getc() != expected) { + ungetc(); + return false; + } + return true; + } + bool match(const std::string &pattern) { + for (std::string::const_iterator pi(pattern.begin()); pi != pattern.end(); ++pi) { + if (getc() != *pi) { + ungetc(); + return false; + } + } + return true; + } +}; + +template inline int _parse_quadhex(input &in) { + int uni_ch = 0, hex; + for (int i = 0; i < 4; i++) { + if ((hex = in.getc()) == -1) { + return -1; + } + if ('0' <= hex && hex <= '9') { + hex -= '0'; + } else if ('A' <= hex && hex <= 'F') { + hex -= 'A' - 0xa; + } else if ('a' <= hex && hex <= 'f') { + hex -= 'a' - 0xa; + } else { + in.ungetc(); + return -1; + } + uni_ch = uni_ch * 16 + hex; + } + return uni_ch; +} + +template inline bool _parse_codepoint(String &out, input &in) { + int uni_ch; + if ((uni_ch = _parse_quadhex(in)) == -1) { + return false; + } + if (0xd800 <= uni_ch && uni_ch <= 0xdfff) { + if (0xdc00 <= uni_ch) { + // a second 16-bit of a surrogate pair appeared + return false; + } + // first 16-bit of surrogate pair, get the next one + if (in.getc() != '\\' || in.getc() != 'u') { + in.ungetc(); + return false; + } + int second = _parse_quadhex(in); + if (!(0xdc00 <= second && second <= 0xdfff)) { + return false; + } + uni_ch = ((uni_ch - 0xd800) << 10) | ((second - 0xdc00) & 0x3ff); + uni_ch += 0x10000; + } + if (uni_ch < 0x80) { + out.push_back(static_cast(uni_ch)); + } else { + if (uni_ch < 0x800) { + out.push_back(static_cast(0xc0 | (uni_ch >> 6))); + } else { + if (uni_ch < 0x10000) { + out.push_back(static_cast(0xe0 | (uni_ch >> 12))); + } else { + out.push_back(static_cast(0xf0 | (uni_ch >> 18))); + out.push_back(static_cast(0x80 | ((uni_ch >> 12) & 0x3f))); + } + 
out.push_back(static_cast(0x80 | ((uni_ch >> 6) & 0x3f))); + } + out.push_back(static_cast(0x80 | (uni_ch & 0x3f))); + } + return true; +} + +template inline bool _parse_string(String &out, input &in) { + while (1) { + int ch = in.getc(); + if (ch < ' ') { + in.ungetc(); + return false; + } else if (ch == '"') { + return true; + } else if (ch == '\\') { + if ((ch = in.getc()) == -1) { + return false; + } + switch (ch) { +#define MAP(sym, val) \ + case sym: \ + out.push_back(val); \ + break + MAP('"', '\"'); + MAP('\\', '\\'); + MAP('/', '/'); + MAP('b', '\b'); + MAP('f', '\f'); + MAP('n', '\n'); + MAP('r', '\r'); + MAP('t', '\t'); +#undef MAP + case 'u': + if (!_parse_codepoint(out, in)) { + return false; + } + break; + default: + return false; + } + } else { + out.push_back(static_cast(ch)); + } + } + return false; +} + +template inline bool _parse_array(Context &ctx, input &in) { + if (!ctx.parse_array_start()) { + return false; + } + size_t idx = 0; + if (in.expect(']')) { + return ctx.parse_array_stop(idx); + } + do { + if (!ctx.parse_array_item(in, idx)) { + return false; + } + idx++; + } while (in.expect(',')); + return in.expect(']') && ctx.parse_array_stop(idx); +} + +template inline bool _parse_object(Context &ctx, input &in) { + if (!ctx.parse_object_start()) { + return false; + } + if (in.expect('}')) { + return ctx.parse_object_stop(); + } + do { + std::string key; + if (!in.expect('"') || !_parse_string(key, in) || !in.expect(':')) { + return false; + } + if (!ctx.parse_object_item(in, key)) { + return false; + } + } while (in.expect(',')); + return in.expect('}') && ctx.parse_object_stop(); +} + +template inline std::string _parse_number(input &in) { + std::string num_str; + while (1) { + int ch = in.getc(); + if (('0' <= ch && ch <= '9') || ch == '+' || ch == '-' || ch == 'e' || ch == 'E') { + num_str.push_back(static_cast(ch)); + } else if (ch == '.') { +#if PICOJSON_USE_LOCALE + num_str += localeconv()->decimal_point; +#else + 
num_str.push_back('.'); +#endif + } else { + in.ungetc(); + break; + } + } + return num_str; +} + +template inline bool _parse(Context &ctx, input &in) { + in.skip_ws(); + int ch = in.getc(); + switch (ch) { +#define IS(ch, text, op) \ + case ch: \ + if (in.match(text) && op) { \ + return true; \ + } else { \ + return false; \ + } + IS('n', "ull", ctx.set_null()); + IS('f', "alse", ctx.set_bool(false)); + IS('t', "rue", ctx.set_bool(true)); +#undef IS + case '"': + return ctx.parse_string(in); + case '[': + return _parse_array(ctx, in); + case '{': + return _parse_object(ctx, in); + default: + if (('0' <= ch && ch <= '9') || ch == '-') { + double f; + char *endp; + in.ungetc(); + std::string num_str(_parse_number(in)); + if (num_str.empty()) { + return false; + } +#ifdef PICOJSON_USE_INT64 + { + errno = 0; + intmax_t ival = strtoimax(num_str.c_str(), &endp, 10); + if (errno == 0 && std::numeric_limits::min() <= ival && ival <= std::numeric_limits::max() && + endp == num_str.c_str() + num_str.size()) { + ctx.set_int64(ival); + return true; + } + } +#endif + f = strtod(num_str.c_str(), &endp); + if (endp == num_str.c_str() + num_str.size()) { + ctx.set_number(f); + return true; + } + return false; + } + break; + } + in.ungetc(); + return false; +} + +class deny_parse_context { +public: + bool set_null() { + return false; + } + bool set_bool(bool) { + return false; + } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t) { + return false; + } +#endif + bool set_number(double) { + return false; + } + template bool parse_string(input &) { + return false; + } + bool parse_array_start() { + return false; + } + template bool parse_array_item(input &, size_t) { + return false; + } + bool parse_array_stop(size_t) { + return false; + } + bool parse_object_start() { + return false; + } + template bool parse_object_item(input &, const std::string &) { + return false; + } +}; + +class default_parse_context { +protected: + value *out_; + size_t depths_; + +public: + 
default_parse_context(value *out, size_t depths = DEFAULT_MAX_DEPTHS) : out_(out), depths_(depths) { + } + bool set_null() { + *out_ = value(); + return true; + } + bool set_bool(bool b) { + *out_ = value(b); + return true; + } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t i) { + *out_ = value(i); + return true; + } +#endif + bool set_number(double f) { + *out_ = value(f); + return true; + } + template bool parse_string(input &in) { + *out_ = value(string_type, false); + return _parse_string(out_->get(), in); + } + bool parse_array_start() { + if (depths_ == 0) + return false; + --depths_; + *out_ = value(array_type, false); + return true; + } + template bool parse_array_item(input &in, size_t) { + array &a = out_->get(); + a.push_back(value()); + default_parse_context ctx(&a.back(), depths_); + return _parse(ctx, in); + } + bool parse_array_stop(size_t) { + ++depths_; + return true; + } + bool parse_object_start() { + if (depths_ == 0) + return false; + *out_ = value(object_type, false); + return true; + } + template bool parse_object_item(input &in, const std::string &key) { + object &o = out_->get(); + default_parse_context ctx(&o[key], depths_); + return _parse(ctx, in); + } + bool parse_object_stop() { + ++depths_; + return true; + } + +private: + default_parse_context(const default_parse_context &); + default_parse_context &operator=(const default_parse_context &); +}; + +class null_parse_context { +protected: + size_t depths_; + +public: + struct dummy_str { + void push_back(int) { + } + }; + +public: + null_parse_context(size_t depths = DEFAULT_MAX_DEPTHS) : depths_(depths) { + } + bool set_null() { + return true; + } + bool set_bool(bool) { + return true; + } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t) { + return true; + } +#endif + bool set_number(double) { + return true; + } + template bool parse_string(input &in) { + dummy_str s; + return _parse_string(s, in); + } + bool parse_array_start() { + if (depths_ == 0) + return false; + 
--depths_; + return true; + } + template bool parse_array_item(input &in, size_t) { + return _parse(*this, in); + } + bool parse_array_stop(size_t) { + ++depths_; + return true; + } + bool parse_object_start() { + if (depths_ == 0) + return false; + --depths_; + return true; + } + template bool parse_object_item(input &in, const std::string &) { + ++depths_; + return _parse(*this, in); + } + bool parse_object_stop() { + return true; + } + +private: + null_parse_context(const null_parse_context &); + null_parse_context &operator=(const null_parse_context &); +}; + +// obsolete, use the version below +template inline std::string parse(value &out, Iter &pos, const Iter &last) { + std::string err; + pos = parse(out, pos, last, &err); + return err; +} + +template inline Iter _parse(Context &ctx, const Iter &first, const Iter &last, std::string *err) { + input in(first, last); + if (!_parse(ctx, in) && err != NULL) { + char buf[64]; + SNPRINTF(buf, sizeof(buf), "syntax error at line %d near: ", in.line()); + *err = buf; + while (1) { + int ch = in.getc(); + if (ch == -1 || ch == '\n') { + break; + } else if (ch >= ' ') { + err->push_back(static_cast(ch)); + } + } + } + return in.cur(); +} + +template inline Iter parse(value &out, const Iter &first, const Iter &last, std::string *err) { + default_parse_context ctx(&out); + return _parse(ctx, first, last, err); +} + +inline std::string parse(value &out, const std::string &s) { + std::string err; + parse(out, s.begin(), s.end(), &err); + return err; +} + +inline std::string parse(value &out, std::istream &is) { + std::string err; + parse(out, std::istreambuf_iterator(is.rdbuf()), std::istreambuf_iterator(), &err); + return err; +} + +template struct last_error_t { static std::string s; }; +template std::string last_error_t::s; + +inline void set_last_error(const std::string &s) { + last_error_t::s = s; +} + +inline const std::string &get_last_error() { + return last_error_t::s; +} + +inline bool operator==(const value &x, 
const value &y) { + if (x.is()) + return y.is(); +#define PICOJSON_CMP(type) \ + if (x.is()) \ + return y.is() && x.get() == y.get() + PICOJSON_CMP(bool); + PICOJSON_CMP(double); + PICOJSON_CMP(std::string); + PICOJSON_CMP(array); + PICOJSON_CMP(object); +#undef PICOJSON_CMP + PICOJSON_ASSERT(0); +#ifdef _MSC_VER + __assume(0); +#endif + return false; +} + +inline bool operator!=(const value &x, const value &y) { + return !(x == y); +} +} + +#if !PICOJSON_USE_RVALUE_REFERENCE +namespace std { +template <> inline void swap(picojson::value &x, picojson::value &y) { + x.swap(y); +} +} +#endif + +inline std::istream &operator>>(std::istream &is, picojson::value &x) { + picojson::set_last_error(std::string()); + const std::string err(picojson::parse(x, is)); + if (!err.empty()) { + picojson::set_last_error(err); + is.setstate(std::ios::failbit); + } + return is; +} + +inline std::ostream &operator<<(std::ostream &os, const picojson::value &x) { + x.serialize(std::ostream_iterator(os)); + return os; +} +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +#endif \ No newline at end of file From c1b1ae0515e4ad1216e2378366b11f7a08abee66 Mon Sep 17 00:00:00 2001 From: YangZhou <56786796+SmileGoat@users.noreply.github.com> Date: Wed, 4 Jan 2023 16:52:19 +0800 Subject: [PATCH 05/50] [speechx]add kaldi-native-fbank && refactor frontend (#2794) * replace kaldi-fbank with kaldi-native-fbank * make kaldi-native-fbank work --- .../u2pp_ol/wenetspeech/local/feat.sh | 11 +- .../u2pp_ol/wenetspeech/local/recognizer.sh | 2 +- .../wenetspeech/local/recognizer_quant.sh | 2 +- speechx/examples/u2pp_ol/wenetspeech/run.sh | 3 +- .../ctc_prefix_beam_search_decoder_main.cc | 13 +- speechx/speechx/asr/nnet/nnet_itf.h | 6 +- speechx/speechx/asr/nnet/nnet_producer.cc | 19 +- speechx/speechx/asr/nnet/nnet_producer.h | 2 +- speechx/speechx/asr/nnet/u2_nnet.cc | 19 +- speechx/speechx/asr/nnet/u2_nnet.h | 4 +- speechx/speechx/asr/recognizer/CMakeLists.txt | 4 +- 
.../speechx/asr/recognizer/u2_recognizer.cc | 7 +- .../speechx/asr/recognizer/u2_recognizer.h | 2 +- .../asr/recognizer/u2_recognizer_main.cc | 4 +- .../recognizer/u2_recognizer_thread_main.cc | 4 +- speechx/speechx/common/CMakeLists.txt | 6 - .../common/frontend/audio/CMakeLists.txt | 24 +- .../common/frontend/audio/assembler.cc | 33 +- .../speechx/common/frontend/audio/assembler.h | 8 +- .../common/frontend/audio/audio_cache.cc | 25 +- .../common/frontend/audio/audio_cache.h | 4 +- speechx/speechx/common/frontend/audio/cmvn.cc | 111 +- speechx/speechx/common/frontend/audio/cmvn.h | 11 +- .../frontend/audio/cmvn_json2kaldi_main.cc | 98 - .../frontend/audio/compute_fbank_main.cc | 14 +- .../common/frontend/audio/data_cache.h | 16 +- speechx/speechx/common/frontend/audio/fbank.h | 29 +- .../common/frontend/audio/feature-fbank.cc | 123 + .../common/frontend/audio/feature-fbank.h | 137 + .../frontend/audio/feature-functions.cc | 49 + .../common/frontend/audio/feature-functions.h | 38 + .../common/frontend/audio/feature-window.cc | 247 ++ .../common/frontend/audio/feature-window.h | 183 + .../common/frontend/audio/feature_cache.cc | 21 +- .../common/frontend/audio/feature_cache.h | 10 +- .../common/frontend/audio/feature_common.h | 16 +- .../frontend/audio/feature_common_inl.h | 82 +- .../common/frontend/audio/feature_pipeline.cc | 11 +- .../common/frontend/audio/feature_pipeline.h | 38 +- speechx/speechx/common/frontend/audio/fftsg.c | 3271 +++++++++++++++++ .../common/frontend/audio/frontend_itf.h | 4 +- .../common/frontend/audio/mel-computations.cc | 277 ++ .../common/frontend/audio/mel-computations.h | 120 + speechx/speechx/common/frontend/audio/rfft.cc | 66 + speechx/speechx/common/frontend/audio/rfft.h | 56 + 45 files changed, 4824 insertions(+), 406 deletions(-) delete mode 100644 speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc create mode 100644 speechx/speechx/common/frontend/audio/feature-fbank.cc create mode 100644 
speechx/speechx/common/frontend/audio/feature-fbank.h create mode 100644 speechx/speechx/common/frontend/audio/feature-functions.cc create mode 100644 speechx/speechx/common/frontend/audio/feature-functions.h create mode 100644 speechx/speechx/common/frontend/audio/feature-window.cc create mode 100644 speechx/speechx/common/frontend/audio/feature-window.h create mode 100644 speechx/speechx/common/frontend/audio/fftsg.c create mode 100644 speechx/speechx/common/frontend/audio/mel-computations.cc create mode 100644 speechx/speechx/common/frontend/audio/mel-computations.h create mode 100644 speechx/speechx/common/frontend/audio/rfft.cc create mode 100644 speechx/speechx/common/frontend/audio/rfft.h diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/feat.sh b/speechx/examples/u2pp_ol/wenetspeech/local/feat.sh index e181951e..8221611c 100755 --- a/speechx/examples/u2pp_ol/wenetspeech/local/feat.sh +++ b/speechx/examples/u2pp_ol/wenetspeech/local/feat.sh @@ -19,21 +19,12 @@ aishell_wav_scp=aishell_test.scp if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - cmvn_json2kaldi_main \ - --json_file $model_dir/mean_std.json \ - --cmvn_write_path $exp/cmvn.ark \ - --binary=false - - echo "convert json cmvn to kaldi ark." 
-fi - -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \ compute_fbank_main \ --num_bins 80 \ - --cmvn_file=$exp/cmvn.ark \ + --cmvn_file=$model_dir/mean_std.json \ --streaming_chunk=36 \ --wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \ --feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank.ark,$data/split${nj}/JOB/fbank.scp diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/recognizer.sh b/speechx/examples/u2pp_ol/wenetspeech/local/recognizer.sh index 344fbcbc..fd66e60c 100755 --- a/speechx/examples/u2pp_ol/wenetspeech/local/recognizer.sh +++ b/speechx/examples/u2pp_ol/wenetspeech/local/recognizer.sh @@ -19,7 +19,7 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \ u2_recognizer_main \ --use_fbank=true \ --num_bins=80 \ - --cmvn_file=$exp/cmvn.ark \ + --cmvn_file=$model_dir/mean_std.json \ --model_path=$model_dir/export.jit \ --vocab_path=$model_dir/unit.txt \ --nnet_decoder_chunk=16 \ diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh b/speechx/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh index 1ce403a3..555feb83 100755 --- a/speechx/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh +++ b/speechx/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh @@ -19,7 +19,7 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.quant.log \ u2_recognizer_main \ --use_fbank=true \ --num_bins=80 \ - --cmvn_file=$exp/cmvn.ark \ + --cmvn_file=$model_dir/mean_std.json \ --model_path=$model_dir/export \ --vocab_path=$model_dir/unit.txt \ --nnet_decoder_chunk=16 \ diff --git a/speechx/examples/u2pp_ol/wenetspeech/run.sh b/speechx/examples/u2pp_ol/wenetspeech/run.sh index 4bbf7920..002bd304 100755 --- a/speechx/examples/u2pp_ol/wenetspeech/run.sh +++ b/speechx/examples/u2pp_ol/wenetspeech/run.sh @@ -22,7 +22,6 @@ if [ ! 
-d ${SPEECHX_BUILD} ]; then popd fi - ckpt_dir=$data/model if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then @@ -72,7 +71,7 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # process cmvn and compute fbank feat + # process compute fbank feat ./local/feat.sh fi diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc index 31276895..b42ca69b 100644 --- a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc +++ b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "base/common.h" #include "decoder/ctc_prefix_beam_search_decoder.h" +#include "base/common.h" #include "frontend/audio/data_cache.h" #include "fst/symbol-table.h" #include "kaldi/util/table-types.h" @@ -124,15 +124,14 @@ int main(int argc, char* argv[]) { } - kaldi::Vector feature_chunk(this_chunk_size * - feat_dim); + std::vector feature_chunk(this_chunk_size * + feat_dim); int32 start = chunk_idx * chunk_stride; for (int row_id = 0; row_id < this_chunk_size; ++row_id) { kaldi::SubVector feat_row(feature, start); - kaldi::SubVector feature_chunk_row( - feature_chunk.Data() + row_id * feat_dim, feat_dim); - - feature_chunk_row.CopyFromVec(feat_row); + std::memcpy(feature_chunk.data() + row_id * feat_dim, + feat_row.Data(), + feat_dim * sizeof(kaldi::BaseFloat)); ++start; } diff --git a/speechx/speechx/asr/nnet/nnet_itf.h b/speechx/speechx/asr/nnet/nnet_itf.h index a504cce5..91d7f231 100644 --- a/speechx/speechx/asr/nnet/nnet_itf.h +++ b/speechx/speechx/asr/nnet/nnet_itf.h @@ -71,7 +71,7 @@ struct ModelOptions { struct NnetOut { // nnet out. maybe logprob or prob. Almost time this is logprob. - kaldi::Vector logprobs; + std::vector logprobs; int32 vocab_dim; // nnet state. Only using in Attention model. 
@@ -89,7 +89,7 @@ class NnetInterface { // nnet do not cache feats, feats cached by frontend. // nnet cache model state, i.e. encoder_outs, att_cache, cnn_cache, // frame_offset. - virtual void FeedForward(const kaldi::Vector& features, + virtual void FeedForward(const std::vector& features, const int32& feature_dim, NnetOut* out) = 0; @@ -105,7 +105,7 @@ class NnetInterface { // using to get encoder outs. e.g. seq2seq with Attention model. virtual void EncoderOuts( - std::vector>* encoder_out) const = 0; + std::vector>* encoder_out) const = 0; }; diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/speechx/speechx/asr/nnet/nnet_producer.cc index 95507591..886c14d0 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.cc +++ b/speechx/speechx/asr/nnet/nnet_producer.cc @@ -17,13 +17,14 @@ namespace ppspeech { using kaldi::Vector; +using std::vector; using kaldi::BaseFloat; NnetProducer::NnetProducer(std::shared_ptr nnet, std::shared_ptr frontend) : nnet_(nnet), frontend_(frontend) {} -void NnetProducer::Accept(const kaldi::VectorBase& inputs) { +void NnetProducer::Accept(const std::vector& inputs) { frontend_->Accept(inputs); bool result = false; do { @@ -49,26 +50,24 @@ bool NnetProducer::Read(std::vector* nnet_prob) { } bool NnetProducer::Compute() { - Vector features; + vector features; if (frontend_ == NULL || frontend_->Read(&features) == false) { // no feat or frontend_ not init. 
VLOG(3) << "no feat avalible"; return false; } CHECK_GE(frontend_->Dim(), 0); - VLOG(2) << "Forward in " << features.Dim() / frontend_->Dim() << " feats."; + VLOG(2) << "Forward in " << features.size() / frontend_->Dim() << " feats."; NnetOut out; nnet_->FeedForward(features, frontend_->Dim(), &out); int32& vocab_dim = out.vocab_dim; - Vector& logprobs = out.logprobs; - size_t nframes = logprobs.Dim() / vocab_dim; + size_t nframes = out.logprobs.size() / vocab_dim; VLOG(2) << "Forward out " << nframes << " decoder frames."; - std::vector logprob(vocab_dim); for (size_t idx = 0; idx < nframes; ++idx) { - for (size_t prob_idx = 0; prob_idx < vocab_dim; ++prob_idx) { - logprob[prob_idx] = logprobs(idx * vocab_dim + prob_idx); - } + std::vector logprob( + out.logprobs.data() + idx * vocab_dim, + out.logprobs.data() + (idx + 1) * vocab_dim); cache_.push_back(logprob); } return true; @@ -80,4 +79,4 @@ void NnetProducer::AttentionRescoring(const std::vector>& hyps, nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score); } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/asr/nnet/nnet_producer.h b/speechx/speechx/asr/nnet/nnet_producer.h index 65e9116f..953943cc 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.h +++ b/speechx/speechx/asr/nnet/nnet_producer.h @@ -27,7 +27,7 @@ class NnetProducer { std::shared_ptr frontend = NULL); // Feed feats or waves - void Accept(const kaldi::VectorBase& inputs); + void Accept(const std::vector& inputs); void Acceptlikelihood(const kaldi::Matrix& likelihood); diff --git a/speechx/speechx/asr/nnet/u2_nnet.cc b/speechx/speechx/asr/nnet/u2_nnet.cc index 7707406a..e3277a38 100644 --- a/speechx/speechx/asr/nnet/u2_nnet.cc +++ b/speechx/speechx/asr/nnet/u2_nnet.cc @@ -165,23 +165,16 @@ void U2Nnet::FeedEncoderOuts(const paddle::Tensor& encoder_out) { } -void U2Nnet::FeedForward(const kaldi::Vector& features, +void U2Nnet::FeedForward(const std::vector& features, const 
int32& feature_dim, NnetOut* out) { kaldi::Timer timer; - std::vector chunk_feats(features.Data(), - features.Data() + features.Dim()); std::vector ctc_probs; ForwardEncoderChunkImpl( - chunk_feats, feature_dim, &ctc_probs, &out->vocab_dim); - - out->logprobs.Resize(ctc_probs.size(), kaldi::kSetZero); - std::memcpy(out->logprobs.Data(), - ctc_probs.data(), - ctc_probs.size() * sizeof(kaldi::BaseFloat)); + features, feature_dim, &out->logprobs, &out->vocab_dim); VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. " - << chunk_feats.size() / feature_dim << " frames."; + << features.size() / feature_dim << " frames."; } @@ -638,7 +631,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, void U2Nnet::EncoderOuts( - std::vector>* encoder_out) const { + std::vector>* encoder_out) const { // list of (B=1,T,D) int size = encoder_outs_.size(); VLOG(3) << "encoder_outs_ size: " << size; @@ -657,8 +650,8 @@ void U2Nnet::EncoderOuts( const float* this_tensor_ptr = item.data(); for (int j = 0; j < T; j++) { const float* cur = this_tensor_ptr + j * D; - kaldi::Vector out(D); - std::memcpy(out.Data(), cur, D * sizeof(kaldi::BaseFloat)); + std::vector out(D); + std::memcpy(out.data(), cur, D * sizeof(kaldi::BaseFloat)); encoder_out->emplace_back(out); } } diff --git a/speechx/speechx/asr/nnet/u2_nnet.h b/speechx/speechx/asr/nnet/u2_nnet.h index 23cc0ea3..f7b703f6 100644 --- a/speechx/speechx/asr/nnet/u2_nnet.h +++ b/speechx/speechx/asr/nnet/u2_nnet.h @@ -76,7 +76,7 @@ class U2Nnet : public U2NnetBase { explicit U2Nnet(const ModelOptions& opts); U2Nnet(const U2Nnet& other); - void FeedForward(const kaldi::Vector& features, + void FeedForward(const std::vector& features, const int32& feature_dim, NnetOut* out) override; @@ -111,7 +111,7 @@ class U2Nnet : public U2NnetBase { void FeedEncoderOuts(const paddle::Tensor& encoder_out); void EncoderOuts( - std::vector>* encoder_out) const; + std::vector>* encoder_out) const; private: ModelOptions opts_; diff --git 
a/speechx/speechx/asr/recognizer/CMakeLists.txt b/speechx/speechx/asr/recognizer/CMakeLists.txt index 6d8db93c..17ba018f 100644 --- a/speechx/speechx/asr/recognizer/CMakeLists.txt +++ b/speechx/speechx/asr/recognizer/CMakeLists.txt @@ -15,8 +15,8 @@ set(TEST_BINS foreach(bin_name IN LISTS TEST_BINS) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-feat-common) target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) -endforeach() \ No newline at end of file +endforeach() diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.cc b/speechx/speechx/asr/recognizer/u2_recognizer.cc index ea62ae1a..a7644430 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer.cc @@ -19,9 +19,6 @@ namespace ppspeech { using kaldi::BaseFloat; -using kaldi::SubVector; -using kaldi::Vector; -using kaldi::VectorBase; using std::unique_ptr; using std::vector; @@ -67,10 +64,10 @@ void U2Recognizer::ResetContinuousDecoding() { } -void U2Recognizer::Accept(const VectorBase& waves) { +void U2Recognizer::Accept(const vector& waves) { kaldi::Timer timer; nnet_producer_->Accept(waves); - VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " << waves.Dim() + VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. 
" << waves.size() << " samples."; } diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h index 855d161a..c92e0b6a 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/speechx/speechx/asr/recognizer/u2_recognizer.h @@ -115,7 +115,7 @@ class U2Recognizer { void Reset(); void ResetContinuousDecoding(); - void Accept(const kaldi::VectorBase& waves); + void Accept(const std::vector& waves); void Decode(); void Rescoring(); diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_main.cc index d7c58407..3e64011c 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_main.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer_main.cc @@ -71,9 +71,9 @@ int main(int argc, char* argv[]) { int cur_chunk_size = std::min(chunk_sample_size, tot_samples - sample_offset); - kaldi::Vector wav_chunk(cur_chunk_size); + std::vector wav_chunk(cur_chunk_size); for (int i = 0; i < cur_chunk_size; ++i) { - wav_chunk(i) = waveform(sample_offset + i); + wav_chunk[i] = waveform(sample_offset + i); } // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc index e73efef1..bb72b3b6 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc @@ -81,9 +81,9 @@ int main(int argc, char* argv[]) { int cur_chunk_size = std::min(chunk_sample_size, tot_samples - sample_offset); - kaldi::Vector wav_chunk(cur_chunk_size); + std::vector wav_chunk(cur_chunk_size); for (int i = 0; i < cur_chunk_size; ++i) { - wav_chunk(i) = waveform(sample_offset + i); + wav_chunk[i] = waveform(sample_offset + i); } // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); diff --git a/speechx/speechx/common/CMakeLists.txt b/speechx/speechx/common/CMakeLists.txt index 
dea9eb05..00426cb5 100644 --- a/speechx/speechx/common/CMakeLists.txt +++ b/speechx/speechx/common/CMakeLists.txt @@ -1,16 +1,10 @@ include_directories( ${CMAKE_CURRENT_SOURCE_DIR} -${CMAKE_CURRENT_SOURCE_DIR}/base -) - -include_directories( ${CMAKE_CURRENT_SOURCE_DIR}/../ -${CMAKE_CURRENT_SOURCE_DIR}/utils ) add_subdirectory(utils) include_directories( -${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/frontend ) add_subdirectory(frontend) diff --git a/speechx/speechx/common/frontend/audio/CMakeLists.txt b/speechx/speechx/common/frontend/audio/CMakeLists.txt index 050d78be..d5396ab2 100644 --- a/speechx/speechx/common/frontend/audio/CMakeLists.txt +++ b/speechx/speechx/common/frontend/audio/CMakeLists.txt @@ -1,29 +1,27 @@ +add_library(kaldi-native-fbank-core + feature-fbank.cc + feature-functions.cc + feature-window.cc + fftsg.c + mel-computations.cc + rfft.cc +) + add_library(frontend STATIC cmvn.cc - db_norm.cc - linear_spectrogram.cc audio_cache.cc feature_cache.cc feature_pipeline.cc - fbank.cc assembler.cc ) -target_link_libraries(frontend PUBLIC kaldi-matrix kaldi-feat-common kaldi-fbank) - - - -set(bin_name cmvn_json2kaldi_main) -add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog) +target_link_libraries(frontend PUBLIC kaldi-native-fbank-core utils) set(BINS - compute_linear_spectrogram_main compute_fbank_main ) foreach(bin_name IN LISTS BINS) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} PUBLIC frontend utils kaldi-util gflags glog) + target_link_libraries(${bin_name} PUBLIC frontend utils kaldi-util gflags glog kaldi-feat-common) endforeach() diff --git a/speechx/speechx/common/frontend/audio/assembler.cc 
b/speechx/speechx/common/frontend/audio/assembler.cc index 9d5fc403..30a650d3 100644 --- a/speechx/speechx/common/frontend/audio/assembler.cc +++ b/speechx/speechx/common/frontend/audio/assembler.cc @@ -17,8 +17,8 @@ namespace ppspeech { using kaldi::BaseFloat; -using kaldi::Vector; -using kaldi::VectorBase; +using std::vector; +using std::vector; using std::unique_ptr; Assembler::Assembler(AssemblerOptions opts, @@ -33,13 +33,13 @@ Assembler::Assembler(AssemblerOptions opts, dim_ = base_extractor_->Dim(); } -void Assembler::Accept(const kaldi::VectorBase& inputs) { +void Assembler::Accept(const std::vector& inputs) { // read inputs base_extractor_->Accept(inputs); } // pop feature chunk -bool Assembler::Read(kaldi::Vector* feats) { +bool Assembler::Read(std::vector* feats) { kaldi::Timer timer; bool result = Compute(feats); VLOG(1) << "Assembler::Read cost: " << timer.Elapsed() << " sec."; @@ -47,14 +47,14 @@ bool Assembler::Read(kaldi::Vector* feats) { } // read frame by frame from base_feature_extractor_ into cache_ -bool Assembler::Compute(Vector* feats) { +bool Assembler::Compute(vector* feats) { // compute and feed frame by frame while (feature_cache_.size() < frame_chunk_size_) { - Vector feature; + vector feature; bool result = base_extractor_->Read(&feature); - if (result == false || feature.Dim() == 0) { + if (result == false || feature.size() == 0) { VLOG(3) << "result: " << result - << " feature dim: " << feature.Dim(); + << " feature dim: " << feature.size(); if (IsFinished() == false) { VLOG(3) << "finished reading feature. cache size: " << feature_cache_.size(); @@ -65,7 +65,7 @@ bool Assembler::Compute(Vector* feats) { } } - CHECK(feature.Dim() == dim_); + CHECK(feature.size() == dim_); feature_cache_.push(feature); nframes_ += 1; @@ -73,14 +73,14 @@ bool Assembler::Compute(Vector* feats) { } if (feature_cache_.size() < receptive_filed_length_) { - VLOG(3) << "feature_cache less than receptive_filed_lenght. 
" + VLOG(3) << "feature_cache less than receptive_filed_length. " << feature_cache_.size() << ": " << receptive_filed_length_; return false; } if (fill_zero_) { while (feature_cache_.size() < frame_chunk_size_) { - Vector feature(dim_, kaldi::kSetZero); + vector feature(dim_, kaldi::kSetZero); nframes_ += 1; feature_cache_.push(feature); } @@ -88,16 +88,17 @@ bool Assembler::Compute(Vector* feats) { int32 this_chunk_size = std::min(static_cast(feature_cache_.size()), frame_chunk_size_); - feats->Resize(dim_ * this_chunk_size); + feats->resize(dim_ * this_chunk_size); VLOG(3) << "read " << this_chunk_size << " feat."; int32 counter = 0; while (counter < this_chunk_size) { - Vector& val = feature_cache_.front(); - CHECK(val.Dim() == dim_) << val.Dim(); + vector& val = feature_cache_.front(); + CHECK(val.size() == dim_) << val.size(); int32 start = counter * dim_; - feats->Range(start, dim_).CopyFromVec(val); + std::memcpy(feats->data() + start, + val.data(), val.size() * sizeof(BaseFloat)); if (this_chunk_size - counter <= cache_size_) { feature_cache_.push(val); @@ -115,7 +116,7 @@ bool Assembler::Compute(Vector* feats) { void Assembler::Reset() { - std::queue> empty; + std::queue> empty; std::swap(feature_cache_, empty); nframes_ = 0; base_extractor_->Reset(); diff --git a/speechx/speechx/common/frontend/audio/assembler.h b/speechx/speechx/common/frontend/audio/assembler.h index 72e6f635..700e60d9 100644 --- a/speechx/speechx/common/frontend/audio/assembler.h +++ b/speechx/speechx/common/frontend/audio/assembler.h @@ -36,10 +36,10 @@ class Assembler : public FrontendInterface { std::unique_ptr base_extractor = NULL); // Feed feats or waves - void Accept(const kaldi::VectorBase& inputs) override; + void Accept(const std::vector& inputs) override; // feats size = num_frames * feat_dim - bool Read(kaldi::Vector* feats) override; + bool Read(std::vector* feats) override; // feat dim size_t Dim() const override { return dim_; } @@ -51,7 +51,7 @@ class Assembler : public 
FrontendInterface { void Reset() override; private: - bool Compute(kaldi::Vector* feats); + bool Compute(std::vector* feats); bool fill_zero_{false}; @@ -60,7 +60,7 @@ class Assembler : public FrontendInterface { int32 frame_chunk_stride_; // stride int32 cache_size_; // window - stride int32 receptive_filed_length_; - std::queue> feature_cache_; + std::queue> feature_cache_; std::unique_ptr base_extractor_; int32 nframes_; // num frame computed diff --git a/speechx/speechx/common/frontend/audio/audio_cache.cc b/speechx/speechx/common/frontend/audio/audio_cache.cc index c6a91f4b..2221e1c9 100644 --- a/speechx/speechx/common/frontend/audio/audio_cache.cc +++ b/speechx/speechx/common/frontend/audio/audio_cache.cc @@ -19,8 +19,7 @@ namespace ppspeech { using kaldi::BaseFloat; -using kaldi::Vector; -using kaldi::VectorBase; +using std::vector; AudioCache::AudioCache(int buffer_size, bool to_float32) : finished_(false), @@ -37,25 +36,25 @@ BaseFloat AudioCache::Convert2PCM32(BaseFloat val) { return val * (1. / std::pow(2.0, 15)); } -void AudioCache::Accept(const VectorBase& waves) { +void AudioCache::Accept(const vector& waves) { kaldi::Timer timer; std::unique_lock lock(mutex_); - while (size_ + waves.Dim() > ring_buffer_.size()) { + while (size_ + waves.size() > ring_buffer_.size()) { ready_feed_condition_.wait(lock); } - for (size_t idx = 0; idx < waves.Dim(); ++idx) { + for (size_t idx = 0; idx < waves.size(); ++idx) { int32 buffer_idx = (idx + offset_ + size_) % ring_buffer_.size(); - ring_buffer_[buffer_idx] = waves(idx); - if (to_float32_) ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx)); + ring_buffer_[buffer_idx] = waves[idx]; + if (to_float32_) ring_buffer_[buffer_idx] = Convert2PCM32(waves[idx]); } - size_ += waves.Dim(); + size_ += waves.size(); VLOG(1) << "AudioCache::Accept cost: " << timer.Elapsed() << " sec. 
" - << waves.Dim() << " samples."; + << waves.size() << " samples."; } -bool AudioCache::Read(Vector* waves) { +bool AudioCache::Read(vector* waves) { kaldi::Timer timer; - size_t chunk_size = waves->Dim(); + size_t chunk_size = waves->size(); std::unique_lock lock(mutex_); while (chunk_size > size_) { // when audio is empty and no more data feed @@ -78,12 +77,12 @@ bool AudioCache::Read(Vector* waves) { // read last chunk data if (chunk_size > size_) { chunk_size = size_; - waves->Resize(chunk_size); + waves->resize(chunk_size); } for (size_t idx = 0; idx < chunk_size; ++idx) { int buff_idx = (offset_ + idx) % ring_buffer_.size(); - waves->Data()[idx] = ring_buffer_[buff_idx]; + waves->at(idx) = ring_buffer_[buff_idx]; } size_ -= chunk_size; offset_ = (offset_ + chunk_size) % ring_buffer_.size(); diff --git a/speechx/speechx/common/frontend/audio/audio_cache.h b/speechx/speechx/common/frontend/audio/audio_cache.h index 4708a6e0..d3cfbc3f 100644 --- a/speechx/speechx/common/frontend/audio/audio_cache.h +++ b/speechx/speechx/common/frontend/audio/audio_cache.h @@ -26,9 +26,9 @@ class AudioCache : public FrontendInterface { explicit AudioCache(int buffer_size = 1000 * kint16max, bool to_float32 = false); - virtual void Accept(const kaldi::VectorBase& waves); + virtual void Accept(const std::vector& waves); - virtual bool Read(kaldi::Vector* waves); + virtual bool Read(std::vector* waves); // the audio dim is 1, one sample, which is useless, // so we return size_(cache samples) instead. 
diff --git a/speechx/speechx/common/frontend/audio/cmvn.cc b/speechx/speechx/common/frontend/audio/cmvn.cc index a4d861d2..58ec299c 100644 --- a/speechx/speechx/common/frontend/audio/cmvn.cc +++ b/speechx/speechx/common/frontend/audio/cmvn.cc @@ -15,15 +15,12 @@ #include "frontend/audio/cmvn.h" -#include "kaldi/feat/cmvn.h" -#include "kaldi/util/kaldi-io.h" +#include "utils/file_utils.h" +#include "utils/picojson.h" namespace ppspeech { using kaldi::BaseFloat; -using kaldi::SubVector; -using kaldi::Vector; -using kaldi::VectorBase; using std::unique_ptr; using std::vector; @@ -32,22 +29,46 @@ CMVN::CMVN(std::string cmvn_file, unique_ptr base_extractor) : var_norm_(true) { CHECK_NE(cmvn_file, ""); base_extractor_ = std::move(base_extractor); + ReadCMVNFromJson(cmvn_file); + dim_ = mean_stats_.size() - 1; +} + +void CMVN::ReadCMVNFromJson(string cmvn_file) { + std::string json_str = ppspeech::ReadFile2String(cmvn_file); + picojson::value value; + std::string err; + const char* json_end = picojson::parse( + value, json_str.c_str(), json_str.c_str() + json_str.size(), &err); + if (!value.is()) { + LOG(ERROR) << "Input json file format error."; + } + const picojson::value::array& mean_stat = + value.get("mean_stat").get(); + for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) { + mean_stats_.push_back((*it).get()); + } + + const picojson::value::array& var_stat = + value.get("var_stat").get(); + for (auto it = var_stat.begin(); it != var_stat.end(); it++) { + var_stats_.push_back((*it).get()); + } - bool binary; - kaldi::Input ki(cmvn_file, &binary); - stats_.Read(ki.Stream(), binary); - dim_ = stats_.NumCols() - 1; + kaldi::int32 frame_num = value.get("frame_num").get(); + LOG(INFO) << "nframe: " << frame_num; + mean_stats_.push_back(frame_num); + var_stats_.push_back(0); } -void CMVN::Accept(const kaldi::VectorBase& inputs) { +void CMVN::Accept(const std::vector& inputs) { // feed waves/feats to compute feature base_extractor_->Accept(inputs); return; } 
-bool CMVN::Read(kaldi::Vector* feats) { +bool CMVN::Read(std::vector* feats) { // compute feature - if (base_extractor_->Read(feats) == false || feats->Dim() == 0) { + if (base_extractor_->Read(feats) == false || feats->size() == 0) { return false; } @@ -59,74 +80,78 @@ bool CMVN::Read(kaldi::Vector* feats) { } // feats contain num_frames feature. -void CMVN::Compute(VectorBase* feats) const { +void CMVN::Compute(vector* feats) const { KALDI_ASSERT(feats != NULL); - if (stats_.NumRows() > 2 || stats_.NumRows() < 1 || - feats->Dim() % dim_ != 0) { - KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << ',' - << stats_.NumCols() - 1 << ", feats " << feats->Dim() << 'x'; + if (feats->size() % dim_ != 0) { + LOG(ERROR)<< "Dim mismatch: cmvn " << mean_stats_.size() << ',' + << var_stats_.size() - 1 << ", feats " << feats->size() << 'x'; } - if (stats_.NumRows() == 1 && var_norm_) { - KALDI_ERR + if (var_stats_.size() == 0 && var_norm_) { + LOG(ERROR) << "You requested variance normalization but no variance stats_ " << "are supplied."; } - double count = stats_(0, dim_); + double count = mean_stats_[dim_]; // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when // computing an offset and representing it as stats_, we use a count of one. if (count < 1.0) - KALDI_ERR << "Insufficient stats_ for cepstral mean and variance " + LOG(ERROR) << "Insufficient stats_ for cepstral mean and variance " "normalization: " << "count = " << count; if (!var_norm_) { - Vector offset(feats->Dim()); - SubVector mean_stats(stats_.RowData(0), dim_); - Vector mean_stats_apply(feats->Dim()); + vector offset(feats->size()); + vector mean_stats(mean_stats_); + for (size_t i = 0; i < mean_stats.size(); ++i) { + mean_stats[i] /= count; + } + vector mean_stats_apply(feats->size()); // fill the datat of mean_stats in mean_stats_appy whose dim_ is equal // with the dim_ of feature. 
// the dim_ of feats = dim_ * num_frames; - for (int32 idx = 0; idx < feats->Dim() / dim_; ++idx) { - SubVector stats_tmp(mean_stats_apply.Data() + dim_ * idx, - dim_); - stats_tmp.CopyFromVec(mean_stats); + for (int32 idx = 0; idx < feats->size() / dim_; ++idx) { + std::memcpy(mean_stats_apply.data() + dim_ * idx, + mean_stats.data(), dim_* sizeof(double)); + } + for (size_t idx = 0; idx < feats->size(); ++idx) { + feats->at(idx) += offset[idx]; } - offset.AddVec(-1.0 / count, mean_stats_apply); - feats->AddVec(1.0, offset); return; } // norm(0, d) = mean offset; // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d). - kaldi::Matrix norm(2, feats->Dim()); + vector norm0(feats->size()); + vector norm1(feats->size()); for (int32 d = 0; d < dim_; d++) { double mean, offset, scale; - mean = stats_(0, d) / count; - double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20; + mean = mean_stats_[d] / count; + double var = (var_stats_[d] / count) - mean * mean, floor = 1.0e-20; if (var < floor) { - KALDI_WARN << "Flooring cepstral variance from " << var << " to " + LOG(WARNING) << "Flooring cepstral variance from " << var << " to " << floor; var = floor; } scale = 1.0 / sqrt(var); if (scale != scale || 1 / scale == 0.0) - KALDI_ERR + LOG(ERROR) << "NaN or infinity in cepstral mean/variance computation"; offset = -(mean * scale); - for (int32 d_skip = d; d_skip < feats->Dim();) { - norm(0, d_skip) = offset; - norm(1, d_skip) = scale; + for (int32 d_skip = d; d_skip < feats->size();) { + norm0[d_skip] = offset; + norm1[d_skip] = scale; d_skip = d_skip + dim_; } } // Apply the normalization. 
- feats->MulElements(norm.Row(1)); - feats->AddVec(1.0, norm.Row(0)); -} + for (size_t idx = 0; idx < feats->size(); ++idx) { + feats->at(idx) *= norm1[idx]; + } -void CMVN::ApplyCMVN(kaldi::MatrixBase* feats) { - ApplyCmvn(stats_, var_norm_, feats); + for (size_t idx = 0; idx < feats->size(); ++idx) { + feats->at(idx) += norm0[idx]; + } } } // namespace ppspeech diff --git a/speechx/speechx/common/frontend/audio/cmvn.h b/speechx/speechx/common/frontend/audio/cmvn.h index 50ef5649..261d90b2 100644 --- a/speechx/speechx/common/frontend/audio/cmvn.h +++ b/speechx/speechx/common/frontend/audio/cmvn.h @@ -25,11 +25,11 @@ class CMVN : public FrontendInterface { public: explicit CMVN(std::string cmvn_file, std::unique_ptr base_extractor); - virtual void Accept(const kaldi::VectorBase& inputs); + virtual void Accept(const std::vector& inputs); // the length of feats = feature_row * feature_dim, // the Matrix is squashed into Vector - virtual bool Read(kaldi::Vector* feats); + virtual bool Read(std::vector* feats); // the dim_ is the feautre dim. virtual size_t Dim() const { return dim_; } virtual void SetFinished() { base_extractor_->SetFinished(); } @@ -37,9 +37,10 @@ class CMVN : public FrontendInterface { virtual void Reset() { base_extractor_->Reset(); } private: - void Compute(kaldi::VectorBase* feats) const; - void ApplyCMVN(kaldi::MatrixBase* feats); - kaldi::Matrix stats_; + void ReadCMVNFromJson(std::string cmvn_file); + void Compute(std::vector* feats) const; + std::vector mean_stats_; + std::vector var_stats_; std::unique_ptr base_extractor_; size_t dim_; bool var_norm_; diff --git a/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc b/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc deleted file mode 100644 index 8c65b346..00000000 --- a/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Note: Do not print/log ondemand object. - -#include "base/common.h" -#include "base/flags.h" -#include "base/log.h" -#include "kaldi/matrix/kaldi-matrix.h" -#include "kaldi/util/kaldi-io.h" -#include "utils/file_utils.h" -#include "utils/picojson.h" - -DEFINE_string(json_file, "", "cmvn json file"); -DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn"); -DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)"); - -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - LOG(INFO) << "cmvn josn path: " << FLAGS_json_file; - - auto ifs = std::ifstream(FLAGS_json_file); - std::string json_str = ppspeech::ReadFile2String(FLAGS_json_file); - picojson::value value; - std::string err; - const char* json_end = picojson::parse( - value, json_str.c_str(), json_str.c_str() + json_str.size(), &err); - if (!value.is()) { - LOG(ERROR) << "Input json file format error."; - } - - const picojson::value::object& obj = value.get(); - for (picojson::value::object::const_iterator elem = obj.begin(); - elem != obj.end(); - ++elem) { - if (elem->first == "mean_stat") { - VLOG(2) << "mean_stat:" << elem->second; - // const picojson::value tmp = - // elem->second.get(0);//(); - double tmp = - elem->second.get(0).get(); //(); - VLOG(2) 
<< "tmp: " << tmp; - } - if (elem->first == "var_stat") { - VLOG(2) << "var_stat: " << elem->second; - } - if (elem->first == "frame_num") { - VLOG(2) << "frame_num: " << elem->second; - } - } - - const picojson::value::array& mean_stat = - value.get("mean_stat").get(); - std::vector mean_stat_vec; - for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) { - mean_stat_vec.push_back((*it).get()); - } - - const picojson::value::array& var_stat = - value.get("var_stat").get(); - std::vector var_stat_vec; - for (auto it = var_stat.begin(); it != var_stat.end(); it++) { - var_stat_vec.push_back((*it).get()); - } - - kaldi::int32 frame_num = value.get("frame_num").get(); - LOG(INFO) << "nframe: " << frame_num; - - size_t mean_size = mean_stat_vec.size(); - kaldi::Matrix cmvn_stats(2, mean_size + 1); - for (size_t idx = 0; idx < mean_size; ++idx) { - cmvn_stats(0, idx) = mean_stat_vec[idx]; - cmvn_stats(1, idx) = var_stat_vec[idx]; - } - cmvn_stats(0, mean_size) = frame_num; - VLOG(2) << cmvn_stats; - - kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary); - LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path; - LOG(INFO) << "Binary: " << FLAGS_binary; - return 0; -} diff --git a/speechx/speechx/common/frontend/audio/compute_fbank_main.cc b/speechx/speechx/common/frontend/audio/compute_fbank_main.cc index e2b54a8a..fc6eb063 100644 --- a/speechx/speechx/common/frontend/audio/compute_fbank_main.cc +++ b/speechx/speechx/common/frontend/audio/compute_fbank_main.cc @@ -56,7 +56,7 @@ int main(int argc, char* argv[]) { std::unique_ptr data_source( new ppspeech::AudioCache(3600 * 1600, false)); - kaldi::FbankOptions opt; + knf::FbankOptions opt; opt.frame_opts.frame_length_ms = 25; opt.frame_opts.frame_shift_ms = 10; opt.mel_opts.num_bins = FLAGS_num_bins; @@ -117,9 +117,9 @@ int main(int argc, char* argv[]) { std::min(chunk_sample_size, tot_samples - sample_offset); // get chunk wav - kaldi::Vector wav_chunk(cur_chunk_size); + 
std::vector wav_chunk(cur_chunk_size); for (int i = 0; i < cur_chunk_size; ++i) { - wav_chunk(i) = waveform(sample_offset + i); + wav_chunk[i] = waveform(sample_offset + i); } // compute feat @@ -131,10 +131,14 @@ int main(int argc, char* argv[]) { } // read feat - kaldi::Vector features; + kaldi::Vector features(feature_cache.Dim()); bool flag = true; do { - flag = feature_cache.Read(&features); + std::vector tmp; + flag = feature_cache.Read(&tmp); + std::memcpy(features.Data(), + tmp.data(), + tmp.size() * sizeof(BaseFloat)); if (flag && features.Dim() != 0) { feats.push_back(features); feature_rows += features.Dim() / feature_cache.Dim(); diff --git a/speechx/speechx/common/frontend/audio/data_cache.h b/speechx/speechx/common/frontend/audio/data_cache.h index 5fe5e4fe..d18d444d 100644 --- a/speechx/speechx/common/frontend/audio/data_cache.h +++ b/speechx/speechx/common/frontend/audio/data_cache.h @@ -15,10 +15,10 @@ #pragma once - #include "base/common.h" #include "frontend/audio/frontend_itf.h" +using std::vector; namespace ppspeech { @@ -30,16 +30,16 @@ class DataCache : public FrontendInterface { DataCache() : finished_{false}, dim_{0} {} // accept waves/feats - void Accept(const kaldi::VectorBase& inputs) override { - data_ = inputs; + void Accept(const std::vector& inputs) override { + data_ = std::move(inputs); } - bool Read(kaldi::Vector* feats) override { - if (data_.Dim() == 0) { + bool Read(vector* feats) override { + if (data_.size() == 0) { return false; } - (*feats) = data_; - data_.Resize(0); + (*feats) = std::move(data_); + data_.resize(0); return true; } @@ -53,7 +53,7 @@ class DataCache : public FrontendInterface { } private: - kaldi::Vector data_; + std::vector data_; bool finished_; int32 dim_; diff --git a/speechx/speechx/common/frontend/audio/fbank.h b/speechx/speechx/common/frontend/audio/fbank.h index a1e65413..434ae7d6 100644 --- a/speechx/speechx/common/frontend/audio/fbank.h +++ b/speechx/speechx/common/frontend/audio/fbank.h @@ -16,35 
+16,10 @@ #include "base/common.h" #include "frontend/audio/feature_common.h" -#include "frontend/audio/frontend_itf.h" -#include "kaldi/feat/feature-fbank.h" -#include "kaldi/feat/feature-mfcc.h" -#include "kaldi/matrix/kaldi-vector.h" +#include "frontend/audio/feature-fbank.h" namespace ppspeech { -class FbankComputer { - public: - typedef kaldi::FbankOptions Options; - explicit FbankComputer(const Options& opts); - - kaldi::FrameExtractionOptions& GetFrameOptions() { - return opts_.frame_opts; - } - - bool Compute(kaldi::Vector* window, - kaldi::Vector* feat); - int32 Dim() const; - - bool NeedRawLogEnergy(); - - private: - Options opts_; - - kaldi::FbankComputer computer_; - DISALLOW_COPY_AND_ASSIGN(FbankComputer); -}; - -typedef StreamingFeatureTpl Fbank; +typedef StreamingFeatureTpl Fbank; } // namespace ppspeech diff --git a/speechx/speechx/common/frontend/audio/feature-fbank.cc b/speechx/speechx/common/frontend/audio/feature-fbank.cc new file mode 100644 index 00000000..7a6da943 --- /dev/null +++ b/speechx/speechx/common/frontend/audio/feature-fbank.cc @@ -0,0 +1,123 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// This file is copied/modified from kaldi/src/feat/feature-fbank.cc +// +#include "frontend/audio/feature-fbank.h" + +#include + +#include "frontend/audio/feature-functions.h" + +namespace knf { + +static void Sqrt(float *in_out, int32_t n) { + for (int32_t i = 0; i != n; ++i) { + in_out[i] = std::sqrt(in_out[i]); + } +} + +std::ostream &operator<<(std::ostream &os, const FbankOptions &opts) { + os << opts.ToString(); + return os; +} + +FbankComputer::FbankComputer(const FbankOptions &opts) + : opts_(opts), rfft_(opts.frame_opts.PaddedWindowSize()) { + if (opts.energy_floor > 0.0f) { + log_energy_floor_ = logf(opts.energy_floor); + } + + // We'll definitely need the filterbanks info for VTLN warping factor 1.0. + // [note: this call caches it.] + GetMelBanks(1.0f); +} + +FbankComputer::~FbankComputer() { + for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter) + delete iter->second; +} + +const MelBanks *FbankComputer::GetMelBanks(float vtln_warp) { + MelBanks *this_mel_banks = nullptr; + + // std::map::iterator iter = mel_banks_.find(vtln_warp); + auto iter = mel_banks_.find(vtln_warp); + if (iter == mel_banks_.end()) { + this_mel_banks = + new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp); + mel_banks_[vtln_warp] = this_mel_banks; + } else { + this_mel_banks = iter->second; + } + return this_mel_banks; +} + +void FbankComputer::Compute(float signal_raw_log_energy, + float vtln_warp, + std::vector *signal_frame, + float *feature) { + const MelBanks &mel_banks = *(GetMelBanks(vtln_warp)); + + CHECK_EQ(signal_frame->size(), opts_.frame_opts.PaddedWindowSize()); + + // Compute energy after window function (not the raw one). 
+ if (opts_.use_energy && !opts_.raw_energy) { + signal_raw_log_energy = + std::log(std::max(InnerProduct(signal_frame->data(), + signal_frame->data(), + signal_frame->size()), + std::numeric_limits::epsilon())); + } + rfft_.Compute(signal_frame->data()); // signal_frame is modified in-place + ComputePowerSpectrum(signal_frame); + + // Use magnitude instead of power if requested. + if (!opts_.use_power) { + Sqrt(signal_frame->data(), signal_frame->size() / 2 + 1); + } + + int32_t mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0); + + // Its length is opts_.mel_opts.num_bins + float *mel_energies = feature + mel_offset; + + // Sum with mel filter banks over the power spectrum + mel_banks.Compute(signal_frame->data(), mel_energies); + + if (opts_.use_log_fbank) { + // Avoid log of zero (which should be prevented anyway by dithering). + for (int32_t i = 0; i != opts_.mel_opts.num_bins; ++i) { + auto t = std::max(mel_energies[i], + std::numeric_limits::epsilon()); + mel_energies[i] = std::log(t); + } + } + + // Copy energy as first value (or the last, if htk_compat == true). + if (opts_.use_energy) { + if (opts_.energy_floor > 0.0 && + signal_raw_log_energy < log_energy_floor_) { + signal_raw_log_energy = log_energy_floor_; + } + int32_t energy_index = opts_.htk_compat ? opts_.mel_opts.num_bins : 0; + feature[energy_index] = signal_raw_log_energy; + } +} + +} // namespace knf diff --git a/speechx/speechx/common/frontend/audio/feature-fbank.h b/speechx/speechx/common/frontend/audio/feature-fbank.h new file mode 100644 index 00000000..3c43a3c8 --- /dev/null +++ b/speechx/speechx/common/frontend/audio/feature-fbank.h @@ -0,0 +1,137 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is copied/modified from kaldi/src/feat/feature-fbank.h + +#ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_ +#define KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_ + +#include + +#include "frontend/audio/feature-window.h" +#include "frontend/audio/mel-computations.h" +#include "frontend/audio/rfft.h" + +namespace knf { + +struct FbankOptions { + FrameExtractionOptions frame_opts; + MelBanksOptions mel_opts; + // append an extra dimension with energy to the filter banks + bool use_energy = false; + float energy_floor = 0.0f; // active iff use_energy==true + + // If true, compute log_energy before preemphasis and windowing + // If false, compute log_energy after preemphasis ans windowing + bool raw_energy = true; // active iff use_energy==true + + // If true, put energy last (if using energy) + // If false, put energy first + bool htk_compat = false; // active iff use_energy==true + + // if true (default), produce log-filterbank, else linear + bool use_log_fbank = true; + + // if true (default), use power in filterbank + // analysis, else magnitude. 
+ bool use_power = true; + + FbankOptions() { mel_opts.num_bins = 23; } + + std::string ToString() const { + std::ostringstream os; + os << "frame_opts: \n"; + os << frame_opts << "\n"; + os << "\n"; + + os << "mel_opts: \n"; + os << mel_opts << "\n"; + + os << "use_energy: " << use_energy << "\n"; + os << "energy_floor: " << energy_floor << "\n"; + os << "raw_energy: " << raw_energy << "\n"; + os << "htk_compat: " << htk_compat << "\n"; + os << "use_log_fbank: " << use_log_fbank << "\n"; + os << "use_power: " << use_power << "\n"; + return os.str(); + } +}; + +std::ostream &operator<<(std::ostream &os, const FbankOptions &opts); + +class FbankComputer { + public: + using Options = FbankOptions; + + explicit FbankComputer(const FbankOptions &opts); + ~FbankComputer(); + + int32_t Dim() const { + return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0); + } + + // if true, compute log_energy_pre_window but after dithering and dc removal + bool NeedRawLogEnergy() const { + return opts_.use_energy && opts_.raw_energy; + } + + const FrameExtractionOptions &GetFrameOptions() const { + return opts_.frame_opts; + } + + const FbankOptions &GetOptions() const { return opts_; } + + /** + Function that computes one frame of features from + one frame of signal. + + @param [in] signal_raw_log_energy The log-energy of the frame of the + signal + prior to windowing and pre-emphasis, or + log(numeric_limits::min()), whichever is greater. Must be + ignored by this function if this class returns false from + this->NeedsRawLogEnergy(). + @param [in] vtln_warp The VTLN warping factor that the user wants + to be applied when computing features for this utterance. Will + normally be 1.0, meaning no warping is to be done. The value will + be ignored for feature types that don't support VLTN, such as + spectrogram features. + @param [in] signal_frame One frame of the signal, + as extracted using the function ExtractWindow() using the options + returned by this->GetFrameOptions(). 
The function will use the + vector as a workspace, which is why it's a non-const pointer. + @param [out] feature Pointer to a vector of size this->Dim(), to which + the computed feature will be written. It should be pre-allocated. + */ + void Compute(float signal_raw_log_energy, + float vtln_warp, + std::vector *signal_frame, + float *feature); + + private: + const MelBanks *GetMelBanks(float vtln_warp); + + FbankOptions opts_; + float log_energy_floor_; + std::map mel_banks_; // float is VTLN coefficient. + Rfft rfft_; +}; + +} // namespace knf + +#endif // KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_ diff --git a/speechx/speechx/common/frontend/audio/feature-functions.cc b/speechx/speechx/common/frontend/audio/feature-functions.cc new file mode 100644 index 00000000..399041e4 --- /dev/null +++ b/speechx/speechx/common/frontend/audio/feature-functions.cc @@ -0,0 +1,49 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is copied/modified from kaldi/src/feat/feature-functions.cc + +#include "frontend/audio/feature-functions.h" + +#include +#include + +namespace knf { + +void ComputePowerSpectrum(std::vector *complex_fft) { + int32_t dim = complex_fft->size(); + + // now we have in complex_fft, first half of complex spectrum + // it's stored as [real0, realN/2, real1, im1, real2, im2, ...] 
+ + float *p = complex_fft->data(); + int32_t half_dim = dim / 2; + float first_energy = p[0] * p[0]; + float last_energy = p[1] * p[1]; // handle this special case + + for (int32_t i = 1; i < half_dim; ++i) { + float real = p[i * 2]; + float im = p[i * 2 + 1]; + p[i] = real * real + im * im; + } + p[0] = first_energy; + p[half_dim] = last_energy; // Will actually never be used, and anyway + // if the signal has been bandlimited sensibly this should be zero. +} + +} // namespace knf diff --git a/speechx/speechx/common/frontend/audio/feature-functions.h b/speechx/speechx/common/frontend/audio/feature-functions.h new file mode 100644 index 00000000..852d0612 --- /dev/null +++ b/speechx/speechx/common/frontend/audio/feature-functions.h @@ -0,0 +1,38 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is copied/modified from kaldi/src/feat/feature-functions.h +#ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H +#define KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H + +#include +namespace knf { + +// ComputePowerSpectrum converts a complex FFT (as produced by the FFT +// functions in csrc/rfft.h), and converts it into +// a power spectrum. 
If the complex FFT is a vector of size n (representing +// half of the complex FFT of a real signal of size n, as described there), +// this function computes in the first (n/2) + 1 elements of it, the +// energies of the fft bins from zero to the Nyquist frequency. Contents of the +// remaining (n/2) - 1 elements are undefined at output. + +void ComputePowerSpectrum(std::vector *complex_fft); + +} // namespace knf + +#endif // KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H diff --git a/speechx/speechx/common/frontend/audio/feature-window.cc b/speechx/speechx/common/frontend/audio/feature-window.cc new file mode 100644 index 00000000..7778a1b9 --- /dev/null +++ b/speechx/speechx/common/frontend/audio/feature-window.cc @@ -0,0 +1,247 @@ +// kaldi-native-fbank/csrc/feature-window.cc +// +// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + +// This file is copied/modified from kaldi/src/feat/feature-window.cc + +#include "frontend/audio/feature-window.h" + +#include +#include + +#ifndef M_2PI +#define M_2PI 6.283185307179586476925286766559005 +#endif + +namespace knf { + +std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) { + os << opts.ToString(); + return os; +} + +FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts) + : window_(opts.WindowSize()) { + int32_t frame_length = opts.WindowSize(); + CHECK_GT(frame_length, 0); + + float *window_data = window_.data(); + + double a = M_2PI / (frame_length - 1); + for (int32_t i = 0; i < frame_length; i++) { + double i_fl = static_cast(i); + if (opts.window_type == "hanning") { + window_data[i] = 0.5 - 0.5 * cos(a * i_fl); + } else if (opts.window_type == "sine") { + // when you are checking ws wikipedia, please + // note that 0.5 * a = M_PI/(frame_length-1) + window_data[i] = sin(0.5 * a * i_fl); + } else if (opts.window_type == "hamming") { + window_data[i] = 0.54 - 0.46 * cos(a * i_fl); + } else if (opts.window_type == + "povey") { // like hamming but 
goes to zero at edges. + window_data[i] = pow(0.5 - 0.5 * cos(a * i_fl), 0.85); + } else if (opts.window_type == "rectangular") { + window_data[i] = 1.0; + } else if (opts.window_type == "blackman") { + window_data[i] = opts.blackman_coeff - 0.5 * cos(a * i_fl) + + (0.5 - opts.blackman_coeff) * cos(2 * a * i_fl); + } else { + LOG(FATAL) << "Invalid window type " << opts.window_type; + } + } +} + +void FeatureWindowFunction::Apply(float *wave) const { + int32_t window_size = window_.size(); + const float *p = window_.data(); + for (int32_t k = 0; k != window_size; ++k) { + wave[k] *= p[k]; + } +} + +int64_t FirstSampleOfFrame(int32_t frame, const FrameExtractionOptions &opts) { + int64_t frame_shift = opts.WindowShift(); + if (opts.snip_edges) { + return frame * frame_shift; + } else { + int64_t midpoint_of_frame = frame_shift * frame + frame_shift / 2, + beginning_of_frame = midpoint_of_frame - opts.WindowSize() / 2; + return beginning_of_frame; + } +} + +int32_t NumFrames(int64_t num_samples, + const FrameExtractionOptions &opts, + bool flush /*= true*/) { + int64_t frame_shift = opts.WindowShift(); + int64_t frame_length = opts.WindowSize(); + if (opts.snip_edges) { + // with --snip-edges=true (the default), we use a HTK-like approach to + // determining the number of frames-- all frames have to fit completely + // into + // the waveform, and the first frame begins at sample zero. + if (num_samples < frame_length) + return 0; + else + return (1 + ((num_samples - frame_length) / frame_shift)); + // You can understand the expression above as follows: 'num_samples - + // frame_length' is how much room we have to shift the frame within the + // waveform; 'frame_shift' is how much we shift it each time; and the + // ratio + // is how many times we can shift it (integer arithmetic rounds down). + } else { + // if --snip-edges=false, the number of frames is determined by rounding + // the + // (file-length / frame-shift) to the nearest integer. 
The point of + // this + // formula is to make the number of frames an obvious and predictable + // function of the frame shift and signal length, which makes many + // segmentation-related questions simpler. + // + // Because integer division in C++ rounds toward zero, we add (half the + // frame-shift minus epsilon) before dividing, to have the effect of + // rounding towards the closest integer. + int32_t num_frames = (num_samples + (frame_shift / 2)) / frame_shift; + + if (flush) return num_frames; + + // note: 'end' always means the last plus one, i.e. one past the last. + int64_t end_sample_of_last_frame = + FirstSampleOfFrame(num_frames - 1, opts) + frame_length; + + // the following code is optimized more for clarity than efficiency. + // If flush == false, we can't output frames that extend past the end + // of the signal. + while (num_frames > 0 && end_sample_of_last_frame > num_samples) { + num_frames--; + end_sample_of_last_frame -= frame_shift; + } + return num_frames; + } +} + +void ExtractWindow(int64_t sample_offset, + const std::vector &wave, + int32_t f, + const FrameExtractionOptions &opts, + const FeatureWindowFunction &window_function, + std::vector *window, + float *log_energy_pre_window /*= nullptr*/) { + CHECK(sample_offset >= 0 && wave.size() != 0); + + int32_t frame_length = opts.WindowSize(); + int32_t frame_length_padded = opts.PaddedWindowSize(); + + int64_t num_samples = sample_offset + wave.size(); + int64_t start_sample = FirstSampleOfFrame(f, opts); + int64_t end_sample = start_sample + frame_length; + + if (opts.snip_edges) { + CHECK(start_sample >= sample_offset && end_sample <= num_samples); + } else { + CHECK(sample_offset == 0 || start_sample >= sample_offset); + } + + if (window->size() != frame_length_padded) { + window->resize(frame_length_padded); + } + + // wave_start and wave_end are start and end indexes into 'wave', for the + // piece of wave that we're trying to extract. 
+ int32_t wave_start = int32_t(start_sample - sample_offset); + int32_t wave_end = wave_start + frame_length; + + if (wave_start >= 0 && wave_end <= wave.size()) { + // the normal case-- no edge effects to consider. + std::copy(wave.begin() + wave_start, + wave.begin() + wave_start + frame_length, + window->data()); + } else { + // Deal with any end effects by reflection, if needed. This code will + // only + // be reached for about two frames per utterance, so we don't concern + // ourselves excessively with efficiency. + int32_t wave_dim = wave.size(); + for (int32_t s = 0; s < frame_length; ++s) { + int32_t s_in_wave = s + wave_start; + while (s_in_wave < 0 || s_in_wave >= wave_dim) { + // reflect around the beginning or end of the wave. + // e.g. -1 -> 0, -2 -> 1. + // dim -> dim - 1, dim + 1 -> dim - 2. + // the code supports repeated reflections, although this + // would only be needed in pathological cases. + if (s_in_wave < 0) + s_in_wave = -s_in_wave - 1; + else + s_in_wave = 2 * wave_dim - 1 - s_in_wave; + } + (*window)[s] = wave[s_in_wave]; + } + } + + ProcessWindow(opts, window_function, window->data(), log_energy_pre_window); +} + +static void RemoveDcOffset(float *d, int32_t n) { + float sum = 0; + for (int32_t i = 0; i != n; ++i) { + sum += d[i]; + } + + float mean = sum / n; + + for (int32_t i = 0; i != n; ++i) { + d[i] -= mean; + } +} + +float InnerProduct(const float *a, const float *b, int32_t n) { + float sum = 0; + for (int32_t i = 0; i != n; ++i) { + sum += a[i] * b[i]; + } + return sum; +} + +static void Preemphasize(float *d, int32_t n, float preemph_coeff) { + if (preemph_coeff == 0.0) { + return; + } + + CHECK(preemph_coeff >= 0.0 && preemph_coeff <= 1.0); + + for (int32_t i = n - 1; i > 0; --i) { + d[i] -= preemph_coeff * d[i - 1]; + } + d[0] -= preemph_coeff * d[0]; +} + +void ProcessWindow(const FrameExtractionOptions &opts, + const FeatureWindowFunction &window_function, + float *window, + float *log_energy_pre_window /*= nullptr*/) { 
+ int32_t frame_length = opts.WindowSize(); + + // TODO(fangjun): Remove dither + CHECK_EQ(opts.dither, 0); + + if (opts.remove_dc_offset) { + RemoveDcOffset(window, frame_length); + } + + if (log_energy_pre_window != NULL) { + float energy = + std::max(InnerProduct(window, window, frame_length), + std::numeric_limits::epsilon()); + *log_energy_pre_window = std::log(energy); + } + + if (opts.preemph_coeff != 0.0) { + Preemphasize(window, frame_length, opts.preemph_coeff); + } + + window_function.Apply(window); +} + +} // namespace knf diff --git a/speechx/speechx/common/frontend/audio/feature-window.h b/speechx/speechx/common/frontend/audio/feature-window.h new file mode 100644 index 00000000..8c86bf05 --- /dev/null +++ b/speechx/speechx/common/frontend/audio/feature-window.h @@ -0,0 +1,183 @@ +// kaldi-native-fbank/csrc/feature-window.h +// +// Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + +// This file is copied/modified from kaldi/src/feat/feature-window.h + +#ifndef KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_ +#define KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_ + +#include +#include +#include + +#include "base/log.h" + +namespace knf { + +inline int32_t RoundUpToNearestPowerOfTwo(int32_t n) { + // copied from kaldi/src/base/kaldi-math.cc + CHECK_GT(n, 0); + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + return n + 1; +} + +struct FrameExtractionOptions { + float samp_freq = 16000; + float frame_shift_ms = 10.0f; // in milliseconds. + float frame_length_ms = 25.0f; // in milliseconds. + float dither = 1.0f; // Amount of dithering, 0.0 means no dither. + float preemph_coeff = 0.97f; // Preemphasis coefficient. + bool remove_dc_offset = true; // Subtract mean of wave before FFT. + std::string window_type = "povey"; // e.g. 
Hamming window + // May be "hamming", "rectangular", "povey", "hanning", "sine", "blackman" + // "povey" is a window I made to be similar to Hamming but to go to zero at + // the edges, it's pow((0.5 - 0.5*cos(n/N*2*pi)), 0.85) I just don't think + // the + // Hamming window makes sense as a windowing function. + bool round_to_power_of_two = true; + float blackman_coeff = 0.42f; + bool snip_edges = true; + // bool allow_downsample = false; + // bool allow_upsample = false; + + // Used for streaming feature extraction. It indicates the number + // of feature frames to keep in the recycling vector. -1 means to + // keep all feature frames. + int32_t max_feature_vectors = -1; + + int32_t WindowShift() const { + return static_cast(samp_freq * 0.001f * frame_shift_ms); + } + int32_t WindowSize() const { + return static_cast(samp_freq * 0.001f * frame_length_ms); + } + int32_t PaddedWindowSize() const { + return (round_to_power_of_two ? RoundUpToNearestPowerOfTwo(WindowSize()) + : WindowSize()); + } + std::string ToString() const { + std::ostringstream os; +#define KNF_PRINT(x) os << #x << ": " << x << "\n" + KNF_PRINT(samp_freq); + KNF_PRINT(frame_shift_ms); + KNF_PRINT(frame_length_ms); + KNF_PRINT(dither); + KNF_PRINT(preemph_coeff); + KNF_PRINT(remove_dc_offset); + KNF_PRINT(window_type); + KNF_PRINT(round_to_power_of_two); + KNF_PRINT(blackman_coeff); + KNF_PRINT(snip_edges); + // KNF_PRINT(allow_downsample); + // KNF_PRINT(allow_upsample); + KNF_PRINT(max_feature_vectors); +#undef KNF_PRINT + return os.str(); + } +}; + +std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts); + +class FeatureWindowFunction { + public: + FeatureWindowFunction() = default; + explicit FeatureWindowFunction(const FrameExtractionOptions &opts); + /** + * @param wave Pointer to a 1-D array of shape [window_size]. + * It is modified in-place: wave[i] = wave[i] * window_[i]. 
+ * @param + */ + void Apply(float *wave) const; + + private: + std::vector window_; // of size opts.WindowSize() +}; + +int64_t FirstSampleOfFrame(int32_t frame, const FrameExtractionOptions &opts); + +/** + This function returns the number of frames that we can extract from a wave + file with the given number of samples in it (assumed to have the same + sampling rate as specified in 'opts'). + + @param [in] num_samples The number of samples in the wave file. + @param [in] opts The frame-extraction options class + + @param [in] flush True if we are asserting that this number of samples + is 'all there is', false if we expecting more data to possibly come in. This + only makes a difference to the answer + if opts.snips_edges== false. For offline feature extraction you always want + flush == true. In an online-decoding context, once you know (or decide) that + no more data is coming in, you'd call it with flush == true at the end to + flush out any remaining data. +*/ +int32_t NumFrames(int64_t num_samples, + const FrameExtractionOptions &opts, + bool flush = true); + +/* + ExtractWindow() extracts a windowed frame of waveform (possibly with a + power-of-two, padded size, depending on the config), including all the + processing done by ProcessWindow(). + + @param [in] sample_offset If 'wave' is not the entire waveform, but + part of it to the left has been discarded, then the + number of samples prior to 'wave' that we have + already discarded. Set this to zero if you are + processing the entire waveform in one piece, or + if you get 'no matching function' compilation + errors when updating the code. + @param [in] wave The waveform + @param [in] f The frame index to be extracted, with + 0 <= f < NumFrames(sample_offset + wave.Dim(), opts, true) + @param [in] opts The options class to be used + @param [in] window_function The windowing function, as derived from the + options class. + @param [out] window The windowed, possibly-padded waveform to be + extracted. 
Will be resized as needed. + @param [out] log_energy_pre_window If non-NULL, the log-energy of + the signal prior to pre-emphasis and multiplying by + the windowing function will be written to here. +*/ +void ExtractWindow(int64_t sample_offset, + const std::vector &wave, + int32_t f, + const FrameExtractionOptions &opts, + const FeatureWindowFunction &window_function, + std::vector *window, + float *log_energy_pre_window = nullptr); + +/** + This function does all the windowing steps after actually + extracting the windowed signal: depending on the + configuration, it does dithering, dc offset removal, + preemphasis, and multiplication by the windowing function. + @param [in] opts The options class to be used + @param [in] window_function The windowing function-- should have + been initialized using 'opts'. + @param [in,out] window A vector of size opts.WindowSize(). Note: + it will typically be a sub-vector of a larger vector of size + opts.PaddedWindowSize(), with the remaining samples zero, + as the FFT code is more efficient if it operates on data with + power-of-two size. + @param [out] log_energy_pre_window If non-NULL, then after dithering and + DC offset removal, this function will write to this pointer the log of + the total energy (i.e. sum-squared) of the frame. 
+ */ +void ProcessWindow(const FrameExtractionOptions &opts, + const FeatureWindowFunction &window_function, + float *window, + float *log_energy_pre_window = nullptr); + +// Compute the inner product of two vectors +float InnerProduct(const float *a, const float *b, int32_t n); + +} // namespace knf + +#endif // KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_ diff --git a/speechx/speechx/common/frontend/audio/feature_cache.cc b/speechx/speechx/common/frontend/audio/feature_cache.cc index 5110d704..dc60e3e4 100644 --- a/speechx/speechx/common/frontend/audio/feature_cache.cc +++ b/speechx/speechx/common/frontend/audio/feature_cache.cc @@ -17,9 +17,6 @@ namespace ppspeech { using kaldi::BaseFloat; -using kaldi::SubVector; -using kaldi::Vector; -using kaldi::VectorBase; using std::unique_ptr; using std::vector; @@ -31,7 +28,7 @@ FeatureCache::FeatureCache(FeatureCacheOptions opts, dim_ = base_extractor_->Dim(); } -void FeatureCache::Accept(const kaldi::VectorBase& inputs) { +void FeatureCache::Accept(const std::vector& inputs) { // read inputs base_extractor_->Accept(inputs); @@ -43,7 +40,7 @@ void FeatureCache::Accept(const kaldi::VectorBase& inputs) { } // pop feature chunk -bool FeatureCache::Read(kaldi::Vector* feats) { +bool FeatureCache::Read(std::vector* feats) { kaldi::Timer timer; std::unique_lock lock(mutex_); @@ -59,8 +56,7 @@ bool FeatureCache::Read(kaldi::Vector* feats) { if (cache_.empty()) return false; // read from cache - feats->Resize(cache_.front().Dim()); - feats->CopyFromVec(cache_.front()); + *feats = cache_.front(); cache_.pop(); ready_feed_condition_.notify_one(); VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec."; @@ -70,21 +66,20 @@ bool FeatureCache::Read(kaldi::Vector* feats) { // read all data from base_feature_extractor_ into cache_ bool FeatureCache::Compute() { // compute and feed - Vector feature; + vector feature; bool result = base_extractor_->Read(&feature); - if (result == false || feature.Dim() == 0) return false; + if 
(result == false || feature.size() == 0) return false; kaldi::Timer timer; - int32 num_chunk = feature.Dim() / dim_; + int32 num_chunk = feature.size() / dim_; nframe_ += num_chunk; VLOG(3) << "nframe computed: " << nframe_; for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { int32 start = chunk_idx * dim_; - Vector feature_chunk(dim_); - SubVector tmp(feature.Data() + start, dim_); - feature_chunk.CopyFromVec(tmp); + vector feature_chunk(feature.data() + start, + feature.data() + start + dim_); std::unique_lock lock(mutex_); while (cache_.size() >= max_size_) { diff --git a/speechx/speechx/common/frontend/audio/feature_cache.h b/speechx/speechx/common/frontend/audio/feature_cache.h index a4ebd604..8d17151c 100644 --- a/speechx/speechx/common/frontend/audio/feature_cache.h +++ b/speechx/speechx/common/frontend/audio/feature_cache.h @@ -32,10 +32,10 @@ class FeatureCache : public FrontendInterface { std::unique_ptr base_extractor = NULL); // Feed feats or waves - virtual void Accept(const kaldi::VectorBase& inputs); + virtual void Accept(const std::vector& inputs); // feats size = num_frames * feat_dim - virtual bool Read(kaldi::Vector* feats); + virtual bool Read(std::vector* feats); // feat dim virtual size_t Dim() const { return dim_; } @@ -54,7 +54,7 @@ class FeatureCache : public FrontendInterface { virtual bool IsFinished() const { return base_extractor_->IsFinished(); } void Reset() override { - std::queue> empty; + std::queue> empty; std::swap(cache_, empty); nframe_ = 0; base_extractor_->Reset(); @@ -71,8 +71,8 @@ class FeatureCache : public FrontendInterface { std::unique_ptr base_extractor_; kaldi::int32 timeout_; // ms - kaldi::Vector remained_feature_; - std::queue> cache_; // feature cache + std::vector remained_feature_; + std::queue> cache_; // feature cache std::mutex mutex_; std::condition_variable ready_feed_condition_; std::condition_variable ready_read_condition_; diff --git a/speechx/speechx/common/frontend/audio/feature_common.h 
b/speechx/speechx/common/frontend/audio/feature_common.h index bad705c9..f88dd960 100644 --- a/speechx/speechx/common/frontend/audio/feature_common.h +++ b/speechx/speechx/common/frontend/audio/feature_common.h @@ -15,7 +15,7 @@ #pragma once #include "frontend_itf.h" -#include "kaldi/feat/feature-window.h" +#include "frontend/audio/feature-window.h" namespace ppspeech { @@ -25,8 +25,8 @@ class StreamingFeatureTpl : public FrontendInterface { typedef typename F::Options Options; StreamingFeatureTpl(const Options& opts, std::unique_ptr base_extractor); - virtual void Accept(const kaldi::VectorBase& waves); - virtual bool Read(kaldi::Vector* feats); + virtual void Accept(const std::vector& waves); + virtual bool Read(std::vector* feats); // the dim_ is the dim of single frame feature virtual size_t Dim() const { return computer_.Dim(); } @@ -37,16 +37,16 @@ class StreamingFeatureTpl : public FrontendInterface { virtual void Reset() { base_extractor_->Reset(); - remained_wav_.Resize(0); + remained_wav_.resize(0); } private: - bool Compute(const kaldi::Vector& waves, - kaldi::Vector* feats); + bool Compute(const std::vector& waves, + std::vector* feats); Options opts_; std::unique_ptr base_extractor_; - kaldi::FeatureWindowFunction window_function_; - kaldi::Vector remained_wav_; + knf::FeatureWindowFunction window_function_; + std::vector remained_wav_; F computer_; }; diff --git a/speechx/speechx/common/frontend/audio/feature_common_inl.h b/speechx/speechx/common/frontend/audio/feature_common_inl.h index dcf44ef6..ac239974 100644 --- a/speechx/speechx/common/frontend/audio/feature_common_inl.h +++ b/speechx/speechx/common/frontend/audio/feature_common_inl.h @@ -24,75 +24,77 @@ StreamingFeatureTpl::StreamingFeatureTpl( template void StreamingFeatureTpl::Accept( - const kaldi::VectorBase& waves) { + const std::vector& waves) { base_extractor_->Accept(waves); } template -bool StreamingFeatureTpl::Read(kaldi::Vector* feats) { - kaldi::Vector wav(base_extractor_->Dim()); 
+bool StreamingFeatureTpl::Read(std::vector* feats) { + std::vector wav(base_extractor_->Dim()); bool flag = base_extractor_->Read(&wav); - if (flag == false || wav.Dim() == 0) return false; + if (flag == false || wav.size() == 0) return false; - kaldi::Timer timer; // append remaned waves - int32 wav_len = wav.Dim(); - int32 left_len = remained_wav_.Dim(); - kaldi::Vector waves(left_len + wav_len); - waves.Range(0, left_len).CopyFromVec(remained_wav_); - waves.Range(left_len, wav_len).CopyFromVec(wav); + int32 wav_len = wav.size(); + int32 left_len = remained_wav_.size(); + std::vector waves(left_len + wav_len); + std::memcpy(waves.data(), + remained_wav_.data(), + left_len * sizeof(kaldi::BaseFloat)); + std::memcpy(waves.data() + left_len, + wav.data(), + wav_len * sizeof(kaldi::BaseFloat)); // compute speech feature Compute(waves, feats); // cache remaned waves - kaldi::FrameExtractionOptions frame_opts = computer_.GetFrameOptions(); - int32 num_frames = kaldi::NumFrames(waves.Dim(), frame_opts); + knf::FrameExtractionOptions frame_opts = computer_.GetFrameOptions(); + int32 num_frames = knf::NumFrames(waves.size(), frame_opts); int32 frame_shift = frame_opts.WindowShift(); - int32 left_samples = waves.Dim() - frame_shift * num_frames; - remained_wav_.Resize(left_samples); - remained_wav_.CopyFromVec( - waves.Range(frame_shift * num_frames, left_samples)); - VLOG(1) << "StreamingFeatureTpl::Read cost: " << timer.Elapsed() - << " sec."; + int32 left_samples = waves.size() - frame_shift * num_frames; + remained_wav_.resize(left_samples); + std::memcpy(remained_wav_.data(), + waves.data() + frame_shift * num_frames, + left_samples * sizeof(BaseFloat)); return true; } // Compute feat template -bool StreamingFeatureTpl::Compute( - const kaldi::Vector& waves, - kaldi::Vector* feats) { - const kaldi::FrameExtractionOptions& frame_opts = - computer_.GetFrameOptions(); - int32 num_samples = waves.Dim(); +bool StreamingFeatureTpl::Compute(const std::vector& waves, + 
std::vector* feats) { + const knf::FrameExtractionOptions& frame_opts = computer_.GetFrameOptions(); + int32 num_samples = waves.size(); int32 frame_length = frame_opts.WindowSize(); int32 sample_rate = frame_opts.samp_freq; if (num_samples < frame_length) { return true; } - int32 num_frames = kaldi::NumFrames(num_samples, frame_opts); - feats->Resize(num_frames * Dim()); + int32 num_frames = knf::NumFrames(num_samples, frame_opts); + feats->resize(num_frames * Dim()); - kaldi::Vector window; + std::vector window; bool need_raw_log_energy = computer_.NeedRawLogEnergy(); for (int32 frame = 0; frame < num_frames; frame++) { + std::fill(window.begin(), window.end(), 0); kaldi::BaseFloat raw_log_energy = 0.0; - kaldi::ExtractWindow(0, - waves, - frame, - frame_opts, - window_function_, - &window, - need_raw_log_energy ? &raw_log_energy : NULL); + kaldi::BaseFloat vtln_warp = 1.0; + knf::ExtractWindow(0, + waves, + frame, + frame_opts, + window_function_, + &window, + need_raw_log_energy ? &raw_log_energy : NULL); - kaldi::Vector this_feature(computer_.Dim(), - kaldi::kUndefined); - computer_.Compute(&window, &this_feature); - kaldi::SubVector output_row( - feats->Data() + frame * Dim(), Dim()); - output_row.CopyFromVec(this_feature); + std::vector this_feature(computer_.Dim()); + computer_.Compute( + raw_log_energy, vtln_warp, &window, this_feature.data()); + std::memcpy(feats->data() + frame * Dim(), + this_feature.data(), + sizeof(BaseFloat) * Dim()); } return true; } diff --git a/speechx/speechx/common/frontend/audio/feature_pipeline.cc b/speechx/speechx/common/frontend/audio/feature_pipeline.cc index 2931b96b..8344ee65 100644 --- a/speechx/speechx/common/frontend/audio/feature_pipeline.cc +++ b/speechx/speechx/common/frontend/audio/feature_pipeline.cc @@ -21,17 +21,12 @@ using std::unique_ptr; FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) : opts_(opts) { unique_ptr data_source( - new ppspeech::AudioCache(1000 * kint16max, opts.to_float32)); + 
new ppspeech::AudioCache(1000 * kint16max, false)); unique_ptr base_feature; - if (opts.use_fbank) { - base_feature.reset( - new ppspeech::Fbank(opts.fbank_opts, std::move(data_source))); - } else { - base_feature.reset(new ppspeech::LinearSpectrogram( - opts.linear_spectrogram_opts, std::move(data_source))); - } + base_feature.reset( + new ppspeech::Fbank(opts.fbank_opts, std::move(data_source))); CHECK_NE(opts.cmvn_file, ""); unique_ptr cmvn( diff --git a/speechx/speechx/common/frontend/audio/feature_pipeline.h b/speechx/speechx/common/frontend/audio/feature_pipeline.h index e83a3f31..0afb873e 100644 --- a/speechx/speechx/common/frontend/audio/feature_pipeline.h +++ b/speechx/speechx/common/frontend/audio/feature_pipeline.h @@ -22,11 +22,9 @@ #include "frontend/audio/fbank.h" #include "frontend/audio/feature_cache.h" #include "frontend/audio/frontend_itf.h" -#include "frontend/audio/linear_spectrogram.h" #include "frontend/audio/normalizer.h" // feature -DECLARE_bool(use_fbank); DECLARE_bool(fill_zero); DECLARE_int32(num_bins); DECLARE_string(cmvn_file); @@ -40,10 +38,7 @@ namespace ppspeech { struct FeaturePipelineOptions { std::string cmvn_file{}; - bool to_float32{false}; // true, only for linear feature - bool use_fbank{true}; - LinearSpectrogramOptions linear_spectrogram_opts{}; - kaldi::FbankOptions fbank_opts{}; + knf::FbankOptions fbank_opts{}; FeatureCacheOptions feature_cache_opts{}; AssemblerOptions assembler_opts{}; @@ -53,30 +48,17 @@ struct FeaturePipelineOptions { LOG(INFO) << "cmvn file: " << opts.cmvn_file; // frame options - kaldi::FrameExtractionOptions frame_opts; + knf::FrameExtractionOptions frame_opts; frame_opts.dither = 0.0; LOG(INFO) << "dither: " << frame_opts.dither; frame_opts.frame_shift_ms = 10; LOG(INFO) << "frame shift ms: " << frame_opts.frame_shift_ms; - opts.use_fbank = FLAGS_use_fbank; - LOG(INFO) << "feature type: " << (opts.use_fbank ? 
"fbank" : "linear"); - if (opts.use_fbank) { - opts.to_float32 = false; - frame_opts.window_type = "povey"; - frame_opts.frame_length_ms = 25; - opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins; - LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins; - - opts.fbank_opts.frame_opts = frame_opts; - } else { - opts.to_float32 = true; - frame_opts.remove_dc_offset = false; - frame_opts.frame_length_ms = 20; - frame_opts.window_type = "hanning"; - frame_opts.preemph_coeff = 0.0; - - opts.linear_spectrogram_opts.frame_opts = frame_opts; - } + frame_opts.window_type = "povey"; + frame_opts.frame_length_ms = 25; + opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins; + LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins; + + opts.fbank_opts.frame_opts = frame_opts; LOG(INFO) << "frame length ms: " << frame_opts.frame_length_ms; // assembler opts @@ -100,10 +82,10 @@ struct FeaturePipelineOptions { class FeaturePipeline : public FrontendInterface { public: explicit FeaturePipeline(const FeaturePipelineOptions& opts); - virtual void Accept(const kaldi::VectorBase& waves) { + virtual void Accept(const std::vector& waves) { base_extractor_->Accept(waves); } - virtual bool Read(kaldi::Vector* feats) { + virtual bool Read(std::vector* feats) { return base_extractor_->Read(feats); } virtual size_t Dim() const { return base_extractor_->Dim(); } diff --git a/speechx/speechx/common/frontend/audio/fftsg.c b/speechx/speechx/common/frontend/audio/fftsg.c new file mode 100644 index 00000000..ec8217a2 --- /dev/null +++ b/speechx/speechx/common/frontend/audio/fftsg.c @@ -0,0 +1,3271 @@ +/* This file is copied from + * https://www.kurims.kyoto-u.ac.jp/~ooura/fft.html + */ +/* +Fast Fourier/Cosine/Sine Transform + dimension :one + data length :power of 2 + decimation :frequency + radix :split-radix + data :inplace + table :use +functions + cdft: Complex Discrete Fourier Transform + rdft: Real Discrete Fourier Transform + ddct: Discrete Cosine Transform + ddst: Discrete 
Sine Transform + dfct: Cosine Transform of RDFT (Real Symmetric DFT) + dfst: Sine Transform of RDFT (Real Anti-symmetric DFT) +function prototypes + void cdft(int, int, double *, int *, double *); + void rdft(int, int, double *, int *, double *); + void ddct(int, int, double *, int *, double *); + void ddst(int, int, double *, int *, double *); + void dfct(int, double *, double *, int *, double *); + void dfst(int, double *, double *, int *, double *); +macro definitions + USE_CDFT_PTHREADS : default=not defined + CDFT_THREADS_BEGIN_N : must be >= 512, default=8192 + CDFT_4THREADS_BEGIN_N : must be >= 512, default=65536 + USE_CDFT_WINTHREADS : default=not defined + CDFT_THREADS_BEGIN_N : must be >= 512, default=32768 + CDFT_4THREADS_BEGIN_N : must be >= 512, default=524288 + + +-------- Complex DFT (Discrete Fourier Transform) -------- + [definition] + + X[k] = sum_j=0^n-1 x[j]*exp(2*pi*i*j*k/n), 0<=k + X[k] = sum_j=0^n-1 x[j]*exp(-2*pi*i*j*k/n), 0<=k + ip[0] = 0; // first time only + cdft(2*n, 1, a, ip, w); + + ip[0] = 0; // first time only + cdft(2*n, -1, a, ip, w); + [parameters] + 2*n :data length (int) + n >= 1, n = power of 2 + a[0...2*n-1] :input/output data (double *) + input data + a[2*j] = Re(x[j]), + a[2*j+1] = Im(x[j]), 0<=j= 2+sqrt(n) + strictly, + length of ip >= + 2+(1<<(int)(log(n+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n/2-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + cdft(2*n, -1, a, ip, w); + is + cdft(2*n, 1, a, ip, w); + for (j = 0; j <= 2 * n - 1; j++) { + a[j] *= 1.0 / n; + } + . 
+ + +-------- Real DFT / Inverse of Real DFT -------- + [definition] + RDFT + R[k] = sum_j=0^n-1 a[j]*cos(2*pi*j*k/n), 0<=k<=n/2 + I[k] = sum_j=0^n-1 a[j]*sin(2*pi*j*k/n), 0 IRDFT (excluding scale) + a[k] = (R[0] + R[n/2]*cos(pi*k))/2 + + sum_j=1^n/2-1 R[j]*cos(2*pi*j*k/n) + + sum_j=1^n/2-1 I[j]*sin(2*pi*j*k/n), 0<=k + ip[0] = 0; // first time only + rdft(n, 1, a, ip, w); + + ip[0] = 0; // first time only + rdft(n, -1, a, ip, w); + [parameters] + n :data length (int) + n >= 2, n = power of 2 + a[0...n-1] :input/output data (double *) + + output data + a[2*k] = R[k], 0<=k + input data + a[2*j] = R[j], 0<=j= 2+sqrt(n/2) + strictly, + length of ip >= + 2+(1<<(int)(log(n/2+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n/2-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + rdft(n, 1, a, ip, w); + is + rdft(n, -1, a, ip, w); + for (j = 0; j <= n - 1; j++) { + a[j] *= 2.0 / n; + } + . + + +-------- DCT (Discrete Cosine Transform) / Inverse of DCT -------- + [definition] + IDCT (excluding scale) + C[k] = sum_j=0^n-1 a[j]*cos(pi*j*(k+1/2)/n), 0<=k DCT + C[k] = sum_j=0^n-1 a[j]*cos(pi*(j+1/2)*k/n), 0<=k + ip[0] = 0; // first time only + ddct(n, 1, a, ip, w); + + ip[0] = 0; // first time only + ddct(n, -1, a, ip, w); + [parameters] + n :data length (int) + n >= 2, n = power of 2 + a[0...n-1] :input/output data (double *) + output data + a[k] = C[k], 0<=k= 2+sqrt(n/2) + strictly, + length of ip >= + 2+(1<<(int)(log(n/2+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n*5/4-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + ddct(n, -1, a, ip, w); + is + a[0] *= 0.5; + ddct(n, 1, a, ip, w); + for (j = 0; j <= n - 1; j++) { + a[j] *= 2.0 / n; + } + . 
+ + +-------- DST (Discrete Sine Transform) / Inverse of DST -------- + [definition] + IDST (excluding scale) + S[k] = sum_j=1^n A[j]*sin(pi*j*(k+1/2)/n), 0<=k DST + S[k] = sum_j=0^n-1 a[j]*sin(pi*(j+1/2)*k/n), 0 + ip[0] = 0; // first time only + ddst(n, 1, a, ip, w); + + ip[0] = 0; // first time only + ddst(n, -1, a, ip, w); + [parameters] + n :data length (int) + n >= 2, n = power of 2 + a[0...n-1] :input/output data (double *) + + input data + a[j] = A[j], 0 + output data + a[k] = S[k], 0= 2+sqrt(n/2) + strictly, + length of ip >= + 2+(1<<(int)(log(n/2+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n*5/4-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + ddst(n, -1, a, ip, w); + is + a[0] *= 0.5; + ddst(n, 1, a, ip, w); + for (j = 0; j <= n - 1; j++) { + a[j] *= 2.0 / n; + } + . + + +-------- Cosine Transform of RDFT (Real Symmetric DFT) -------- + [definition] + C[k] = sum_j=0^n a[j]*cos(pi*j*k/n), 0<=k<=n + [usage] + ip[0] = 0; // first time only + dfct(n, a, t, ip, w); + [parameters] + n :data length - 1 (int) + n >= 2, n = power of 2 + a[0...n] :input/output data (double *) + output data + a[k] = C[k], 0<=k<=n + t[0...n/2] :work area (double *) + ip[0...*] :work area for bit reversal (int *) + length of ip >= 2+sqrt(n/4) + strictly, + length of ip >= + 2+(1<<(int)(log(n/4+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n*5/8-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + a[0] *= 0.5; + a[n] *= 0.5; + dfct(n, a, t, ip, w); + is + a[0] *= 0.5; + a[n] *= 0.5; + dfct(n, a, t, ip, w); + for (j = 0; j <= n; j++) { + a[j] *= 2.0 / n; + } + . 
+ + +-------- Sine Transform of RDFT (Real Anti-symmetric DFT) -------- + [definition] + S[k] = sum_j=1^n-1 a[j]*sin(pi*j*k/n), 0= 2, n = power of 2 + a[0...n-1] :input/output data (double *) + output data + a[k] = S[k], 0= 2+sqrt(n/4) + strictly, + length of ip >= + 2+(1<<(int)(log(n/4+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n*5/8-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + dfst(n, a, t, ip, w); + is + dfst(n, a, t, ip, w); + for (j = 1; j <= n - 1; j++) { + a[j] *= 2.0 / n; + } + . + + +Appendix : + The cos/sin table is recalculated when the larger table required. + w[] and ip[] are compatible with all routines. +*/ + + +void cdft(int n, int isgn, double *a, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void cftbsub(int n, double *a, int *ip, int nw, double *w); + int nw; + + nw = ip[0]; + if (n > (nw << 2)) { + nw = n >> 2; + makewt(nw, ip, w); + } + if (isgn >= 0) { + cftfsub(n, a, ip, nw, w); + } else { + cftbsub(n, a, ip, nw, w); + } +} + + +void rdft(int n, int isgn, double *a, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void cftbsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void rftbsub(int n, double *a, int nc, double *c); + int nw, nc; + double xi; + + nw = ip[0]; + if (n > (nw << 2)) { + nw = n >> 2; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > (nc << 2)) { + nc = n >> 2; + makect(nc, ip, w + nw); + } + if (isgn >= 0) { + if (n > 4) { + cftfsub(n, a, ip, nw, w); + rftfsub(n, a, nc, w + nw); + } else if (n == 4) { + cftfsub(n, a, ip, nw, w); + } + xi = a[0] - a[1]; + a[0] += a[1]; + a[1] = xi; + } else { + a[1] = 0.5 * (a[0] - a[1]); + a[0] -= a[1]; + if (n > 4) { + rftbsub(n, a, nc, w + nw); + 
cftbsub(n, a, ip, nw, w); + } else if (n == 4) { + cftbsub(n, a, ip, nw, w); + } + } +} + + +void ddct(int n, int isgn, double *a, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void cftbsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void rftbsub(int n, double *a, int nc, double *c); + void dctsub(int n, double *a, int nc, double *c); + int j, nw, nc; + double xr; + + nw = ip[0]; + if (n > (nw << 2)) { + nw = n >> 2; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > nc) { + nc = n; + makect(nc, ip, w + nw); + } + if (isgn < 0) { + xr = a[n - 1]; + for (j = n - 2; j >= 2; j -= 2) { + a[j + 1] = a[j] - a[j - 1]; + a[j] += a[j - 1]; + } + a[1] = a[0] - xr; + a[0] += xr; + if (n > 4) { + rftbsub(n, a, nc, w + nw); + cftbsub(n, a, ip, nw, w); + } else if (n == 4) { + cftbsub(n, a, ip, nw, w); + } + } + dctsub(n, a, nc, w + nw); + if (isgn >= 0) { + if (n > 4) { + cftfsub(n, a, ip, nw, w); + rftfsub(n, a, nc, w + nw); + } else if (n == 4) { + cftfsub(n, a, ip, nw, w); + } + xr = a[0] - a[1]; + a[0] += a[1]; + for (j = 2; j < n; j += 2) { + a[j - 1] = a[j] - a[j + 1]; + a[j] += a[j + 1]; + } + a[n - 1] = xr; + } +} + + +void ddst(int n, int isgn, double *a, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void cftbsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void rftbsub(int n, double *a, int nc, double *c); + void dstsub(int n, double *a, int nc, double *c); + int j, nw, nc; + double xr; + + nw = ip[0]; + if (n > (nw << 2)) { + nw = n >> 2; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > nc) { + nc = n; + makect(nc, ip, w + nw); + } + if (isgn < 0) { + xr = a[n - 1]; + for (j = n - 2; j >= 2; j -= 2) { + a[j + 
1] = -a[j] - a[j - 1]; + a[j] -= a[j - 1]; + } + a[1] = a[0] + xr; + a[0] -= xr; + if (n > 4) { + rftbsub(n, a, nc, w + nw); + cftbsub(n, a, ip, nw, w); + } else if (n == 4) { + cftbsub(n, a, ip, nw, w); + } + } + dstsub(n, a, nc, w + nw); + if (isgn >= 0) { + if (n > 4) { + cftfsub(n, a, ip, nw, w); + rftfsub(n, a, nc, w + nw); + } else if (n == 4) { + cftfsub(n, a, ip, nw, w); + } + xr = a[0] - a[1]; + a[0] += a[1]; + for (j = 2; j < n; j += 2) { + a[j - 1] = -a[j] - a[j + 1]; + a[j] -= a[j + 1]; + } + a[n - 1] = -xr; + } +} + + +void dfct(int n, double *a, double *t, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void dctsub(int n, double *a, int nc, double *c); + int j, k, l, m, mh, nw, nc; + double xr, xi, yr, yi; + + nw = ip[0]; + if (n > (nw << 3)) { + nw = n >> 3; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > (nc << 1)) { + nc = n >> 1; + makect(nc, ip, w + nw); + } + m = n >> 1; + yi = a[m]; + xi = a[0] + a[n]; + a[0] -= a[n]; + t[0] = xi - yi; + t[m] = xi + yi; + if (n > 2) { + mh = m >> 1; + for (j = 1; j < mh; j++) { + k = m - j; + xr = a[j] - a[n - j]; + xi = a[j] + a[n - j]; + yr = a[k] - a[n - k]; + yi = a[k] + a[n - k]; + a[j] = xr; + a[k] = yr; + t[j] = xi - yi; + t[k] = xi + yi; + } + t[mh] = a[mh] + a[n - mh]; + a[mh] -= a[n - mh]; + dctsub(m, a, nc, w + nw); + if (m > 4) { + cftfsub(m, a, ip, nw, w); + rftfsub(m, a, nc, w + nw); + } else if (m == 4) { + cftfsub(m, a, ip, nw, w); + } + a[n - 1] = a[0] - a[1]; + a[1] = a[0] + a[1]; + for (j = m - 2; j >= 2; j -= 2) { + a[2 * j + 1] = a[j] + a[j + 1]; + a[2 * j - 1] = a[j] - a[j + 1]; + } + l = 2; + m = mh; + while (m >= 2) { + dctsub(m, t, nc, w + nw); + if (m > 4) { + cftfsub(m, t, ip, nw, w); + rftfsub(m, t, nc, w + nw); + } else if (m == 4) { + cftfsub(m, t, ip, nw, w); + } + a[n - l] = t[0] - t[1]; + a[l] = t[0] + 
t[1]; + k = 0; + for (j = 2; j < m; j += 2) { + k += l << 2; + a[k - l] = t[j] - t[j + 1]; + a[k + l] = t[j] + t[j + 1]; + } + l <<= 1; + mh = m >> 1; + for (j = 0; j < mh; j++) { + k = m - j; + t[j] = t[m + k] - t[m + j]; + t[k] = t[m + k] + t[m + j]; + } + t[mh] = t[m + mh]; + m = mh; + } + a[l] = t[0]; + a[n] = t[2] - t[1]; + a[0] = t[2] + t[1]; + } else { + a[1] = a[0]; + a[2] = t[0]; + a[0] = t[1]; + } +} + + +void dfst(int n, double *a, double *t, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void dstsub(int n, double *a, int nc, double *c); + int j, k, l, m, mh, nw, nc; + double xr, xi, yr, yi; + + nw = ip[0]; + if (n > (nw << 3)) { + nw = n >> 3; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > (nc << 1)) { + nc = n >> 1; + makect(nc, ip, w + nw); + } + if (n > 2) { + m = n >> 1; + mh = m >> 1; + for (j = 1; j < mh; j++) { + k = m - j; + xr = a[j] + a[n - j]; + xi = a[j] - a[n - j]; + yr = a[k] + a[n - k]; + yi = a[k] - a[n - k]; + a[j] = xr; + a[k] = yr; + t[j] = xi + yi; + t[k] = xi - yi; + } + t[0] = a[mh] - a[n - mh]; + a[mh] += a[n - mh]; + a[0] = a[m]; + dstsub(m, a, nc, w + nw); + if (m > 4) { + cftfsub(m, a, ip, nw, w); + rftfsub(m, a, nc, w + nw); + } else if (m == 4) { + cftfsub(m, a, ip, nw, w); + } + a[n - 1] = a[1] - a[0]; + a[1] = a[0] + a[1]; + for (j = m - 2; j >= 2; j -= 2) { + a[2 * j + 1] = a[j] - a[j + 1]; + a[2 * j - 1] = -a[j] - a[j + 1]; + } + l = 2; + m = mh; + while (m >= 2) { + dstsub(m, t, nc, w + nw); + if (m > 4) { + cftfsub(m, t, ip, nw, w); + rftfsub(m, t, nc, w + nw); + } else if (m == 4) { + cftfsub(m, t, ip, nw, w); + } + a[n - l] = t[1] - t[0]; + a[l] = t[0] + t[1]; + k = 0; + for (j = 2; j < m; j += 2) { + k += l << 2; + a[k - l] = -t[j] - t[j + 1]; + a[k + l] = t[j] - t[j + 1]; + } + l <<= 1; + mh = m >> 1; + for (j = 1; j < mh; j++) { + k = m - 
j; + t[j] = t[m + k] + t[m + j]; + t[k] = t[m + k] - t[m + j]; + } + t[0] = t[m + mh]; + m = mh; + } + a[l] = t[0]; + } + a[0] = 0; +} + + +/* -------- initializing routines -------- */ + + +#include + +void makewt(int nw, int *ip, double *w) { + void makeipt(int nw, int *ip); + int j, nwh, nw0, nw1; + double delta, wn4r, wk1r, wk1i, wk3r, wk3i; + + ip[0] = nw; + ip[1] = 1; + if (nw > 2) { + nwh = nw >> 1; + delta = atan(1.0) / nwh; + wn4r = cos(delta * nwh); + w[0] = 1; + w[1] = wn4r; + if (nwh == 4) { + w[2] = cos(delta * 2); + w[3] = sin(delta * 2); + } else if (nwh > 4) { + makeipt(nw, ip); + w[2] = 0.5 / cos(delta * 2); + w[3] = 0.5 / cos(delta * 6); + for (j = 4; j < nwh; j += 4) { + w[j] = cos(delta * j); + w[j + 1] = sin(delta * j); + w[j + 2] = cos(3 * delta * j); + w[j + 3] = -sin(3 * delta * j); + } + } + nw0 = 0; + while (nwh > 2) { + nw1 = nw0 + nwh; + nwh >>= 1; + w[nw1] = 1; + w[nw1 + 1] = wn4r; + if (nwh == 4) { + wk1r = w[nw0 + 4]; + wk1i = w[nw0 + 5]; + w[nw1 + 2] = wk1r; + w[nw1 + 3] = wk1i; + } else if (nwh > 4) { + wk1r = w[nw0 + 4]; + wk3r = w[nw0 + 6]; + w[nw1 + 2] = 0.5 / wk1r; + w[nw1 + 3] = 0.5 / wk3r; + for (j = 4; j < nwh; j += 4) { + wk1r = w[nw0 + 2 * j]; + wk1i = w[nw0 + 2 * j + 1]; + wk3r = w[nw0 + 2 * j + 2]; + wk3i = w[nw0 + 2 * j + 3]; + w[nw1 + j] = wk1r; + w[nw1 + j + 1] = wk1i; + w[nw1 + j + 2] = wk3r; + w[nw1 + j + 3] = wk3i; + } + } + nw0 = nw1; + } + } +} + + +void makeipt(int nw, int *ip) { + int j, l, m, m2, p, q; + + ip[2] = 0; + ip[3] = 16; + m = 2; + for (l = nw; l > 32; l >>= 2) { + m2 = m << 1; + q = m2 << 3; + for (j = m; j < m2; j++) { + p = ip[j] << 2; + ip[m + j] = p; + ip[m2 + j] = p + q; + } + m = m2; + } +} + + +void makect(int nc, int *ip, double *c) { + int j, nch; + double delta; + + ip[1] = nc; + if (nc > 1) { + nch = nc >> 1; + delta = atan(1.0) / nch; + c[0] = cos(delta * nch); + c[nch] = 0.5 * c[0]; + for (j = 1; j < nch; j++) { + c[j] = 0.5 * cos(delta * j); + c[nc - j] = 0.5 * sin(delta * j); + } + } 
+} + + +/* -------- child routines -------- */ + + +#ifdef USE_CDFT_PTHREADS +#define USE_CDFT_THREADS +#ifndef CDFT_THREADS_BEGIN_N +#define CDFT_THREADS_BEGIN_N 8192 +#endif +#ifndef CDFT_4THREADS_BEGIN_N +#define CDFT_4THREADS_BEGIN_N 65536 +#endif +#include +#include +#include +#define cdft_thread_t pthread_t +#define cdft_thread_create(thp, func, argp) \ + { \ + if (pthread_create(thp, NULL, func, (void *)argp) != 0) { \ + fprintf(stderr, "cdft thread error\n"); \ + exit(1); \ + } \ + } +#define cdft_thread_wait(th) \ + { \ + if (pthread_join(th, NULL) != 0) { \ + fprintf(stderr, "cdft thread error\n"); \ + exit(1); \ + } \ + } +#endif /* USE_CDFT_PTHREADS */ + + +#ifdef USE_CDFT_WINTHREADS +#define USE_CDFT_THREADS +#ifndef CDFT_THREADS_BEGIN_N +#define CDFT_THREADS_BEGIN_N 32768 +#endif +#ifndef CDFT_4THREADS_BEGIN_N +#define CDFT_4THREADS_BEGIN_N 524288 +#endif +#include +#include +#include +#define cdft_thread_t HANDLE +#define cdft_thread_create(thp, func, argp) \ + { \ + DWORD thid; \ + *(thp) = CreateThread( \ + NULL, 0, (LPTHREAD_START_ROUTINE)func, (LPVOID)argp, 0, &thid); \ + if (*(thp) == 0) { \ + fprintf(stderr, "cdft thread error\n"); \ + exit(1); \ + } \ + } +#define cdft_thread_wait(th) \ + { \ + WaitForSingleObject(th, INFINITE); \ + CloseHandle(th); \ + } +#endif /* USE_CDFT_WINTHREADS */ + + +void cftfsub(int n, double *a, int *ip, int nw, double *w) { + void bitrv2(int n, int *ip, double *a); + void bitrv216(double *a); + void bitrv208(double *a); + void cftf1st(int n, double *a, double *w); + void cftrec4(int n, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftfx41(int n, double *a, int nw, double *w); + void cftf161(double *a, double *w); + void cftf081(double *a, double *w); + void cftf040(double *a); + void cftx020(double *a); +#ifdef USE_CDFT_THREADS + void cftrec4_th(int n, double *a, int nw, double *w); +#endif /* USE_CDFT_THREADS */ + + if (n > 8) { + if (n > 32) { + cftf1st(n, 
a, &w[nw - (n >> 2)]); +#ifdef USE_CDFT_THREADS + if (n > CDFT_THREADS_BEGIN_N) { + cftrec4_th(n, a, nw, w); + } else +#endif /* USE_CDFT_THREADS */ + if (n > 512) { + cftrec4(n, a, nw, w); + } else if (n > 128) { + cftleaf(n, 1, a, nw, w); + } else { + cftfx41(n, a, nw, w); + } + bitrv2(n, ip, a); + } else if (n == 32) { + cftf161(a, &w[nw - 8]); + bitrv216(a); + } else { + cftf081(a, w); + bitrv208(a); + } + } else if (n == 8) { + cftf040(a); + } else if (n == 4) { + cftx020(a); + } +} + + +void cftbsub(int n, double *a, int *ip, int nw, double *w) { + void bitrv2conj(int n, int *ip, double *a); + void bitrv216neg(double *a); + void bitrv208neg(double *a); + void cftb1st(int n, double *a, double *w); + void cftrec4(int n, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftfx41(int n, double *a, int nw, double *w); + void cftf161(double *a, double *w); + void cftf081(double *a, double *w); + void cftb040(double *a); + void cftx020(double *a); +#ifdef USE_CDFT_THREADS + void cftrec4_th(int n, double *a, int nw, double *w); +#endif /* USE_CDFT_THREADS */ + + if (n > 8) { + if (n > 32) { + cftb1st(n, a, &w[nw - (n >> 2)]); +#ifdef USE_CDFT_THREADS + if (n > CDFT_THREADS_BEGIN_N) { + cftrec4_th(n, a, nw, w); + } else +#endif /* USE_CDFT_THREADS */ + if (n > 512) { + cftrec4(n, a, nw, w); + } else if (n > 128) { + cftleaf(n, 1, a, nw, w); + } else { + cftfx41(n, a, nw, w); + } + bitrv2conj(n, ip, a); + } else if (n == 32) { + cftf161(a, &w[nw - 8]); + bitrv216neg(a); + } else { + cftf081(a, w); + bitrv208neg(a); + } + } else if (n == 8) { + cftb040(a); + } else if (n == 4) { + cftx020(a); + } +} + + +void bitrv2(int n, int *ip, double *a) { + int j, j1, k, k1, l, m, nh, nm; + double xr, xi, yr, yi; + + m = 1; + for (l = n >> 2; l > 8; l >>= 2) { + m <<= 1; + } + nh = n >> 1; + nm = 4 * m; + if (l == 8) { + for (k = 0; k < m; k++) { + for (j = 0; j < k; j++) { + j1 = 4 * j + 2 * ip[m + k]; + k1 = 4 * k + 2 * ip[m + 
j]; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh; + k1 += 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += 2; + k1 += nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh; + k1 -= 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 
* nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + k1 = 4 * k + 2 * ip[m + k]; + j1 = k1 + 2; + k1 += nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= 2; + k1 -= nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh + 2; + k1 += nh + 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh - nm; + k1 += 2 * nm - 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + } else { + for (k = 0; k < m; k++) { + for (j = 0; j < k; j++) { + j1 = 4 * j + ip[m + k]; + k1 = 4 * k + ip[m + j]; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh; + k1 += 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= nm; + xr = 
a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += 2; + k1 += nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh; + k1 -= 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + k1 = 4 * k + ip[m + k]; + j1 = k1 + 2; + k1 += nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + } +} + + +void bitrv2conj(int n, int *ip, double *a) { + int j, j1, k, k1, l, m, nh, nm; + double xr, xi, yr, yi; + + m = 1; + for (l = n >> 2; l > 8; l >>= 2) { + m <<= 1; + } + nh = n >> 1; + nm = 4 * m; + if (l == 8) { + for (k = 0; k < m; k++) { + for (j = 0; j < k; j++) { + j1 = 4 * j + 2 * ip[m + k]; + k1 = 4 * k + 2 * ip[m + j]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 
1] = xi; + j1 += nh; + k1 += 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 += nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += 2; + k1 += nh; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh; + k1 -= 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 += nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + k1 = 4 * k + 2 * ip[m + k]; + j1 = k1 + 2; + k1 += nh; + a[j1 - 1] = -a[j1 - 1]; + xr = a[j1]; + xi = -a[j1 + 1]; + 
yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + a[k1 + 3] = -a[k1 + 3]; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= 2; + k1 -= nh; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh + 2; + k1 += nh + 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh - nm; + k1 += 2 * nm - 2; + a[j1 - 1] = -a[j1 - 1]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + a[k1 + 3] = -a[k1 + 3]; + } + } else { + for (k = 0; k < m; k++) { + for (j = 0; j < k; j++) { + j1 = 4 * j + ip[m + k]; + k1 = 4 * k + ip[m + j]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh; + k1 += 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += 2; + k1 += nh; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh; 
+ k1 -= 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + k1 = 4 * k + ip[m + k]; + j1 = k1 + 2; + k1 += nh; + a[j1 - 1] = -a[j1 - 1]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + a[k1 + 3] = -a[k1 + 3]; + j1 += nm; + k1 += nm; + a[j1 - 1] = -a[j1 - 1]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + a[k1 + 3] = -a[k1 + 3]; + } + } +} + + +void bitrv216(double *a) { + double x1r, x1i, x2r, x2i, x3r, x3i, x4r, x4i, x5r, x5i, x7r, x7i, x8r, x8i, + x10r, x10i, x11r, x11i, x12r, x12i, x13r, x13i, x14r, x14i; + + x1r = a[2]; + x1i = a[3]; + x2r = a[4]; + x2i = a[5]; + x3r = a[6]; + x3i = a[7]; + x4r = a[8]; + x4i = a[9]; + x5r = a[10]; + x5i = a[11]; + x7r = a[14]; + x7i = a[15]; + x8r = a[16]; + x8i = a[17]; + x10r = a[20]; + x10i = a[21]; + x11r = a[22]; + x11i = a[23]; + x12r = a[24]; + x12i = a[25]; + x13r = a[26]; + x13i = a[27]; + x14r = a[28]; + x14i = a[29]; + a[2] = x8r; + a[3] = x8i; + a[4] = x4r; + a[5] = x4i; + a[6] = x12r; + a[7] = x12i; + a[8] = x2r; + a[9] = x2i; + a[10] = x10r; + a[11] = x10i; + a[14] = x14r; + a[15] = x14i; + a[16] = x1r; + a[17] = x1i; + a[20] = x5r; + a[21] = x5i; + a[22] = x13r; + a[23] = x13i; + a[24] = x3r; + a[25] = x3i; + a[26] = x11r; + a[27] = x11i; + a[28] = x7r; + a[29] = x7i; +} + + +void bitrv216neg(double *a) { + double x1r, x1i, x2r, x2i, x3r, x3i, x4r, x4i, x5r, x5i, x6r, x6i, x7r, x7i, + x8r, x8i, x9r, x9i, x10r, x10i, x11r, x11i, x12r, x12i, x13r, x13i, + x14r, x14i, x15r, x15i; + + x1r = a[2]; + x1i = a[3]; + x2r = a[4]; + x2i = a[5]; + x3r = a[6]; + x3i = a[7]; + x4r = a[8]; + x4i = a[9]; + x5r = a[10]; + x5i = 
a[11]; + x6r = a[12]; + x6i = a[13]; + x7r = a[14]; + x7i = a[15]; + x8r = a[16]; + x8i = a[17]; + x9r = a[18]; + x9i = a[19]; + x10r = a[20]; + x10i = a[21]; + x11r = a[22]; + x11i = a[23]; + x12r = a[24]; + x12i = a[25]; + x13r = a[26]; + x13i = a[27]; + x14r = a[28]; + x14i = a[29]; + x15r = a[30]; + x15i = a[31]; + a[2] = x15r; + a[3] = x15i; + a[4] = x7r; + a[5] = x7i; + a[6] = x11r; + a[7] = x11i; + a[8] = x3r; + a[9] = x3i; + a[10] = x13r; + a[11] = x13i; + a[12] = x5r; + a[13] = x5i; + a[14] = x9r; + a[15] = x9i; + a[16] = x1r; + a[17] = x1i; + a[18] = x14r; + a[19] = x14i; + a[20] = x6r; + a[21] = x6i; + a[22] = x10r; + a[23] = x10i; + a[24] = x2r; + a[25] = x2i; + a[26] = x12r; + a[27] = x12i; + a[28] = x4r; + a[29] = x4i; + a[30] = x8r; + a[31] = x8i; +} + + +void bitrv208(double *a) { + double x1r, x1i, x3r, x3i, x4r, x4i, x6r, x6i; + + x1r = a[2]; + x1i = a[3]; + x3r = a[6]; + x3i = a[7]; + x4r = a[8]; + x4i = a[9]; + x6r = a[12]; + x6i = a[13]; + a[2] = x4r; + a[3] = x4i; + a[6] = x6r; + a[7] = x6i; + a[8] = x1r; + a[9] = x1i; + a[12] = x3r; + a[13] = x3i; +} + + +void bitrv208neg(double *a) { + double x1r, x1i, x2r, x2i, x3r, x3i, x4r, x4i, x5r, x5i, x6r, x6i, x7r, x7i; + + x1r = a[2]; + x1i = a[3]; + x2r = a[4]; + x2i = a[5]; + x3r = a[6]; + x3i = a[7]; + x4r = a[8]; + x4i = a[9]; + x5r = a[10]; + x5i = a[11]; + x6r = a[12]; + x6i = a[13]; + x7r = a[14]; + x7i = a[15]; + a[2] = x7r; + a[3] = x7i; + a[4] = x3r; + a[5] = x3i; + a[6] = x5r; + a[7] = x5i; + a[8] = x1r; + a[9] = x1i; + a[10] = x6r; + a[11] = x6i; + a[12] = x2r; + a[13] = x2i; + a[14] = x4r; + a[15] = x4i; +} + + +void cftf1st(int n, double *a, double *w) { + int j, j0, j1, j2, j3, k, m, mh; + double wn4r, csc1, csc3, wk1r, wk1i, wk3r, wk3i, wd1r, wd1i, wd3r, wd3i; + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i, y0r, y0i, y1r, y1i, y2r, y2i, + y3r, y3i; + + mh = n >> 3; + m = 2 * mh; + j1 = m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[0] + a[j2]; + x0i = a[1] + a[j2 + 1]; + x1r = a[0] - 
a[j2]; + x1i = a[1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + a[j2] = x1r - x3i; + a[j2 + 1] = x1i + x3r; + a[j3] = x1r + x3i; + a[j3 + 1] = x1i - x3r; + wn4r = w[1]; + csc1 = w[2]; + csc3 = w[3]; + wd1r = 1; + wd1i = 0; + wd3r = 1; + wd3i = 0; + k = 0; + for (j = 2; j < mh - 2; j += 4) { + k += 4; + wk1r = csc1 * (wd1r + w[k]); + wk1i = csc1 * (wd1i + w[k + 1]); + wk3r = csc3 * (wd3r + w[k + 2]); + wk3i = csc3 * (wd3i + w[k + 3]); + wd1r = w[k]; + wd1i = w[k + 1]; + wd3r = w[k + 2]; + wd3i = w[k + 3]; + j1 = j + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j] + a[j2]; + x0i = a[j + 1] + a[j2 + 1]; + x1r = a[j] - a[j2]; + x1i = a[j + 1] - a[j2 + 1]; + y0r = a[j + 2] + a[j2 + 2]; + y0i = a[j + 3] + a[j2 + 3]; + y1r = a[j + 2] - a[j2 + 2]; + y1i = a[j + 3] - a[j2 + 3]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + y2r = a[j1 + 2] + a[j3 + 2]; + y2i = a[j1 + 3] + a[j3 + 3]; + y3r = a[j1 + 2] - a[j3 + 2]; + y3i = a[j1 + 3] - a[j3 + 3]; + a[j] = x0r + x2r; + a[j + 1] = x0i + x2i; + a[j + 2] = y0r + y2r; + a[j + 3] = y0i + y2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + a[j1 + 2] = y0r - y2r; + a[j1 + 3] = y0i - y2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wk1r * x0r - wk1i * x0i; + a[j2 + 1] = wk1r * x0i + wk1i * x0r; + x0r = y1r - y3i; + x0i = y1i + y3r; + a[j2 + 2] = wd1r * x0r - wd1i * x0i; + a[j2 + 3] = wd1r * x0i + wd1i * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = wk3r * x0r + wk3i * x0i; + a[j3 + 1] = wk3r * x0i - wk3i * x0r; + x0r = y1r + y3i; + x0i = y1i - y3r; + a[j3 + 2] = wd3r * x0r + wd3i * x0i; + a[j3 + 3] = wd3r * x0i - wd3i * x0r; + j0 = m - j; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] + a[j2]; + x0i = a[j0 + 1] + a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = a[j0 + 1] - a[j2 + 1]; + y0r = a[j0 
- 2] + a[j2 - 2]; + y0i = a[j0 - 1] + a[j2 - 1]; + y1r = a[j0 - 2] - a[j2 - 2]; + y1i = a[j0 - 1] - a[j2 - 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + y2r = a[j1 - 2] + a[j3 - 2]; + y2i = a[j1 - 1] + a[j3 - 1]; + y3r = a[j1 - 2] - a[j3 - 2]; + y3i = a[j1 - 1] - a[j3 - 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i + x2i; + a[j0 - 2] = y0r + y2r; + a[j0 - 1] = y0i + y2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + a[j1 - 2] = y0r - y2r; + a[j1 - 1] = y0i - y2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wk1i * x0r - wk1r * x0i; + a[j2 + 1] = wk1i * x0i + wk1r * x0r; + x0r = y1r - y3i; + x0i = y1i + y3r; + a[j2 - 2] = wd1i * x0r - wd1r * x0i; + a[j2 - 1] = wd1i * x0i + wd1r * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = wk3i * x0r + wk3r * x0i; + a[j3 + 1] = wk3i * x0i - wk3r * x0r; + x0r = y1r + y3i; + x0i = y1i - y3r; + a[j3 - 2] = wd3i * x0r + wd3r * x0i; + a[j3 - 1] = wd3i * x0i - wd3r * x0r; + } + wk1r = csc1 * (wd1r + wn4r); + wk1i = csc1 * (wd1i + wn4r); + wk3r = csc3 * (wd3r - wn4r); + wk3i = csc3 * (wd3i - wn4r); + j0 = mh; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0 - 2] + a[j2 - 2]; + x0i = a[j0 - 1] + a[j2 - 1]; + x1r = a[j0 - 2] - a[j2 - 2]; + x1i = a[j0 - 1] - a[j2 - 1]; + x2r = a[j1 - 2] + a[j3 - 2]; + x2i = a[j1 - 1] + a[j3 - 1]; + x3r = a[j1 - 2] - a[j3 - 2]; + x3i = a[j1 - 1] - a[j3 - 1]; + a[j0 - 2] = x0r + x2r; + a[j0 - 1] = x0i + x2i; + a[j1 - 2] = x0r - x2r; + a[j1 - 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2 - 2] = wk1r * x0r - wk1i * x0i; + a[j2 - 1] = wk1r * x0i + wk1i * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3 - 2] = wk3r * x0r + wk3i * x0i; + a[j3 - 1] = wk3r * x0i - wk3i * x0r; + x0r = a[j0] + a[j2]; + x0i = a[j0 + 1] + a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = a[j0 + 1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j0] = x0r + x2r; + a[j0 
+ 1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wn4r * (x0r - x0i); + a[j2 + 1] = wn4r * (x0i + x0r); + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = -wn4r * (x0r + x0i); + a[j3 + 1] = -wn4r * (x0i - x0r); + x0r = a[j0 + 2] + a[j2 + 2]; + x0i = a[j0 + 3] + a[j2 + 3]; + x1r = a[j0 + 2] - a[j2 + 2]; + x1i = a[j0 + 3] - a[j2 + 3]; + x2r = a[j1 + 2] + a[j3 + 2]; + x2i = a[j1 + 3] + a[j3 + 3]; + x3r = a[j1 + 2] - a[j3 + 2]; + x3i = a[j1 + 3] - a[j3 + 3]; + a[j0 + 2] = x0r + x2r; + a[j0 + 3] = x0i + x2i; + a[j1 + 2] = x0r - x2r; + a[j1 + 3] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2 + 2] = wk1i * x0r - wk1r * x0i; + a[j2 + 3] = wk1i * x0i + wk1r * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3 + 2] = wk3i * x0r + wk3r * x0i; + a[j3 + 3] = wk3i * x0i - wk3r * x0r; +} + + +void cftb1st(int n, double *a, double *w) { + int j, j0, j1, j2, j3, k, m, mh; + double wn4r, csc1, csc3, wk1r, wk1i, wk3r, wk3i, wd1r, wd1i, wd3r, wd3i; + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i, y0r, y0i, y1r, y1i, y2r, y2i, + y3r, y3i; + + mh = n >> 3; + m = 2 * mh; + j1 = m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[0] + a[j2]; + x0i = -a[1] - a[j2 + 1]; + x1r = a[0] - a[j2]; + x1i = -a[1] + a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[0] = x0r + x2r; + a[1] = x0i - x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i + x2i; + a[j2] = x1r + x3i; + a[j2 + 1] = x1i + x3r; + a[j3] = x1r - x3i; + a[j3 + 1] = x1i - x3r; + wn4r = w[1]; + csc1 = w[2]; + csc3 = w[3]; + wd1r = 1; + wd1i = 0; + wd3r = 1; + wd3i = 0; + k = 0; + for (j = 2; j < mh - 2; j += 4) { + k += 4; + wk1r = csc1 * (wd1r + w[k]); + wk1i = csc1 * (wd1i + w[k + 1]); + wk3r = csc3 * (wd3r + w[k + 2]); + wk3i = csc3 * (wd3i + w[k + 3]); + wd1r = w[k]; + wd1i = w[k + 1]; + wd3r = w[k + 2]; + wd3i = w[k + 3]; + j1 = j + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j] + a[j2]; + x0i = -a[j + 1] - 
a[j2 + 1]; + x1r = a[j] - a[j2]; + x1i = -a[j + 1] + a[j2 + 1]; + y0r = a[j + 2] + a[j2 + 2]; + y0i = -a[j + 3] - a[j2 + 3]; + y1r = a[j + 2] - a[j2 + 2]; + y1i = -a[j + 3] + a[j2 + 3]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + y2r = a[j1 + 2] + a[j3 + 2]; + y2i = a[j1 + 3] + a[j3 + 3]; + y3r = a[j1 + 2] - a[j3 + 2]; + y3i = a[j1 + 3] - a[j3 + 3]; + a[j] = x0r + x2r; + a[j + 1] = x0i - x2i; + a[j + 2] = y0r + y2r; + a[j + 3] = y0i - y2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i + x2i; + a[j1 + 2] = y0r - y2r; + a[j1 + 3] = y0i + y2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2] = wk1r * x0r - wk1i * x0i; + a[j2 + 1] = wk1r * x0i + wk1i * x0r; + x0r = y1r + y3i; + x0i = y1i + y3r; + a[j2 + 2] = wd1r * x0r - wd1i * x0i; + a[j2 + 3] = wd1r * x0i + wd1i * x0r; + x0r = x1r - x3i; + x0i = x1i - x3r; + a[j3] = wk3r * x0r + wk3i * x0i; + a[j3 + 1] = wk3r * x0i - wk3i * x0r; + x0r = y1r - y3i; + x0i = y1i - y3r; + a[j3 + 2] = wd3r * x0r + wd3i * x0i; + a[j3 + 3] = wd3r * x0i - wd3i * x0r; + j0 = m - j; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] + a[j2]; + x0i = -a[j0 + 1] - a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = -a[j0 + 1] + a[j2 + 1]; + y0r = a[j0 - 2] + a[j2 - 2]; + y0i = -a[j0 - 1] - a[j2 - 1]; + y1r = a[j0 - 2] - a[j2 - 2]; + y1i = -a[j0 - 1] + a[j2 - 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + y2r = a[j1 - 2] + a[j3 - 2]; + y2i = a[j1 - 1] + a[j3 - 1]; + y3r = a[j1 - 2] - a[j3 - 2]; + y3i = a[j1 - 1] - a[j3 - 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i - x2i; + a[j0 - 2] = y0r + y2r; + a[j0 - 1] = y0i - y2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i + x2i; + a[j1 - 2] = y0r - y2r; + a[j1 - 1] = y0i + y2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2] = wk1i * x0r - wk1r * x0i; + a[j2 + 1] = wk1i * x0i + wk1r * x0r; + x0r = y1r + y3i; + x0i = y1i + y3r; + a[j2 - 2] = wd1i * x0r - wd1r * x0i; + a[j2 - 1] = wd1i * x0i + 
wd1r * x0r; + x0r = x1r - x3i; + x0i = x1i - x3r; + a[j3] = wk3i * x0r + wk3r * x0i; + a[j3 + 1] = wk3i * x0i - wk3r * x0r; + x0r = y1r - y3i; + x0i = y1i - y3r; + a[j3 - 2] = wd3i * x0r + wd3r * x0i; + a[j3 - 1] = wd3i * x0i - wd3r * x0r; + } + wk1r = csc1 * (wd1r + wn4r); + wk1i = csc1 * (wd1i + wn4r); + wk3r = csc3 * (wd3r - wn4r); + wk3i = csc3 * (wd3i - wn4r); + j0 = mh; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0 - 2] + a[j2 - 2]; + x0i = -a[j0 - 1] - a[j2 - 1]; + x1r = a[j0 - 2] - a[j2 - 2]; + x1i = -a[j0 - 1] + a[j2 - 1]; + x2r = a[j1 - 2] + a[j3 - 2]; + x2i = a[j1 - 1] + a[j3 - 1]; + x3r = a[j1 - 2] - a[j3 - 2]; + x3i = a[j1 - 1] - a[j3 - 1]; + a[j0 - 2] = x0r + x2r; + a[j0 - 1] = x0i - x2i; + a[j1 - 2] = x0r - x2r; + a[j1 - 1] = x0i + x2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2 - 2] = wk1r * x0r - wk1i * x0i; + a[j2 - 1] = wk1r * x0i + wk1i * x0r; + x0r = x1r - x3i; + x0i = x1i - x3r; + a[j3 - 2] = wk3r * x0r + wk3i * x0i; + a[j3 - 1] = wk3r * x0i - wk3i * x0r; + x0r = a[j0] + a[j2]; + x0i = -a[j0 + 1] - a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = -a[j0 + 1] + a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i - x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i + x2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2] = wn4r * (x0r - x0i); + a[j2 + 1] = wn4r * (x0i + x0r); + x0r = x1r - x3i; + x0i = x1i - x3r; + a[j3] = -wn4r * (x0r + x0i); + a[j3 + 1] = -wn4r * (x0i - x0r); + x0r = a[j0 + 2] + a[j2 + 2]; + x0i = -a[j0 + 3] - a[j2 + 3]; + x1r = a[j0 + 2] - a[j2 + 2]; + x1i = -a[j0 + 3] + a[j2 + 3]; + x2r = a[j1 + 2] + a[j3 + 2]; + x2i = a[j1 + 3] + a[j3 + 3]; + x3r = a[j1 + 2] - a[j3 + 2]; + x3i = a[j1 + 3] - a[j3 + 3]; + a[j0 + 2] = x0r + x2r; + a[j0 + 3] = x0i - x2i; + a[j1 + 2] = x0r - x2r; + a[j1 + 3] = x0i + x2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2 + 2] = wk1i * x0r - wk1r * x0i; + a[j2 + 3] = wk1i * x0i + wk1r * x0r; + x0r = x1r 
- x3i; + x0i = x1i - x3r; + a[j3 + 2] = wk3i * x0r + wk3r * x0i; + a[j3 + 3] = wk3i * x0i - wk3r * x0r; +} + + +#ifdef USE_CDFT_THREADS +struct cdft_arg_st { + int n0; + int n; + double *a; + int nw; + double *w; +}; +typedef struct cdft_arg_st cdft_arg_t; + + +void cftrec4_th(int n, double *a, int nw, double *w) { + void *cftrec1_th(void *p); + void *cftrec2_th(void *p); + int i, idiv4, m, nthread; + cdft_thread_t th[4]; + cdft_arg_t ag[4]; + + nthread = 2; + idiv4 = 0; + m = n >> 1; + if (n > CDFT_4THREADS_BEGIN_N) { + nthread = 4; + idiv4 = 1; + m >>= 1; + } + for (i = 0; i < nthread; i++) { + ag[i].n0 = n; + ag[i].n = m; + ag[i].a = &a[i * m]; + ag[i].nw = nw; + ag[i].w = w; + if (i != idiv4) { + cdft_thread_create(&th[i], cftrec1_th, &ag[i]); + } else { + cdft_thread_create(&th[i], cftrec2_th, &ag[i]); + } + } + for (i = 0; i < nthread; i++) { + cdft_thread_wait(th[i]); + } +} + + +void *cftrec1_th(void *p) { + int cfttree(int n, int j, int k, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftmdl1(int n, double *a, double *w); + int isplt, j, k, m, n, n0, nw; + double *a, *w; + + n0 = ((cdft_arg_t *)p)->n0; + n = ((cdft_arg_t *)p)->n; + a = ((cdft_arg_t *)p)->a; + nw = ((cdft_arg_t *)p)->nw; + w = ((cdft_arg_t *)p)->w; + m = n0; + while (m > 512) { + m >>= 2; + cftmdl1(m, &a[n - m], &w[nw - (m >> 1)]); + } + cftleaf(m, 1, &a[n - m], nw, w); + k = 0; + for (j = n - m; j > 0; j -= m) { + k++; + isplt = cfttree(m, j, k, a, nw, w); + cftleaf(m, isplt, &a[j - m], nw, w); + } + return (void *)0; +} + + +void *cftrec2_th(void *p) { + int cfttree(int n, int j, int k, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftmdl2(int n, double *a, double *w); + int isplt, j, k, m, n, n0, nw; + double *a, *w; + + n0 = ((cdft_arg_t *)p)->n0; + n = ((cdft_arg_t *)p)->n; + a = ((cdft_arg_t *)p)->a; + nw = ((cdft_arg_t *)p)->nw; + w = ((cdft_arg_t *)p)->w; + k = 1; + m = 
n0; + while (m > 512) { + m >>= 2; + k <<= 2; + cftmdl2(m, &a[n - m], &w[nw - m]); + } + cftleaf(m, 0, &a[n - m], nw, w); + k >>= 1; + for (j = n - m; j > 0; j -= m) { + k++; + isplt = cfttree(m, j, k, a, nw, w); + cftleaf(m, isplt, &a[j - m], nw, w); + } + return (void *)0; +} +#endif /* USE_CDFT_THREADS */ + + +void cftrec4(int n, double *a, int nw, double *w) { + int cfttree(int n, int j, int k, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftmdl1(int n, double *a, double *w); + int isplt, j, k, m; + + m = n; + while (m > 512) { + m >>= 2; + cftmdl1(m, &a[n - m], &w[nw - (m >> 1)]); + } + cftleaf(m, 1, &a[n - m], nw, w); + k = 0; + for (j = n - m; j > 0; j -= m) { + k++; + isplt = cfttree(m, j, k, a, nw, w); + cftleaf(m, isplt, &a[j - m], nw, w); + } +} + + +int cfttree(int n, int j, int k, double *a, int nw, double *w) { + void cftmdl1(int n, double *a, double *w); + void cftmdl2(int n, double *a, double *w); + int i, isplt, m; + + if ((k & 3) != 0) { + isplt = k & 1; + if (isplt != 0) { + cftmdl1(n, &a[j - n], &w[nw - (n >> 1)]); + } else { + cftmdl2(n, &a[j - n], &w[nw - n]); + } + } else { + m = n; + for (i = k; (i & 3) == 0; i >>= 2) { + m <<= 2; + } + isplt = i & 1; + if (isplt != 0) { + while (m > 128) { + cftmdl1(m, &a[j - m], &w[nw - (m >> 1)]); + m >>= 2; + } + } else { + while (m > 128) { + cftmdl2(m, &a[j - m], &w[nw - m]); + m >>= 2; + } + } + } + return isplt; +} + + +void cftleaf(int n, int isplt, double *a, int nw, double *w) { + void cftmdl1(int n, double *a, double *w); + void cftmdl2(int n, double *a, double *w); + void cftf161(double *a, double *w); + void cftf162(double *a, double *w); + void cftf081(double *a, double *w); + void cftf082(double *a, double *w); + + if (n == 512) { + cftmdl1(128, a, &w[nw - 64]); + cftf161(a, &w[nw - 8]); + cftf162(&a[32], &w[nw - 32]); + cftf161(&a[64], &w[nw - 8]); + cftf161(&a[96], &w[nw - 8]); + cftmdl2(128, &a[128], &w[nw - 128]); + cftf161(&a[128], 
&w[nw - 8]); + cftf162(&a[160], &w[nw - 32]); + cftf161(&a[192], &w[nw - 8]); + cftf162(&a[224], &w[nw - 32]); + cftmdl1(128, &a[256], &w[nw - 64]); + cftf161(&a[256], &w[nw - 8]); + cftf162(&a[288], &w[nw - 32]); + cftf161(&a[320], &w[nw - 8]); + cftf161(&a[352], &w[nw - 8]); + if (isplt != 0) { + cftmdl1(128, &a[384], &w[nw - 64]); + cftf161(&a[480], &w[nw - 8]); + } else { + cftmdl2(128, &a[384], &w[nw - 128]); + cftf162(&a[480], &w[nw - 32]); + } + cftf161(&a[384], &w[nw - 8]); + cftf162(&a[416], &w[nw - 32]); + cftf161(&a[448], &w[nw - 8]); + } else { + cftmdl1(64, a, &w[nw - 32]); + cftf081(a, &w[nw - 8]); + cftf082(&a[16], &w[nw - 8]); + cftf081(&a[32], &w[nw - 8]); + cftf081(&a[48], &w[nw - 8]); + cftmdl2(64, &a[64], &w[nw - 64]); + cftf081(&a[64], &w[nw - 8]); + cftf082(&a[80], &w[nw - 8]); + cftf081(&a[96], &w[nw - 8]); + cftf082(&a[112], &w[nw - 8]); + cftmdl1(64, &a[128], &w[nw - 32]); + cftf081(&a[128], &w[nw - 8]); + cftf082(&a[144], &w[nw - 8]); + cftf081(&a[160], &w[nw - 8]); + cftf081(&a[176], &w[nw - 8]); + if (isplt != 0) { + cftmdl1(64, &a[192], &w[nw - 32]); + cftf081(&a[240], &w[nw - 8]); + } else { + cftmdl2(64, &a[192], &w[nw - 64]); + cftf082(&a[240], &w[nw - 8]); + } + cftf081(&a[192], &w[nw - 8]); + cftf082(&a[208], &w[nw - 8]); + cftf081(&a[224], &w[nw - 8]); + } +} + + +void cftmdl1(int n, double *a, double *w) { + int j, j0, j1, j2, j3, k, m, mh; + double wn4r, wk1r, wk1i, wk3r, wk3i; + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i; + + mh = n >> 3; + m = 2 * mh; + j1 = m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[0] + a[j2]; + x0i = a[1] + a[j2 + 1]; + x1r = a[0] - a[j2]; + x1i = a[1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + a[j2] = x1r - x3i; + a[j2 + 1] = x1i + x3r; + a[j3] = x1r + x3i; + a[j3 + 1] = x1i - x3r; + wn4r = w[1]; + k = 0; + for (j = 2; j < mh; j += 2) { + k 
+= 4; + wk1r = w[k]; + wk1i = w[k + 1]; + wk3r = w[k + 2]; + wk3i = w[k + 3]; + j1 = j + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j] + a[j2]; + x0i = a[j + 1] + a[j2 + 1]; + x1r = a[j] - a[j2]; + x1i = a[j + 1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j] = x0r + x2r; + a[j + 1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wk1r * x0r - wk1i * x0i; + a[j2 + 1] = wk1r * x0i + wk1i * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = wk3r * x0r + wk3i * x0i; + a[j3 + 1] = wk3r * x0i - wk3i * x0r; + j0 = m - j; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] + a[j2]; + x0i = a[j0 + 1] + a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = a[j0 + 1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wk1i * x0r - wk1r * x0i; + a[j2 + 1] = wk1i * x0i + wk1r * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = wk3i * x0r + wk3r * x0i; + a[j3 + 1] = wk3i * x0i - wk3r * x0r; + } + j0 = mh; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] + a[j2]; + x0i = a[j0 + 1] + a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = a[j0 + 1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wn4r * (x0r - x0i); + a[j2 + 1] = wn4r * (x0i + x0r); + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = -wn4r * (x0r + x0i); + a[j3 + 1] = -wn4r * (x0i - x0r); +} + + +void cftmdl2(int n, double *a, double *w) { + int j, j0, j1, j2, j3, k, kr, m, mh; + double wn4r, wk1r, wk1i, wk3r, wk3i, wd1r, wd1i, wd3r, wd3i; + double x0r, x0i, x1r, x1i, 
x2r, x2i, x3r, x3i, y0r, y0i, y2r, y2i; + + mh = n >> 3; + m = 2 * mh; + wn4r = w[1]; + j1 = m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[0] - a[j2 + 1]; + x0i = a[1] + a[j2]; + x1r = a[0] + a[j2 + 1]; + x1i = a[1] - a[j2]; + x2r = a[j1] - a[j3 + 1]; + x2i = a[j1 + 1] + a[j3]; + x3r = a[j1] + a[j3 + 1]; + x3i = a[j1 + 1] - a[j3]; + y0r = wn4r * (x2r - x2i); + y0i = wn4r * (x2i + x2r); + a[0] = x0r + y0r; + a[1] = x0i + y0i; + a[j1] = x0r - y0r; + a[j1 + 1] = x0i - y0i; + y0r = wn4r * (x3r - x3i); + y0i = wn4r * (x3i + x3r); + a[j2] = x1r - y0i; + a[j2 + 1] = x1i + y0r; + a[j3] = x1r + y0i; + a[j3 + 1] = x1i - y0r; + k = 0; + kr = 2 * m; + for (j = 2; j < mh; j += 2) { + k += 4; + wk1r = w[k]; + wk1i = w[k + 1]; + wk3r = w[k + 2]; + wk3i = w[k + 3]; + kr -= 4; + wd1i = w[kr]; + wd1r = w[kr + 1]; + wd3i = w[kr + 2]; + wd3r = w[kr + 3]; + j1 = j + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j] - a[j2 + 1]; + x0i = a[j + 1] + a[j2]; + x1r = a[j] + a[j2 + 1]; + x1i = a[j + 1] - a[j2]; + x2r = a[j1] - a[j3 + 1]; + x2i = a[j1 + 1] + a[j3]; + x3r = a[j1] + a[j3 + 1]; + x3i = a[j1 + 1] - a[j3]; + y0r = wk1r * x0r - wk1i * x0i; + y0i = wk1r * x0i + wk1i * x0r; + y2r = wd1r * x2r - wd1i * x2i; + y2i = wd1r * x2i + wd1i * x2r; + a[j] = y0r + y2r; + a[j + 1] = y0i + y2i; + a[j1] = y0r - y2r; + a[j1 + 1] = y0i - y2i; + y0r = wk3r * x1r + wk3i * x1i; + y0i = wk3r * x1i - wk3i * x1r; + y2r = wd3r * x3r + wd3i * x3i; + y2i = wd3r * x3i - wd3i * x3r; + a[j2] = y0r + y2r; + a[j2 + 1] = y0i + y2i; + a[j3] = y0r - y2r; + a[j3 + 1] = y0i - y2i; + j0 = m - j; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] - a[j2 + 1]; + x0i = a[j0 + 1] + a[j2]; + x1r = a[j0] + a[j2 + 1]; + x1i = a[j0 + 1] - a[j2]; + x2r = a[j1] - a[j3 + 1]; + x2i = a[j1 + 1] + a[j3]; + x3r = a[j1] + a[j3 + 1]; + x3i = a[j1 + 1] - a[j3]; + y0r = wd1i * x0r - wd1r * x0i; + y0i = wd1i * x0i + wd1r * x0r; + y2r = wk1i * x2r - wk1r * x2i; + y2i = wk1i * x2i + wk1r * x2r; + a[j0] = y0r + y2r; + a[j0 + 1] = y0i + y2i; 
+ a[j1] = y0r - y2r; + a[j1 + 1] = y0i - y2i; + y0r = wd3i * x1r + wd3r * x1i; + y0i = wd3i * x1i - wd3r * x1r; + y2r = wk3i * x3r + wk3r * x3i; + y2i = wk3i * x3i - wk3r * x3r; + a[j2] = y0r + y2r; + a[j2 + 1] = y0i + y2i; + a[j3] = y0r - y2r; + a[j3 + 1] = y0i - y2i; + } + wk1r = w[m]; + wk1i = w[m + 1]; + j0 = mh; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] - a[j2 + 1]; + x0i = a[j0 + 1] + a[j2]; + x1r = a[j0] + a[j2 + 1]; + x1i = a[j0 + 1] - a[j2]; + x2r = a[j1] - a[j3 + 1]; + x2i = a[j1 + 1] + a[j3]; + x3r = a[j1] + a[j3 + 1]; + x3i = a[j1 + 1] - a[j3]; + y0r = wk1r * x0r - wk1i * x0i; + y0i = wk1r * x0i + wk1i * x0r; + y2r = wk1i * x2r - wk1r * x2i; + y2i = wk1i * x2i + wk1r * x2r; + a[j0] = y0r + y2r; + a[j0 + 1] = y0i + y2i; + a[j1] = y0r - y2r; + a[j1 + 1] = y0i - y2i; + y0r = wk1i * x1r - wk1r * x1i; + y0i = wk1i * x1i + wk1r * x1r; + y2r = wk1r * x3r - wk1i * x3i; + y2i = wk1r * x3i + wk1i * x3r; + a[j2] = y0r - y2r; + a[j2 + 1] = y0i - y2i; + a[j3] = y0r + y2r; + a[j3 + 1] = y0i + y2i; +} + + +void cftfx41(int n, double *a, int nw, double *w) { + void cftf161(double *a, double *w); + void cftf162(double *a, double *w); + void cftf081(double *a, double *w); + void cftf082(double *a, double *w); + + if (n == 128) { + cftf161(a, &w[nw - 8]); + cftf162(&a[32], &w[nw - 32]); + cftf161(&a[64], &w[nw - 8]); + cftf161(&a[96], &w[nw - 8]); + } else { + cftf081(a, &w[nw - 8]); + cftf082(&a[16], &w[nw - 8]); + cftf081(&a[32], &w[nw - 8]); + cftf081(&a[48], &w[nw - 8]); + } +} + + +void cftf161(double *a, double *w) { + double wn4r, wk1r, wk1i, x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i, y0r, y0i, + y1r, y1i, y2r, y2i, y3r, y3i, y4r, y4i, y5r, y5i, y6r, y6i, y7r, y7i, + y8r, y8i, y9r, y9i, y10r, y10i, y11r, y11i, y12r, y12i, y13r, y13i, + y14r, y14i, y15r, y15i; + + wn4r = w[1]; + wk1r = w[2]; + wk1i = w[3]; + x0r = a[0] + a[16]; + x0i = a[1] + a[17]; + x1r = a[0] - a[16]; + x1i = a[1] - a[17]; + x2r = a[8] + a[24]; + x2i = a[9] + a[25]; + x3r = a[8] - 
a[24]; + x3i = a[9] - a[25]; + y0r = x0r + x2r; + y0i = x0i + x2i; + y4r = x0r - x2r; + y4i = x0i - x2i; + y8r = x1r - x3i; + y8i = x1i + x3r; + y12r = x1r + x3i; + y12i = x1i - x3r; + x0r = a[2] + a[18]; + x0i = a[3] + a[19]; + x1r = a[2] - a[18]; + x1i = a[3] - a[19]; + x2r = a[10] + a[26]; + x2i = a[11] + a[27]; + x3r = a[10] - a[26]; + x3i = a[11] - a[27]; + y1r = x0r + x2r; + y1i = x0i + x2i; + y5r = x0r - x2r; + y5i = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + y9r = wk1r * x0r - wk1i * x0i; + y9i = wk1r * x0i + wk1i * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + y13r = wk1i * x0r - wk1r * x0i; + y13i = wk1i * x0i + wk1r * x0r; + x0r = a[4] + a[20]; + x0i = a[5] + a[21]; + x1r = a[4] - a[20]; + x1i = a[5] - a[21]; + x2r = a[12] + a[28]; + x2i = a[13] + a[29]; + x3r = a[12] - a[28]; + x3i = a[13] - a[29]; + y2r = x0r + x2r; + y2i = x0i + x2i; + y6r = x0r - x2r; + y6i = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + y10r = wn4r * (x0r - x0i); + y10i = wn4r * (x0i + x0r); + x0r = x1r + x3i; + x0i = x1i - x3r; + y14r = wn4r * (x0r + x0i); + y14i = wn4r * (x0i - x0r); + x0r = a[6] + a[22]; + x0i = a[7] + a[23]; + x1r = a[6] - a[22]; + x1i = a[7] - a[23]; + x2r = a[14] + a[30]; + x2i = a[15] + a[31]; + x3r = a[14] - a[30]; + x3i = a[15] - a[31]; + y3r = x0r + x2r; + y3i = x0i + x2i; + y7r = x0r - x2r; + y7i = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + y11r = wk1i * x0r - wk1r * x0i; + y11i = wk1i * x0i + wk1r * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + y15r = wk1r * x0r - wk1i * x0i; + y15i = wk1r * x0i + wk1i * x0r; + x0r = y12r - y14r; + x0i = y12i - y14i; + x1r = y12r + y14r; + x1i = y12i + y14i; + x2r = y13r - y15r; + x2i = y13i - y15i; + x3r = y13r + y15r; + x3i = y13i + y15i; + a[24] = x0r + x2r; + a[25] = x0i + x2i; + a[26] = x0r - x2r; + a[27] = x0i - x2i; + a[28] = x1r - x3i; + a[29] = x1i + x3r; + a[30] = x1r + x3i; + a[31] = x1i - x3r; + x0r = y8r + y10r; + x0i = y8i + y10i; + x1r = y8r - y10r; + x1i = y8i - y10i; + x2r = y9r + 
y11r; + x2i = y9i + y11i; + x3r = y9r - y11r; + x3i = y9i - y11i; + a[16] = x0r + x2r; + a[17] = x0i + x2i; + a[18] = x0r - x2r; + a[19] = x0i - x2i; + a[20] = x1r - x3i; + a[21] = x1i + x3r; + a[22] = x1r + x3i; + a[23] = x1i - x3r; + x0r = y5r - y7i; + x0i = y5i + y7r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + x0r = y5r + y7i; + x0i = y5i - y7r; + x3r = wn4r * (x0r - x0i); + x3i = wn4r * (x0i + x0r); + x0r = y4r - y6i; + x0i = y4i + y6r; + x1r = y4r + y6i; + x1i = y4i - y6r; + a[8] = x0r + x2r; + a[9] = x0i + x2i; + a[10] = x0r - x2r; + a[11] = x0i - x2i; + a[12] = x1r - x3i; + a[13] = x1i + x3r; + a[14] = x1r + x3i; + a[15] = x1i - x3r; + x0r = y0r + y2r; + x0i = y0i + y2i; + x1r = y0r - y2r; + x1i = y0i - y2i; + x2r = y1r + y3r; + x2i = y1i + y3i; + x3r = y1r - y3r; + x3i = y1i - y3i; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[2] = x0r - x2r; + a[3] = x0i - x2i; + a[4] = x1r - x3i; + a[5] = x1i + x3r; + a[6] = x1r + x3i; + a[7] = x1i - x3r; +} + + +void cftf162(double *a, double *w) { + double wn4r, wk1r, wk1i, wk2r, wk2i, wk3r, wk3i, x0r, x0i, x1r, x1i, x2r, + x2i, y0r, y0i, y1r, y1i, y2r, y2i, y3r, y3i, y4r, y4i, y5r, y5i, y6r, + y6i, y7r, y7i, y8r, y8i, y9r, y9i, y10r, y10i, y11r, y11i, y12r, y12i, + y13r, y13i, y14r, y14i, y15r, y15i; + + wn4r = w[1]; + wk1r = w[4]; + wk1i = w[5]; + wk3r = w[6]; + wk3i = -w[7]; + wk2r = w[8]; + wk2i = w[9]; + x1r = a[0] - a[17]; + x1i = a[1] + a[16]; + x0r = a[8] - a[25]; + x0i = a[9] + a[24]; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + y0r = x1r + x2r; + y0i = x1i + x2i; + y4r = x1r - x2r; + y4i = x1i - x2i; + x1r = a[0] + a[17]; + x1i = a[1] - a[16]; + x0r = a[8] + a[25]; + x0i = a[9] - a[24]; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + y8r = x1r - x2i; + y8i = x1i + x2r; + y12r = x1r + x2i; + y12i = x1i - x2r; + x0r = a[2] - a[19]; + x0i = a[3] + a[18]; + x1r = wk1r * x0r - wk1i * x0i; + x1i = wk1r * x0i + wk1i * x0r; + x0r = a[10] - a[27]; + x0i = a[11] + a[26]; + x2r = wk3i 
* x0r - wk3r * x0i; + x2i = wk3i * x0i + wk3r * x0r; + y1r = x1r + x2r; + y1i = x1i + x2i; + y5r = x1r - x2r; + y5i = x1i - x2i; + x0r = a[2] + a[19]; + x0i = a[3] - a[18]; + x1r = wk3r * x0r - wk3i * x0i; + x1i = wk3r * x0i + wk3i * x0r; + x0r = a[10] + a[27]; + x0i = a[11] - a[26]; + x2r = wk1r * x0r + wk1i * x0i; + x2i = wk1r * x0i - wk1i * x0r; + y9r = x1r - x2r; + y9i = x1i - x2i; + y13r = x1r + x2r; + y13i = x1i + x2i; + x0r = a[4] - a[21]; + x0i = a[5] + a[20]; + x1r = wk2r * x0r - wk2i * x0i; + x1i = wk2r * x0i + wk2i * x0r; + x0r = a[12] - a[29]; + x0i = a[13] + a[28]; + x2r = wk2i * x0r - wk2r * x0i; + x2i = wk2i * x0i + wk2r * x0r; + y2r = x1r + x2r; + y2i = x1i + x2i; + y6r = x1r - x2r; + y6i = x1i - x2i; + x0r = a[4] + a[21]; + x0i = a[5] - a[20]; + x1r = wk2i * x0r - wk2r * x0i; + x1i = wk2i * x0i + wk2r * x0r; + x0r = a[12] + a[29]; + x0i = a[13] - a[28]; + x2r = wk2r * x0r - wk2i * x0i; + x2i = wk2r * x0i + wk2i * x0r; + y10r = x1r - x2r; + y10i = x1i - x2i; + y14r = x1r + x2r; + y14i = x1i + x2i; + x0r = a[6] - a[23]; + x0i = a[7] + a[22]; + x1r = wk3r * x0r - wk3i * x0i; + x1i = wk3r * x0i + wk3i * x0r; + x0r = a[14] - a[31]; + x0i = a[15] + a[30]; + x2r = wk1i * x0r - wk1r * x0i; + x2i = wk1i * x0i + wk1r * x0r; + y3r = x1r + x2r; + y3i = x1i + x2i; + y7r = x1r - x2r; + y7i = x1i - x2i; + x0r = a[6] + a[23]; + x0i = a[7] - a[22]; + x1r = wk1i * x0r + wk1r * x0i; + x1i = wk1i * x0i - wk1r * x0r; + x0r = a[14] + a[31]; + x0i = a[15] - a[30]; + x2r = wk3i * x0r - wk3r * x0i; + x2i = wk3i * x0i + wk3r * x0r; + y11r = x1r + x2r; + y11i = x1i + x2i; + y15r = x1r - x2r; + y15i = x1i - x2i; + x1r = y0r + y2r; + x1i = y0i + y2i; + x2r = y1r + y3r; + x2i = y1i + y3i; + a[0] = x1r + x2r; + a[1] = x1i + x2i; + a[2] = x1r - x2r; + a[3] = x1i - x2i; + x1r = y0r - y2r; + x1i = y0i - y2i; + x2r = y1r - y3r; + x2i = y1i - y3i; + a[4] = x1r - x2i; + a[5] = x1i + x2r; + a[6] = x1r + x2i; + a[7] = x1i - x2r; + x1r = y4r - y6i; + x1i = y4i + y6r; + x0r = y5r - y7i; + 
x0i = y5i + y7r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + a[8] = x1r + x2r; + a[9] = x1i + x2i; + a[10] = x1r - x2r; + a[11] = x1i - x2i; + x1r = y4r + y6i; + x1i = y4i - y6r; + x0r = y5r + y7i; + x0i = y5i - y7r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + a[12] = x1r - x2i; + a[13] = x1i + x2r; + a[14] = x1r + x2i; + a[15] = x1i - x2r; + x1r = y8r + y10r; + x1i = y8i + y10i; + x2r = y9r - y11r; + x2i = y9i - y11i; + a[16] = x1r + x2r; + a[17] = x1i + x2i; + a[18] = x1r - x2r; + a[19] = x1i - x2i; + x1r = y8r - y10r; + x1i = y8i - y10i; + x2r = y9r + y11r; + x2i = y9i + y11i; + a[20] = x1r - x2i; + a[21] = x1i + x2r; + a[22] = x1r + x2i; + a[23] = x1i - x2r; + x1r = y12r - y14i; + x1i = y12i + y14r; + x0r = y13r + y15i; + x0i = y13i - y15r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + a[24] = x1r + x2r; + a[25] = x1i + x2i; + a[26] = x1r - x2r; + a[27] = x1i - x2i; + x1r = y12r + y14i; + x1i = y12i - y14r; + x0r = y13r - y15i; + x0i = y13i + y15r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + a[28] = x1r - x2i; + a[29] = x1i + x2r; + a[30] = x1r + x2i; + a[31] = x1i - x2r; +} + + +void cftf081(double *a, double *w) { + double wn4r, x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i, y0r, y0i, y1r, y1i, + y2r, y2i, y3r, y3i, y4r, y4i, y5r, y5i, y6r, y6i, y7r, y7i; + + wn4r = w[1]; + x0r = a[0] + a[8]; + x0i = a[1] + a[9]; + x1r = a[0] - a[8]; + x1i = a[1] - a[9]; + x2r = a[4] + a[12]; + x2i = a[5] + a[13]; + x3r = a[4] - a[12]; + x3i = a[5] - a[13]; + y0r = x0r + x2r; + y0i = x0i + x2i; + y2r = x0r - x2r; + y2i = x0i - x2i; + y1r = x1r - x3i; + y1i = x1i + x3r; + y3r = x1r + x3i; + y3i = x1i - x3r; + x0r = a[2] + a[10]; + x0i = a[3] + a[11]; + x1r = a[2] - a[10]; + x1i = a[3] - a[11]; + x2r = a[6] + a[14]; + x2i = a[7] + a[15]; + x3r = a[6] - a[14]; + x3i = a[7] - a[15]; + y4r = x0r + x2r; + y4i = x0i + x2i; + y6r = x0r - x2r; + y6i = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + x2r = x1r + x3i; + x2i = x1i - x3r; + 
y5r = wn4r * (x0r - x0i); + y5i = wn4r * (x0r + x0i); + y7r = wn4r * (x2r - x2i); + y7i = wn4r * (x2r + x2i); + a[8] = y1r + y5r; + a[9] = y1i + y5i; + a[10] = y1r - y5r; + a[11] = y1i - y5i; + a[12] = y3r - y7i; + a[13] = y3i + y7r; + a[14] = y3r + y7i; + a[15] = y3i - y7r; + a[0] = y0r + y4r; + a[1] = y0i + y4i; + a[2] = y0r - y4r; + a[3] = y0i - y4i; + a[4] = y2r - y6i; + a[5] = y2i + y6r; + a[6] = y2r + y6i; + a[7] = y2i - y6r; +} + + +void cftf082(double *a, double *w) { + double wn4r, wk1r, wk1i, x0r, x0i, x1r, x1i, y0r, y0i, y1r, y1i, y2r, y2i, + y3r, y3i, y4r, y4i, y5r, y5i, y6r, y6i, y7r, y7i; + + wn4r = w[1]; + wk1r = w[2]; + wk1i = w[3]; + y0r = a[0] - a[9]; + y0i = a[1] + a[8]; + y1r = a[0] + a[9]; + y1i = a[1] - a[8]; + x0r = a[4] - a[13]; + x0i = a[5] + a[12]; + y2r = wn4r * (x0r - x0i); + y2i = wn4r * (x0i + x0r); + x0r = a[4] + a[13]; + x0i = a[5] - a[12]; + y3r = wn4r * (x0r - x0i); + y3i = wn4r * (x0i + x0r); + x0r = a[2] - a[11]; + x0i = a[3] + a[10]; + y4r = wk1r * x0r - wk1i * x0i; + y4i = wk1r * x0i + wk1i * x0r; + x0r = a[2] + a[11]; + x0i = a[3] - a[10]; + y5r = wk1i * x0r - wk1r * x0i; + y5i = wk1i * x0i + wk1r * x0r; + x0r = a[6] - a[15]; + x0i = a[7] + a[14]; + y6r = wk1i * x0r - wk1r * x0i; + y6i = wk1i * x0i + wk1r * x0r; + x0r = a[6] + a[15]; + x0i = a[7] - a[14]; + y7r = wk1r * x0r - wk1i * x0i; + y7i = wk1r * x0i + wk1i * x0r; + x0r = y0r + y2r; + x0i = y0i + y2i; + x1r = y4r + y6r; + x1i = y4i + y6i; + a[0] = x0r + x1r; + a[1] = x0i + x1i; + a[2] = x0r - x1r; + a[3] = x0i - x1i; + x0r = y0r - y2r; + x0i = y0i - y2i; + x1r = y4r - y6r; + x1i = y4i - y6i; + a[4] = x0r - x1i; + a[5] = x0i + x1r; + a[6] = x0r + x1i; + a[7] = x0i - x1r; + x0r = y1r - y3i; + x0i = y1i + y3r; + x1r = y5r - y7r; + x1i = y5i - y7i; + a[8] = x0r + x1r; + a[9] = x0i + x1i; + a[10] = x0r - x1r; + a[11] = x0i - x1i; + x0r = y1r + y3i; + x0i = y1i - y3r; + x1r = y5r + y7r; + x1i = y5i + y7i; + a[12] = x0r - x1i; + a[13] = x0i + x1r; + a[14] = x0r + x1i; + a[15] = 
x0i - x1r; +} + + +void cftf040(double *a) { + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i; + + x0r = a[0] + a[4]; + x0i = a[1] + a[5]; + x1r = a[0] - a[4]; + x1i = a[1] - a[5]; + x2r = a[2] + a[6]; + x2i = a[3] + a[7]; + x3r = a[2] - a[6]; + x3i = a[3] - a[7]; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[2] = x1r - x3i; + a[3] = x1i + x3r; + a[4] = x0r - x2r; + a[5] = x0i - x2i; + a[6] = x1r + x3i; + a[7] = x1i - x3r; +} + + +void cftb040(double *a) { + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i; + + x0r = a[0] + a[4]; + x0i = a[1] + a[5]; + x1r = a[0] - a[4]; + x1i = a[1] - a[5]; + x2r = a[2] + a[6]; + x2i = a[3] + a[7]; + x3r = a[2] - a[6]; + x3i = a[3] - a[7]; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[2] = x1r + x3i; + a[3] = x1i - x3r; + a[4] = x0r - x2r; + a[5] = x0i - x2i; + a[6] = x1r - x3i; + a[7] = x1i + x3r; +} + + +void cftx020(double *a) { + double x0r, x0i; + + x0r = a[0] - a[2]; + x0i = a[1] - a[3]; + a[0] += a[2]; + a[1] += a[3]; + a[2] = x0r; + a[3] = x0i; +} + + +void rftfsub(int n, double *a, int nc, double *c) { + int j, k, kk, ks, m; + double wkr, wki, xr, xi, yr, yi; + + m = n >> 1; + ks = 2 * nc / m; + kk = 0; + for (j = 2; j < m; j += 2) { + k = n - j; + kk += ks; + wkr = 0.5 - c[nc - kk]; + wki = c[kk]; + xr = a[j] - a[k]; + xi = a[j + 1] + a[k + 1]; + yr = wkr * xr - wki * xi; + yi = wkr * xi + wki * xr; + a[j] -= yr; + a[j + 1] -= yi; + a[k] += yr; + a[k + 1] -= yi; + } +} + + +void rftbsub(int n, double *a, int nc, double *c) { + int j, k, kk, ks, m; + double wkr, wki, xr, xi, yr, yi; + + m = n >> 1; + ks = 2 * nc / m; + kk = 0; + for (j = 2; j < m; j += 2) { + k = n - j; + kk += ks; + wkr = 0.5 - c[nc - kk]; + wki = c[kk]; + xr = a[j] - a[k]; + xi = a[j + 1] + a[k + 1]; + yr = wkr * xr + wki * xi; + yi = wkr * xi - wki * xr; + a[j] -= yr; + a[j + 1] -= yi; + a[k] += yr; + a[k + 1] -= yi; + } +} + + +void dctsub(int n, double *a, int nc, double *c) { + int j, k, kk, ks, m; + double wkr, wki, xr; + + m = n >> 1; + ks = nc / n; + kk 
= 0; + for (j = 1; j < m; j++) { + k = n - j; + kk += ks; + wkr = c[kk] - c[nc - kk]; + wki = c[kk] + c[nc - kk]; + xr = wki * a[j] - wkr * a[k]; + a[j] = wkr * a[j] + wki * a[k]; + a[k] = xr; + } + a[m] *= c[0]; +} + + +void dstsub(int n, double *a, int nc, double *c) { + int j, k, kk, ks, m; + double wkr, wki, xr; + + m = n >> 1; + ks = nc / n; + kk = 0; + for (j = 1; j < m; j++) { + k = n - j; + kk += ks; + wkr = c[kk] - c[nc - kk]; + wki = c[kk] + c[nc - kk]; + xr = wki * a[k] - wkr * a[j]; + a[k] = wkr * a[k] + wki * a[j]; + a[j] = xr; + } + a[m] *= c[0]; +} diff --git a/speechx/speechx/common/frontend/audio/frontend_itf.h b/speechx/speechx/common/frontend/audio/frontend_itf.h index 7913cc7c..3df8fb09 100644 --- a/speechx/speechx/common/frontend/audio/frontend_itf.h +++ b/speechx/speechx/common/frontend/audio/frontend_itf.h @@ -22,13 +22,13 @@ namespace ppspeech { class FrontendInterface { public: // Feed inputs: features(2D saved in 1D) or waveforms(1D). - virtual void Accept(const kaldi::VectorBase& inputs) = 0; + virtual void Accept(const std::vector& inputs) = 0; // Fetch processed data: features or waveforms. // For features(2D saved in 1D), the Matrix is squashed into Vector, // the length of output = feature_row * feature_dim. // For waveforms(1D), samples saved in vector. - virtual bool Read(kaldi::Vector* outputs) = 0; + virtual bool Read(std::vector* outputs) = 0; // Dim is the feature dim. For waveforms(1D), Dim is zero; else is specific, // e.g 80 for fbank. 
diff --git a/speechx/speechx/common/frontend/audio/mel-computations.cc b/speechx/speechx/common/frontend/audio/mel-computations.cc new file mode 100644 index 00000000..a876368e --- /dev/null +++ b/speechx/speechx/common/frontend/audio/mel-computations.cc @@ -0,0 +1,277 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is copied/modified from kaldi/src/feat/mel-computations.cc + +#include "frontend/audio/mel-computations.h" + +#include +#include + +#include "frontend/audio/feature-window.h" + +namespace knf { + +std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts) { + os << opts.ToString(); + return os; +} + +float MelBanks::VtlnWarpFreq( + float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN. + float vtln_high_cutoff, + float low_freq, // upper+lower frequency cutoffs in mel computation + float high_freq, + float vtln_warp_factor, + float freq) { + /// This computes a VTLN warping function that is not the same as HTK's one, + /// but has similar inputs (this function has the advantage of never + /// producing + /// empty bins). 
+ + /// This function computes a warp function F(freq), defined between low_freq + /// and high_freq inclusive, with the following properties: + /// F(low_freq) == low_freq + /// F(high_freq) == high_freq + /// The function is continuous and piecewise linear with two inflection + /// points. + /// The lower inflection point (measured in terms of the unwarped + /// frequency) is at frequency l, determined as described below. + /// The higher inflection point is at a frequency h, determined as + /// described below. + /// If l <= f <= h, then F(f) = f/vtln_warp_factor. + /// If the higher inflection point (measured in terms of the unwarped + /// frequency) is at h, then max(h, F(h)) == vtln_high_cutoff. + /// Since (by the last point) F(h) == h/vtln_warp_factor, then + /// max(h, h/vtln_warp_factor) == vtln_high_cutoff, so + /// h = vtln_high_cutoff / max(1, 1/vtln_warp_factor). + /// = vtln_high_cutoff * min(1, vtln_warp_factor). + /// If the lower inflection point (measured in terms of the unwarped + /// frequency) is at l, then min(l, F(l)) == vtln_low_cutoff + /// This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor) + /// = vtln_low_cutoff * max(1, vtln_warp_factor) + + if (freq < low_freq || freq > high_freq) + return freq; // in case this gets called + // for out-of-range frequencies, just return the freq. 
+ + CHECK_GT(vtln_low_cutoff, low_freq); + CHECK_LT(vtln_high_cutoff, high_freq); + + float one = 1.0f; + float l = vtln_low_cutoff * std::max(one, vtln_warp_factor); + float h = vtln_high_cutoff * std::min(one, vtln_warp_factor); + float scale = 1.0f / vtln_warp_factor; + float Fl = scale * l; // F(l); + float Fh = scale * h; // F(h); + CHECK(l > low_freq && h < high_freq); + // slope of left part of the 3-piece linear function + float scale_left = (Fl - low_freq) / (l - low_freq); + // [slope of center part is just "scale"] + + // slope of right part of the 3-piece linear function + float scale_right = (high_freq - Fh) / (high_freq - h); + + if (freq < l) { + return low_freq + scale_left * (freq - low_freq); + } else if (freq < h) { + return scale * freq; + } else { // freq >= h + return high_freq + scale_right * (freq - high_freq); + } +} + +float MelBanks::VtlnWarpMelFreq( + float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN. + float vtln_high_cutoff, + float low_freq, // upper+lower frequency cutoffs in mel computation + float high_freq, + float vtln_warp_factor, + float mel_freq) { + return MelScale(VtlnWarpFreq(vtln_low_cutoff, + vtln_high_cutoff, + low_freq, + high_freq, + vtln_warp_factor, + InverseMelScale(mel_freq))); +} + +MelBanks::MelBanks(const MelBanksOptions &opts, + const FrameExtractionOptions &frame_opts, + float vtln_warp_factor) + : htk_mode_(opts.htk_mode) { + int32_t num_bins = opts.num_bins; + if (num_bins < 3) LOG(FATAL) << "Must have at least 3 mel bins"; + + float sample_freq = frame_opts.samp_freq; + int32_t window_length_padded = frame_opts.PaddedWindowSize(); + CHECK_EQ(window_length_padded % 2, 0); + + int32_t num_fft_bins = window_length_padded / 2; + float nyquist = 0.5f * sample_freq; + + float low_freq = opts.low_freq, high_freq; + if (opts.high_freq > 0.0f) + high_freq = opts.high_freq; + else + high_freq = nyquist + opts.high_freq; + + if (low_freq < 0.0f || low_freq >= nyquist || high_freq <= 0.0f || + high_freq 
> nyquist || high_freq <= low_freq) { + LOG(FATAL) << "Bad values in options: low-freq " << low_freq + << " and high-freq " << high_freq << " vs. nyquist " + << nyquist; + } + + float fft_bin_width = sample_freq / window_length_padded; + // fft-bin width [think of it as Nyquist-freq / half-window-length] + + float mel_low_freq = MelScale(low_freq); + float mel_high_freq = MelScale(high_freq); + + debug_ = opts.debug_mel; + + // divide by num_bins+1 in next line because of end-effects where the bins + // spread out to the sides. + float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1); + + float vtln_low = opts.vtln_low, vtln_high = opts.vtln_high; + if (vtln_high < 0.0f) { + vtln_high += nyquist; + } + + if (vtln_warp_factor != 1.0f && + (vtln_low < 0.0f || vtln_low <= low_freq || vtln_low >= high_freq || + vtln_high <= 0.0f || vtln_high >= high_freq || + vtln_high <= vtln_low)) { + LOG(FATAL) << "Bad values in options: vtln-low " << vtln_low + << " and vtln-high " << vtln_high << ", versus " + << "low-freq " << low_freq << " and high-freq " << high_freq; + } + + bins_.resize(num_bins); + center_freqs_.resize(num_bins); + + for (int32_t bin = 0; bin < num_bins; ++bin) { + float left_mel = mel_low_freq + bin * mel_freq_delta, + center_mel = mel_low_freq + (bin + 1) * mel_freq_delta, + right_mel = mel_low_freq + (bin + 2) * mel_freq_delta; + + if (vtln_warp_factor != 1.0f) { + left_mel = VtlnWarpMelFreq(vtln_low, + vtln_high, + low_freq, + high_freq, + vtln_warp_factor, + left_mel); + center_mel = VtlnWarpMelFreq(vtln_low, + vtln_high, + low_freq, + high_freq, + vtln_warp_factor, + center_mel); + right_mel = VtlnWarpMelFreq(vtln_low, + vtln_high, + low_freq, + high_freq, + vtln_warp_factor, + right_mel); + } + center_freqs_[bin] = InverseMelScale(center_mel); + + // this_bin will be a vector of coefficients that is only + // nonzero where this mel bin is active. 
+ std::vector this_bin(num_fft_bins); + + int32_t first_index = -1, last_index = -1; + for (int32_t i = 0; i < num_fft_bins; ++i) { + float freq = (fft_bin_width * i); // Center frequency of this fft + // bin. + float mel = MelScale(freq); + if (mel > left_mel && mel < right_mel) { + float weight; + if (mel <= center_mel) + weight = (mel - left_mel) / (center_mel - left_mel); + else + weight = (right_mel - mel) / (right_mel - center_mel); + this_bin[i] = weight; + if (first_index == -1) first_index = i; + last_index = i; + } + } + CHECK(first_index != -1 && last_index >= first_index && + "You may have set num_mel_bins too large."); + + bins_[bin].first = first_index; + int32_t size = last_index + 1 - first_index; + bins_[bin].second.insert(bins_[bin].second.end(), + this_bin.begin() + first_index, + this_bin.begin() + first_index + size); + + // Replicate a bug in HTK, for testing purposes. + if (opts.htk_mode && bin == 0 && mel_low_freq != 0.0f) { + bins_[bin].second[0] = 0.0; + } + } // for (int32_t bin = 0; bin < num_bins; ++bin) { + + if (debug_) { + std::ostringstream os; + for (size_t i = 0; i < bins_.size(); i++) { + os << "bin " << i << ", offset = " << bins_[i].first << ", vec = "; + for (auto k : bins_[i].second) os << k << ", "; + os << "\n"; + } + LOG(INFO) << os.str(); + } +} + +// "power_spectrum" contains fft energies. +void MelBanks::Compute(const float *power_spectrum, + float *mel_energies_out) const { + int32_t num_bins = bins_.size(); + + for (int32_t i = 0; i < num_bins; i++) { + int32_t offset = bins_[i].first; + const auto &v = bins_[i].second; + float energy = 0; + for (int32_t k = 0; k != v.size(); ++k) { + energy += v[k] * power_spectrum[k + offset]; + } + + // HTK-like flooring- for testing purposes (we prefer dither) + if (htk_mode_ && energy < 1.0) { + energy = 1.0; + } + + mel_energies_out[i] = energy; + + // The following assert was added due to a problem with OpenBlas that + // we had at one point (it was a bug in that library). 
Just to detect + // it early. + CHECK_EQ(energy, energy); // check that energy is not nan + } + + if (debug_) { + fprintf(stderr, "MEL BANKS:\n"); + for (int32_t i = 0; i < num_bins; i++) + fprintf(stderr, " %f", mel_energies_out[i]); + fprintf(stderr, "\n"); + } +} + +} // namespace knf diff --git a/speechx/speechx/common/frontend/audio/mel-computations.h b/speechx/speechx/common/frontend/audio/mel-computations.h new file mode 100644 index 00000000..3f1b9678 --- /dev/null +++ b/speechx/speechx/common/frontend/audio/mel-computations.h @@ -0,0 +1,120 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// This file is copied/modified from kaldi/src/feat/mel-computations.h +#ifndef KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_ +#define KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_ + +#include +#include + +#include "frontend/audio/feature-window.h" + +namespace knf { + +struct MelBanksOptions { + int32_t num_bins = 25; // e.g. 25; number of triangular bins + float low_freq = 20; // e.g. 20; lower frequency cutoff + + // an upper frequency cutoff; 0 -> no cutoff, negative + // ->added to the Nyquist frequency to get the cutoff. + float high_freq = 0; + + float vtln_low = 100; // vtln lower cutoff of warping function. 
+ + // vtln upper cutoff of warping function: if negative, added + // to the Nyquist frequency to get the cutoff. + float vtln_high = -500; + + bool debug_mel = false; + // htk_mode is a "hidden" config, it does not show up on command line. + // Enables more exact compatibility with HTK, for testing purposes. Affects + // mel-energy flooring and reproduces a bug in HTK. + bool htk_mode = false; + + std::string ToString() const { + std::ostringstream os; + os << "num_bins: " << num_bins << "\n"; + os << "low_freq: " << low_freq << "\n"; + os << "high_freq: " << high_freq << "\n"; + os << "vtln_low: " << vtln_low << "\n"; + os << "vtln_high: " << vtln_high << "\n"; + os << "debug_mel: " << debug_mel << "\n"; + os << "htk_mode: " << htk_mode << "\n"; + return os.str(); + } +}; + +std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts); + +class MelBanks { + public: + static inline float InverseMelScale(float mel_freq) { + return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f); + } + + static inline float MelScale(float freq) { + return 1127.0f * logf(1.0f + freq / 700.0f); + } + + static float VtlnWarpFreq( + float vtln_low_cutoff, + float vtln_high_cutoff, // discontinuities in warp func + float low_freq, + float high_freq, // upper+lower frequency cutoffs in + // the mel computation + float vtln_warp_factor, + float freq); + + static float VtlnWarpMelFreq(float vtln_low_cutoff, + float vtln_high_cutoff, + float low_freq, + float high_freq, + float vtln_warp_factor, + float mel_freq); + + // TODO(fangjun): Remove vtln_warp_factor + MelBanks(const MelBanksOptions &opts, + const FrameExtractionOptions &frame_opts, + float vtln_warp_factor); + + /// Compute Mel energies (note: not log energies). + /// At input, "fft_energies" contains the FFT energies (not log). 
+ /// + /// @param fft_energies 1-D array of size num_fft_bins/2+1 + /// @param mel_energies_out 1-D array of size num_mel_bins + void Compute(const float *fft_energies, float *mel_energies_out) const; + + int32_t NumBins() const { return bins_.size(); } + + private: + // center frequencies of bins, numbered from 0 ... num_bins-1. + // Needed by GetCenterFreqs(). + std::vector center_freqs_; + + // the "bins_" vector is a vector, one for each bin, of a pair: + // (the first nonzero fft-bin), (the vector of weights). + std::vector>> bins_; + + // TODO(fangjun): Remove debug_ and htk_mode_ + bool debug_; + bool htk_mode_; +}; + +} // namespace knf + +#endif // KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_ diff --git a/speechx/speechx/common/frontend/audio/rfft.cc b/speechx/speechx/common/frontend/audio/rfft.cc new file mode 100644 index 00000000..84fbc9c4 --- /dev/null +++ b/speechx/speechx/common/frontend/audio/rfft.cc @@ -0,0 +1,66 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/audio/rfft.h" + +#include +#include + +#include "base/log.h" + +// see fftsg.c +#ifdef __cplusplus +extern "C" void rdft(int n, int isgn, double *a, int *ip, double *w); +#else +void rdft(int n, int isgn, double *a, int *ip, double *w); +#endif + +namespace knf { +class Rfft::RfftImpl { + public: + explicit RfftImpl(int32_t n) : n_(n), ip_(2 + std::sqrt(n / 2)), w_(n / 2) { + CHECK_EQ(n & (n - 1), 0); + } + + void Compute(float *in_out) { + std::vector d(in_out, in_out + n_); + + Compute(d.data()); + + std::copy(d.begin(), d.end(), in_out); + } + + void Compute(double *in_out) { + // 1 means forward fft + rdft(n_, 1, in_out, ip_.data(), w_.data()); + } + + private: + int32_t n_; + std::vector ip_; + std::vector w_; +}; + +Rfft::Rfft(int32_t n) : impl_(std::make_unique(n)) {} + +Rfft::~Rfft() = default; + +void Rfft::Compute(float *in_out) { impl_->Compute(in_out); } +void Rfft::Compute(double *in_out) { impl_->Compute(in_out); } + +} // namespace knf diff --git a/speechx/speechx/common/frontend/audio/rfft.h b/speechx/speechx/common/frontend/audio/rfft.h new file mode 100644 index 00000000..52da2626 --- /dev/null +++ b/speechx/speechx/common/frontend/audio/rfft.h @@ -0,0 +1,56 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef KALDI_NATIVE_FBANK_CSRC_RFFT_H_ +#define KALDI_NATIVE_FBANK_CSRC_RFFT_H_ + +#include + +namespace knf { + +// n-point Real discrete Fourier transform +// where n is a power of 2. n >= 2 +// +// R[k] = sum_j=0^n-1 in[j]*cos(2*pi*j*k/n), 0<=k<=n/2 +// I[k] = sum_j=0^n-1 in[j]*sin(2*pi*j*k/n), 0 impl_; +}; + +} // namespace knf + +#endif // KALDI_NATIVE_FBANK_CSRC_RFFT_H_ From ee7c266f130182d1bac4db378932784eec8b48f6 Mon Sep 17 00:00:00 2001 From: YangZhou <56786796+SmileGoat@users.noreply.github.com> Date: Wed, 11 Jan 2023 20:14:48 +0800 Subject: [PATCH 06/50] [speechx] rm openblas && refactor kaldi-matrix, kaldi-vector (#2824) * rm openblas && refactor kaldi-matrix kaldi-vector --- speechx/CMakeLists.txt | 3 - .../ctc_prefix_beam_search_decoder_main.cc | 2 +- speechx/speechx/asr/nnet/decodable.h | 2 +- speechx/speechx/asr/nnet/nnet_itf.h | 1 - speechx/speechx/asr/nnet/nnet_producer.cc | 2 +- speechx/speechx/asr/nnet/nnet_producer.h | 2 +- speechx/speechx/asr/nnet/u2_nnet.h | 2 +- speechx/speechx/asr/nnet/u2_nnet_main.cc | 4 +- speechx/speechx/asr/recognizer/CMakeLists.txt | 2 +- .../speechx/asr/recognizer/u2_recognizer.h | 2 +- .../asr/recognizer/u2_recognizer_main.cc | 2 +- .../recognizer/u2_recognizer_thread_main.cc | 2 +- speechx/speechx/common/CMakeLists.txt | 2 + .../speechx/common/frontend/CMakeLists.txt | 28 +- .../common/frontend/{audio => }/assembler.cc | 2 +- .../common/frontend/{audio => }/assembler.h | 2 +- .../common/frontend/audio/CMakeLists.txt | 27 - .../frontend/{audio => }/audio_cache.cc | 2 +- .../common/frontend/{audio => }/audio_cache.h | 2 +- .../common/frontend/{audio => }/cmvn.cc | 2 +- .../common/frontend/{audio => }/cmvn.h | 3 +- .../{audio => }/compute_fbank_main.cc | 14 +- .../compute_linear_spectrogram_main.cc | 0 .../common/frontend/{audio => }/data_cache.h | 2 +- .../common/frontend/{audio => }/db_norm.cc | 0 .../common/frontend/{audio => }/db_norm.h | 0 .../common/frontend/{audio => }/fbank.cc | 0 
.../common/frontend/{audio => }/fbank.h | 4 +- .../frontend/{audio => }/feature-fbank.cc | 4 +- .../frontend/{audio => }/feature-fbank.h | 6 +- .../frontend/{audio => }/feature-functions.cc | 2 +- .../frontend/{audio => }/feature-functions.h | 0 .../frontend/{audio => }/feature-window.cc | 2 +- .../frontend/{audio => }/feature-window.h | 0 .../frontend/{audio => }/feature_cache.cc | 2 +- .../frontend/{audio => }/feature_cache.h | 2 +- .../frontend/{audio => }/feature_common.h | 4 +- .../frontend/{audio => }/feature_common_inl.h | 0 .../frontend/{audio => }/feature_pipeline.cc | 2 +- .../frontend/{audio => }/feature_pipeline.h | 14 +- .../common/frontend/{audio => }/fftsg.c | 0 .../frontend/{audio => }/frontend_itf.h | 2 +- .../{audio => }/linear_spectrogram.cc | 0 .../frontend/{audio => }/linear_spectrogram.h | 0 .../frontend/{audio => }/mel-computations.cc | 4 +- .../frontend/{audio => }/mel-computations.h | 2 +- .../common/frontend/{audio => }/mfcc.cc | 0 .../common/frontend/{audio => }/mfcc.h | 0 .../common/frontend/{audio => }/normalizer.h | 3 +- .../common/frontend/{audio => }/rfft.cc | 2 +- .../common/frontend/{audio => }/rfft.h | 0 .../feat => common/frontend}/wave-reader.cc | 2 +- .../feat => common/frontend}/wave-reader.h | 0 speechx/speechx/common/matrix/CMakeLists.txt | 7 + .../matrix/kaldi-matrix-inl.h | 3 +- .../{kaldi => common}/matrix/kaldi-matrix.cc | 168 +- .../{kaldi => common}/matrix/kaldi-matrix.h | 387 +--- .../matrix/kaldi-vector-inl.h | 16 +- .../{kaldi => common}/matrix/kaldi-vector.cc | 1192 +++++------- speechx/speechx/common/matrix/kaldi-vector.h | 345 ++++ .../{kaldi => common}/matrix/matrix-common.h | 21 +- speechx/speechx/kaldi/CMakeLists.txt | 2 - speechx/speechx/kaldi/feat/CMakeLists.txt | 20 - speechx/speechx/kaldi/feat/cmvn.cc | 183 -- speechx/speechx/kaldi/feat/cmvn.h | 75 - .../speechx/kaldi/feat/feature-common-inl.h | 99 - speechx/speechx/kaldi/feat/feature-common.h | 176 -- speechx/speechx/kaldi/feat/feature-fbank.cc | 125 -- 
speechx/speechx/kaldi/feat/feature-fbank.h | 149 -- .../speechx/kaldi/feat/feature-functions.cc | 362 ---- .../speechx/kaldi/feat/feature-functions.h | 204 -- speechx/speechx/kaldi/feat/feature-mfcc.cc | 157 -- speechx/speechx/kaldi/feat/feature-mfcc.h | 154 -- speechx/speechx/kaldi/feat/feature-plp.cc | 191 -- speechx/speechx/kaldi/feat/feature-plp.h | 176 -- .../speechx/kaldi/feat/feature-spectrogram.cc | 82 - .../speechx/kaldi/feat/feature-spectrogram.h | 117 -- speechx/speechx/kaldi/feat/feature-window.cc | 222 --- speechx/speechx/kaldi/feat/feature-window.h | 223 --- .../speechx/kaldi/feat/mel-computations.cc | 340 ---- speechx/speechx/kaldi/feat/mel-computations.h | 171 -- .../speechx/kaldi/feat/online-feature-itf.h | 125 -- speechx/speechx/kaldi/feat/online-feature.cc | 679 ------- speechx/speechx/kaldi/feat/online-feature.h | 632 ------- speechx/speechx/kaldi/feat/pitch-functions.cc | 1667 ----------------- speechx/speechx/kaldi/feat/pitch-functions.h | 450 ----- speechx/speechx/kaldi/feat/resample.cc | 377 ---- speechx/speechx/kaldi/feat/resample.h | 287 --- speechx/speechx/kaldi/feat/signal.cc | 129 -- speechx/speechx/kaldi/feat/signal.h | 58 - speechx/speechx/kaldi/matrix/CMakeLists.txt | 16 - speechx/speechx/kaldi/matrix/cblas-wrappers.h | 491 ----- .../speechx/kaldi/matrix/compressed-matrix.cc | 876 --------- .../speechx/kaldi/matrix/compressed-matrix.h | 283 --- speechx/speechx/kaldi/matrix/jama-eig.h | 924 --------- speechx/speechx/kaldi/matrix/jama-svd.h | 531 ------ speechx/speechx/kaldi/matrix/kaldi-blas.h | 139 -- speechx/speechx/kaldi/matrix/kaldi-vector.h | 612 ------ .../kaldi/matrix/matrix-functions-inl.h | 56 - .../speechx/kaldi/matrix/matrix-functions.cc | 773 -------- .../speechx/kaldi/matrix/matrix-functions.h | 174 -- speechx/speechx/kaldi/matrix/matrix-lib.h | 37 - speechx/speechx/kaldi/matrix/optimization.cc | 577 ------ speechx/speechx/kaldi/matrix/optimization.h | 248 --- speechx/speechx/kaldi/matrix/packed-matrix.cc | 438 ----- 
speechx/speechx/kaldi/matrix/packed-matrix.h | 197 -- speechx/speechx/kaldi/matrix/qr.cc | 580 ------ speechx/speechx/kaldi/matrix/sp-matrix-inl.h | 42 - speechx/speechx/kaldi/matrix/sp-matrix.cc | 1216 ------------ speechx/speechx/kaldi/matrix/sp-matrix.h | 517 ----- speechx/speechx/kaldi/matrix/sparse-matrix.cc | 1296 ------------- speechx/speechx/kaldi/matrix/sparse-matrix.h | 452 ----- speechx/speechx/kaldi/matrix/srfft.cc | 440 ----- speechx/speechx/kaldi/matrix/srfft.h | 141 -- speechx/speechx/kaldi/matrix/tp-matrix.cc | 145 -- speechx/speechx/kaldi/matrix/tp-matrix.h | 134 -- speechx/speechx/kaldi/util/kaldi-holder-inl.h | 290 +-- speechx/speechx/kaldi/util/kaldi-holder.cc | 3 +- speechx/speechx/kaldi/util/kaldi-holder.h | 15 +- speechx/speechx/kaldi/util/table-types.h | 91 +- 120 files changed, 1281 insertions(+), 20393 deletions(-) rename speechx/speechx/common/frontend/{audio => }/assembler.cc (99%) rename speechx/speechx/common/frontend/{audio => }/assembler.h (98%) delete mode 100644 speechx/speechx/common/frontend/audio/CMakeLists.txt rename speechx/speechx/common/frontend/{audio => }/audio_cache.cc (98%) rename speechx/speechx/common/frontend/{audio => }/audio_cache.h (98%) rename speechx/speechx/common/frontend/{audio => }/cmvn.cc (99%) rename speechx/speechx/common/frontend/{audio => }/cmvn.h (94%) rename speechx/speechx/common/frontend/{audio => }/compute_fbank_main.cc (96%) rename speechx/speechx/common/frontend/{audio => }/compute_linear_spectrogram_main.cc (100%) rename speechx/speechx/common/frontend/{audio => }/data_cache.h (96%) rename speechx/speechx/common/frontend/{audio => }/db_norm.cc (100%) rename speechx/speechx/common/frontend/{audio => }/db_norm.h (100%) rename speechx/speechx/common/frontend/{audio => }/fbank.cc (100%) rename speechx/speechx/common/frontend/{audio => }/fbank.h (90%) rename speechx/speechx/common/frontend/{audio => }/feature-fbank.cc (97%) rename speechx/speechx/common/frontend/{audio => }/feature-fbank.h (97%) 
rename speechx/speechx/common/frontend/{audio => }/feature-functions.cc (97%) rename speechx/speechx/common/frontend/{audio => }/feature-functions.h (100%) rename speechx/speechx/common/frontend/{audio => }/feature-window.cc (99%) rename speechx/speechx/common/frontend/{audio => }/feature-window.h (100%) rename speechx/speechx/common/frontend/{audio => }/feature_cache.cc (97%) rename speechx/speechx/common/frontend/{audio => }/feature_cache.h (98%) rename speechx/speechx/common/frontend/{audio => }/feature_common.h (95%) rename speechx/speechx/common/frontend/{audio => }/feature_common_inl.h (100%) rename speechx/speechx/common/frontend/{audio => }/feature_pipeline.cc (96%) rename speechx/speechx/common/frontend/{audio => }/feature_pipeline.h (93%) rename speechx/speechx/common/frontend/{audio => }/fftsg.c (100%) rename speechx/speechx/common/frontend/{audio => }/frontend_itf.h (97%) rename speechx/speechx/common/frontend/{audio => }/linear_spectrogram.cc (100%) rename speechx/speechx/common/frontend/{audio => }/linear_spectrogram.h (100%) rename speechx/speechx/common/frontend/{audio => }/mel-computations.cc (99%) rename speechx/speechx/common/frontend/{audio => }/mel-computations.h (98%) rename speechx/speechx/common/frontend/{audio => }/mfcc.cc (100%) rename speechx/speechx/common/frontend/{audio => }/mfcc.h (100%) rename speechx/speechx/common/frontend/{audio => }/normalizer.h (90%) rename speechx/speechx/common/frontend/{audio => }/rfft.cc (98%) rename speechx/speechx/common/frontend/{audio => }/rfft.h (100%) rename speechx/speechx/{kaldi/feat => common/frontend}/wave-reader.cc (99%) rename speechx/speechx/{kaldi/feat => common/frontend}/wave-reader.h (100%) create mode 100644 speechx/speechx/common/matrix/CMakeLists.txt rename speechx/speechx/{kaldi => common}/matrix/kaldi-matrix-inl.h (99%) rename speechx/speechx/{kaldi => common}/matrix/kaldi-matrix.cc (97%) rename speechx/speechx/{kaldi => common}/matrix/kaldi-matrix.h (72%) rename speechx/speechx/{kaldi 
=> common}/matrix/kaldi-vector-inl.h (84%) rename speechx/speechx/{kaldi => common}/matrix/kaldi-vector.cc (52%) create mode 100644 speechx/speechx/common/matrix/kaldi-vector.h rename speechx/speechx/{kaldi => common}/matrix/matrix-common.h (78%) delete mode 100644 speechx/speechx/kaldi/feat/CMakeLists.txt delete mode 100644 speechx/speechx/kaldi/feat/cmvn.cc delete mode 100644 speechx/speechx/kaldi/feat/cmvn.h delete mode 100644 speechx/speechx/kaldi/feat/feature-common-inl.h delete mode 100644 speechx/speechx/kaldi/feat/feature-common.h delete mode 100644 speechx/speechx/kaldi/feat/feature-fbank.cc delete mode 100644 speechx/speechx/kaldi/feat/feature-fbank.h delete mode 100644 speechx/speechx/kaldi/feat/feature-functions.cc delete mode 100644 speechx/speechx/kaldi/feat/feature-functions.h delete mode 100644 speechx/speechx/kaldi/feat/feature-mfcc.cc delete mode 100644 speechx/speechx/kaldi/feat/feature-mfcc.h delete mode 100644 speechx/speechx/kaldi/feat/feature-plp.cc delete mode 100644 speechx/speechx/kaldi/feat/feature-plp.h delete mode 100644 speechx/speechx/kaldi/feat/feature-spectrogram.cc delete mode 100644 speechx/speechx/kaldi/feat/feature-spectrogram.h delete mode 100644 speechx/speechx/kaldi/feat/feature-window.cc delete mode 100644 speechx/speechx/kaldi/feat/feature-window.h delete mode 100644 speechx/speechx/kaldi/feat/mel-computations.cc delete mode 100644 speechx/speechx/kaldi/feat/mel-computations.h delete mode 100644 speechx/speechx/kaldi/feat/online-feature-itf.h delete mode 100644 speechx/speechx/kaldi/feat/online-feature.cc delete mode 100644 speechx/speechx/kaldi/feat/online-feature.h delete mode 100644 speechx/speechx/kaldi/feat/pitch-functions.cc delete mode 100644 speechx/speechx/kaldi/feat/pitch-functions.h delete mode 100644 speechx/speechx/kaldi/feat/resample.cc delete mode 100644 speechx/speechx/kaldi/feat/resample.h delete mode 100644 speechx/speechx/kaldi/feat/signal.cc delete mode 100644 speechx/speechx/kaldi/feat/signal.h delete 
mode 100644 speechx/speechx/kaldi/matrix/CMakeLists.txt delete mode 100644 speechx/speechx/kaldi/matrix/cblas-wrappers.h delete mode 100644 speechx/speechx/kaldi/matrix/compressed-matrix.cc delete mode 100644 speechx/speechx/kaldi/matrix/compressed-matrix.h delete mode 100644 speechx/speechx/kaldi/matrix/jama-eig.h delete mode 100644 speechx/speechx/kaldi/matrix/jama-svd.h delete mode 100644 speechx/speechx/kaldi/matrix/kaldi-blas.h delete mode 100644 speechx/speechx/kaldi/matrix/kaldi-vector.h delete mode 100644 speechx/speechx/kaldi/matrix/matrix-functions-inl.h delete mode 100644 speechx/speechx/kaldi/matrix/matrix-functions.cc delete mode 100644 speechx/speechx/kaldi/matrix/matrix-functions.h delete mode 100644 speechx/speechx/kaldi/matrix/matrix-lib.h delete mode 100644 speechx/speechx/kaldi/matrix/optimization.cc delete mode 100644 speechx/speechx/kaldi/matrix/optimization.h delete mode 100644 speechx/speechx/kaldi/matrix/packed-matrix.cc delete mode 100644 speechx/speechx/kaldi/matrix/packed-matrix.h delete mode 100644 speechx/speechx/kaldi/matrix/qr.cc delete mode 100644 speechx/speechx/kaldi/matrix/sp-matrix-inl.h delete mode 100644 speechx/speechx/kaldi/matrix/sp-matrix.cc delete mode 100644 speechx/speechx/kaldi/matrix/sp-matrix.h delete mode 100644 speechx/speechx/kaldi/matrix/sparse-matrix.cc delete mode 100644 speechx/speechx/kaldi/matrix/sparse-matrix.h delete mode 100644 speechx/speechx/kaldi/matrix/srfft.cc delete mode 100644 speechx/speechx/kaldi/matrix/srfft.h delete mode 100644 speechx/speechx/kaldi/matrix/tp-matrix.cc delete mode 100644 speechx/speechx/kaldi/matrix/tp-matrix.h diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index cfce63dd..e24744d6 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -53,9 +53,6 @@ include(gflags) include(glog) -#openblas -include(openblas) - # openfst include(openfst) add_dependencies(openfst gflags glog) diff --git 
a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc index b42ca69b..bd73b3ac 100644 --- a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc +++ b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -14,7 +14,7 @@ #include "decoder/ctc_prefix_beam_search_decoder.h" #include "base/common.h" -#include "frontend/audio/data_cache.h" +#include "frontend/data_cache.h" #include "fst/symbol-table.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" diff --git a/speechx/speechx/asr/nnet/decodable.h b/speechx/speechx/asr/nnet/decodable.h index cd498e42..44c7a0c3 100644 --- a/speechx/speechx/asr/nnet/decodable.h +++ b/speechx/speechx/asr/nnet/decodable.h @@ -14,7 +14,7 @@ #include "base/common.h" #include "kaldi/decoder/decodable-itf.h" -#include "kaldi/matrix/kaldi-matrix.h" +#include "matrix/kaldi-matrix.h" #include "nnet/nnet_itf.h" #include "nnet/nnet_producer.h" diff --git a/speechx/speechx/asr/nnet/nnet_itf.h b/speechx/speechx/asr/nnet/nnet_itf.h index 91d7f231..49e517ec 100644 --- a/speechx/speechx/asr/nnet/nnet_itf.h +++ b/speechx/speechx/asr/nnet/nnet_itf.h @@ -15,7 +15,6 @@ #include "base/basic_types.h" #include "kaldi/base/kaldi-types.h" -#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/util/options-itf.h" DECLARE_int32(subsampling_rate); diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/speechx/speechx/asr/nnet/nnet_producer.cc index 886c14d0..6207a6b5 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.cc +++ b/speechx/speechx/asr/nnet/nnet_producer.cc @@ -13,10 +13,10 @@ // limitations under the License. 
#include "nnet/nnet_producer.h" +#include "matrix/kaldi-matrix.h" namespace ppspeech { -using kaldi::Vector; using std::vector; using kaldi::BaseFloat; diff --git a/speechx/speechx/asr/nnet/nnet_producer.h b/speechx/speechx/asr/nnet/nnet_producer.h index 953943cc..dd356f95 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.h +++ b/speechx/speechx/asr/nnet/nnet_producer.h @@ -16,7 +16,7 @@ #include "base/common.h" #include "base/safe_queue.h" -#include "frontend/audio/frontend_itf.h" +#include "frontend/frontend_itf.h" #include "nnet/nnet_itf.h" namespace ppspeech { diff --git a/speechx/speechx/asr/nnet/u2_nnet.h b/speechx/speechx/asr/nnet/u2_nnet.h index f7b703f6..127d84db 100644 --- a/speechx/speechx/asr/nnet/u2_nnet.h +++ b/speechx/speechx/asr/nnet/u2_nnet.h @@ -18,7 +18,7 @@ #pragma once #include "base/common.h" -#include "kaldi/matrix/kaldi-matrix.h" +#include "matrix/kaldi-matrix.h" #include "nnet/nnet_itf.h" #include "paddle/extension.h" #include "paddle/jit/all.h" diff --git a/speechx/speechx/asr/nnet/u2_nnet_main.cc b/speechx/speechx/asr/nnet/u2_nnet_main.cc index 53fc5554..e60ae7e8 100644 --- a/speechx/speechx/asr/nnet/u2_nnet_main.cc +++ b/speechx/speechx/asr/nnet/u2_nnet_main.cc @@ -15,8 +15,8 @@ #include "base/common.h" #include "decoder/param.h" -#include "frontend/audio/assembler.h" -#include "frontend/audio/data_cache.h" +#include "frontend/assembler.h" +#include "frontend/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" #include "nnet/u2_nnet.h" diff --git a/speechx/speechx/asr/recognizer/CMakeLists.txt b/speechx/speechx/asr/recognizer/CMakeLists.txt index 17ba018f..8f9117e4 100644 --- a/speechx/speechx/asr/recognizer/CMakeLists.txt +++ b/speechx/speechx/asr/recognizer/CMakeLists.txt @@ -15,7 +15,7 @@ set(TEST_BINS foreach(bin_name IN LISTS TEST_BINS) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - 
target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-feat-common) + target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h index c92e0b6a..a3bf8aea 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/speechx/speechx/asr/recognizer/u2_recognizer.h @@ -18,7 +18,7 @@ #include "decoder/ctc_beam_search_opt.h" #include "decoder/ctc_prefix_beam_search_decoder.h" #include "decoder/decoder_itf.h" -#include "frontend/audio/feature_pipeline.h" +#include "frontend/feature_pipeline.h" #include "fst/fstlib.h" #include "fst/symbol-table.h" #include "nnet/decodable.h" diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_main.cc index 3e64011c..90c7cc06 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_main.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer_main.cc @@ -13,7 +13,7 @@ // limitations under the License. 
#include "decoder/param.h" -#include "kaldi/feat/wave-reader.h" +#include "frontend/wave-reader.h" #include "kaldi/util/table-types.h" #include "recognizer/u2_recognizer.h" diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc index bb72b3b6..a53b4541 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc @@ -14,7 +14,7 @@ #include "recognizer/u2_recognizer.h" #include "decoder/param.h" -#include "kaldi/feat/wave-reader.h" +#include "frontend/wave-reader.h" #include "kaldi/util/table-types.h" DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); diff --git a/speechx/speechx/common/CMakeLists.txt b/speechx/speechx/common/CMakeLists.txt index 00426cb5..5e0a7d57 100644 --- a/speechx/speechx/common/CMakeLists.txt +++ b/speechx/speechx/common/CMakeLists.txt @@ -4,6 +4,8 @@ ${CMAKE_CURRENT_SOURCE_DIR}/../ ) add_subdirectory(utils) +add_subdirectory(matrix) + include_directories( ${CMAKE_CURRENT_SOURCE_DIR}/frontend ) diff --git a/speechx/speechx/common/frontend/CMakeLists.txt b/speechx/speechx/common/frontend/CMakeLists.txt index 7d10fdec..617c35e1 100644 --- a/speechx/speechx/common/frontend/CMakeLists.txt +++ b/speechx/speechx/common/frontend/CMakeLists.txt @@ -1,2 +1,28 @@ +add_library(kaldi-native-fbank-core + feature-fbank.cc + feature-functions.cc + feature-window.cc + fftsg.c + mel-computations.cc + rfft.cc +) -add_subdirectory(audio) \ No newline at end of file +add_library(frontend STATIC + cmvn.cc + audio_cache.cc + feature_cache.cc + feature_pipeline.cc + assembler.cc + wave-reader.cc +) +target_link_libraries(frontend PUBLIC kaldi-native-fbank-core utils) + +set(BINS + compute_fbank_main +) + +foreach(bin_name IN LISTS BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + 
target_link_libraries(${bin_name} PUBLIC frontend utils kaldi-util gflags glog) +endforeach() diff --git a/speechx/speechx/common/frontend/audio/assembler.cc b/speechx/speechx/common/frontend/assembler.cc similarity index 99% rename from speechx/speechx/common/frontend/audio/assembler.cc rename to speechx/speechx/common/frontend/assembler.cc index 30a650d3..5f019c42 100644 --- a/speechx/speechx/common/frontend/audio/assembler.cc +++ b/speechx/speechx/common/frontend/assembler.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "frontend/audio/assembler.h" +#include "frontend/assembler.h" namespace ppspeech { diff --git a/speechx/speechx/common/frontend/audio/assembler.h b/speechx/speechx/common/frontend/assembler.h similarity index 98% rename from speechx/speechx/common/frontend/audio/assembler.h rename to speechx/speechx/common/frontend/assembler.h index 700e60d9..9ec28053 100644 --- a/speechx/speechx/common/frontend/audio/assembler.h +++ b/speechx/speechx/common/frontend/assembler.h @@ -15,7 +15,7 @@ #pragma once #include "base/common.h" -#include "frontend/audio/frontend_itf.h" +#include "frontend/frontend_itf.h" namespace ppspeech { diff --git a/speechx/speechx/common/frontend/audio/CMakeLists.txt b/speechx/speechx/common/frontend/audio/CMakeLists.txt deleted file mode 100644 index d5396ab2..00000000 --- a/speechx/speechx/common/frontend/audio/CMakeLists.txt +++ /dev/null @@ -1,27 +0,0 @@ -add_library(kaldi-native-fbank-core - feature-fbank.cc - feature-functions.cc - feature-window.cc - fftsg.c - mel-computations.cc - rfft.cc -) - -add_library(frontend STATIC - cmvn.cc - audio_cache.cc - feature_cache.cc - feature_pipeline.cc - assembler.cc -) -target_link_libraries(frontend PUBLIC kaldi-native-fbank-core utils) - -set(BINS - compute_fbank_main -) - -foreach(bin_name IN LISTS BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - 
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} PUBLIC frontend utils kaldi-util gflags glog kaldi-feat-common) -endforeach() diff --git a/speechx/speechx/common/frontend/audio/audio_cache.cc b/speechx/speechx/common/frontend/audio_cache.cc similarity index 98% rename from speechx/speechx/common/frontend/audio/audio_cache.cc rename to speechx/speechx/common/frontend/audio_cache.cc index 2221e1c9..e03ccefa 100644 --- a/speechx/speechx/common/frontend/audio/audio_cache.cc +++ b/speechx/speechx/common/frontend/audio_cache.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "frontend/audio/audio_cache.h" +#include "frontend/audio_cache.h" #include "kaldi/base/timer.h" diff --git a/speechx/speechx/common/frontend/audio/audio_cache.h b/speechx/speechx/common/frontend/audio_cache.h similarity index 98% rename from speechx/speechx/common/frontend/audio/audio_cache.h rename to speechx/speechx/common/frontend/audio_cache.h index d3cfbc3f..58e5452b 100644 --- a/speechx/speechx/common/frontend/audio/audio_cache.h +++ b/speechx/speechx/common/frontend/audio_cache.h @@ -16,7 +16,7 @@ #pragma once #include "base/common.h" -#include "frontend/audio/frontend_itf.h" +#include "frontend/frontend_itf.h" namespace ppspeech { diff --git a/speechx/speechx/common/frontend/audio/cmvn.cc b/speechx/speechx/common/frontend/cmvn.cc similarity index 99% rename from speechx/speechx/common/frontend/audio/cmvn.cc rename to speechx/speechx/common/frontend/cmvn.cc index 58ec299c..2fac1506 100644 --- a/speechx/speechx/common/frontend/audio/cmvn.cc +++ b/speechx/speechx/common/frontend/cmvn.cc @@ -13,7 +13,7 @@ // limitations under the License. 
-#include "frontend/audio/cmvn.h" +#include "frontend/cmvn.h" #include "utils/file_utils.h" #include "utils/picojson.h" diff --git a/speechx/speechx/common/frontend/audio/cmvn.h b/speechx/speechx/common/frontend/cmvn.h similarity index 94% rename from speechx/speechx/common/frontend/audio/cmvn.h rename to speechx/speechx/common/frontend/cmvn.h index 261d90b2..c515b6ae 100644 --- a/speechx/speechx/common/frontend/audio/cmvn.h +++ b/speechx/speechx/common/frontend/cmvn.h @@ -15,8 +15,7 @@ #pragma once #include "base/common.h" -#include "frontend/audio/frontend_itf.h" -#include "kaldi/matrix/kaldi-matrix.h" +#include "frontend/frontend_itf.h" #include "kaldi/util/options-itf.h" namespace ppspeech { diff --git a/speechx/speechx/common/frontend/audio/compute_fbank_main.cc b/speechx/speechx/common/frontend/compute_fbank_main.cc similarity index 96% rename from speechx/speechx/common/frontend/audio/compute_fbank_main.cc rename to speechx/speechx/common/frontend/compute_fbank_main.cc index fc6eb063..d7d5165c 100644 --- a/speechx/speechx/common/frontend/audio/compute_fbank_main.cc +++ b/speechx/speechx/common/frontend/compute_fbank_main.cc @@ -16,13 +16,13 @@ #include "base/flags.h" #include "base/log.h" -#include "frontend/audio/audio_cache.h" -#include "frontend/audio/data_cache.h" -#include "frontend/audio/fbank.h" -#include "frontend/audio/feature_cache.h" -#include "frontend/audio/frontend_itf.h" -#include "frontend/audio/normalizer.h" -#include "kaldi/feat/wave-reader.h" +#include "frontend/audio_cache.h" +#include "frontend/data_cache.h" +#include "frontend/fbank.h" +#include "frontend/feature_cache.h" +#include "frontend/frontend_itf.h" +#include "frontend/normalizer.h" +#include "frontend/wave-reader.h" #include "kaldi/util/kaldi-io.h" #include "kaldi/util/table-types.h" diff --git a/speechx/speechx/common/frontend/audio/compute_linear_spectrogram_main.cc b/speechx/speechx/common/frontend/compute_linear_spectrogram_main.cc similarity index 100% rename from 
speechx/speechx/common/frontend/audio/compute_linear_spectrogram_main.cc rename to speechx/speechx/common/frontend/compute_linear_spectrogram_main.cc diff --git a/speechx/speechx/common/frontend/audio/data_cache.h b/speechx/speechx/common/frontend/data_cache.h similarity index 96% rename from speechx/speechx/common/frontend/audio/data_cache.h rename to speechx/speechx/common/frontend/data_cache.h index d18d444d..7a37adf4 100644 --- a/speechx/speechx/common/frontend/audio/data_cache.h +++ b/speechx/speechx/common/frontend/data_cache.h @@ -16,7 +16,7 @@ #pragma once #include "base/common.h" -#include "frontend/audio/frontend_itf.h" +#include "frontend/frontend_itf.h" using std::vector; diff --git a/speechx/speechx/common/frontend/audio/db_norm.cc b/speechx/speechx/common/frontend/db_norm.cc similarity index 100% rename from speechx/speechx/common/frontend/audio/db_norm.cc rename to speechx/speechx/common/frontend/db_norm.cc diff --git a/speechx/speechx/common/frontend/audio/db_norm.h b/speechx/speechx/common/frontend/db_norm.h similarity index 100% rename from speechx/speechx/common/frontend/audio/db_norm.h rename to speechx/speechx/common/frontend/db_norm.h diff --git a/speechx/speechx/common/frontend/audio/fbank.cc b/speechx/speechx/common/frontend/fbank.cc similarity index 100% rename from speechx/speechx/common/frontend/audio/fbank.cc rename to speechx/speechx/common/frontend/fbank.cc diff --git a/speechx/speechx/common/frontend/audio/fbank.h b/speechx/speechx/common/frontend/fbank.h similarity index 90% rename from speechx/speechx/common/frontend/audio/fbank.h rename to speechx/speechx/common/frontend/fbank.h index 434ae7d6..61d9c9aa 100644 --- a/speechx/speechx/common/frontend/audio/fbank.h +++ b/speechx/speechx/common/frontend/fbank.h @@ -15,8 +15,8 @@ #pragma once #include "base/common.h" -#include "frontend/audio/feature_common.h" -#include "frontend/audio/feature-fbank.h" +#include "frontend/feature_common.h" +#include "frontend/feature-fbank.h" namespace 
ppspeech { diff --git a/speechx/speechx/common/frontend/audio/feature-fbank.cc b/speechx/speechx/common/frontend/feature-fbank.cc similarity index 97% rename from speechx/speechx/common/frontend/audio/feature-fbank.cc rename to speechx/speechx/common/frontend/feature-fbank.cc index 7a6da943..2393e153 100644 --- a/speechx/speechx/common/frontend/audio/feature-fbank.cc +++ b/speechx/speechx/common/frontend/feature-fbank.cc @@ -18,11 +18,11 @@ // This file is copied/modified from kaldi/src/feat/feature-fbank.cc // -#include "frontend/audio/feature-fbank.h" +#include "frontend/feature-fbank.h" #include -#include "frontend/audio/feature-functions.h" +#include "frontend/feature-functions.h" namespace knf { diff --git a/speechx/speechx/common/frontend/audio/feature-fbank.h b/speechx/speechx/common/frontend/feature-fbank.h similarity index 97% rename from speechx/speechx/common/frontend/audio/feature-fbank.h rename to speechx/speechx/common/frontend/feature-fbank.h index 3c43a3c8..30085245 100644 --- a/speechx/speechx/common/frontend/audio/feature-fbank.h +++ b/speechx/speechx/common/frontend/feature-fbank.h @@ -23,9 +23,9 @@ #include -#include "frontend/audio/feature-window.h" -#include "frontend/audio/mel-computations.h" -#include "frontend/audio/rfft.h" +#include "frontend/feature-window.h" +#include "frontend/mel-computations.h" +#include "frontend/rfft.h" namespace knf { diff --git a/speechx/speechx/common/frontend/audio/feature-functions.cc b/speechx/speechx/common/frontend/feature-functions.cc similarity index 97% rename from speechx/speechx/common/frontend/audio/feature-functions.cc rename to speechx/speechx/common/frontend/feature-functions.cc index 399041e4..178c711b 100644 --- a/speechx/speechx/common/frontend/audio/feature-functions.cc +++ b/speechx/speechx/common/frontend/feature-functions.cc @@ -18,7 +18,7 @@ // This file is copied/modified from kaldi/src/feat/feature-functions.cc -#include "frontend/audio/feature-functions.h" +#include 
"frontend/feature-functions.h" #include #include diff --git a/speechx/speechx/common/frontend/audio/feature-functions.h b/speechx/speechx/common/frontend/feature-functions.h similarity index 100% rename from speechx/speechx/common/frontend/audio/feature-functions.h rename to speechx/speechx/common/frontend/feature-functions.h diff --git a/speechx/speechx/common/frontend/audio/feature-window.cc b/speechx/speechx/common/frontend/feature-window.cc similarity index 99% rename from speechx/speechx/common/frontend/audio/feature-window.cc rename to speechx/speechx/common/frontend/feature-window.cc index 7778a1b9..1c474ccb 100644 --- a/speechx/speechx/common/frontend/audio/feature-window.cc +++ b/speechx/speechx/common/frontend/feature-window.cc @@ -4,7 +4,7 @@ // This file is copied/modified from kaldi/src/feat/feature-window.cc -#include "frontend/audio/feature-window.h" +#include "frontend/feature-window.h" #include #include diff --git a/speechx/speechx/common/frontend/audio/feature-window.h b/speechx/speechx/common/frontend/feature-window.h similarity index 100% rename from speechx/speechx/common/frontend/audio/feature-window.h rename to speechx/speechx/common/frontend/feature-window.h diff --git a/speechx/speechx/common/frontend/audio/feature_cache.cc b/speechx/speechx/common/frontend/feature_cache.cc similarity index 97% rename from speechx/speechx/common/frontend/audio/feature_cache.cc rename to speechx/speechx/common/frontend/feature_cache.cc index dc60e3e4..e6ac3c23 100644 --- a/speechx/speechx/common/frontend/audio/feature_cache.cc +++ b/speechx/speechx/common/frontend/feature_cache.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "frontend/audio/feature_cache.h" +#include "frontend/feature_cache.h" namespace ppspeech { diff --git a/speechx/speechx/common/frontend/audio/feature_cache.h b/speechx/speechx/common/frontend/feature_cache.h similarity index 98% rename from speechx/speechx/common/frontend/audio/feature_cache.h rename to speechx/speechx/common/frontend/feature_cache.h index 8d17151c..51816a1d 100644 --- a/speechx/speechx/common/frontend/audio/feature_cache.h +++ b/speechx/speechx/common/frontend/feature_cache.h @@ -15,7 +15,7 @@ #pragma once #include "base/common.h" -#include "frontend/audio/frontend_itf.h" +#include "frontend/frontend_itf.h" namespace ppspeech { diff --git a/speechx/speechx/common/frontend/audio/feature_common.h b/speechx/speechx/common/frontend/feature_common.h similarity index 95% rename from speechx/speechx/common/frontend/audio/feature_common.h rename to speechx/speechx/common/frontend/feature_common.h index f88dd960..7864bd30 100644 --- a/speechx/speechx/common/frontend/audio/feature_common.h +++ b/speechx/speechx/common/frontend/feature_common.h @@ -15,7 +15,7 @@ #pragma once #include "frontend_itf.h" -#include "frontend/audio/feature-window.h" +#include "frontend/feature-window.h" namespace ppspeech { @@ -52,4 +52,4 @@ class StreamingFeatureTpl : public FrontendInterface { } // namespace ppspeech -#include "frontend/audio/feature_common_inl.h" +#include "frontend/feature_common_inl.h" diff --git a/speechx/speechx/common/frontend/audio/feature_common_inl.h b/speechx/speechx/common/frontend/feature_common_inl.h similarity index 100% rename from speechx/speechx/common/frontend/audio/feature_common_inl.h rename to speechx/speechx/common/frontend/feature_common_inl.h diff --git a/speechx/speechx/common/frontend/audio/feature_pipeline.cc b/speechx/speechx/common/frontend/feature_pipeline.cc similarity index 96% rename from speechx/speechx/common/frontend/audio/feature_pipeline.cc rename to speechx/speechx/common/frontend/feature_pipeline.cc index 
8344ee65..34e55a10 100644 --- a/speechx/speechx/common/frontend/audio/feature_pipeline.cc +++ b/speechx/speechx/common/frontend/feature_pipeline.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "frontend/audio/feature_pipeline.h" +#include "frontend/feature_pipeline.h" namespace ppspeech { diff --git a/speechx/speechx/common/frontend/audio/feature_pipeline.h b/speechx/speechx/common/frontend/feature_pipeline.h similarity index 93% rename from speechx/speechx/common/frontend/audio/feature_pipeline.h rename to speechx/speechx/common/frontend/feature_pipeline.h index 0afb873e..ea7e2bba 100644 --- a/speechx/speechx/common/frontend/audio/feature_pipeline.h +++ b/speechx/speechx/common/frontend/feature_pipeline.h @@ -16,13 +16,13 @@ #pragma once -#include "frontend/audio/assembler.h" -#include "frontend/audio/audio_cache.h" -#include "frontend/audio/data_cache.h" -#include "frontend/audio/fbank.h" -#include "frontend/audio/feature_cache.h" -#include "frontend/audio/frontend_itf.h" -#include "frontend/audio/normalizer.h" +#include "frontend/assembler.h" +#include "frontend/audio_cache.h" +#include "frontend/data_cache.h" +#include "frontend/fbank.h" +#include "frontend/feature_cache.h" +#include "frontend/frontend_itf.h" +#include "frontend/cmvn.h" // feature DECLARE_bool(fill_zero); diff --git a/speechx/speechx/common/frontend/audio/fftsg.c b/speechx/speechx/common/frontend/fftsg.c similarity index 100% rename from speechx/speechx/common/frontend/audio/fftsg.c rename to speechx/speechx/common/frontend/fftsg.c diff --git a/speechx/speechx/common/frontend/audio/frontend_itf.h b/speechx/speechx/common/frontend/frontend_itf.h similarity index 97% rename from speechx/speechx/common/frontend/audio/frontend_itf.h rename to speechx/speechx/common/frontend/frontend_itf.h index 3df8fb09..57186ec4 100644 --- a/speechx/speechx/common/frontend/audio/frontend_itf.h +++ 
b/speechx/speechx/common/frontend/frontend_itf.h @@ -15,7 +15,7 @@ #pragma once #include "base/basic_types.h" -#include "kaldi/matrix/kaldi-vector.h" +#include "matrix/kaldi-vector.h" namespace ppspeech { diff --git a/speechx/speechx/common/frontend/audio/linear_spectrogram.cc b/speechx/speechx/common/frontend/linear_spectrogram.cc similarity index 100% rename from speechx/speechx/common/frontend/audio/linear_spectrogram.cc rename to speechx/speechx/common/frontend/linear_spectrogram.cc diff --git a/speechx/speechx/common/frontend/audio/linear_spectrogram.h b/speechx/speechx/common/frontend/linear_spectrogram.h similarity index 100% rename from speechx/speechx/common/frontend/audio/linear_spectrogram.h rename to speechx/speechx/common/frontend/linear_spectrogram.h diff --git a/speechx/speechx/common/frontend/audio/mel-computations.cc b/speechx/speechx/common/frontend/mel-computations.cc similarity index 99% rename from speechx/speechx/common/frontend/audio/mel-computations.cc rename to speechx/speechx/common/frontend/mel-computations.cc index a876368e..3998af22 100644 --- a/speechx/speechx/common/frontend/audio/mel-computations.cc +++ b/speechx/speechx/common/frontend/mel-computations.cc @@ -18,12 +18,12 @@ // This file is copied/modified from kaldi/src/feat/mel-computations.cc -#include "frontend/audio/mel-computations.h" +#include "frontend/mel-computations.h" #include #include -#include "frontend/audio/feature-window.h" +#include "frontend/feature-window.h" namespace knf { diff --git a/speechx/speechx/common/frontend/audio/mel-computations.h b/speechx/speechx/common/frontend/mel-computations.h similarity index 98% rename from speechx/speechx/common/frontend/audio/mel-computations.h rename to speechx/speechx/common/frontend/mel-computations.h index 3f1b9678..2f9938bc 100644 --- a/speechx/speechx/common/frontend/audio/mel-computations.h +++ b/speechx/speechx/common/frontend/mel-computations.h @@ -22,7 +22,7 @@ #include #include -#include 
"frontend/audio/feature-window.h" +#include "frontend/feature-window.h" namespace knf { diff --git a/speechx/speechx/common/frontend/audio/mfcc.cc b/speechx/speechx/common/frontend/mfcc.cc similarity index 100% rename from speechx/speechx/common/frontend/audio/mfcc.cc rename to speechx/speechx/common/frontend/mfcc.cc diff --git a/speechx/speechx/common/frontend/audio/mfcc.h b/speechx/speechx/common/frontend/mfcc.h similarity index 100% rename from speechx/speechx/common/frontend/audio/mfcc.h rename to speechx/speechx/common/frontend/mfcc.h diff --git a/speechx/speechx/common/frontend/audio/normalizer.h b/speechx/speechx/common/frontend/normalizer.h similarity index 90% rename from speechx/speechx/common/frontend/audio/normalizer.h rename to speechx/speechx/common/frontend/normalizer.h index dcf721dd..5a6ca573 100644 --- a/speechx/speechx/common/frontend/audio/normalizer.h +++ b/speechx/speechx/common/frontend/normalizer.h @@ -14,5 +14,4 @@ #pragma once -#include "frontend/audio/cmvn.h" -#include "frontend/audio/db_norm.h" \ No newline at end of file +#include "frontend/cmvn.h" \ No newline at end of file diff --git a/speechx/speechx/common/frontend/audio/rfft.cc b/speechx/speechx/common/frontend/rfft.cc similarity index 98% rename from speechx/speechx/common/frontend/audio/rfft.cc rename to speechx/speechx/common/frontend/rfft.cc index 84fbc9c4..f0a3ebc7 100644 --- a/speechx/speechx/common/frontend/audio/rfft.cc +++ b/speechx/speechx/common/frontend/rfft.cc @@ -16,7 +16,7 @@ * limitations under the License. 
*/ -#include "frontend/audio/rfft.h" +#include "frontend/rfft.h" #include #include diff --git a/speechx/speechx/common/frontend/audio/rfft.h b/speechx/speechx/common/frontend/rfft.h similarity index 100% rename from speechx/speechx/common/frontend/audio/rfft.h rename to speechx/speechx/common/frontend/rfft.h diff --git a/speechx/speechx/kaldi/feat/wave-reader.cc b/speechx/speechx/common/frontend/wave-reader.cc similarity index 99% rename from speechx/speechx/kaldi/feat/wave-reader.cc rename to speechx/speechx/common/frontend/wave-reader.cc index f8259a3a..42bf79c6 100644 --- a/speechx/speechx/kaldi/feat/wave-reader.cc +++ b/speechx/speechx/common/frontend/wave-reader.cc @@ -25,7 +25,7 @@ #include #include -#include "feat/wave-reader.h" +#include "frontend/wave-reader.h" #include "base/kaldi-error.h" #include "base/kaldi-utils.h" diff --git a/speechx/speechx/kaldi/feat/wave-reader.h b/speechx/speechx/common/frontend/wave-reader.h similarity index 100% rename from speechx/speechx/kaldi/feat/wave-reader.h rename to speechx/speechx/common/frontend/wave-reader.h diff --git a/speechx/speechx/common/matrix/CMakeLists.txt b/speechx/speechx/common/matrix/CMakeLists.txt new file mode 100644 index 00000000..a4b34d54 --- /dev/null +++ b/speechx/speechx/common/matrix/CMakeLists.txt @@ -0,0 +1,7 @@ + +add_library(kaldi-matrix +kaldi-matrix.cc +kaldi-vector.cc +) + +target_link_libraries(kaldi-matrix kaldi-base) diff --git a/speechx/speechx/kaldi/matrix/kaldi-matrix-inl.h b/speechx/speechx/common/matrix/kaldi-matrix-inl.h similarity index 99% rename from speechx/speechx/kaldi/matrix/kaldi-matrix-inl.h rename to speechx/speechx/common/matrix/kaldi-matrix-inl.h index c2ff0079..eafbc6fb 100644 --- a/speechx/speechx/kaldi/matrix/kaldi-matrix-inl.h +++ b/speechx/speechx/common/matrix/kaldi-matrix-inl.h @@ -28,7 +28,7 @@ namespace kaldi { template Matrix::Matrix(): MatrixBase(NULL, 0, 0, 0) { } - +/* template<> template<> void MatrixBase::AddVecVec(const float alpha, const VectorBase 
&ra, const VectorBase &rb); @@ -36,6 +36,7 @@ void MatrixBase::AddVecVec(const float alpha, const VectorBase &ra template<> template<> void MatrixBase::AddVecVec(const double alpha, const VectorBase &ra, const VectorBase &rb); +*/ template inline std::ostream & operator << (std::ostream & os, const MatrixBase & M) { diff --git a/speechx/speechx/kaldi/matrix/kaldi-matrix.cc b/speechx/speechx/common/matrix/kaldi-matrix.cc similarity index 97% rename from speechx/speechx/kaldi/matrix/kaldi-matrix.cc rename to speechx/speechx/common/matrix/kaldi-matrix.cc index faf23cdf..e446a6bf 100644 --- a/speechx/speechx/kaldi/matrix/kaldi-matrix.cc +++ b/speechx/speechx/common/matrix/kaldi-matrix.cc @@ -23,17 +23,9 @@ // limitations under the License. #include "matrix/kaldi-matrix.h" -#include "matrix/sp-matrix.h" -#include "matrix/jama-svd.h" -#include "matrix/jama-eig.h" -#include "matrix/compressed-matrix.h" -#include "matrix/sparse-matrix.h" - -static_assert(int(kaldi::kNoTrans) == int(CblasNoTrans) && int(kaldi::kTrans) == int(CblasTrans), - "kaldi::kNoTrans and kaldi::kTrans must be equal to the appropriate CBLAS library constants!"); namespace kaldi { - +/* template void MatrixBase::Invert(Real *log_det, Real *det_sign, bool inverse_needed) { @@ -206,29 +198,30 @@ void MatrixBase::SetMatMatDivMat(const MatrixBase& A, } } } +*/ - -template -void MatrixBase::CopyLowerToUpper() { - KALDI_ASSERT(num_rows_ == num_cols_); - Real *data = data_; - MatrixIndexT num_rows = num_rows_, stride = stride_; - for (int32 i = 0; i < num_rows; i++) - for (int32 j = 0; j < i; j++) - data[j * stride + i ] = data[i * stride + j]; -} +//template +//void MatrixBase::CopyLowerToUpper() { + //KALDI_ASSERT(num_rows_ == num_cols_); + //Real *data = data_; + //MatrixIndexT num_rows = num_rows_, stride = stride_; + //for (int32 i = 0; i < num_rows; i++) + //for (int32 j = 0; j < i; j++) + //data[j * stride + i ] = data[i * stride + j]; +//} -template -void MatrixBase::CopyUpperToLower() { - 
KALDI_ASSERT(num_rows_ == num_cols_); - Real *data = data_; - MatrixIndexT num_rows = num_rows_, stride = stride_; - for (int32 i = 0; i < num_rows; i++) - for (int32 j = 0; j < i; j++) - data[i * stride + j] = data[j * stride + i]; -} +//template +//void MatrixBase::CopyUpperToLower() { + //KALDI_ASSERT(num_rows_ == num_cols_); + //Real *data = data_; + //MatrixIndexT num_rows = num_rows_, stride = stride_; + //for (int32 i = 0; i < num_rows; i++) + //for (int32 j = 0; j < i; j++) + //data[i * stride + j] = data[j * stride + i]; +//} +/* template void MatrixBase::SymAddMat2(const Real alpha, const MatrixBase &A, @@ -734,7 +727,7 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, } #endif - +*/ // Copy constructor. Copies data to newly allocated memory. template Matrix::Matrix (const MatrixBase & M, @@ -898,6 +891,7 @@ template void MatrixBase::CopyFromMat(const MatrixBase & M, MatrixTransposeType Trans); +/* // Specialize the template for CopyFromSp for float, float. template<> template<> @@ -992,7 +986,7 @@ template void MatrixBase::CopyFromTp(const TpMatrix & M, MatrixTransposeType trans); - +*/ template void MatrixBase::CopyRowsFromVec(const VectorBase &rv) { if (rv.Dim() == num_rows_*num_cols_) { @@ -1076,7 +1070,6 @@ void MatrixBase::CopyColsFromVec(const VectorBase &rv) { } } - template void MatrixBase::CopyRowFromVec(const VectorBase &rv, const MatrixIndexT row) { KALDI_ASSERT(rv.Dim() == num_cols_ && @@ -1088,7 +1081,7 @@ void MatrixBase::CopyRowFromVec(const VectorBase &rv, const MatrixIn std::memcpy(row_data, rv_data, num_cols_ * sizeof(Real)); } - +/* template void MatrixBase::CopyDiagFromVec(const VectorBase &rv) { KALDI_ASSERT(rv.Dim() == std::min(num_cols_, num_rows_)); @@ -1096,7 +1089,7 @@ void MatrixBase::CopyDiagFromVec(const VectorBase &rv) { Real *my_data = this->Data(); for (; rv_data != rv_end; rv_data++, my_data += (this->stride_+1)) *my_data = *rv_data; -} +}*/ template void MatrixBase::CopyColFromVec(const VectorBase &rv, @@ 
-1135,7 +1128,7 @@ void Matrix::Destroy() { } - +/* template void MatrixBase::MulElements(const MatrixBase &a) { KALDI_ASSERT(a.NumRows() == num_rows_ && a.NumCols() == num_cols_); @@ -1325,6 +1318,7 @@ void MatrixBase::MulColsVec(const VectorBase &scale) { } } } +*/ template void MatrixBase::SetZero() { @@ -1344,6 +1338,7 @@ void MatrixBase::Set(Real value) { } } +/* template void MatrixBase::SetUnit() { SetZero(); @@ -1374,6 +1369,7 @@ void MatrixBase::SetRandUniform() { } } } +*/ template void MatrixBase::Write(std::ostream &os, bool binary) const { @@ -1420,23 +1416,11 @@ void MatrixBase::Write(std::ostream &os, bool binary) const { template -void MatrixBase::Read(std::istream & is, bool binary, bool add) { - if (add) { - Matrix tmp(num_rows_, num_cols_); - tmp.Read(is, binary, false); // read without adding. - if (tmp.num_rows_ != this->num_rows_ || tmp.num_cols_ != this->num_cols_) - KALDI_ERR << "MatrixBase::Read, size mismatch " - << this->num_rows_ << ", " << this->num_cols_ - << " vs. " << tmp.num_rows_ << ", " << tmp.num_cols_; - this->AddMat(1.0, tmp); - return; - } - // now assume add == false. - +void MatrixBase::Read(std::istream & is, bool binary) { // In order to avoid rewriting this, we just declare a Matrix and // use it to read the data, then copy. Matrix tmp; - tmp.Read(is, binary, false); + tmp.Read(is, binary); if (tmp.NumRows() != NumRows() || tmp.NumCols() != NumCols()) { KALDI_ERR << "MatrixBase::Read, size mismatch " << NumRows() << " x " << NumCols() << " versus " @@ -1447,23 +1431,7 @@ void MatrixBase::Read(std::istream & is, bool binary, bool add) { template -void Matrix::Read(std::istream & is, bool binary, bool add) { - if (add) { - Matrix tmp; - tmp.Read(is, binary, false); // read without adding. - if (this->num_rows_ == 0) this->Resize(tmp.num_rows_, tmp.num_cols_); - else { - if (this->num_rows_ != tmp.num_rows_ || this->num_cols_ != tmp.num_cols_) { - if (tmp.num_rows_ == 0) return; // do nothing in this case. 
- else KALDI_ERR << "Matrix::Read, size mismatch " - << this->num_rows_ << ", " << this->num_cols_ - << " vs. " << tmp.num_rows_ << ", " << tmp.num_cols_; - } - } - this->AddMat(1.0, tmp); - return; - } - +void Matrix::Read(std::istream & is, bool binary) { // now assume add == false. MatrixIndexT pos_at_start = is.tellg(); std::ostringstream specific_error; @@ -1472,10 +1440,10 @@ void Matrix::Read(std::istream & is, bool binary, bool add) { int peekval = Peek(is, binary); if (peekval == 'C') { // This code enables us to read CompressedMatrix as a regular matrix. - CompressedMatrix compressed_mat; - compressed_mat.Read(is, binary); // at this point, add == false. - this->Resize(compressed_mat.NumRows(), compressed_mat.NumCols()); - compressed_mat.CopyToMat(this); + //CompressedMatrix compressed_mat; + //compressed_mat.Read(is, binary); // at this point, add == false. + //this->Resize(compressed_mat.NumRows(), compressed_mat.NumCols()); + //compressed_mat.CopyToMat(this); return; } const char *my_token = (sizeof(Real) == 4 ? "FM" : "DM"); @@ -1483,7 +1451,7 @@ void Matrix::Read(std::istream & is, bool binary, bool add) { if (peekval == other_token_start) { // need to instantiate the other type to read it. typedef typename OtherReal::Real OtherType; // if Real == float, OtherType == double, and vice versa. Matrix other(this->num_rows_, this->num_cols_); - other.Read(is, binary, false); // add is false at this point anyway. + other.Read(is, binary); // add is false at this point anyway. 
this->Resize(other.NumRows(), other.NumCols()); this->CopyFromMat(other); return; @@ -1672,7 +1640,7 @@ SubMatrix::SubMatrix(Real *data, } } - +/* template void MatrixBase::Add(const Real alpha) { Real *data = data_; @@ -1812,15 +1780,15 @@ void MatrixBase::DestructiveSvd(VectorBase *s, MatrixBase *U, for(int32 i = 0; i < NumRows(); i++) (*this)(i, i) *= 1.00001; }*/ - bool ans = JamaSvd(s, U, Vt); - if (Vt != NULL) Vt->Transpose(); // possibly to do: change this and also the transpose inside the JamaSvd routine. note, Vt is square. - if (!ans) { - KALDI_ERR << "Error doing Svd"; // This one will be caught. - } -#endif - if (prescale != 1.0) s->Scale(1.0/prescale); -} - +// bool ans = JamaSvd(s, U, Vt); + //if (Vt != NULL) Vt->Transpose(); // possibly to do: change this and also the transpose inside the JamaSvd routine. note, Vt is square. + //if (!ans) { + //KALDI_ERR << "Error doing Svd"; // This one will be caught. + //} +//#endif + //if (prescale != 1.0) s->Scale(1.0/prescale); +//} +/* template void MatrixBase::Svd(VectorBase *s, MatrixBase *U, MatrixBase *Vt) const { try { @@ -2052,17 +2020,18 @@ void MatrixBase::InvertDouble(Real *log_det, Real *det_sign, if (log_det) *log_det = log_det_tmp; if (det_sign) *det_sign = det_sign_tmp; } +*/ -template -void MatrixBase::CopyFromMat(const CompressedMatrix &mat) { - mat.CopyToMat(this); -} +//template +//void MatrixBase::CopyFromMat(const CompressedMatrix &mat) { + //mat.CopyToMat(this); +//} -template -Matrix::Matrix(const CompressedMatrix &M): MatrixBase() { - Resize(M.NumRows(), M.NumCols(), kUndefined); - M.CopyToMat(this); -} +//template +//Matrix::Matrix(const CompressedMatrix &M): MatrixBase() { + //Resize(M.NumRows(), M.NumCols(), kUndefined); + //M.CopyToMat(this); +//} @@ -2074,7 +2043,7 @@ void MatrixBase::InvertElements() { } } } - +/* template void MatrixBase::Transpose() { KALDI_ASSERT(num_rows_ == num_cols_); @@ -2250,7 +2219,7 @@ bool MatrixBase::Power(Real power) { (*this).AddMatMat(1.0, tmp, 
kNoTrans, P, kNoTrans, 0.0); return true; } - +*/ template void Matrix::Swap(Matrix *other) { std::swap(this->data_, other->data_); @@ -2258,7 +2227,7 @@ void Matrix::Swap(Matrix *other) { std::swap(this->num_rows_, other->num_rows_); std::swap(this->stride_, other->stride_); } - +/* // Repeating this comment that appeared in the header: // Eigenvalue Decomposition of a square NxN matrix into the form (*this) = P D // P^{-1}. Be careful: the relationship of D to the eigenvalues we output is @@ -2298,7 +2267,7 @@ void MatrixBase::Eig(MatrixBase *P, // INT_32 mVersion; // INT_32 mSampSize; // }; - +/* template bool ReadHtk(std::istream &is, Matrix *M_ptr, HtkHeader *header_ptr) { @@ -2821,7 +2790,7 @@ void MatrixBase::GroupMax(const MatrixBase &src) { } } } - +*/ template void MatrixBase::CopyCols(const MatrixBase &src, const MatrixIndexT *indices) { @@ -2847,7 +2816,7 @@ void MatrixBase::CopyCols(const MatrixBase &src, } } - +/* template void MatrixBase::AddCols(const MatrixBase &src, const MatrixIndexT *indices) { @@ -2871,8 +2840,9 @@ void MatrixBase::AddCols(const MatrixBase &src, this_data[c] += src_data[*index_ptr]; } } -} +}*/ +/* template void MatrixBase::CopyRows(const MatrixBase &src, const MatrixIndexT *indices) { @@ -3022,9 +2992,9 @@ void MatrixBase::DiffTanh(const MatrixBase &value, value_data += value_stride; diff_data += diff_stride; } -} - +}*/ +/* template template void MatrixBase::AddVecToRows(const Real alpha, const VectorBase &v) { @@ -3087,7 +3057,7 @@ template void MatrixBase::AddVecToCols(const double alpha, const VectorBase &v); template void MatrixBase::AddVecToCols(const double alpha, const VectorBase &v); - +*/ //Explicit instantiation of the classes //Apparently, it seems to be necessary that the instantiation //happens at the end of the file. 
Otherwise, not all the member diff --git a/speechx/speechx/kaldi/matrix/kaldi-matrix.h b/speechx/speechx/common/matrix/kaldi-matrix.h similarity index 72% rename from speechx/speechx/kaldi/matrix/kaldi-matrix.h rename to speechx/speechx/common/matrix/kaldi-matrix.h index 4387538c..92274487 100644 --- a/speechx/speechx/kaldi/matrix/kaldi-matrix.h +++ b/speechx/speechx/common/matrix/kaldi-matrix.h @@ -32,13 +32,6 @@ namespace kaldi { /// @{ \addtogroup matrix_funcs_scalar -/// We need to declare this here as it will be a friend function. -/// tr(A B), or tr(A B^T). -template -Real TraceMatMat(const MatrixBase &A, const MatrixBase &B, - MatrixTransposeType trans = kNoTrans); -/// @} - /// \addtogroup matrix_group /// @{ @@ -50,15 +43,8 @@ class MatrixBase { public: // so this child can access protected members of other instances. friend class Matrix; + friend class SubMatrix; // friend declarations for CUDA matrices (see ../cudamatrix/) - friend class CuMatrixBase; - friend class CuMatrix; - friend class CuSubMatrix; - friend class CuPackedMatrix; - friend class PackedMatrix; - friend class SparseMatrix; - friend class SparseMatrix; - friend class SparseMatrix; /// Returns number of rows (or zero for empty matrix). inline MatrixIndexT NumRows() const { return num_rows_; } @@ -127,14 +113,6 @@ class MatrixBase { /// Sets all elements to a specific value. void Set(Real); /// Sets to zero, except ones along diagonal [for non-square matrices too] - void SetUnit(); - /// Sets to random values of a normal distribution - void SetRandn(); - /// Sets to numbers uniformly distributed on (0, 1) - void SetRandUniform(); - - /* Copying functions. These do not resize the matrix! */ - /// Copy given matrix. (no resize is done). template @@ -142,21 +120,17 @@ class MatrixBase { MatrixTransposeType trans = kNoTrans); /// Copy from compressed matrix. - void CopyFromMat(const CompressedMatrix &M); - - /// Copy given spmatrix. (no resize is done). 
- template - void CopyFromSp(const SpMatrix &M); + //void CopyFromMat(const CompressedMatrix &M); /// Copy given tpmatrix. (no resize is done). - template - void CopyFromTp(const TpMatrix &M, - MatrixTransposeType trans = kNoTrans); + //template + //void CopyFromTp(const TpMatrix &M, + //MatrixTransposeType trans = kNoTrans); /// Copy from CUDA matrix. Implemented in ../cudamatrix/cu-matrix.h - template - void CopyFromMat(const CuMatrixBase &M, - MatrixTransposeType trans = kNoTrans); + //template + //void CopyFromMat(const CuMatrixBase &M, + //MatrixTransposeType trans = kNoTrans); /// This function has two modes of operation. If v.Dim() == NumRows() * /// NumCols(), then treats the vector as a row-by-row concatenation of a @@ -165,7 +139,7 @@ class MatrixBase { void CopyRowsFromVec(const VectorBase &v); /// This version of CopyRowsFromVec is implemented in ../cudamatrix/cu-vector.cc - void CopyRowsFromVec(const CuVectorBase &v); + //void CopyRowsFromVec(const CuVectorBase &v); template void CopyRowsFromVec(const VectorBase &v); @@ -215,7 +189,7 @@ class MatrixBase { return SubMatrix(*this, 0, num_rows_, col_offset, num_cols); } - /* Various special functions. */ +/* /// Returns sum of all elements in matrix. Real Sum() const; /// Returns trace of matrix. @@ -268,15 +242,16 @@ class MatrixBase { /// Does inversion in double precision even if matrix was not double. void InvertDouble(Real *LogDet = NULL, Real *det_sign = NULL, bool inverse_needed = true); - +*/ /// Inverts all the elements of the matrix void InvertElements(); - +/* /// Transpose the matrix. This one is only /// applicable to square matrices (the one in the /// Matrix child class works also for non-square. void Transpose(); +*/ /// Copies column r from column indices[r] of src. /// As a special case, if indexes[i] == -1, sets column i to zero. 
/// all elements of "indices" must be in [-1, src.NumCols()-1], @@ -296,8 +271,8 @@ class MatrixBase { /// indices.size() must equal this->NumCols(), /// all elements of "reorder" must be in [-1, src.NumCols()-1], /// and src.NumRows() must equal this.NumRows() - void AddCols(const MatrixBase &src, - const MatrixIndexT *indices); + //void AddCols(const MatrixBase &src, + // const MatrixIndexT *indices); /// Copies row r of this matrix from an array of floats at the location given /// by src[r]. If any src[r] is NULL then this.Row(r) will be set to zero. @@ -314,30 +289,30 @@ class MatrixBase { /// Does for each row r, this.Row(r) += alpha * src.row(indexes[r]). /// If indexes[r] < 0, does not add anything. all elements of "indexes" must /// be in [-1, src.NumRows()-1], and src.NumCols() must equal this.NumCols(). - void AddRows(Real alpha, - const MatrixBase &src, - const MatrixIndexT *indexes); + // void AddRows(Real alpha, + // const MatrixBase &src, + // const MatrixIndexT *indexes); /// Does for each row r, this.Row(r) += alpha * src[r], treating src[r] as the /// beginning of a region of memory representing a vector of floats, of the /// same length as this.NumCols(). If src[r] is NULL, does not add anything. - void AddRows(Real alpha, const Real *const *src); + //void AddRows(Real alpha, const Real *const *src); /// For each row r of this matrix, adds it (times alpha) to the array of /// floats at the location given by dst[r]. If dst[r] is NULL, does not do /// anything for that row. Requires that none of the memory regions pointed /// to by the pointers in "dst" overlap (e.g. none of the pointers should be /// the same). - void AddToRows(Real alpha, Real *const *dst) const; + //void AddToRows(Real alpha, Real *const *dst) const; /// For each row i of *this, adds this->Row(i) to /// dst->Row(indexes(i)) if indexes(i) >= 0, else do nothing. /// Requires that all the indexes[i] that are >= 0 /// be distinct, otherwise the behavior is undefined. 
- void AddToRows(Real alpha, - const MatrixIndexT *indexes, - MatrixBase *dst) const; - + //void AddToRows(Real alpha, + // const MatrixIndexT *indexes, + // MatrixBase *dst) const; +/* inline void ApplyPow(Real power) { this -> Pow(*this, power); } @@ -374,7 +349,7 @@ class MatrixBase { inline void ApplyLog() { this -> Log(*this); } - +*/ /// Eigenvalue Decomposition of a square NxN matrix into the form (*this) = P D /// P^{-1}. Be careful: the relationship of D to the eigenvalues we output is /// slightly complicated, due to the need for P to be real. In the symmetric @@ -389,9 +364,9 @@ class MatrixBase { /// instead (*this) P = P D. /// /// The non-member function CreateEigenvalueMatrix creates D from eigs_real and eigs_imag. - void Eig(MatrixBase *P, - VectorBase *eigs_real, - VectorBase *eigs_imag) const; + //void Eig(MatrixBase *P, + // VectorBase *eigs_real, + // VectorBase *eigs_imag) const; /// The Power method attempts to take the matrix to a power using a method that /// works in general for fractional and negative powers. The input matrix must @@ -400,7 +375,7 @@ class MatrixBase { /// return false and leave the matrix unchanged, if at entry the matrix had /// real negative eigenvalues (or if it had zero eigenvalues and the power was /// negative). - bool Power(Real pow); +// bool Power(Real pow); /** Singular value decomposition Major limitations: @@ -413,31 +388,32 @@ class MatrixBase { expect that S.Dim() == m, U is either NULL or m by n, and v is either NULL or n by n. The singular values are not sorted (use SortSvd for that). */ - void DestructiveSvd(VectorBase *s, MatrixBase *U, - MatrixBase *Vt); // Destroys calling matrix. + //void DestructiveSvd(VectorBase *s, MatrixBase *U, + // MatrixBase *Vt); // Destroys calling matrix. /// Compute SVD (*this) = U diag(s) Vt. Note that the V in the call is already /// transposed; the normal formulation is U diag(s) V^T. /// Null pointers for U or V mean we don't want that output (this saves /// compute). 
The singular values are not sorted (use SortSvd for that). - void Svd(VectorBase *s, MatrixBase *U, - MatrixBase *Vt) const; + //void Svd(VectorBase *s, MatrixBase *U, + // MatrixBase *Vt) const; /// Compute SVD but only retain the singular values. - void Svd(VectorBase *s) const { Svd(s, NULL, NULL); } + //void Svd(VectorBase *s) const { Svd(s, NULL, NULL); } /// Returns smallest singular value. - Real MinSingularValue() const { - Vector tmp(std::min(NumRows(), NumCols())); - Svd(&tmp); - return tmp.Min(); - } + //Real MinSingularValue() const { + // Vector tmp(std::min(NumRows(), NumCols())); + //Svd(&tmp); + //return tmp.Min(); + //} - void TestUninitialized() const; // This function is designed so that if any element + //void TestUninitialized() const; // This function is designed so that if any element // if the matrix is uninitialized memory, valgrind will complain. /// Returns condition number by computing Svd. Works even if cols > rows. /// Returns infinity if all singular values are zero. + /* Real Cond() const; /// Returns true if matrix is Symmetric. @@ -559,7 +535,7 @@ class MatrixBase { // element-by-element, set *this = diff * (1.0 - value^2). void DiffTanh(const MatrixBase &value, const MatrixBase &diff); - +*/ /** Uses Svd to compute the eigenvalue decomposition of a symmetric positive * semi-definite matrix: (*this) = rP * diag(rS) * rP^T, with rP an * orthogonal matrix so rP^{-1} = rP^T. Throws exception if input was not @@ -571,208 +547,15 @@ class MatrixBase { * SpMatrix and use Eig() function there, which uses eigenvalue decomposition * directly rather than SVD. */ - void SymPosSemiDefEig(VectorBase *s, MatrixBase *P, - Real check_thresh = 0.001); - - friend Real kaldi::TraceMatMat(const MatrixBase &A, - const MatrixBase &B, MatrixTransposeType trans); // tr (A B) - - // so it can get around const restrictions on the pointer to data_. 
- friend class SubMatrix; - - /// Add a scalar to each element - void Add(const Real alpha); - - /// Add a scalar to each diagonal element. - void AddToDiag(const Real alpha); - - /// *this += alpha * a * b^T - template - void AddVecVec(const Real alpha, const VectorBase &a, - const VectorBase &b); - - /// [each row of *this] += alpha * v - template - void AddVecToRows(const Real alpha, const VectorBase &v); - - /// [each col of *this] += alpha * v - template - void AddVecToCols(const Real alpha, const VectorBase &v); - - /// *this += alpha * M [or M^T] - void AddMat(const Real alpha, const MatrixBase &M, - MatrixTransposeType transA = kNoTrans); - - /// *this += alpha * A [or A^T]. - void AddSmat(Real alpha, const SparseMatrix &A, - MatrixTransposeType trans = kNoTrans); - - /// (*this) = alpha * op(A) * B + beta * (*this), where A is sparse. - /// Multiplication of sparse with dense matrix. See also AddMatSmat. - void AddSmatMat(Real alpha, const SparseMatrix &A, - MatrixTransposeType transA, const MatrixBase &B, - Real beta); - - /// (*this) = alpha * A * op(B) + beta * (*this), where B is sparse - /// and op(B) is either B or trans(B) depending on the 'transB' argument. - /// This is multiplication of a dense by a sparse matrix. See also - /// AddSmatMat. - void AddMatSmat(Real alpha, const MatrixBase &A, - const SparseMatrix &B, MatrixTransposeType transB, - Real beta); - - /// *this = beta * *this + alpha * M M^T, for symmetric matrices. It only - /// updates the lower triangle of *this. It will leave the matrix asymmetric; - /// if you need it symmetric as a regular matrix, do CopyLowerToUpper(). - void SymAddMat2(const Real alpha, const MatrixBase &M, - MatrixTransposeType transA, Real beta); - - /// *this = beta * *this + alpha * diag(v) * M [or M^T]. - /// The same as adding M but scaling each row M_i by v(i). 
- void AddDiagVecMat(const Real alpha, const VectorBase &v, - const MatrixBase &M, MatrixTransposeType transM, - Real beta = 1.0); - - /// *this = beta * *this + alpha * M [or M^T] * diag(v) - /// The same as adding M but scaling each column M_j by v(j). - void AddMatDiagVec(const Real alpha, - const MatrixBase &M, MatrixTransposeType transM, - VectorBase &v, - Real beta = 1.0); - - /// *this = beta * *this + alpha * A .* B (.* element by element multiplication) - void AddMatMatElements(const Real alpha, - const MatrixBase& A, - const MatrixBase& B, - const Real beta); - - /// *this += alpha * S - template - void AddSp(const Real alpha, const SpMatrix &S); - - void AddMatMat(const Real alpha, - const MatrixBase& A, MatrixTransposeType transA, - const MatrixBase& B, MatrixTransposeType transB, - const Real beta); - - /// *this = a * b / c (by element; when c = 0, *this = a) - void SetMatMatDivMat(const MatrixBase& A, - const MatrixBase& B, - const MatrixBase& C); - - /// A version of AddMatMat specialized for when the second argument - /// contains a lot of zeroes. - void AddMatSmat(const Real alpha, - const MatrixBase& A, MatrixTransposeType transA, - const MatrixBase& B, MatrixTransposeType transB, - const Real beta); - - /// A version of AddMatMat specialized for when the first argument - /// contains a lot of zeroes. - void AddSmatMat(const Real alpha, - const MatrixBase& A, MatrixTransposeType transA, - const MatrixBase& B, MatrixTransposeType transB, - const Real beta); - - /// this <-- beta*this + alpha*A*B*C. - void AddMatMatMat(const Real alpha, - const MatrixBase& A, MatrixTransposeType transA, - const MatrixBase& B, MatrixTransposeType transB, - const MatrixBase& C, MatrixTransposeType transC, - const Real beta); - - /// this <-- beta*this + alpha*SpA*B. - // This and the routines below are really - // stubs that need to be made more efficient. 
- void AddSpMat(const Real alpha, - const SpMatrix& A, - const MatrixBase& B, MatrixTransposeType transB, - const Real beta) { - Matrix M(A); - return AddMatMat(alpha, M, kNoTrans, B, transB, beta); - } - /// this <-- beta*this + alpha*A*B. - void AddTpMat(const Real alpha, - const TpMatrix& A, MatrixTransposeType transA, - const MatrixBase& B, MatrixTransposeType transB, - const Real beta) { - Matrix M(A); - return AddMatMat(alpha, M, transA, B, transB, beta); - } - /// this <-- beta*this + alpha*A*B. - void AddMatSp(const Real alpha, - const MatrixBase& A, MatrixTransposeType transA, - const SpMatrix& B, - const Real beta) { - Matrix M(B); - return AddMatMat(alpha, A, transA, M, kNoTrans, beta); - } - /// this <-- beta*this + alpha*A*B*C. - void AddSpMatSp(const Real alpha, - const SpMatrix &A, - const MatrixBase& B, MatrixTransposeType transB, - const SpMatrix& C, - const Real beta) { - Matrix M(A), N(C); - return AddMatMatMat(alpha, M, kNoTrans, B, transB, N, kNoTrans, beta); - } - /// this <-- beta*this + alpha*A*B. - void AddMatTp(const Real alpha, - const MatrixBase& A, MatrixTransposeType transA, - const TpMatrix& B, MatrixTransposeType transB, - const Real beta) { - Matrix M(B); - return AddMatMat(alpha, A, transA, M, transB, beta); - } - - /// this <-- beta*this + alpha*A*B. - void AddTpTp(const Real alpha, - const TpMatrix& A, MatrixTransposeType transA, - const TpMatrix& B, MatrixTransposeType transB, - const Real beta) { - Matrix M(A), N(B); - return AddMatMat(alpha, M, transA, N, transB, beta); - } - - /// this <-- beta*this + alpha*A*B. - // This one is more efficient, not like the others above. - void AddSpSp(const Real alpha, - const SpMatrix& A, const SpMatrix& B, - const Real beta); - - /// Copy lower triangle to upper triangle (symmetrize) - void CopyLowerToUpper(); - - /// Copy upper triangle to lower triangle (symmetrize) - void CopyUpperToLower(); - - /// This function orthogonalizes the rows of a matrix using the Gram-Schmidt - /// process. 
It is only applicable if NumRows() <= NumCols(). It will use - /// random number generation to fill in rows with something nonzero, in cases - /// where the original matrix was of deficient row rank. - void OrthogonalizeRows(); /// stream read. /// Use instead of stream<<*this, if you want to add to existing contents. // Will throw exception on failure. - void Read(std::istream & in, bool binary, bool add = false); + void Read(std::istream & in, bool binary); /// write to stream. void Write(std::ostream & out, bool binary) const; // Below is internal methods for Svd, user does not have to know about this. -#if !defined(HAVE_ATLAS) && !defined(USE_KALDI_SVD) - // protected: - // Should be protected but used directly in testing routine. - // destroys *this! - void LapackGesvd(VectorBase *s, MatrixBase *U, - MatrixBase *Vt); -#else - protected: - // destroys *this! - bool JamaSvd(VectorBase *s, MatrixBase *U, - MatrixBase *V); - -#endif protected: /// Initializer, callable only from child. @@ -827,19 +610,9 @@ class Matrix : public MatrixBase { MatrixStrideType stride_type = kDefaultStride): MatrixBase() { Resize(r, c, resize_type, stride_type); } - /// Copy constructor from CUDA matrix - /// This is defined in ../cudamatrix/cu-matrix.h - template - explicit Matrix(const CuMatrixBase &cu, - MatrixTransposeType trans = kNoTrans); - - /// Swaps the contents of *this and *other. Shallow swap. void Swap(Matrix *other); - /// Defined in ../cudamatrix/cu-matrix.cc - void Swap(CuMatrix *mat); - /// Constructor from any MatrixBase. Can also copy with transpose. /// Allocates new memory. explicit Matrix(const MatrixBase & M, @@ -853,40 +626,29 @@ class Matrix : public MatrixBase { explicit Matrix(const MatrixBase & M, MatrixTransposeType trans = kNoTrans); - /// Copy constructor taking SpMatrix... 
- /// It is symmetric, so no option for transpose, and NumRows == Cols - template - explicit Matrix(const SpMatrix & M) : MatrixBase() { - Resize(M.NumRows(), M.NumRows(), kUndefined); - this->CopyFromSp(M); - } - - /// Constructor from CompressedMatrix - explicit Matrix(const CompressedMatrix &C); - /// Copy constructor taking TpMatrix... - template - explicit Matrix(const TpMatrix & M, - MatrixTransposeType trans = kNoTrans) : MatrixBase() { - if (trans == kNoTrans) { - Resize(M.NumRows(), M.NumCols(), kUndefined); - this->CopyFromTp(M); - } else { - Resize(M.NumCols(), M.NumRows(), kUndefined); - this->CopyFromTp(M, kTrans); - } - } + //template + //explicit Matrix(const TpMatrix & M, + //MatrixTransposeType trans = kNoTrans) : MatrixBase() { + //if (trans == kNoTrans) { + //Resize(M.NumRows(), M.NumCols(), kUndefined); + //this->CopyFromTp(M); + //} else { + //Resize(M.NumCols(), M.NumRows(), kUndefined); + //this->CopyFromTp(M, kTrans); + //} + //} /// read from stream. // Unlike one in base, allows resizing. - void Read(std::istream & in, bool binary, bool add = false); + void Read(std::istream & in, bool binary); /// Remove a specified row. void RemoveRow(MatrixIndexT i); /// Transpose the matrix. Works for non-square /// matrices as well as square ones. - void Transpose(); + //void Transpose(); /// Distructor to free matrices. ~Matrix() { Destroy(); } @@ -947,37 +709,6 @@ class Matrix : public MatrixBase { /// A structure containing the HTK header. /// [TODO: change the style of the variables to Kaldi-compliant] -struct HtkHeader { - /// Number of samples. - int32 mNSamples; - /// Sample period. - int32 mSamplePeriod; - /// Sample size - int16 mSampleSize; - /// Sample kind. - uint16 mSampleKind; -}; - -// Read HTK formatted features from file into matrix. -template -bool ReadHtk(std::istream &is, Matrix *M, HtkHeader *header_ptr); - -// Write (HTK format) features to file from matrix. 
-template -bool WriteHtk(std::ostream &os, const MatrixBase &M, HtkHeader htk_hdr); - -// Write (CMUSphinx format) features to file from matrix. -template -bool WriteSphinx(std::ostream &os, const MatrixBase &M); - -/// @} end of "addtogroup matrix_funcs_io" - -/** - Sub-matrix representation. - Can work with sub-parts of a matrix using this class. - Note that SubMatrix is not very const-correct-- it allows you to - change the contents of a const Matrix. Be careful! -*/ template class SubMatrix : public MatrixBase { @@ -1012,6 +743,7 @@ class SubMatrix : public MatrixBase { /// Disallow assignment. SubMatrix &operator = (const SubMatrix &other); }; + /// @} End of "addtogroup matrix_funcs_io". /// \addtogroup matrix_funcs_scalar @@ -1019,7 +751,7 @@ class SubMatrix : public MatrixBase { // Some declarations. These are traces of products. - +/************************ template bool ApproxEqual(const MatrixBase &A, const MatrixBase &B, Real tol = 0.01) { @@ -1085,7 +817,7 @@ void CreateEigenvalueMatrix(const VectorBase &real, const VectorBase template bool AttemptComplexPower(Real *x_re, Real *x_im, Real power); - +**********/ /// @} end of addtogroup matrix_funcs_misc @@ -1101,7 +833,6 @@ std::istream & operator >> (std::istream & In, MatrixBase & M); template std::istream & operator >> (std::istream & In, Matrix & M); - template bool SameDim(const MatrixBase &M, const MatrixBase &N) { return (M.NumRows() == N.NumRows() && M.NumCols() == N.NumCols()); diff --git a/speechx/speechx/kaldi/matrix/kaldi-vector-inl.h b/speechx/speechx/common/matrix/kaldi-vector-inl.h similarity index 84% rename from speechx/speechx/kaldi/matrix/kaldi-vector-inl.h rename to speechx/speechx/common/matrix/kaldi-vector-inl.h index c3a4f52f..82620276 100644 --- a/speechx/speechx/kaldi/matrix/kaldi-vector-inl.h +++ b/speechx/speechx/common/matrix/kaldi-vector-inl.h @@ -44,14 +44,14 @@ std::istream &operator >> (std::istream &is, Vector &rv) { return is; } -template<> -template<> -void 
VectorBase::AddVec(const float alpha, const VectorBase &rv); - -template<> -template<> -void VectorBase::AddVec(const double alpha, - const VectorBase &rv); +//template<> +//template<> +//void VectorBase::AddVec(const float alpha, const VectorBase &rv); + +//template<> +//template<> +//void VectorBase::AddVec(const double alpha, + //const VectorBase &rv); } // namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/kaldi-vector.cc b/speechx/speechx/common/matrix/kaldi-vector.cc similarity index 52% rename from speechx/speechx/kaldi/matrix/kaldi-vector.cc rename to speechx/speechx/common/matrix/kaldi-vector.cc index ccc7e89b..9f2bd08e 100644 --- a/speechx/speechx/kaldi/matrix/kaldi-vector.cc +++ b/speechx/speechx/common/matrix/kaldi-vector.cc @@ -25,144 +25,11 @@ #include #include -#include "matrix/cblas-wrappers.h" #include "matrix/kaldi-vector.h" #include "matrix/kaldi-matrix.h" -#include "matrix/sp-matrix.h" -#include "matrix/sparse-matrix.h" namespace kaldi { -template -Real VecVec(const VectorBase &a, - const VectorBase &b) { - MatrixIndexT adim = a.Dim(); - KALDI_ASSERT(adim == b.Dim()); - return cblas_Xdot(adim, a.Data(), 1, b.Data(), 1); -} - -template -float VecVec<>(const VectorBase &a, - const VectorBase &b); -template -double VecVec<>(const VectorBase &a, - const VectorBase &b); - -template -Real VecVec(const VectorBase &ra, - const VectorBase &rb) { - MatrixIndexT adim = ra.Dim(); - KALDI_ASSERT(adim == rb.Dim()); - const Real *a_data = ra.Data(); - const OtherReal *b_data = rb.Data(); - Real sum = 0.0; - for (MatrixIndexT i = 0; i < adim; i++) - sum += a_data[i]*b_data[i]; - return sum; -} - -// instantiate the template above. 
-template -float VecVec<>(const VectorBase &ra, - const VectorBase &rb); -template -double VecVec<>(const VectorBase &ra, - const VectorBase &rb); - - -template<> -template<> -void VectorBase::AddVec(const float alpha, - const VectorBase &v) { - KALDI_ASSERT(dim_ == v.dim_); - KALDI_ASSERT(&v != this); - cblas_Xaxpy(dim_, alpha, v.Data(), 1, data_, 1); -} - -template<> -template<> -void VectorBase::AddVec(const double alpha, - const VectorBase &v) { - KALDI_ASSERT(dim_ == v.dim_); - KALDI_ASSERT(&v != this); - cblas_Xaxpy(dim_, alpha, v.Data(), 1, data_, 1); -} - -template -void VectorBase::AddMatVec(const Real alpha, - const MatrixBase &M, - MatrixTransposeType trans, - const VectorBase &v, - const Real beta) { - KALDI_ASSERT((trans == kNoTrans && M.NumCols() == v.dim_ && M.NumRows() == dim_) - || (trans == kTrans && M.NumRows() == v.dim_ && M.NumCols() == dim_)); - KALDI_ASSERT(&v != this); - cblas_Xgemv(trans, M.NumRows(), M.NumCols(), alpha, M.Data(), M.Stride(), - v.Data(), 1, beta, data_, 1); -} - -template -void VectorBase::AddMatSvec(const Real alpha, - const MatrixBase &M, - MatrixTransposeType trans, - const VectorBase &v, - const Real beta) { - KALDI_ASSERT((trans == kNoTrans && M.NumCols() == v.dim_ && M.NumRows() == dim_) - || (trans == kTrans && M.NumRows() == v.dim_ && M.NumCols() == dim_)); - KALDI_ASSERT(&v != this); - Xgemv_sparsevec(trans, M.NumRows(), M.NumCols(), alpha, M.Data(), M.Stride(), - v.Data(), 1, beta, data_, 1); - return; - /* - MatrixIndexT this_dim = this->dim_, v_dim = v.dim_, - M_stride = M.Stride(); - Real *this_data = this->data_; - const Real *M_data = M.Data(), *v_data = v.data_; - if (beta != 1.0) this->Scale(beta); - if (trans == kNoTrans) { - for (MatrixIndexT i = 0; i < v_dim; i++) { - Real v_i = v_data[i]; - if (v_i == 0.0) continue; - // Add to *this, the i'th column of the Matrix, times v_i. 
- cblas_Xaxpy(this_dim, v_i * alpha, M_data + i, M_stride, this_data, 1); - } - } else { // The transposed case is slightly more efficient, I guess. - for (MatrixIndexT i = 0; i < v_dim; i++) { - Real v_i = v.data_[i]; - if (v_i == 0.0) continue; - // Add to *this, the i'th row of the Matrix, times v_i. - cblas_Xaxpy(this_dim, v_i * alpha, - M_data + (i * M_stride), 1, this_data, 1); - } - }*/ -} - -template -void VectorBase::AddSpVec(const Real alpha, - const SpMatrix &M, - const VectorBase &v, - const Real beta) { - KALDI_ASSERT(M.NumRows() == v.dim_ && dim_ == v.dim_); - KALDI_ASSERT(&v != this); - cblas_Xspmv(alpha, M.NumRows(), M.Data(), v.Data(), 1, beta, data_, 1); -} - - -template -void VectorBase::MulTp(const TpMatrix &M, - const MatrixTransposeType trans) { - KALDI_ASSERT(M.NumRows() == dim_); - cblas_Xtpmv(trans,M.Data(),M.NumRows(),data_,1); -} - -template -void VectorBase::Solve(const TpMatrix &M, - const MatrixTransposeType trans) { - KALDI_ASSERT(M.NumRows() == dim_); - cblas_Xtpsv(trans, M.Data(), M.NumRows(), data_, 1); -} - - template inline void Vector::Init(const MatrixIndexT dim) { KALDI_ASSERT(dim >= 0); @@ -232,6 +99,7 @@ void VectorBase::CopyFromVec(const VectorBase &v) { } } +/* template template void VectorBase::CopyFromPacked(const PackedMatrix& M) { @@ -249,7 +117,7 @@ template void VectorBase::CopyFromPtr(const Real *data, MatrixIndexT sz) { KALDI_ASSERT(dim_ == sz); std::memcpy(this->data_, data, Dim() * sizeof(Real)); -} +}*/ template template @@ -297,6 +165,7 @@ bool VectorBase::IsZero(Real cutoff) const { return (abs_max <= cutoff); } +/* template void VectorBase::SetRandn() { kaldi::RandomState rstate; @@ -330,7 +199,7 @@ MatrixIndexT VectorBase::RandCategorical() const { } return dim_ - 1; // Should only happen if RandUniform() // returns exactly 1, or due to roundoff. 
-} +}*/ template void VectorBase::Set(Real f) { @@ -426,6 +295,7 @@ void VectorBase::CopyRowFromMat(const MatrixBase &mat, MatrixInde template void VectorBase::CopyRowFromMat(const MatrixBase &mat, MatrixIndexT row); +/* template template void VectorBase::CopyRowFromSp(const SpMatrix &sp, MatrixIndexT row) { @@ -451,28 +321,6 @@ void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT r template void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT row); - -#ifdef HAVE_MKL -template<> -void VectorBase::Pow(const VectorBase &v, float power) { - vsPowx(dim_, data_, power, v.data_); -} -template<> -void VectorBase::Pow(const VectorBase &v, double power) { - vdPowx(dim_, data_, power, v.data_); -} -#else - -// takes elements to a power. Does not check output. -template -void VectorBase::Pow(const VectorBase &v, Real power) { - KALDI_ASSERT(dim_ == v.dim_); - for (MatrixIndexT i = 0; i < dim_; i++) { - data_[i] = pow(v.data_[i], power); - } -} -#endif - // takes absolute value of the elements to a power. // Throws exception if could not (but only for power != 1 and power != 2). template @@ -648,7 +496,7 @@ Real VectorBase::Min(MatrixIndexT *index_out) const { if (data[i] < ans) { ans = data[i]; index = i; } *index_out = index; return ans; -} +}*/ template @@ -670,434 +518,424 @@ void VectorBase::CopyColFromMat(const MatrixBase &mat, MatrixInde template void VectorBase::CopyColFromMat(const MatrixBase &mat, MatrixIndexT col); -template -void VectorBase::CopyDiagFromMat(const MatrixBase &M) { - KALDI_ASSERT(dim_ == std::min(M.NumRows(), M.NumCols())); - cblas_Xcopy(dim_, M.Data(), M.Stride() + 1, data_, 1); -} - -template -void VectorBase::CopyDiagFromPacked(const PackedMatrix &M) { - KALDI_ASSERT(dim_ == M.NumCols()); - for (MatrixIndexT i = 0; i < dim_; i++) - data_[i] = M(i, i); - // could make this more efficient. -} - -template -Real VectorBase::Sum() const { - // Do a dot-product with a size-1 array with a stride of 0 to - // implement sum. 
This allows us to access SIMD operations in a - // cross-platform way via your BLAS library. - Real one(1); - return cblas_Xdot(dim_, data_, 1, &one, 0); -} - -template -Real VectorBase::SumLog() const { - double sum_log = 0.0; - double prod = 1.0; - for (MatrixIndexT i = 0; i < dim_; i++) { - prod *= data_[i]; - // Possible future work (arnab): change these magic values to pre-defined - // constants - if (prod < 1.0e-10 || prod > 1.0e+10) { - sum_log += Log(prod); - prod = 1.0; - } - } - if (prod != 1.0) sum_log += Log(prod); - return sum_log; -} - -template -void VectorBase::AddRowSumMat(Real alpha, const MatrixBase &M, Real beta) { - KALDI_ASSERT(dim_ == M.NumCols()); - MatrixIndexT num_rows = M.NumRows(), stride = M.Stride(), dim = dim_; - Real *data = data_; - - // implement the function according to a dimension cutoff for computation efficiency - if (num_rows <= 64) { - cblas_Xscal(dim, beta, data, 1); - const Real *m_data = M.Data(); - for (MatrixIndexT i = 0; i < num_rows; i++, m_data += stride) - cblas_Xaxpy(dim, alpha, m_data, 1, data, 1); - - } else { - Vector ones(M.NumRows()); - ones.Set(1.0); - this->AddMatVec(alpha, M, kTrans, ones, beta); - } -} - -template -void VectorBase::AddColSumMat(Real alpha, const MatrixBase &M, Real beta) { - KALDI_ASSERT(dim_ == M.NumRows()); - MatrixIndexT num_cols = M.NumCols(); - - // implement the function according to a dimension cutoff for computation efficiency - if (num_cols <= 64) { - for (MatrixIndexT i = 0; i < dim_; i++) { - double sum = 0.0; - const Real *src = M.RowData(i); - for (MatrixIndexT j = 0; j < num_cols; j++) - sum += src[j]; - data_[i] = alpha * sum + beta * data_[i]; - } - } else { - Vector ones(M.NumCols()); - ones.Set(1.0); - this->AddMatVec(alpha, M, kNoTrans, ones, beta); - } -} - -template -Real VectorBase::LogSumExp(Real prune) const { - Real sum; - if (sizeof(sum) == 8) sum = kLogZeroDouble; - else sum = kLogZeroFloat; - Real max_elem = Max(), cutoff; - if (sizeof(Real) == 4) cutoff = 
max_elem + kMinLogDiffFloat; - else cutoff = max_elem + kMinLogDiffDouble; - if (prune > 0.0 && max_elem - prune > cutoff) // explicit pruning... - cutoff = max_elem - prune; - - double sum_relto_max_elem = 0.0; - - for (MatrixIndexT i = 0; i < dim_; i++) { - BaseFloat f = data_[i]; - if (f >= cutoff) - sum_relto_max_elem += Exp(f - max_elem); - } - return max_elem + Log(sum_relto_max_elem); -} - -template -void VectorBase::InvertElements() { - for (MatrixIndexT i = 0; i < dim_; i++) { - data_[i] = static_cast(1 / data_[i]); - } -} - -template -void VectorBase::ApplyLog() { - for (MatrixIndexT i = 0; i < dim_; i++) { - if (data_[i] < 0.0) - KALDI_ERR << "Trying to take log of a negative number."; - data_[i] = Log(data_[i]); - } -} - -template -void VectorBase::ApplyLogAndCopy(const VectorBase &v) { - KALDI_ASSERT(dim_ == v.Dim()); - for (MatrixIndexT i = 0; i < dim_; i++) { - data_[i] = Log(v(i)); - } -} - -template -void VectorBase::ApplyExp() { - for (MatrixIndexT i = 0; i < dim_; i++) { - data_[i] = Exp(data_[i]); - } -} - -template -void VectorBase::ApplyAbs() { - for (MatrixIndexT i = 0; i < dim_; i++) { data_[i] = std::abs(data_[i]); } -} - -template -void VectorBase::Floor(const VectorBase &v, Real floor_val, MatrixIndexT *floored_count) { - KALDI_ASSERT(dim_ == v.dim_); - if (floored_count == nullptr) { - for (MatrixIndexT i = 0; i < dim_; i++) { - data_[i] = std::max(v.data_[i], floor_val); - } - } else { - MatrixIndexT num_floored = 0; - for (MatrixIndexT i = 0; i < dim_; i++) { - if (v.data_[i] < floor_val) { - data_[i] = floor_val; - num_floored++; - } else { - data_[i] = v.data_[i]; - } - } - *floored_count = num_floored; - } -} - -template -void VectorBase::Ceiling(const VectorBase &v, Real ceil_val, MatrixIndexT *ceiled_count) { - KALDI_ASSERT(dim_ == v.dim_); - if (ceiled_count == nullptr) { - for (MatrixIndexT i = 0; i < dim_; i++) { - data_[i] = std::min(v.data_[i], ceil_val); - } - } else { - MatrixIndexT num_changed = 0; - for (MatrixIndexT i = 
0; i < dim_; i++) { - if (v.data_[i] > ceil_val) { - data_[i] = ceil_val; - num_changed++; - } else { - data_[i] = v.data_[i]; - } - } - *ceiled_count = num_changed; - } -} - -template -MatrixIndexT VectorBase::ApplyFloor(const VectorBase &floor_vec) { - KALDI_ASSERT(floor_vec.Dim() == dim_); - MatrixIndexT num_floored = 0; - for (MatrixIndexT i = 0; i < dim_; i++) { - if (data_[i] < floor_vec(i)) { - data_[i] = floor_vec(i); - num_floored++; - } - } - return num_floored; -} - -template -Real VectorBase::ApplySoftMax() { - Real max = this->Max(), sum = 0.0; - for (MatrixIndexT i = 0; i < dim_; i++) { - sum += (data_[i] = Exp(data_[i] - max)); - } - this->Scale(1.0 / sum); - return max + Log(sum); -} - -template -Real VectorBase::ApplyLogSoftMax() { - Real max = this->Max(), sum = 0.0; - for (MatrixIndexT i = 0; i < dim_; i++) { - sum += Exp((data_[i] -= max)); - } - sum = Log(sum); - this->Add(-1.0 * sum); - return max + sum; -} - -#ifdef HAVE_MKL -template<> -void VectorBase::Tanh(const VectorBase &src) { - KALDI_ASSERT(dim_ == src.dim_); - vsTanh(dim_, src.data_, data_); -} -template<> -void VectorBase::Tanh(const VectorBase &src) { - KALDI_ASSERT(dim_ == src.dim_); - vdTanh(dim_, src.data_, data_); -} -#else -template -void VectorBase::Tanh(const VectorBase &src) { - KALDI_ASSERT(dim_ == src.dim_); - for (MatrixIndexT i = 0; i < dim_; i++) { - Real x = src.data_[i]; - if (x > 0.0) { - Real inv_expx = Exp(-x); - x = -1.0 + 2.0 / (1.0 + inv_expx * inv_expx); - } else { - Real expx = Exp(x); - x = 1.0 - 2.0 / (1.0 + expx * expx); - } - data_[i] = x; - } -} -#endif - -#ifdef HAVE_MKL -// Implementing sigmoid based on tanh. 
-template<> -void VectorBase::Sigmoid(const VectorBase &src) { - KALDI_ASSERT(dim_ == src.dim_); - this->CopyFromVec(src); - this->Scale(0.5); - vsTanh(dim_, data_, data_); - this->Add(1.0); - this->Scale(0.5); -} -template<> -void VectorBase::Sigmoid(const VectorBase &src) { - KALDI_ASSERT(dim_ == src.dim_); - this->CopyFromVec(src); - this->Scale(0.5); - vdTanh(dim_, data_, data_); - this->Add(1.0); - this->Scale(0.5); -} -#else -template -void VectorBase::Sigmoid(const VectorBase &src) { - KALDI_ASSERT(dim_ == src.dim_); - for (MatrixIndexT i = 0; i < dim_; i++) { - Real x = src.data_[i]; - // We aim to avoid floating-point overflow here. - if (x > 0.0) { - x = 1.0 / (1.0 + Exp(-x)); - } else { - Real ex = Exp(x); - x = ex / (ex + 1.0); - } - data_[i] = x; - } -} -#endif - - -template -void VectorBase::Add(Real c) { - for (MatrixIndexT i = 0; i < dim_; i++) { - data_[i] += c; - } -} - -template -void VectorBase::Scale(Real alpha) { - cblas_Xscal(dim_, alpha, data_, 1); -} - -template -void VectorBase::MulElements(const VectorBase &v) { - KALDI_ASSERT(dim_ == v.dim_); - for (MatrixIndexT i = 0; i < dim_; i++) { - data_[i] *= v.data_[i]; - } -} - -template // Set each element to y = (x == orig ? changed : x). -void VectorBase::ReplaceValue(Real orig, Real changed) { - Real *data = data_; - for (MatrixIndexT i = 0; i < dim_; i++) - if (data[i] == orig) data[i] = changed; -} - - -template -template -void VectorBase::MulElements(const VectorBase &v) { - KALDI_ASSERT(dim_ == v.Dim()); - const OtherReal *other_ptr = v.Data(); - for (MatrixIndexT i = 0; i < dim_; i++) { - data_[i] *= other_ptr[i]; - } -} -// instantiate template. 
-template -void VectorBase::MulElements(const VectorBase &v); -template -void VectorBase::MulElements(const VectorBase &v); - - -template -void VectorBase::AddVecVec(Real alpha, const VectorBase &v, - const VectorBase &r, Real beta) { - KALDI_ASSERT(v.data_ != this->data_ && r.data_ != this->data_); - // We pretend that v is a band-diagonal matrix. - KALDI_ASSERT(dim_ == v.dim_ && dim_ == r.dim_); - cblas_Xgbmv(kNoTrans, dim_, dim_, 0, 0, alpha, v.data_, 1, - r.data_, 1, beta, this->data_, 1); -} - - -template -void VectorBase::DivElements(const VectorBase &v) { - KALDI_ASSERT(dim_ == v.dim_); - for (MatrixIndexT i = 0; i < dim_; i++) { - data_[i] /= v.data_[i]; - } -} - -template -template -void VectorBase::DivElements(const VectorBase &v) { - KALDI_ASSERT(dim_ == v.Dim()); - const OtherReal *other_ptr = v.Data(); - for (MatrixIndexT i = 0; i < dim_; i++) { - data_[i] /= other_ptr[i]; - } -} -// instantiate template. -template -void VectorBase::DivElements(const VectorBase &v); -template -void VectorBase::DivElements(const VectorBase &v); - -template -void VectorBase::AddVecDivVec(Real alpha, const VectorBase &v, - const VectorBase &rr, Real beta) { - KALDI_ASSERT((dim_ == v.dim_ && dim_ == rr.dim_)); - for (MatrixIndexT i = 0; i < dim_; i++) { - data_[i] = alpha * v.data_[i]/rr.data_[i] + beta * data_[i] ; - } -} - -template -template -void VectorBase::AddVec(const Real alpha, const VectorBase &v) { - KALDI_ASSERT(dim_ == v.dim_); - // remove __restrict__ if it causes compilation problems. 
- Real *__restrict__ data = data_; - OtherReal *__restrict__ other_data = v.data_; - MatrixIndexT dim = dim_; - if (alpha != 1.0) - for (MatrixIndexT i = 0; i < dim; i++) - data[i] += alpha * other_data[i]; - else - for (MatrixIndexT i = 0; i < dim; i++) - data[i] += other_data[i]; -} - -template -void VectorBase::AddVec(const float alpha, const VectorBase &v); -template -void VectorBase::AddVec(const double alpha, const VectorBase &v); - -template -template -void VectorBase::AddVec2(const Real alpha, const VectorBase &v) { - KALDI_ASSERT(dim_ == v.dim_); - // remove __restrict__ if it causes compilation problems. - Real *__restrict__ data = data_; - OtherReal *__restrict__ other_data = v.data_; - MatrixIndexT dim = dim_; - if (alpha != 1.0) - for (MatrixIndexT i = 0; i < dim; i++) - data[i] += alpha * other_data[i] * other_data[i]; - else - for (MatrixIndexT i = 0; i < dim; i++) - data[i] += other_data[i] * other_data[i]; -} - -template -void VectorBase::AddVec2(const float alpha, const VectorBase &v); -template -void VectorBase::AddVec2(const double alpha, const VectorBase &v); - - -template -void VectorBase::Read(std::istream &is, bool binary, bool add) { - if (add) { - Vector tmp(Dim()); - tmp.Read(is, binary, false); // read without adding. - if (this->Dim() != tmp.Dim()) { - KALDI_ERR << "VectorBase::Read, size mismatch " << this->Dim()<<" vs. "<AddVec(1.0, tmp); - return; - } // now assume add == false. - +//template +//void VectorBase::CopyDiagFromMat(const MatrixBase &M) { + //KALDI_ASSERT(dim_ == std::min(M.NumRows(), M.NumCols())); + //cblas_Xcopy(dim_, M.Data(), M.Stride() + 1, data_, 1); +//} + +//template +//void VectorBase::CopyDiagFromPacked(const PackedMatrix &M) { + //KALDI_ASSERT(dim_ == M.NumCols()); + //for (MatrixIndexT i = 0; i < dim_; i++) + //data_[i] = M(i, i); + //// could make this more efficient. +//} + +//template +//Real VectorBase::Sum() const { + //// Do a dot-product with a size-1 array with a stride of 0 to + //// implement sum. 
This allows us to access SIMD operations in a + //// cross-platform way via your BLAS library. + //Real one(1); + //return cblas_Xdot(dim_, data_, 1, &one, 0); +//} + +//template +//Real VectorBase::SumLog() const { + //double sum_log = 0.0; + //double prod = 1.0; + //for (MatrixIndexT i = 0; i < dim_; i++) { + //prod *= data_[i]; + //// Possible future work (arnab): change these magic values to pre-defined + //// constants + //if (prod < 1.0e-10 || prod > 1.0e+10) { + //sum_log += Log(prod); + //prod = 1.0; + //} + //} + //if (prod != 1.0) sum_log += Log(prod); + //return sum_log; +//} + +//template +//void VectorBase::AddRowSumMat(Real alpha, const MatrixBase &M, Real beta) { + //KALDI_ASSERT(dim_ == M.NumCols()); + //MatrixIndexT num_rows = M.NumRows(), stride = M.Stride(), dim = dim_; + //Real *data = data_; + + //// implement the function according to a dimension cutoff for computation efficiency + //if (num_rows <= 64) { + //cblas_Xscal(dim, beta, data, 1); + //const Real *m_data = M.Data(); + //for (MatrixIndexT i = 0; i < num_rows; i++, m_data += stride) + //cblas_Xaxpy(dim, alpha, m_data, 1, data, 1); + + //} else { + //Vector ones(M.NumRows()); + //ones.Set(1.0); + //this->AddMatVec(alpha, M, kTrans, ones, beta); + //} +//} + +//template +//void VectorBase::AddColSumMat(Real alpha, const MatrixBase &M, Real beta) { + //KALDI_ASSERT(dim_ == M.NumRows()); + //MatrixIndexT num_cols = M.NumCols(); + + //// implement the function according to a dimension cutoff for computation efficiency + //if (num_cols <= 64) { + //for (MatrixIndexT i = 0; i < dim_; i++) { + //double sum = 0.0; + //const Real *src = M.RowData(i); + //for (MatrixIndexT j = 0; j < num_cols; j++) + //sum += src[j]; + //data_[i] = alpha * sum + beta * data_[i]; + //} + //} else { + //Vector ones(M.NumCols()); + //ones.Set(1.0); + //this->AddMatVec(alpha, M, kNoTrans, ones, beta); + //} +//} + +//template +//Real VectorBase::LogSumExp(Real prune) const { + //Real sum; + //if (sizeof(sum) == 8) 
sum = kLogZeroDouble; + //else sum = kLogZeroFloat; + //Real max_elem = Max(), cutoff; + //if (sizeof(Real) == 4) cutoff = max_elem + kMinLogDiffFloat; + //else cutoff = max_elem + kMinLogDiffDouble; + //if (prune > 0.0 && max_elem - prune > cutoff) // explicit pruning... + //cutoff = max_elem - prune; + + //double sum_relto_max_elem = 0.0; + + //for (MatrixIndexT i = 0; i < dim_; i++) { + //BaseFloat f = data_[i]; + //if (f >= cutoff) + //sum_relto_max_elem += Exp(f - max_elem); + //} + //return max_elem + Log(sum_relto_max_elem); +//} + +//template +//void VectorBase::InvertElements() { + //for (MatrixIndexT i = 0; i < dim_; i++) { + //data_[i] = static_cast(1 / data_[i]); + //} +//} + +//template +//void VectorBase::ApplyLog() { + //for (MatrixIndexT i = 0; i < dim_; i++) { + //if (data_[i] < 0.0) + //KALDI_ERR << "Trying to take log of a negative number."; + //data_[i] = Log(data_[i]); + //} +//} + +//template +//void VectorBase::ApplyLogAndCopy(const VectorBase &v) { + //KALDI_ASSERT(dim_ == v.Dim()); + //for (MatrixIndexT i = 0; i < dim_; i++) { + //data_[i] = Log(v(i)); + //} +//} + +//template +//void VectorBase::ApplyExp() { + //for (MatrixIndexT i = 0; i < dim_; i++) { + //data_[i] = Exp(data_[i]); + //} +//} + +//template +//void VectorBase::ApplyAbs() { + //for (MatrixIndexT i = 0; i < dim_; i++) { data_[i] = std::abs(data_[i]); } +//} + +//template +//void VectorBase::Floor(const VectorBase &v, Real floor_val, MatrixIndexT *floored_count) { + //KALDI_ASSERT(dim_ == v.dim_); + //if (floored_count == nullptr) { + //for (MatrixIndexT i = 0; i < dim_; i++) { + //data_[i] = std::max(v.data_[i], floor_val); + //} + //} else { + //MatrixIndexT num_floored = 0; + //for (MatrixIndexT i = 0; i < dim_; i++) { + //if (v.data_[i] < floor_val) { + //data_[i] = floor_val; + //num_floored++; + //} else { + //data_[i] = v.data_[i]; + //} + //} + //*floored_count = num_floored; + //} +//} + +//template +//void VectorBase::Ceiling(const VectorBase &v, Real ceil_val, 
MatrixIndexT *ceiled_count) { + //KALDI_ASSERT(dim_ == v.dim_); + //if (ceiled_count == nullptr) { + //for (MatrixIndexT i = 0; i < dim_; i++) { + //data_[i] = std::min(v.data_[i], ceil_val); + //} + //} else { + //MatrixIndexT num_changed = 0; + //for (MatrixIndexT i = 0; i < dim_; i++) { + //if (v.data_[i] > ceil_val) { + //data_[i] = ceil_val; + //num_changed++; + //} else { + //data_[i] = v.data_[i]; + //} + //} + //*ceiled_count = num_changed; + //} +//} + +//template +//MatrixIndexT VectorBase::ApplyFloor(const VectorBase &floor_vec) { + //KALDI_ASSERT(floor_vec.Dim() == dim_); + //MatrixIndexT num_floored = 0; + //for (MatrixIndexT i = 0; i < dim_; i++) { + //if (data_[i] < floor_vec(i)) { + //data_[i] = floor_vec(i); + //num_floored++; + //} + //} + //return num_floored; +//} + +//template +//Real VectorBase::ApplySoftMax() { + //Real max = this->Max(), sum = 0.0; + //for (MatrixIndexT i = 0; i < dim_; i++) { + //sum += (data_[i] = Exp(data_[i] - max)); + //} + //this->Scale(1.0 / sum); + //return max + Log(sum); +//} + +//template +//Real VectorBase::ApplyLogSoftMax() { + //Real max = this->Max(), sum = 0.0; + //for (MatrixIndexT i = 0; i < dim_; i++) { + //sum += Exp((data_[i] -= max)); + //} + //sum = Log(sum); + //this->Add(-1.0 * sum); + //return max + sum; +//} + +//#ifdef HAVE_MKL +//template<> +//void VectorBase::Tanh(const VectorBase &src) { + //KALDI_ASSERT(dim_ == src.dim_); + //vsTanh(dim_, src.data_, data_); +//} +//template<> +//void VectorBase::Tanh(const VectorBase &src) { + //KALDI_ASSERT(dim_ == src.dim_); + //vdTanh(dim_, src.data_, data_); +//} +//#else +//template +//void VectorBase::Tanh(const VectorBase &src) { + //KALDI_ASSERT(dim_ == src.dim_); + //for (MatrixIndexT i = 0; i < dim_; i++) { + //Real x = src.data_[i]; + //if (x > 0.0) { + //Real inv_expx = Exp(-x); + //x = -1.0 + 2.0 / (1.0 + inv_expx * inv_expx); + //} else { + //Real expx = Exp(x); + //x = 1.0 - 2.0 / (1.0 + expx * expx); + //} + //data_[i] = x; + //} +//} +//#endif 
+ +//#ifdef HAVE_MKL +//// Implementing sigmoid based on tanh. +//template<> +//void VectorBase::Sigmoid(const VectorBase &src) { + //KALDI_ASSERT(dim_ == src.dim_); + //this->CopyFromVec(src); + //this->Scale(0.5); + //vsTanh(dim_, data_, data_); + //this->Add(1.0); + //this->Scale(0.5); +//} +//template<> +//void VectorBase::Sigmoid(const VectorBase &src) { + //KALDI_ASSERT(dim_ == src.dim_); + //this->CopyFromVec(src); + //this->Scale(0.5); + //vdTanh(dim_, data_, data_); + //this->Add(1.0); + //this->Scale(0.5); +//} +//#else +//template +//void VectorBase::Sigmoid(const VectorBase &src) { + //KALDI_ASSERT(dim_ == src.dim_); + //for (MatrixIndexT i = 0; i < dim_; i++) { + //Real x = src.data_[i]; + //// We aim to avoid floating-point overflow here. + //if (x > 0.0) { + //x = 1.0 / (1.0 + Exp(-x)); + //} else { + //Real ex = Exp(x); + //x = ex / (ex + 1.0); + //} + //data_[i] = x; + //} +//} +//#endif + + +//template +//void VectorBase::Add(Real c) { + //for (MatrixIndexT i = 0; i < dim_; i++) { + //data_[i] += c; + //} +//} + +//template +//void VectorBase::Scale(Real alpha) { + //cblas_Xscal(dim_, alpha, data_, 1); +//} + +//template +//void VectorBase::MulElements(const VectorBase &v) { + //KALDI_ASSERT(dim_ == v.dim_); + //for (MatrixIndexT i = 0; i < dim_; i++) { + //data_[i] *= v.data_[i]; + //} +//} + +//template // Set each element to y = (x == orig ? changed : x). +//void VectorBase::ReplaceValue(Real orig, Real changed) { + //Real *data = data_; + //for (MatrixIndexT i = 0; i < dim_; i++) + //if (data[i] == orig) data[i] = changed; +//} + + +//template +//template +//void VectorBase::MulElements(const VectorBase &v) { + //KALDI_ASSERT(dim_ == v.Dim()); + //const OtherReal *other_ptr = v.Data(); + //for (MatrixIndexT i = 0; i < dim_; i++) { + //data_[i] *= other_ptr[i]; + //} +//} +//// instantiate template. 
+//template +//void VectorBase::MulElements(const VectorBase &v); +//template +//void VectorBase::MulElements(const VectorBase &v); + + +//template +//void VectorBase::AddVecVec(Real alpha, const VectorBase &v, + //const VectorBase &r, Real beta) { + //KALDI_ASSERT(v.data_ != this->data_ && r.data_ != this->data_); + //// We pretend that v is a band-diagonal matrix. + //KALDI_ASSERT(dim_ == v.dim_ && dim_ == r.dim_); + //cblas_Xgbmv(kNoTrans, dim_, dim_, 0, 0, alpha, v.data_, 1, + //r.data_, 1, beta, this->data_, 1); +//} + + +//template +//void VectorBase::DivElements(const VectorBase &v) { + //KALDI_ASSERT(dim_ == v.dim_); + //for (MatrixIndexT i = 0; i < dim_; i++) { + //data_[i] /= v.data_[i]; + //} +//} + +//template +//template +//void VectorBase::DivElements(const VectorBase &v) { + //KALDI_ASSERT(dim_ == v.Dim()); + //const OtherReal *other_ptr = v.Data(); + //for (MatrixIndexT i = 0; i < dim_; i++) { + //data_[i] /= other_ptr[i]; + //} +//} +//// instantiate template. +//template +//void VectorBase::DivElements(const VectorBase &v); +//template +//void VectorBase::DivElements(const VectorBase &v); + +//template +//void VectorBase::AddVecDivVec(Real alpha, const VectorBase &v, + //const VectorBase &rr, Real beta) { + //KALDI_ASSERT((dim_ == v.dim_ && dim_ == rr.dim_)); + //for (MatrixIndexT i = 0; i < dim_; i++) { + //data_[i] = alpha * v.data_[i]/rr.data_[i] + beta * data_[i] ; + //} +//} + +//template +//template +//void VectorBase::AddVec(const Real alpha, const VectorBase &v) { + //KALDI_ASSERT(dim_ == v.dim_); + //// remove __restrict__ if it causes compilation problems. 
+ //Real *__restrict__ data = data_; + //OtherReal *__restrict__ other_data = v.data_; + //MatrixIndexT dim = dim_; + //if (alpha != 1.0) + //for (MatrixIndexT i = 0; i < dim; i++) + //data[i] += alpha * other_data[i]; + //else + //for (MatrixIndexT i = 0; i < dim; i++) + //data[i] += other_data[i]; +//} + +//template +//void VectorBase::AddVec(const float alpha, const VectorBase &v); +//template +//void VectorBase::AddVec(const double alpha, const VectorBase &v); + +//template +//template +//void VectorBase::AddVec2(const Real alpha, const VectorBase &v) { + //KALDI_ASSERT(dim_ == v.dim_); + //// remove __restrict__ if it causes compilation problems. + //Real *__restrict__ data = data_; + //OtherReal *__restrict__ other_data = v.data_; + //MatrixIndexT dim = dim_; + //if (alpha != 1.0) + //for (MatrixIndexT i = 0; i < dim; i++) + //data[i] += alpha * other_data[i] * other_data[i]; + //else + //for (MatrixIndexT i = 0; i < dim; i++) + //data[i] += other_data[i] * other_data[i]; +//} + +//template +//void VectorBase::AddVec2(const float alpha, const VectorBase &v); +//template +//void VectorBase::AddVec2(const double alpha, const VectorBase &v); + + +template +void VectorBase::Read(std::istream &is, bool binary) { // In order to avoid rewriting this, we just declare a Vector and // use it to read the data, then copy. Vector tmp; - tmp.Read(is, binary, false); + tmp.Read(is, binary); if (tmp.Dim() != Dim()) KALDI_ERR << "VectorBase::Read, size mismatch " << Dim() << " vs. " << tmp.Dim(); @@ -1106,19 +944,7 @@ void VectorBase::Read(std::istream &is, bool binary, bool add) { template -void Vector::Read(std::istream &is, bool binary, bool add) { - if (add) { - Vector tmp(this->Dim()); - tmp.Read(is, binary, false); // read without adding. - if (this->Dim() == 0) this->Resize(tmp.Dim()); - if (this->Dim() != tmp.Dim()) { - KALDI_ERR << "Vector::Read, adding but dimensions mismatch " - << this->Dim() << " vs. 
" << tmp.Dim(); - } - this->AddVec(1.0, tmp); - return; - } // now assume add == false. - +void Vector::Read(std::istream &is, bool binary) { std::ostringstream specific_error; MatrixIndexT pos_at_start = is.tellg(); @@ -1129,7 +955,7 @@ void Vector::Read(std::istream &is, bool binary, bool add) { if (peekval == other_token_start) { // need to instantiate the other type to read it. typedef typename OtherReal::Real OtherType; // if Real == float, OtherType == double, and vice versa. Vector other(this->Dim()); - other.Read(is, binary, false); // add is false at this point. + other.Read(is, binary); // add is false at this point. if (this->Dim() != other.Dim()) this->Resize(other.Dim()); this->CopyFromVec(other); return; @@ -1251,47 +1077,47 @@ void VectorBase::Write(std::ostream & os, bool binary) const { } -template -void VectorBase::AddVec2(const Real alpha, const VectorBase &v) { - KALDI_ASSERT(dim_ == v.dim_); - for (MatrixIndexT i = 0; i < dim_; i++) - data_[i] += alpha * v.data_[i] * v.data_[i]; -} - -// this <-- beta*this + alpha*M*v. 
-template -void VectorBase::AddTpVec(const Real alpha, const TpMatrix &M, - const MatrixTransposeType trans, - const VectorBase &v, - const Real beta) { - KALDI_ASSERT(dim_ == v.dim_ && dim_ == M.NumRows()); - if (beta == 0.0) { - if (&v != this) CopyFromVec(v); - MulTp(M, trans); - if (alpha != 1.0) Scale(alpha); - } else { - Vector tmp(v); - tmp.MulTp(M, trans); - if (beta != 1.0) Scale(beta); // *this <-- beta * *this - AddVec(alpha, tmp); // *this += alpha * M * v - } -} - -template -Real VecMatVec(const VectorBase &v1, const MatrixBase &M, - const VectorBase &v2) { - KALDI_ASSERT(v1.Dim() == M.NumRows() && v2.Dim() == M.NumCols()); - Vector vtmp(M.NumRows()); - vtmp.AddMatVec(1.0, M, kNoTrans, v2, 0.0); - return VecVec(v1, vtmp); -} - -template -float VecMatVec(const VectorBase &v1, const MatrixBase &M, - const VectorBase &v2); -template -double VecMatVec(const VectorBase &v1, const MatrixBase &M, - const VectorBase &v2); +//template +//void VectorBase::AddVec2(const Real alpha, const VectorBase &v) { + //KALDI_ASSERT(dim_ == v.dim_); + //for (MatrixIndexT i = 0; i < dim_; i++) + //data_[i] += alpha * v.data_[i] * v.data_[i]; +//} + +//// this <-- beta*this + alpha*M*v. 
+//template +//void VectorBase::AddTpVec(const Real alpha, const TpMatrix &M, + //const MatrixTransposeType trans, + //const VectorBase &v, + //const Real beta) { + //KALDI_ASSERT(dim_ == v.dim_ && dim_ == M.NumRows()); + //if (beta == 0.0) { + //if (&v != this) CopyFromVec(v); + //MulTp(M, trans); + //if (alpha != 1.0) Scale(alpha); + //} else { + //Vector tmp(v); + //tmp.MulTp(M, trans); + //if (beta != 1.0) Scale(beta); // *this <-- beta * *this + //AddVec(alpha, tmp); // *this += alpha * M * v + //} +//} + +//template +//Real VecMatVec(const VectorBase &v1, const MatrixBase &M, + //const VectorBase &v2) { + //KALDI_ASSERT(v1.Dim() == M.NumRows() && v2.Dim() == M.NumCols()); + //Vector vtmp(M.NumRows()); + //vtmp.AddMatVec(1.0, M, kNoTrans, v2, 0.0); + //return VecVec(v1, vtmp); +//} + +//template +//float VecMatVec(const VectorBase &v1, const MatrixBase &M, + //const VectorBase &v2); +//template +//double VecMatVec(const VectorBase &v1, const MatrixBase &M, + //const VectorBase &v2); template void Vector::Swap(Vector *other) { @@ -1300,51 +1126,51 @@ void Vector::Swap(Vector *other) { } -template -void VectorBase::AddDiagMat2( - Real alpha, const MatrixBase &M, - MatrixTransposeType trans, Real beta) { - if (trans == kNoTrans) { - KALDI_ASSERT(this->dim_ == M.NumRows()); - MatrixIndexT rows = this->dim_, cols = M.NumCols(), - mat_stride = M.Stride(); - Real *data = this->data_; - const Real *mat_data = M.Data(); - for (MatrixIndexT i = 0; i < rows; i++, mat_data += mat_stride, data++) - *data = beta * *data + alpha * cblas_Xdot(cols,mat_data,1,mat_data,1); - } else { - KALDI_ASSERT(this->dim_ == M.NumCols()); - MatrixIndexT rows = M.NumRows(), cols = this->dim_, - mat_stride = M.Stride(); - Real *data = this->data_; - const Real *mat_data = M.Data(); - for (MatrixIndexT i = 0; i < cols; i++, mat_data++, data++) - *data = beta * *data + alpha * cblas_Xdot(rows, mat_data, mat_stride, - mat_data, mat_stride); - } -} - -template -void VectorBase::AddDiagMatMat( - 
Real alpha, - const MatrixBase &M, MatrixTransposeType transM, - const MatrixBase &N, MatrixTransposeType transN, - Real beta) { - MatrixIndexT dim = this->dim_, - M_col_dim = (transM == kTrans ? M.NumRows() : M.NumCols()), - N_row_dim = (transN == kTrans ? N.NumCols() : N.NumRows()); - KALDI_ASSERT(M_col_dim == N_row_dim); // this is the dimension we sum over - MatrixIndexT M_row_stride = M.Stride(), M_col_stride = 1; - if (transM == kTrans) std::swap(M_row_stride, M_col_stride); - MatrixIndexT N_row_stride = N.Stride(), N_col_stride = 1; - if (transN == kTrans) std::swap(N_row_stride, N_col_stride); - - Real *data = this->data_; - const Real *Mdata = M.Data(), *Ndata = N.Data(); - for (MatrixIndexT i = 0; i < dim; i++, Mdata += M_row_stride, Ndata += N_col_stride, data++) { - *data = beta * *data + alpha * cblas_Xdot(M_col_dim, Mdata, M_col_stride, Ndata, N_row_stride); - } -} +//template +//void VectorBase::AddDiagMat2( + //Real alpha, const MatrixBase &M, + //MatrixTransposeType trans, Real beta) { + //if (trans == kNoTrans) { + //KALDI_ASSERT(this->dim_ == M.NumRows()); + //MatrixIndexT rows = this->dim_, cols = M.NumCols(), + //mat_stride = M.Stride(); + //Real *data = this->data_; + //const Real *mat_data = M.Data(); + //for (MatrixIndexT i = 0; i < rows; i++, mat_data += mat_stride, data++) + //*data = beta * *data + alpha * cblas_Xdot(cols,mat_data,1,mat_data,1); + //} else { + //KALDI_ASSERT(this->dim_ == M.NumCols()); + //MatrixIndexT rows = M.NumRows(), cols = this->dim_, + //mat_stride = M.Stride(); + //Real *data = this->data_; + //const Real *mat_data = M.Data(); + //for (MatrixIndexT i = 0; i < cols; i++, mat_data++, data++) + //*data = beta * *data + alpha * cblas_Xdot(rows, mat_data, mat_stride, + //mat_data, mat_stride); + //} +//} + +//template +//void VectorBase::AddDiagMatMat( + //Real alpha, + //const MatrixBase &M, MatrixTransposeType transM, + //const MatrixBase &N, MatrixTransposeType transN, + //Real beta) { + //MatrixIndexT dim = 
this->dim_, + //M_col_dim = (transM == kTrans ? M.NumRows() : M.NumCols()), + //N_row_dim = (transN == kTrans ? N.NumCols() : N.NumRows()); + //KALDI_ASSERT(M_col_dim == N_row_dim); // this is the dimension we sum over + //MatrixIndexT M_row_stride = M.Stride(), M_col_stride = 1; + //if (transM == kTrans) std::swap(M_row_stride, M_col_stride); + //MatrixIndexT N_row_stride = N.Stride(), N_col_stride = 1; + //if (transN == kTrans) std::swap(N_row_stride, N_col_stride); + + //Real *data = this->data_; + //const Real *Mdata = M.Data(), *Ndata = N.Data(); + //for (MatrixIndexT i = 0; i < dim; i++, Mdata += M_row_stride, Ndata += N_col_stride, data++) { + //*data = beta * *data + alpha * cblas_Xdot(M_col_dim, Mdata, M_col_stride, Ndata, N_row_stride); + //} +//} template class Vector; diff --git a/speechx/speechx/common/matrix/kaldi-vector.h b/speechx/speechx/common/matrix/kaldi-vector.h new file mode 100644 index 00000000..5bcbeda9 --- /dev/null +++ b/speechx/speechx/common/matrix/kaldi-vector.h @@ -0,0 +1,345 @@ +// matrix/kaldi-vector.h + +// Copyright 2009-2012 Ondrej Glembek; Microsoft Corporation; Lukas Burget; +// Saarland University (Author: Arnab Ghoshal); +// Ariya Rastrow; Petr Schwarz; Yanmin Qian; +// Karel Vesely; Go Vivace Inc.; Arnab Ghoshal +// Wei Shi; +// 2015 Guoguo Chen +// 2017 Daniel Galvez +// 2019 Yiwen Shao + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_VECTOR_H_ +#define KALDI_MATRIX_KALDI_VECTOR_H_ 1 + +#include "matrix/matrix-common.h" + +namespace kaldi { + +/// \addtogroup matrix_group +/// @{ + +/// Provides a vector abstraction class. +/// This class provides a way to work with vectors in kaldi. +/// It encapsulates basic operations and memory optimizations. +template +class VectorBase { + public: + /// Set vector to all zeros. + void SetZero(); + + /// Returns true if matrix is all zeros. + bool IsZero(Real cutoff = 1.0e-06) const; // replace magic number + + /// Set all members of a vector to a specified value. + void Set(Real f); + + /// Returns the dimension of the vector. + inline MatrixIndexT Dim() const { return dim_; } + + /// Returns the size in memory of the vector, in bytes. + inline MatrixIndexT SizeInBytes() const { return (dim_*sizeof(Real)); } + + /// Returns a pointer to the start of the vector's data. + inline Real* Data() { return data_; } + + /// Returns a pointer to the start of the vector's data (const). + inline const Real* Data() const { return data_; } + + /// Indexing operator (const). + inline Real operator() (MatrixIndexT i) const { + KALDI_PARANOID_ASSERT(static_cast(i) < + static_cast(dim_)); + return *(data_ + i); + } + + /// Indexing operator (non-const). + inline Real & operator() (MatrixIndexT i) { + KALDI_PARANOID_ASSERT(static_cast(i) < + static_cast(dim_)); + return *(data_ + i); + } + + /** @brief Returns a sub-vector of a vector (a range of elements). + * @param o [in] Origin, 0 < o < Dim() + * @param l [in] Length 0 < l < Dim()-o + * @return A SubVector object that aliases the data of the Vector object. + * See @c SubVector class for details */ + SubVector Range(const MatrixIndexT o, const MatrixIndexT l) { + return SubVector(*this, o, l); + } + + /** @brief Returns a const sub-vector of a vector (a range of elements). 
+ * @param o [in] Origin, 0 < o < Dim() + * @param l [in] Length 0 < l < Dim()-o + * @return A SubVector object that aliases the data of the Vector object. + * See @c SubVector class for details */ + const SubVector Range(const MatrixIndexT o, + const MatrixIndexT l) const { + return SubVector(*this, o, l); + } + + /// Copy data from another vector (must match own size). + void CopyFromVec(const VectorBase &v); + + /// Copy data from another vector of different type (double vs. float) + template + void CopyFromVec(const VectorBase &v); + + /// Performs a row stack of the matrix M + void CopyRowsFromMat(const MatrixBase &M); + template + void CopyRowsFromMat(const MatrixBase &M); + + /// Performs a column stack of the matrix M + void CopyColsFromMat(const MatrixBase &M); + + /// Extracts a row of the matrix M. Could also do this with + /// this->Copy(M[row]). + void CopyRowFromMat(const MatrixBase &M, MatrixIndexT row); + /// Extracts a row of the matrix M with type conversion. + template + void CopyRowFromMat(const MatrixBase &M, MatrixIndexT row); + + /// Extracts a column of the matrix M. + template + void CopyColFromMat(const MatrixBase &M , MatrixIndexT col); + + /// Reads from C++ stream (option to add to existing contents). + /// Throws exception on failure + void Read(std::istream &in, bool binary); + + /// Writes to C++ stream (option to write in binary). + void Write(std::ostream &Out, bool binary) const; + + friend class VectorBase; + friend class VectorBase; + protected: + /// Destructor; does not deallocate memory, this is handled by child classes. + /// This destructor is protected so this object can only be + /// deleted via a child. + ~VectorBase() {} + + /// Empty initializer, corresponds to vector of zero size. 
+ explicit VectorBase(): data_(NULL), dim_(0) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } + + /// data memory area + Real* data_; + /// dimension of vector + MatrixIndexT dim_; + KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase); +}; // class VectorBase + +/** @brief A class representing a vector. + * + * This class provides a way to work with vectors in kaldi. + * It encapsulates basic operations and memory optimizations. */ +template +class Vector: public VectorBase { + public: + /// Constructor that takes no arguments. Initializes to empty. + Vector(): VectorBase() {} + + /// Constructor with specific size. Sets to all-zero by default + /// if set_zero == false, memory contents are undefined. + explicit Vector(const MatrixIndexT s, + MatrixResizeType resize_type = kSetZero) + : VectorBase() { Resize(s, resize_type); } + + /// Copy constructor from CUDA vector + /// This is defined in ../cudamatrix/cu-vector.h + //template + //explicit Vector(const CuVectorBase &cu); + + /// Copy constructor. The need for this is controversial. + Vector(const Vector &v) : VectorBase() { // (cannot be explicit) + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + + /// Copy-constructor from base-class, needed to copy from SubVector. + explicit Vector(const VectorBase &v) : VectorBase() { + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + + /// Type conversion constructor. + template + explicit Vector(const VectorBase &v): VectorBase() { + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + +// Took this out since it is unsafe : Arnab +// /// Constructor from a pointer and a size; copies the data to a location +// /// it owns. +// Vector(const Real* Data, const MatrixIndexT s): VectorBase() { +// Resize(s); + // CopyFromPtr(Data, s); +// } + + + /// Swaps the contents of *this and *other. Shallow swap. + void Swap(Vector *other); + + /// Destructor. Deallocates memory. + ~Vector() { Destroy(); } + + /// Read function using C++ streams. 
Can also add to existing contents + /// of matrix. + void Read(std::istream &in, bool binary); + + /// Set vector to a specified size (can be zero). + /// The value of the new data depends on resize_type: + /// -if kSetZero, the new data will be zero + /// -if kUndefined, the new data will be undefined + /// -if kCopyData, the new data will be the same as the old data in any + /// shared positions, and zero elsewhere. + /// This function takes time proportional to the number of data elements. + void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero); + + /// Remove one element and shifts later elements down. + void RemoveElement(MatrixIndexT i); + + /// Assignment operator. + Vector &operator = (const Vector &other) { + Resize(other.Dim(), kUndefined); + this->CopyFromVec(other); + return *this; + } + + /// Assignment operator that takes VectorBase. + Vector &operator = (const VectorBase &other) { + Resize(other.Dim(), kUndefined); + this->CopyFromVec(other); + return *this; + } + private: + /// Init assumes the current contents of the class are invalid (i.e. junk or + /// has already been freed), and it sets the vector to newly allocated memory + /// with the specified dimension. dim == 0 is acceptable. The memory contents + /// pointed to by data_ will be undefined. + void Init(const MatrixIndexT dim); + + /// Destroy function, called internally. + void Destroy(); + +}; + + +/// Represents a non-allocating general vector which can be defined +/// as a sub-vector of higher-level vector [or as the row of a matrix]. +template +class SubVector : public VectorBase { + public: + /// Constructor from a Vector or SubVector. + /// SubVectors are not const-safe and it's very hard to make them + /// so for now we just give up. This function contains const_cast. 
+ SubVector(const VectorBase &t, const MatrixIndexT origin, + const MatrixIndexT length) : VectorBase() { + // following assert equiv to origin>=0 && length>=0 && + // origin+length <= rt.dim_ + KALDI_ASSERT(static_cast(origin)+ + static_cast(length) <= + static_cast(t.Dim())); + VectorBase::data_ = const_cast (t.Data()+origin); + VectorBase::dim_ = length; + } + + /// This constructor initializes the vector to point at the contents + /// of this packed matrix (SpMatrix or TpMatrix). + // SubVector(const PackedMatrix &M) { + //VectorBase::data_ = const_cast (M.Data()); + //VectorBase::dim_ = (M.NumRows()*(M.NumRows()+1))/2; + //} + + /// Copy constructor + SubVector(const SubVector &other) : VectorBase () { + // this copy constructor needed for Range() to work in base class. + VectorBase::data_ = other.data_; + VectorBase::dim_ = other.dim_; + } + + /// Constructor from a pointer to memory and a length. Keeps a pointer + /// to the data but does not take ownership (will never delete). + /// Caution: this constructor enables you to evade const constraints. + SubVector(const Real *data, MatrixIndexT length) : VectorBase () { + VectorBase::data_ = const_cast(data); + VectorBase::dim_ = length; + } + + /// This operation does not preserve const-ness, so be careful. + SubVector(const MatrixBase &matrix, MatrixIndexT row) { + VectorBase::data_ = const_cast(matrix.RowData(row)); + VectorBase::dim_ = matrix.NumCols(); + } + + ~SubVector() {} ///< Destructor (does nothing; no pointers are owned here). + + private: + /// Disallow assignment operator. + SubVector & operator = (const SubVector &other) {} +}; + +/// @} end of "addtogroup matrix_group" +/// \addtogroup matrix_funcs_io +/// @{ +/// Output to a C++ stream. Non-binary by default (use Write for +/// binary output). +template +std::ostream & operator << (std::ostream & out, const VectorBase & v); + +/// Input from a C++ stream. Will automatically read text or +/// binary data from the stream. 
+template +std::istream & operator >> (std::istream & in, VectorBase & v); + +/// Input from a C++ stream. Will automatically read text or +/// binary data from the stream. +template +std::istream & operator >> (std::istream & in, Vector & v); +/// @} end of \addtogroup matrix_funcs_io + +/// \addtogroup matrix_funcs_scalar +/// @{ + + +//template +//bool ApproxEqual(const VectorBase &a, + //const VectorBase &b, Real tol = 0.01) { + //return a.ApproxEqual(b, tol); +//} + +//template +//inline void AssertEqual(VectorBase &a, VectorBase &b, + //float tol = 0.01) { + //KALDI_ASSERT(a.ApproxEqual(b, tol)); +//} + + + +} // namespace kaldi + +// we need to include the implementation +#include "matrix/kaldi-vector-inl.h" + + + +#endif // KALDI_MATRIX_KALDI_VECTOR_H_ diff --git a/speechx/speechx/kaldi/matrix/matrix-common.h b/speechx/speechx/common/matrix/matrix-common.h similarity index 78% rename from speechx/speechx/kaldi/matrix/matrix-common.h rename to speechx/speechx/common/matrix/matrix-common.h index f7047d71..b7bdbbc8 100644 --- a/speechx/speechx/kaldi/matrix/matrix-common.h +++ b/speechx/speechx/common/matrix/matrix-common.h @@ -59,26 +59,7 @@ template class SubVector; template class MatrixBase; template class SubMatrix; template class Matrix; -template class SpMatrix; -template class TpMatrix; -template class PackedMatrix; -template class SparseMatrix; - -// these are classes that won't be defined in this -// directory; they're mostly needed for friend declarations. -template class CuMatrixBase; -template class CuSubMatrix; -template class CuMatrix; -template class CuVectorBase; -template class CuSubVector; -template class CuVector; -template class CuPackedMatrix; -template class CuSpMatrix; -template class CuTpMatrix; -template class CuSparseMatrix; - -class CompressedMatrix; -class GeneralMatrix; + /// This class provides a way for switching between double and float types. 
template class OtherReal { }; // useful in reading+writing routines diff --git a/speechx/speechx/kaldi/CMakeLists.txt b/speechx/speechx/kaldi/CMakeLists.txt index d27668fc..f9b42e06 100644 --- a/speechx/speechx/kaldi/CMakeLists.txt +++ b/speechx/speechx/kaldi/CMakeLists.txt @@ -5,8 +5,6 @@ ${CMAKE_CURRENT_SOURCE_DIR} add_subdirectory(base) add_subdirectory(util) -add_subdirectory(feat) -add_subdirectory(matrix) add_subdirectory(lat) add_subdirectory(fstext) add_subdirectory(decoder) diff --git a/speechx/speechx/kaldi/feat/CMakeLists.txt b/speechx/speechx/kaldi/feat/CMakeLists.txt deleted file mode 100644 index cfbf2025..00000000 --- a/speechx/speechx/kaldi/feat/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -add_library(kaldi-mfcc - feature-mfcc.cc -) -target_link_libraries(kaldi-mfcc PUBLIC kaldi-feat-common) - -add_library(kaldi-fbank - feature-fbank.cc -) -target_link_libraries(kaldi-fbank PUBLIC kaldi-feat-common) - -add_library(kaldi-feat-common - wave-reader.cc - signal.cc - feature-functions.cc - feature-window.cc - resample.cc - mel-computations.cc - cmvn.cc -) -target_link_libraries(kaldi-feat-common PUBLIC kaldi-base kaldi-matrix kaldi-util) diff --git a/speechx/speechx/kaldi/feat/cmvn.cc b/speechx/speechx/kaldi/feat/cmvn.cc deleted file mode 100644 index b2aa46e4..00000000 --- a/speechx/speechx/kaldi/feat/cmvn.cc +++ /dev/null @@ -1,183 +0,0 @@ -// transform/cmvn.cc - -// Copyright 2009-2013 Microsoft Corporation -// Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include "feat/cmvn.h" - -namespace kaldi { - -void InitCmvnStats(int32 dim, Matrix *stats) { - KALDI_ASSERT(dim > 0); - stats->Resize(2, dim+1); -} - -void AccCmvnStats(const VectorBase &feats, BaseFloat weight, MatrixBase *stats) { - int32 dim = feats.Dim(); - KALDI_ASSERT(stats != NULL); - KALDI_ASSERT(stats->NumRows() == 2 && stats->NumCols() == dim + 1); - // Remove these __restrict__ modifiers if they cause compilation problems. - // It's just an optimization. - double *__restrict__ mean_ptr = stats->RowData(0), - *__restrict__ var_ptr = stats->RowData(1), - *__restrict__ count_ptr = mean_ptr + dim; - const BaseFloat * __restrict__ feats_ptr = feats.Data(); - *count_ptr += weight; - // Careful-- if we change the format of the matrix, the "mean_ptr < count_ptr" - // statement below might become wrong. - for (; mean_ptr < count_ptr; mean_ptr++, var_ptr++, feats_ptr++) { - *mean_ptr += *feats_ptr * weight; - *var_ptr += *feats_ptr * *feats_ptr * weight; - } -} - -void AccCmvnStats(const MatrixBase &feats, - const VectorBase *weights, - MatrixBase *stats) { - int32 num_frames = feats.NumRows(); - if (weights != NULL) { - KALDI_ASSERT(weights->Dim() == num_frames); - } - for (int32 i = 0; i < num_frames; i++) { - SubVector this_frame = feats.Row(i); - BaseFloat weight = (weights == NULL ? 
1.0 : (*weights)(i)); - if (weight != 0.0) - AccCmvnStats(this_frame, weight, stats); - } -} - -void ApplyCmvn(const MatrixBase &stats, - bool var_norm, - MatrixBase *feats) { - KALDI_ASSERT(feats != NULL); - int32 dim = stats.NumCols() - 1; - if (stats.NumRows() > 2 || stats.NumRows() < 1 || feats->NumCols() != dim) { - KALDI_ERR << "Dim mismatch: cmvn " - << stats.NumRows() << 'x' << stats.NumCols() - << ", feats " << feats->NumRows() << 'x' << feats->NumCols(); - } - if (stats.NumRows() == 1 && var_norm) - KALDI_ERR << "You requested variance normalization but no variance stats " - << "are supplied."; - - double count = stats(0, dim); - // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when - // computing an offset and representing it as stats, we use a count of one. - if (count < 1.0) - KALDI_ERR << "Insufficient stats for cepstral mean and variance normalization: " - << "count = " << count; - - if (!var_norm) { - Vector offset(dim); - SubVector mean_stats(stats.RowData(0), dim); - offset.AddVec(-1.0 / count, mean_stats); - feats->AddVecToRows(1.0, offset); - return; - } - // norm(0, d) = mean offset; - // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d). - Matrix norm(2, dim); - for (int32 d = 0; d < dim; d++) { - double mean, offset, scale; - mean = stats(0, d)/count; - double var = (stats(1, d)/count) - mean*mean, - floor = 1.0e-20; - if (var < floor) { - KALDI_WARN << "Flooring cepstral variance from " << var << " to " - << floor; - var = floor; - } - scale = 1.0 / sqrt(var); - if (scale != scale || 1/scale == 0.0) - KALDI_ERR << "NaN or infinity in cepstral mean/variance computation"; - offset = -(mean*scale); - norm(0, d) = offset; - norm(1, d) = scale; - } - // Apply the normalization. 
- feats->MulColsVec(norm.Row(1)); - feats->AddVecToRows(1.0, norm.Row(0)); -} - -void ApplyCmvnReverse(const MatrixBase &stats, - bool var_norm, - MatrixBase *feats) { - KALDI_ASSERT(feats != NULL); - int32 dim = stats.NumCols() - 1; - if (stats.NumRows() > 2 || stats.NumRows() < 1 || feats->NumCols() != dim) { - KALDI_ERR << "Dim mismatch: cmvn " - << stats.NumRows() << 'x' << stats.NumCols() - << ", feats " << feats->NumRows() << 'x' << feats->NumCols(); - } - if (stats.NumRows() == 1 && var_norm) - KALDI_ERR << "You requested variance normalization but no variance stats " - << "are supplied."; - - double count = stats(0, dim); - // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when - // computing an offset and representing it as stats, we use a count of one. - if (count < 1.0) - KALDI_ERR << "Insufficient stats for cepstral mean and variance normalization: " - << "count = " << count; - - Matrix norm(2, dim); // norm(0, d) = mean offset - // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d). - for (int32 d = 0; d < dim; d++) { - double mean, offset, scale; - mean = stats(0, d) / count; - if (!var_norm) { - scale = 1.0; - offset = mean; - } else { - double var = (stats(1, d)/count) - mean*mean, - floor = 1.0e-20; - if (var < floor) { - KALDI_WARN << "Flooring cepstral variance from " << var << " to " - << floor; - var = floor; - } - // we aim to transform zero-mean, unit-variance input into data - // with the given mean and variance. 
- scale = sqrt(var); - offset = mean; - } - norm(0, d) = offset; - norm(1, d) = scale; - } - if (var_norm) - feats->MulColsVec(norm.Row(1)); - feats->AddVecToRows(1.0, norm.Row(0)); -} - - -void FakeStatsForSomeDims(const std::vector &dims, - MatrixBase *stats) { - KALDI_ASSERT(stats->NumRows() == 2 && stats->NumCols() > 1); - int32 dim = stats->NumCols() - 1; - double count = (*stats)(0, dim); - for (size_t i = 0; i < dims.size(); i++) { - int32 d = dims[i]; - KALDI_ASSERT(d >= 0 && d < dim); - (*stats)(0, d) = 0.0; - (*stats)(1, d) = count; - } -} - - - -} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/cmvn.h b/speechx/speechx/kaldi/feat/cmvn.h deleted file mode 100644 index c6d1b7f7..00000000 --- a/speechx/speechx/kaldi/feat/cmvn.h +++ /dev/null @@ -1,75 +0,0 @@ -// transform/cmvn.h - -// Copyright 2009-2013 Microsoft Corporation -// Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. 
- - -#ifndef KALDI_TRANSFORM_CMVN_H_ -#define KALDI_TRANSFORM_CMVN_H_ - -#include "base/kaldi-common.h" -#include "matrix/matrix-lib.h" - -namespace kaldi { - -/// This function initializes the matrix to dimension 2 by (dim+1); -/// 1st "dim" elements of 1st row are mean stats, 1st "dim" elements -/// of 2nd row are var stats, last element of 1st row is count, -/// last element of 2nd row is zero. -void InitCmvnStats(int32 dim, Matrix *stats); - -/// Accumulation from a single frame (weighted). -void AccCmvnStats(const VectorBase &feat, - BaseFloat weight, - MatrixBase *stats); - -/// Accumulation from a feature file (possibly weighted-- useful in excluding silence). -void AccCmvnStats(const MatrixBase &feats, - const VectorBase *weights, // or NULL - MatrixBase *stats); - -/// Apply cepstral mean and variance normalization to a matrix of features. -/// If norm_vars == true, expects stats to be of dimension 2 by (dim+1), but -/// if norm_vars == false, will accept stats of dimension 1 by (dim+1); these -/// are produced by the balanced-cmvn code when it computes an offset and -/// represents it as "fake stats". -void ApplyCmvn(const MatrixBase &stats, - bool norm_vars, - MatrixBase *feats); - -/// This is as ApplyCmvn, but does so in the reverse sense, i.e. applies a transform -/// that would take zero-mean, unit-variance input and turn it into output with the -/// stats of "stats". This can be useful if you trained without CMVN but later want -/// to correct a mismatch, so you would first apply CMVN and then do the "reverse" -/// CMVN with the summed stats of your training data. -void ApplyCmvnReverse(const MatrixBase &stats, - bool norm_vars, - MatrixBase *feats); - - -/// Modify the stats so that for some dimensions (specified in "dims"), we -/// replace them with "fake" stats that have zero mean and unit variance; this -/// is done to disable CMVN for those dimensions. 
-void FakeStatsForSomeDims(const std::vector &dims, - MatrixBase *stats); - - - -} // namespace kaldi - -#endif // KALDI_TRANSFORM_CMVN_H_ diff --git a/speechx/speechx/kaldi/feat/feature-common-inl.h b/speechx/speechx/kaldi/feat/feature-common-inl.h deleted file mode 100644 index 26127a4d..00000000 --- a/speechx/speechx/kaldi/feat/feature-common-inl.h +++ /dev/null @@ -1,99 +0,0 @@ -// feat/feature-common-inl.h - -// Copyright 2016 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_FEAT_FEATURE_COMMON_INL_H_ -#define KALDI_FEAT_FEATURE_COMMON_INL_H_ - -#include "feat/resample.h" -// Do not include this file directly. It is included by feat/feature-common.h - -namespace kaldi { - -template -void OfflineFeatureTpl::ComputeFeatures( - const VectorBase &wave, - BaseFloat sample_freq, - BaseFloat vtln_warp, - Matrix *output) { - KALDI_ASSERT(output != NULL); - BaseFloat new_sample_freq = computer_.GetFrameOptions().samp_freq; - if (sample_freq == new_sample_freq) { - Compute(wave, vtln_warp, output); - } else { - if (new_sample_freq < sample_freq && - ! 
computer_.GetFrameOptions().allow_downsample) - KALDI_ERR << "Waveform and config sample Frequency mismatch: " - << sample_freq << " .vs " << new_sample_freq - << " (use --allow-downsample=true to allow " - << " downsampling the waveform)."; - else if (new_sample_freq > sample_freq && - ! computer_.GetFrameOptions().allow_upsample) - KALDI_ERR << "Waveform and config sample Frequency mismatch: " - << sample_freq << " .vs " << new_sample_freq - << " (use --allow-upsample=true option to allow " - << " upsampling the waveform)."; - // Resample the waveform. - Vector resampled_wave(wave); - ResampleWaveform(sample_freq, wave, - new_sample_freq, &resampled_wave); - Compute(resampled_wave, vtln_warp, output); - } -} - -template -void OfflineFeatureTpl::Compute( - const VectorBase &wave, - BaseFloat vtln_warp, - Matrix *output) { - KALDI_ASSERT(output != NULL); - int32 rows_out = NumFrames(wave.Dim(), computer_.GetFrameOptions()), - cols_out = computer_.Dim(); - if (rows_out == 0) { - output->Resize(0, 0); - return; - } - output->Resize(rows_out, cols_out); - Vector window; // windowed waveform. - bool use_raw_log_energy = computer_.NeedRawLogEnergy(); - for (int32 r = 0; r < rows_out; r++) { // r is frame index. - BaseFloat raw_log_energy = 0.0; - ExtractWindow(0, wave, r, computer_.GetFrameOptions(), - feature_window_function_, &window, - (use_raw_log_energy ? &raw_log_energy : NULL)); - - SubVector output_row(*output, r); - computer_.Compute(raw_log_energy, vtln_warp, &window, &output_row); - } -} - -template -void OfflineFeatureTpl::Compute( - const VectorBase &wave, - BaseFloat vtln_warp, - Matrix *output) const { - OfflineFeatureTpl temp(*this); - // call the non-const version of Compute() on a temporary copy of this object. - // This is a workaround for const-ness that may sometimes be useful in - // multi-threaded code, although it's not optimally efficient. 
- temp.Compute(wave, vtln_warp, output); -} - -} // end namespace kaldi - -#endif diff --git a/speechx/speechx/kaldi/feat/feature-common.h b/speechx/speechx/kaldi/feat/feature-common.h deleted file mode 100644 index 3c2fbd37..00000000 --- a/speechx/speechx/kaldi/feat/feature-common.h +++ /dev/null @@ -1,176 +0,0 @@ -// feat/feature-common.h - -// Copyright 2016 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABILITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_FEAT_FEATURE_COMMON_H_ -#define KALDI_FEAT_FEATURE_COMMON_H_ - -#include -#include -#include "feat/feature-window.h" - -namespace kaldi { -/// @addtogroup feat FeatureExtraction -/// @{ - - - -/// This class is only added for documentation, it is not intended to ever be -/// used. -struct ExampleFeatureComputerOptions { - FrameExtractionOptions frame_opts; - // .. more would go here. -}; - -/// This class is only added for documentation, it is not intended to ever be -/// used. It documents the interface of the *Computer classes which wrap the -/// low-level feature extraction. The template argument F of OfflineFeatureTpl must -/// follow this interface. This interface is intended for features such as -/// MFCCs and PLPs which can be computed frame by frame. 
-class ExampleFeatureComputer { - public: - typedef ExampleFeatureComputerOptions Options; - - /// Returns a reference to the frame-extraction options class, which - /// will be part of our own options class. - const FrameExtractionOptions &GetFrameOptions() const { - return opts_.frame_opts; - } - - /// Returns the feature dimension - int32 Dim() const; - - /// Returns true if this function may inspect the raw log-energy of the signal - /// (before windowing and pre-emphasis); it's safe to always return true, but - /// setting it to false enables an optimization. - bool NeedRawLogEnergy() const { return true; } - - /// constructor from options class; it should not store a reference or pointer - /// to the options class but should copy it. - explicit ExampleFeatureComputer(const ExampleFeatureComputerOptions &opts): - opts_(opts) { } - - /// Copy constructor; all of these classes must have one. - ExampleFeatureComputer(const ExampleFeatureComputer &other); - - /** - Function that computes one frame of features from - one frame of signal. - - @param [in] signal_raw_log_energy The log-energy of the frame of the signal - prior to windowing and pre-emphasis, or - log(numeric_limits::min()), whichever is greater. Must be - ignored by this function if this class returns false from - this->NeedRawLogEnergy(). - @param [in] vtln_warp The VTLN warping factor that the user wants - to be applied when computing features for this utterance. Will - normally be 1.0, meaning no warping is to be done. The value will - be ignored for feature types that don't support VLTN, such as - spectrogram features. - @param [in] signal_frame One frame of the signal, - as extracted using the function ExtractWindow() using the options - returned by this->GetFrameOptions(). The function will use the - vector as a workspace, which is why it's a non-const pointer. - @param [out] feature Pointer to a vector of size this->Dim(), to which - the computed feature will be written. 
- */ - void Compute(BaseFloat signal_raw_log_energy, - BaseFloat vtln_warp, - VectorBase *signal_frame, - VectorBase *feature); - - private: - // disallow assignment. - ExampleFeatureComputer &operator = (const ExampleFeatureComputer &in); - Options opts_; -}; - - -/// This templated class is intended for offline feature extraction, i.e. where -/// you have access to the entire signal at the start. It exists mainly to be -/// drop-in replacement for the old (pre-2016) classes Mfcc, Plp and so on, for -/// use in the offline case. In April 2016 we reorganized the online -/// feature-computation code for greater modularity and to have correct support -/// for the snip-edges=false option. -template -class OfflineFeatureTpl { - public: - typedef typename F::Options Options; - - // Note: feature_window_function_ is the windowing function, which initialized - // using the options class, that we cache at this level. - OfflineFeatureTpl(const Options &opts): - computer_(opts), - feature_window_function_(computer_.GetFrameOptions()) { } - - // Internal (and back-compatibility) interface for computing features, which - // requires that the user has already checked that the sampling frequency - // of the waveform is equal to the sampling frequency specified in - // the frame-extraction options. - void Compute(const VectorBase &wave, - BaseFloat vtln_warp, - Matrix *output); - - // This const version of Compute() is a wrapper that - // calls the non-const version on a temporary object. - // It's less efficient than the non-const version. - void Compute(const VectorBase &wave, - BaseFloat vtln_warp, - Matrix *output) const; - - /** - Computes the features for one file (one sequence of features). - This is the newer interface where you specify the sample frequency - of the input waveform. - @param [in] wave The input waveform - @param [in] sample_freq The sampling frequency with which - 'wave' was sampled. 
- if sample_freq is higher than the frequency - specified in the config, we will downsample - the waveform, but if lower, it's an error. - @param [in] vtln_warp The VTLN warping factor (will normally - be 1.0) - @param [out] output The matrix of features, where the row-index - is the frame index. - */ - void ComputeFeatures(const VectorBase &wave, - BaseFloat sample_freq, - BaseFloat vtln_warp, - Matrix *output); - - int32 Dim() const { return computer_.Dim(); } - - // Copy constructor. - OfflineFeatureTpl(const OfflineFeatureTpl &other): - computer_(other.computer_), - feature_window_function_(other.feature_window_function_) { } - private: - // Disallow assignment. - OfflineFeatureTpl &operator =(const OfflineFeatureTpl &other); - - F computer_; - FeatureWindowFunction feature_window_function_; -}; - -/// @} End of "addtogroup feat" -} // namespace kaldi - - -#include "feat/feature-common-inl.h" - -#endif // KALDI_FEAT_FEATURE_COMMON_H_ diff --git a/speechx/speechx/kaldi/feat/feature-fbank.cc b/speechx/speechx/kaldi/feat/feature-fbank.cc deleted file mode 100644 index d9ac03e5..00000000 --- a/speechx/speechx/kaldi/feat/feature-fbank.cc +++ /dev/null @@ -1,125 +0,0 @@ -// feat/feature-fbank.cc - -// Copyright 2009-2012 Karel Vesely -// 2016 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. 
-// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - - -#include "feat/feature-fbank.h" - -namespace kaldi { - -FbankComputer::FbankComputer(const FbankOptions &opts): - opts_(opts), srfft_(NULL) { - if (opts.energy_floor > 0.0) - log_energy_floor_ = Log(opts.energy_floor); - - int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); - if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two... - srfft_ = new SplitRadixRealFft(padded_window_size); - - // We'll definitely need the filterbanks info for VTLN warping factor 1.0. - // [note: this call caches it.] - GetMelBanks(1.0); -} - -FbankComputer::FbankComputer(const FbankComputer &other): - opts_(other.opts_), log_energy_floor_(other.log_energy_floor_), - mel_banks_(other.mel_banks_), srfft_(NULL) { - for (std::map::iterator iter = mel_banks_.begin(); - iter != mel_banks_.end(); - ++iter) - iter->second = new MelBanks(*(iter->second)); - if (other.srfft_) - srfft_ = new SplitRadixRealFft(*(other.srfft_)); -} - -FbankComputer::~FbankComputer() { - for (std::map::iterator iter = mel_banks_.begin(); - iter != mel_banks_.end(); ++iter) - delete iter->second; - delete srfft_; -} - -const MelBanks* FbankComputer::GetMelBanks(BaseFloat vtln_warp) { - MelBanks *this_mel_banks = NULL; - std::map::iterator iter = mel_banks_.find(vtln_warp); - if (iter == mel_banks_.end()) { - this_mel_banks = new MelBanks(opts_.mel_opts, - opts_.frame_opts, - vtln_warp); - mel_banks_[vtln_warp] = this_mel_banks; - } else { - this_mel_banks = iter->second; - } - return this_mel_banks; -} - -void FbankComputer::Compute(BaseFloat signal_raw_log_energy, - BaseFloat vtln_warp, - VectorBase *signal_frame, - VectorBase *feature) { - - const MelBanks &mel_banks = *(GetMelBanks(vtln_warp)); - - KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && - feature->Dim() == this->Dim()); - - - // Compute energy after window function (not the raw 
one). - if (opts_.use_energy && !opts_.raw_energy) - signal_raw_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), - std::numeric_limits::epsilon())); - - if (srfft_ != NULL) // Compute FFT using split-radix algorithm. - srfft_->Compute(signal_frame->Data(), true); - else // An alternative algorithm that works for non-powers-of-two. - RealFft(signal_frame, true); - - // Convert the FFT into a power spectrum. - ComputePowerSpectrum(signal_frame); - SubVector power_spectrum(*signal_frame, 0, - signal_frame->Dim() / 2 + 1); - - // Use magnitude instead of power if requested. - if (!opts_.use_power) - power_spectrum.ApplyPow(0.5); - - int32 mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0); - SubVector mel_energies(*feature, - mel_offset, - opts_.mel_opts.num_bins); - - // Sum with mel fiterbanks over the power spectrum - mel_banks.Compute(power_spectrum, &mel_energies); - if (opts_.use_log_fbank) { - // Avoid log of zero (which should be prevented anyway by dithering). - mel_energies.ApplyFloor(std::numeric_limits::epsilon()); - mel_energies.ApplyLog(); // take the log. - } - - // Copy energy as first value (or the last, if htk_compat == true). - if (opts_.use_energy) { - if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_) { - signal_raw_log_energy = log_energy_floor_; - } - int32 energy_index = opts_.htk_compat ? 
opts_.mel_opts.num_bins : 0; - (*feature)(energy_index) = signal_raw_log_energy; - } -} - -} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/feature-fbank.h b/speechx/speechx/kaldi/feat/feature-fbank.h deleted file mode 100644 index d121cc0e..00000000 --- a/speechx/speechx/kaldi/feat/feature-fbank.h +++ /dev/null @@ -1,149 +0,0 @@ -// feat/feature-fbank.h - -// Copyright 2009-2012 Karel Vesely -// 2016 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_FEAT_FEATURE_FBANK_H_ -#define KALDI_FEAT_FEATURE_FBANK_H_ - -#include -#include - -#include "feat/feature-common.h" -#include "feat/feature-functions.h" -#include "feat/feature-window.h" -#include "feat/mel-computations.h" - -namespace kaldi { -/// @addtogroup feat FeatureExtraction -/// @{ - - -/// FbankOptions contains basic options for computing filterbank features. -/// It only includes things that can be done in a "stateless" way, i.e. -/// it does not include energy max-normalization. -/// It does not include delta computation. 
-struct FbankOptions { - FrameExtractionOptions frame_opts; - MelBanksOptions mel_opts; - bool use_energy; // append an extra dimension with energy to the filter banks - BaseFloat energy_floor; - bool raw_energy; // If true, compute energy before preemphasis and windowing - bool htk_compat; // If true, put energy last (if using energy) - bool use_log_fbank; // if true (default), produce log-filterbank, else linear - bool use_power; // if true (default), use power in filterbank analysis, else magnitude. - - FbankOptions(): mel_opts(23), - // defaults the #mel-banks to 23 for the FBANK computations. - // this seems to be common for 16khz-sampled data, - // but for 8khz-sampled data, 15 may be better. - use_energy(false), - energy_floor(0.0), - raw_energy(true), - htk_compat(false), - use_log_fbank(true), - use_power(true) {} - - void Register(OptionsItf *opts) { - frame_opts.Register(opts); - mel_opts.Register(opts); - opts->Register("use-energy", &use_energy, - "Add an extra dimension with energy to the FBANK output."); - opts->Register("energy-floor", &energy_floor, - "Floor on energy (absolute, not relative) in FBANK computation. " - "Only makes a difference if --use-energy=true; only necessary if " - "--dither=0.0. Suggested values: 0.1 or 1.0"); - opts->Register("raw-energy", &raw_energy, - "If true, compute energy before preemphasis and windowing"); - opts->Register("htk-compat", &htk_compat, "If true, put energy last. " - "Warning: not sufficient to get HTK compatible features (need " - "to change other parameters)."); - opts->Register("use-log-fbank", &use_log_fbank, - "If true, produce log-filterbank, else produce linear."); - opts->Register("use-power", &use_power, - "If true, use power, else use magnitude."); - } -}; - - -/// Class for computing mel-filterbank features; see \ref feat_mfcc for more -/// information. 
-class FbankComputer { - public: - typedef FbankOptions Options; - - explicit FbankComputer(const FbankOptions &opts); - FbankComputer(const FbankComputer &other); - - int32 Dim() const { - return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0); - } - - bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; } - - const FrameExtractionOptions &GetFrameOptions() const { - return opts_.frame_opts; - } - - /** - Function that computes one frame of features from - one frame of signal. - - @param [in] signal_raw_log_energy The log-energy of the frame of the signal - prior to windowing and pre-emphasis, or - log(numeric_limits::min()), whichever is greater. Must be - ignored by this function if this class returns false from - this->NeedsRawLogEnergy(). - @param [in] vtln_warp The VTLN warping factor that the user wants - to be applied when computing features for this utterance. Will - normally be 1.0, meaning no warping is to be done. The value will - be ignored for feature types that don't support VLTN, such as - spectrogram features. - @param [in] signal_frame One frame of the signal, - as extracted using the function ExtractWindow() using the options - returned by this->GetFrameOptions(). The function will use the - vector as a workspace, which is why it's a non-const pointer. - @param [out] feature Pointer to a vector of size this->Dim(), to which - the computed feature will be written. - */ - void Compute(BaseFloat signal_raw_log_energy, - BaseFloat vtln_warp, - VectorBase *signal_frame, - VectorBase *feature); - - ~FbankComputer(); - - const MelBanks *GetMelBanks(BaseFloat vtln_warp); - private: - - - FbankOptions opts_; - BaseFloat log_energy_floor_; - std::map mel_banks_; // BaseFloat is VTLN coefficient. - SplitRadixRealFft *srfft_; - // Disallow assignment. 
- FbankComputer &operator =(const FbankComputer &other); -}; - -typedef OfflineFeatureTpl Fbank; - -/// @} End of "addtogroup feat" -} // namespace kaldi - - -#endif // KALDI_FEAT_FEATURE_FBANK_H_ diff --git a/speechx/speechx/kaldi/feat/feature-functions.cc b/speechx/speechx/kaldi/feat/feature-functions.cc deleted file mode 100644 index 76500ccf..00000000 --- a/speechx/speechx/kaldi/feat/feature-functions.cc +++ /dev/null @@ -1,362 +0,0 @@ -// feat/feature-functions.cc - -// Copyright 2009-2011 Karel Vesely; Petr Motlicek; Microsoft Corporation -// 2013 Johns Hopkins University (author: Daniel Povey) -// 2014 IMSL, PKU-HKUST (author: Wei Shi) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - - -#include "feat/feature-functions.h" -#include "matrix/matrix-functions.h" - - -namespace kaldi { - -void ComputePowerSpectrum(VectorBase *waveform) { - int32 dim = waveform->Dim(); - - // no, letting it be non-power-of-two for now. - // KALDI_ASSERT(dim > 0 && (dim & (dim-1) == 0)); // make sure a power of two.. actually my FFT code - // does not require this (dan) but this is better in case we use different code [dan]. - - // RealFft(waveform, true); // true == forward (not inverse) FFT; makes no difference here, - // as we just want power spectrum. 
- - // now we have in waveform, first half of complex spectrum - // it's stored as [real0, realN/2, real1, im1, real2, im2, ...] - int32 half_dim = dim/2; - BaseFloat first_energy = (*waveform)(0) * (*waveform)(0), - last_energy = (*waveform)(1) * (*waveform)(1); // handle this special case - for (int32 i = 1; i < half_dim; i++) { - BaseFloat real = (*waveform)(i*2), im = (*waveform)(i*2 + 1); - (*waveform)(i) = real*real + im*im; - } - (*waveform)(0) = first_energy; - (*waveform)(half_dim) = last_energy; // Will actually never be used, and anyway - // if the signal has been bandlimited sensibly this should be zero. -} - - -DeltaFeatures::DeltaFeatures(const DeltaFeaturesOptions &opts): opts_(opts) { - KALDI_ASSERT(opts.order >= 0 && opts.order < 1000); // just make sure we don't get binary junk. - // opts will normally be 2 or 3. - KALDI_ASSERT(opts.window > 0 && opts.window < 1000); // again, basic sanity check. - // normally the window size will be two. - - scales_.resize(opts.order+1); - scales_[0].Resize(1); - scales_[0](0) = 1.0; // trivial window for 0th order delta [i.e. baseline feats] - - for (int32 i = 1; i <= opts.order; i++) { - Vector &prev_scales = scales_[i-1], - &cur_scales = scales_[i]; - int32 window = opts.window; // this code is designed to still - // work if instead we later make it an array and do opts.window[i-1], - // or something like that. "window" is a parameter specifying delta-window - // width which is actually 2*window + 1. - KALDI_ASSERT(window != 0); - int32 prev_offset = (static_cast(prev_scales.Dim()-1))/2, - cur_offset = prev_offset + window; - cur_scales.Resize(prev_scales.Dim() + 2*window); // also zeros it. 
- - BaseFloat normalizer = 0.0; - for (int32 j = -window; j <= window; j++) { - normalizer += j*j; - for (int32 k = -prev_offset; k <= prev_offset; k++) { - cur_scales(j+k+cur_offset) += - static_cast(j) * prev_scales(k+prev_offset); - } - } - cur_scales.Scale(1.0 / normalizer); - } -} - -void DeltaFeatures::Process(const MatrixBase &input_feats, - int32 frame, - VectorBase *output_frame) const { - KALDI_ASSERT(frame < input_feats.NumRows()); - int32 num_frames = input_feats.NumRows(), - feat_dim = input_feats.NumCols(); - KALDI_ASSERT(static_cast(output_frame->Dim()) == feat_dim * (opts_.order+1)); - output_frame->SetZero(); - for (int32 i = 0; i <= opts_.order; i++) { - const Vector &scales = scales_[i]; - int32 max_offset = (scales.Dim() - 1) / 2; - SubVector output(*output_frame, i*feat_dim, feat_dim); - for (int32 j = -max_offset; j <= max_offset; j++) { - // if asked to read - int32 offset_frame = frame + j; - if (offset_frame < 0) offset_frame = 0; - else if (offset_frame >= num_frames) - offset_frame = num_frames - 1; - BaseFloat scale = scales(j + max_offset); - if (scale != 0.0) - output.AddVec(scale, input_feats.Row(offset_frame)); - } - } -} - -ShiftedDeltaFeatures::ShiftedDeltaFeatures( - const ShiftedDeltaFeaturesOptions &opts): opts_(opts) { - KALDI_ASSERT(opts.window > 0 && opts.window < 1000); - - // Default window is 1. - int32 window = opts.window; - KALDI_ASSERT(window != 0); - scales_.Resize(1 + 2*window); // also zeros it. 
- BaseFloat normalizer = 0.0; - for (int32 j = -window; j <= window; j++) { - normalizer += j*j; - scales_(j + window) += static_cast(j); - } - scales_.Scale(1.0 / normalizer); -} - -void ShiftedDeltaFeatures::Process(const MatrixBase &input_feats, - int32 frame, - SubVector *output_frame) const { - KALDI_ASSERT(frame < input_feats.NumRows()); - int32 num_frames = input_feats.NumRows(), - feat_dim = input_feats.NumCols(); - KALDI_ASSERT(static_cast(output_frame->Dim()) - == feat_dim * (opts_.num_blocks + 1)); - output_frame->SetZero(); - - // The original features - SubVector output(*output_frame, 0, feat_dim); - output.AddVec(1.0, input_feats.Row(frame)); - - // Concatenate the delta-blocks. Each block is block_shift - // (usually 3) frames apart. - for (int32 i = 0; i < opts_.num_blocks; i++) { - int32 max_offset = (scales_.Dim() - 1) / 2; - SubVector output(*output_frame, (i + 1) * feat_dim, feat_dim); - for (int32 j = -max_offset; j <= max_offset; j++) { - int32 offset_frame = frame + j + i * opts_.block_shift; - if (offset_frame < 0) offset_frame = 0; - else if (offset_frame >= num_frames) - offset_frame = num_frames - 1; - BaseFloat scale = scales_(j + max_offset); - if (scale != 0.0) - output.AddVec(scale, input_feats.Row(offset_frame)); - } - } -} - -void ComputeDeltas(const DeltaFeaturesOptions &delta_opts, - const MatrixBase &input_features, - Matrix *output_features) { - output_features->Resize(input_features.NumRows(), - input_features.NumCols() - *(delta_opts.order + 1)); - DeltaFeatures delta(delta_opts); - for (int32 r = 0; r < static_cast(input_features.NumRows()); r++) { - SubVector row(*output_features, r); - delta.Process(input_features, r, &row); - } -} - -void ComputeShiftedDeltas(const ShiftedDeltaFeaturesOptions &delta_opts, - const MatrixBase &input_features, - Matrix *output_features) { - output_features->Resize(input_features.NumRows(), - input_features.NumCols() - * (delta_opts.num_blocks + 1)); - ShiftedDeltaFeatures delta(delta_opts); - 
- for (int32 r = 0; r < static_cast(input_features.NumRows()); r++) { - SubVector row(*output_features, r); - delta.Process(input_features, r, &row); - } -} - - -void InitIdftBases(int32 n_bases, int32 dimension, Matrix *mat_out) { - BaseFloat angle = M_PI / static_cast(dimension - 1); - BaseFloat scale = 1.0f / (2.0 * static_cast(dimension - 1)); - mat_out->Resize(n_bases, dimension); - for (int32 i = 0; i < n_bases; i++) { - (*mat_out)(i, 0) = 1.0 * scale; - BaseFloat i_fl = static_cast(i); - for (int32 j = 1; j < dimension - 1; j++) { - BaseFloat j_fl = static_cast(j); - (*mat_out)(i, j) = 2.0 * scale * cos(angle * i_fl * j_fl); - } - - (*mat_out)(i, dimension -1) - = scale * cos(angle * i_fl * static_cast(dimension-1)); - } -} - -void SpliceFrames(const MatrixBase &input_features, - int32 left_context, - int32 right_context, - Matrix *output_features) { - int32 T = input_features.NumRows(), D = input_features.NumCols(); - if (T == 0 || D == 0) - KALDI_ERR << "SpliceFrames: empty input"; - KALDI_ASSERT(left_context >= 0 && right_context >= 0); - int32 N = 1 + left_context + right_context; - output_features->Resize(T, D*N); - for (int32 t = 0; t < T; t++) { - SubVector dst_row(*output_features, t); - for (int32 j = 0; j < N; j++) { - int32 t2 = t + j - left_context; - if (t2 < 0) t2 = 0; - if (t2 >= T) t2 = T-1; - SubVector dst(dst_row, j*D, D), - src(input_features, t2); - dst.CopyFromVec(src); - } - } -} - -void ReverseFrames(const MatrixBase &input_features, - Matrix *output_features) { - int32 T = input_features.NumRows(), D = input_features.NumCols(); - if (T == 0 || D == 0) - KALDI_ERR << "ReverseFrames: empty input"; - output_features->Resize(T, D); - for (int32 t = 0; t < T; t++) { - SubVector dst_row(*output_features, t); - SubVector src_row(input_features, T-1-t); - dst_row.CopyFromVec(src_row); - } -} - - -void SlidingWindowCmnOptions::Check() const { - KALDI_ASSERT(cmn_window > 0); - if (center) - KALDI_ASSERT(min_window > 0 && min_window <= 
cmn_window); - // else ignored so value doesn't matter. -} - -// Internal version of SlidingWindowCmn with double-precision arguments. -void SlidingWindowCmnInternal(const SlidingWindowCmnOptions &opts, - const MatrixBase &input, - MatrixBase *output) { - opts.Check(); - int32 num_frames = input.NumRows(), dim = input.NumCols(), - last_window_start = -1, last_window_end = -1, - warning_count = 0; - Vector cur_sum(dim), cur_sumsq(dim); - - for (int32 t = 0; t < num_frames; t++) { - int32 window_start, window_end; // note: window_end will be one - // past the end of the window we use for normalization. - if (opts.center) { - window_start = t - (opts.cmn_window / 2); - window_end = window_start + opts.cmn_window; - } else { - window_start = t - opts.cmn_window; - window_end = t + 1; - } - if (window_start < 0) { // shift window right if starts <0. - window_end -= window_start; - window_start = 0; // or: window_start -= window_start - } - if (!opts.center) { - if (window_end > t) - window_end = std::max(t + 1, opts.min_window); - } - if (window_end > num_frames) { - window_start -= (window_end - num_frames); - window_end = num_frames; - if (window_start < 0) window_start = 0; - } - if (last_window_start == -1) { - SubMatrix input_part(input, - window_start, window_end - window_start, - 0, dim); - cur_sum.AddRowSumMat(1.0, input_part , 0.0); - if (opts.normalize_variance) - cur_sumsq.AddDiagMat2(1.0, input_part, kTrans, 0.0); - } else { - if (window_start > last_window_start) { - KALDI_ASSERT(window_start == last_window_start + 1); - SubVector frame_to_remove(input, last_window_start); - cur_sum.AddVec(-1.0, frame_to_remove); - if (opts.normalize_variance) - cur_sumsq.AddVec2(-1.0, frame_to_remove); - } - if (window_end > last_window_end) { - KALDI_ASSERT(window_end == last_window_end + 1); - SubVector frame_to_add(input, last_window_end); - cur_sum.AddVec(1.0, frame_to_add); - if (opts.normalize_variance) - cur_sumsq.AddVec2(1.0, frame_to_add); - } - } - int32 
window_frames = window_end - window_start; - last_window_start = window_start; - last_window_end = window_end; - - KALDI_ASSERT(window_frames > 0); - SubVector input_frame(input, t), - output_frame(*output, t); - output_frame.CopyFromVec(input_frame); - output_frame.AddVec(-1.0 / window_frames, cur_sum); - - if (opts.normalize_variance) { - if (window_frames == 1) { - output_frame.Set(0.0); - } else { - Vector variance(cur_sumsq); - variance.Scale(1.0 / window_frames); - variance.AddVec2(-1.0 / (window_frames * window_frames), cur_sum); - // now "variance" is the variance of the features in the window, - // around their own mean. - int32 num_floored; - variance.ApplyFloor(1.0e-10, &num_floored); - if (num_floored > 0 && num_frames > 1) { - if (opts.max_warnings == warning_count) { - KALDI_WARN << "Suppressing the remaining variance flooring " - << "warnings. Run program with --max-warnings=-1 to " - << "see all warnings."; - } - // If opts.max_warnings is a negative number, we won't restrict the - // number of times that the warning is printed out. - else if (opts.max_warnings < 0 - || opts.max_warnings > warning_count) { - KALDI_WARN << "Flooring when normalizing variance, floored " - << num_floored << " elements; num-frames was " - << window_frames; - } - warning_count++; - } - variance.ApplyPow(-0.5); // get inverse standard deviation. 
- output_frame.MulElements(variance); - } - } - } -} - - -void SlidingWindowCmn(const SlidingWindowCmnOptions &opts, - const MatrixBase &input, - MatrixBase *output) { - KALDI_ASSERT(SameDim(input, *output) && input.NumRows() > 0); - Matrix input_dbl(input), output_dbl(input.NumRows(), input.NumCols()); - // call double-precision version - SlidingWindowCmnInternal(opts, input_dbl, &output_dbl); - output->CopyFromMat(output_dbl); -} - - - -} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/feature-functions.h b/speechx/speechx/kaldi/feat/feature-functions.h deleted file mode 100644 index 52454f30..00000000 --- a/speechx/speechx/kaldi/feat/feature-functions.h +++ /dev/null @@ -1,204 +0,0 @@ -// feat/feature-functions.h - -// Copyright 2009-2011 Karel Vesely; Petr Motlicek; Microsoft Corporation -// 2014 IMSL, PKU-HKUST (author: Wei Shi) -// 2016 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. 
- - -#ifndef KALDI_FEAT_FEATURE_FUNCTIONS_H_ -#define KALDI_FEAT_FEATURE_FUNCTIONS_H_ - -#include -#include - -#include "matrix/matrix-lib.h" -#include "util/common-utils.h" -#include "base/kaldi-error.h" - -namespace kaldi { -/// @addtogroup feat FeatureExtraction -/// @{ - - -// ComputePowerSpectrum converts a complex FFT (as produced by the FFT -// functions in matrix/matrix-functions.h), and converts it into -// a power spectrum. If the complex FFT is a vector of size n (representing -// half the complex FFT of a real signal of size n, as described there), -// this function computes in the first (n/2) + 1 elements of it, the -// energies of the fft bins from zero to the Nyquist frequency. Contents of the -// remaining (n/2) - 1 elements are undefined at output. -void ComputePowerSpectrum(VectorBase *complex_fft); - - -struct DeltaFeaturesOptions { - int32 order; - int32 window; // e.g. 2; controls window size (window size is 2*window + 1) - // the behavior at the edges is to replicate the first or last frame. - // this is not configurable. - - DeltaFeaturesOptions(int32 order = 2, int32 window = 2): - order(order), window(window) { } - void Register(OptionsItf *opts) { - opts->Register("delta-order", &order, "Order of delta computation"); - opts->Register("delta-window", &window, - "Parameter controlling window for delta computation (actual window" - " size for each delta order is 1 + 2*delta-window-size)"); - } -}; - -class DeltaFeatures { - public: - // This class provides a low-level function to compute delta features. - // The function takes as input a matrix of features and a frame index - // that it should compute the deltas on. It puts its output in an object - // of type VectorBase, of size (original-feature-dimension) * (opts.order+1). 
- // This is not the most efficient way to do the computation, but it's - // state-free and thus easier to understand - - explicit DeltaFeatures(const DeltaFeaturesOptions &opts); - - void Process(const MatrixBase &input_feats, - int32 frame, - VectorBase *output_frame) const; - private: - DeltaFeaturesOptions opts_; - std::vector > scales_; // a scaling window for each - // of the orders, including zero: multiply the features for each - // dimension by this window. -}; - -struct ShiftedDeltaFeaturesOptions { - int32 window, // The time delay and advance - num_blocks, - block_shift; // Distance between consecutive blocks - - ShiftedDeltaFeaturesOptions(): - window(1), num_blocks(7), block_shift(3) { } - void Register(OptionsItf *opts) { - opts->Register("delta-window", &window, "Size of delta advance and delay."); - opts->Register("num-blocks", &num_blocks, "Number of delta blocks in advance" - " of each frame to be concatenated"); - opts->Register("block-shift", &block_shift, "Distance between each block"); - } -}; - -class ShiftedDeltaFeatures { - public: - // This class provides a low-level function to compute shifted - // delta cesptra (SDC). - // The function takes as input a matrix of features and a frame index - // that it should compute the deltas on. It puts its output in an object - // of type VectorBase, of size original-feature-dimension + (1 * num_blocks). - - explicit ShiftedDeltaFeatures(const ShiftedDeltaFeaturesOptions &opts); - - void Process(const MatrixBase &input_feats, - int32 frame, - SubVector *output_frame) const; - private: - ShiftedDeltaFeaturesOptions opts_; - Vector scales_; // a scaling window for each - -}; - -// ComputeDeltas is a convenience function that computes deltas on a feature -// file. If you want to deal with features coming in bit by bit you would have -// to use the DeltaFeatures class directly, and do the computation frame by -// frame. 
Later we will have to come up with a nice mechanism to do this for -// features coming in. -void ComputeDeltas(const DeltaFeaturesOptions &delta_opts, - const MatrixBase &input_features, - Matrix *output_features); - -// ComputeShiftedDeltas computes deltas from a feature file by applying -// ShiftedDeltaFeatures over the frames. This function is provided for -// convenience, however, ShiftedDeltaFeatures can be used directly. -void ComputeShiftedDeltas(const ShiftedDeltaFeaturesOptions &delta_opts, - const MatrixBase &input_features, - Matrix *output_features); - -// SpliceFrames will normally be used together with LDA. -// It splices frames together to make a window. At the -// start and end of an utterance, it duplicates the first -// and last frames. -// Will throw if input features are empty. -// left_context and right_context must be nonnegative. -// these both represent a number of frames (e.g. 4, 4 is -// a good choice). -void SpliceFrames(const MatrixBase &input_features, - int32 left_context, - int32 right_context, - Matrix *output_features); - -// ReverseFrames reverses the frames in time (used for backwards decoding) -void ReverseFrames(const MatrixBase &input_features, - Matrix *output_features); - - -void InitIdftBases(int32 n_bases, int32 dimension, Matrix *mat_out); - - -// This is used for speaker-id. Also see OnlineCmnOptions in ../online2/, which -// is online CMN with no latency, for online speech recognition. 
-struct SlidingWindowCmnOptions { - int32 cmn_window; - int32 min_window; - int32 max_warnings; - bool normalize_variance; - bool center; - - SlidingWindowCmnOptions(): - cmn_window(600), - min_window(100), - max_warnings(5), - normalize_variance(false), - center(false) { } - - void Register(OptionsItf *opts) { - opts->Register("cmn-window", &cmn_window, "Window in frames for running " - "average CMN computation"); - opts->Register("min-cmn-window", &min_window, "Minimum CMN window " - "used at start of decoding (adds latency only at start). " - "Only applicable if center == false, ignored if center==true"); - opts->Register("max-warnings", &max_warnings, "Maximum warnings to report " - "per utterance. 0 to disable, -1 to show all."); - opts->Register("norm-vars", &normalize_variance, "If true, normalize " - "variance to one."); // naming this as in apply-cmvn.cc - opts->Register("center", ¢er, "If true, use a window centered on the " - "current frame (to the extent possible, modulo end effects). " - "If false, window is to the left."); - } - void Check() const; -}; - - -/// Applies sliding-window cepstral mean and/or variance normalization. See the -/// strings registering the options in the options class for information on how -/// this works and what the options are. input and output must have the same -/// dimension. 
-void SlidingWindowCmn(const SlidingWindowCmnOptions &opts, - const MatrixBase &input, - MatrixBase *output); - - -/// @} End of "addtogroup feat" -} // namespace kaldi - - - -#endif // KALDI_FEAT_FEATURE_FUNCTIONS_H_ diff --git a/speechx/speechx/kaldi/feat/feature-mfcc.cc b/speechx/speechx/kaldi/feat/feature-mfcc.cc deleted file mode 100644 index 73ab4b31..00000000 --- a/speechx/speechx/kaldi/feat/feature-mfcc.cc +++ /dev/null @@ -1,157 +0,0 @@ -// feat/feature-mfcc.cc - -// Copyright 2009-2011 Karel Vesely; Petr Motlicek -// 2016 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - - -#include "feat/feature-mfcc.h" - - -namespace kaldi { - - -void MfccComputer::Compute(BaseFloat signal_raw_log_energy, - BaseFloat vtln_warp, - VectorBase *signal_frame, - VectorBase *feature) { - KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && - feature->Dim() == this->Dim()); - - const MelBanks &mel_banks = *(GetMelBanks(vtln_warp)); - - if (opts_.use_energy && !opts_.raw_energy) - signal_raw_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), - std::numeric_limits::epsilon())); - - if (srfft_ != NULL) // Compute FFT using the split-radix algorithm. 
- srfft_->Compute(signal_frame->Data(), true); - else // An alternative algorithm that works for non-powers-of-two. - RealFft(signal_frame, true); - - // Convert the FFT into a power spectrum. - ComputePowerSpectrum(signal_frame); - SubVector power_spectrum(*signal_frame, 0, - signal_frame->Dim() / 2 + 1); - - mel_banks.Compute(power_spectrum, &mel_energies_); - - // avoid log of zero (which should be prevented anyway by dithering). - mel_energies_.ApplyFloor(std::numeric_limits::epsilon()); - mel_energies_.ApplyLog(); // take the log. - - feature->SetZero(); // in case there were NaNs. - // feature = dct_matrix_ * mel_energies [which now have log] - feature->AddMatVec(1.0, dct_matrix_, kNoTrans, mel_energies_, 0.0); - - if (opts_.cepstral_lifter != 0.0) - feature->MulElements(lifter_coeffs_); - - if (opts_.use_energy) { - if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_) - signal_raw_log_energy = log_energy_floor_; - (*feature)(0) = signal_raw_log_energy; - } - - if (opts_.htk_compat) { - BaseFloat energy = (*feature)(0); - for (int32 i = 0; i < opts_.num_ceps - 1; i++) - (*feature)(i) = (*feature)(i+1); - if (!opts_.use_energy) - energy *= M_SQRT2; // scale on C0 (actually removing a scale - // we previously added that's part of one common definition of - // the cosine transform.) - (*feature)(opts_.num_ceps - 1) = energy; - } -} - -MfccComputer::MfccComputer(const MfccOptions &opts): - opts_(opts), srfft_(NULL), - mel_energies_(opts.mel_opts.num_bins) { - - int32 num_bins = opts.mel_opts.num_bins; - if (opts.num_ceps > num_bins) - KALDI_ERR << "num-ceps cannot be larger than num-mel-bins." - << " It should be smaller or equal. You provided num-ceps: " - << opts.num_ceps << " and num-mel-bins: " - << num_bins; - - Matrix dct_matrix(num_bins, num_bins); - ComputeDctMatrix(&dct_matrix); - // Note that we include zeroth dct in either case. If using the - // energy we replace this with the energy. 
This means a different - // ordering of features than HTK. - SubMatrix dct_rows(dct_matrix, 0, opts.num_ceps, 0, num_bins); - dct_matrix_.Resize(opts.num_ceps, num_bins); - dct_matrix_.CopyFromMat(dct_rows); // subset of rows. - if (opts.cepstral_lifter != 0.0) { - lifter_coeffs_.Resize(opts.num_ceps); - ComputeLifterCoeffs(opts.cepstral_lifter, &lifter_coeffs_); - } - if (opts.energy_floor > 0.0) - log_energy_floor_ = Log(opts.energy_floor); - - int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); - if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two... - srfft_ = new SplitRadixRealFft(padded_window_size); - - // We'll definitely need the filterbanks info for VTLN warping factor 1.0. - // [note: this call caches it.] - GetMelBanks(1.0); -} - -MfccComputer::MfccComputer(const MfccComputer &other): - opts_(other.opts_), lifter_coeffs_(other.lifter_coeffs_), - dct_matrix_(other.dct_matrix_), - log_energy_floor_(other.log_energy_floor_), - mel_banks_(other.mel_banks_), - srfft_(NULL), - mel_energies_(other.mel_energies_.Dim(), kUndefined) { - for (std::map::iterator iter = mel_banks_.begin(); - iter != mel_banks_.end(); ++iter) - iter->second = new MelBanks(*(iter->second)); - if (other.srfft_ != NULL) - srfft_ = new SplitRadixRealFft(*(other.srfft_)); -} - - - -MfccComputer::~MfccComputer() { - for (std::map::iterator iter = mel_banks_.begin(); - iter != mel_banks_.end(); - ++iter) - delete iter->second; - delete srfft_; -} - -const MelBanks *MfccComputer::GetMelBanks(BaseFloat vtln_warp) { - MelBanks *this_mel_banks = NULL; - std::map::iterator iter = mel_banks_.find(vtln_warp); - if (iter == mel_banks_.end()) { - this_mel_banks = new MelBanks(opts_.mel_opts, - opts_.frame_opts, - vtln_warp); - mel_banks_[vtln_warp] = this_mel_banks; - } else { - this_mel_banks = iter->second; - } - return this_mel_banks; -} - - - -} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/feature-mfcc.h 
b/speechx/speechx/kaldi/feat/feature-mfcc.h deleted file mode 100644 index dbfb9d60..00000000 --- a/speechx/speechx/kaldi/feat/feature-mfcc.h +++ /dev/null @@ -1,154 +0,0 @@ -// feat/feature-mfcc.h - -// Copyright 2009-2011 Karel Vesely; Petr Motlicek; Saarland University -// 2014-2016 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_FEAT_FEATURE_MFCC_H_ -#define KALDI_FEAT_FEATURE_MFCC_H_ - -#include -#include - -#include "feat/feature-common.h" -#include "feat/feature-functions.h" -#include "feat/feature-window.h" -#include "feat/mel-computations.h" - -namespace kaldi { -/// @addtogroup feat FeatureExtraction -/// @{ - - -/// MfccOptions contains basic options for computing MFCC features. -struct MfccOptions { - FrameExtractionOptions frame_opts; - MelBanksOptions mel_opts; - int32 num_ceps; // e.g. 13: num cepstral coeffs, counting zero. - bool use_energy; // use energy; else C0 - BaseFloat energy_floor; // 0 by default; set to a value like 1.0 or 0.1 if - // you disable dithering. - bool raw_energy; // If true, compute energy before preemphasis and windowing - BaseFloat cepstral_lifter; // Scaling factor on cepstra for HTK compatibility. - // if 0.0, no liftering is done. 
- bool htk_compat; // if true, put energy/C0 last and introduce a factor of - // sqrt(2) on C0 to be the same as HTK. - - MfccOptions() : mel_opts(23), - // defaults the #mel-banks to 23 for the MFCC computations. - // this seems to be common for 16khz-sampled data, - // but for 8khz-sampled data, 15 may be better. - num_ceps(13), - use_energy(true), - energy_floor(0.0), - raw_energy(true), - cepstral_lifter(22.0), - htk_compat(false) {} - - void Register(OptionsItf *opts) { - frame_opts.Register(opts); - mel_opts.Register(opts); - opts->Register("num-ceps", &num_ceps, - "Number of cepstra in MFCC computation (including C0)"); - opts->Register("use-energy", &use_energy, - "Use energy (not C0) in MFCC computation"); - opts->Register("energy-floor", &energy_floor, - "Floor on energy (absolute, not relative) in MFCC computation. " - "Only makes a difference if --use-energy=true; only necessary if " - "--dither=0.0. Suggested values: 0.1 or 1.0"); - opts->Register("raw-energy", &raw_energy, - "If true, compute energy before preemphasis and windowing"); - opts->Register("cepstral-lifter", &cepstral_lifter, - "Constant that controls scaling of MFCCs"); - opts->Register("htk-compat", &htk_compat, - "If true, put energy or C0 last and use a factor of sqrt(2) on " - "C0. Warning: not sufficient to get HTK compatible features " - "(need to change other parameters)."); - } -}; - - - -// This is the new-style interface to the MFCC computation. -class MfccComputer { - public: - typedef MfccOptions Options; - explicit MfccComputer(const MfccOptions &opts); - MfccComputer(const MfccComputer &other); - - const FrameExtractionOptions &GetFrameOptions() const { - return opts_.frame_opts; - } - - int32 Dim() const { return opts_.num_ceps; } - - bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; } - - /** - Function that computes one frame of features from - one frame of signal. 
- - @param [in] signal_raw_log_energy The log-energy of the frame of the signal - prior to windowing and pre-emphasis, or - log(numeric_limits::min()), whichever is greater. Must be - ignored by this function if this class returns false from - this->NeedsRawLogEnergy(). - @param [in] vtln_warp The VTLN warping factor that the user wants - to be applied when computing features for this utterance. Will - normally be 1.0, meaning no warping is to be done. The value will - be ignored for feature types that don't support VLTN, such as - spectrogram features. - @param [in] signal_frame One frame of the signal, - as extracted using the function ExtractWindow() using the options - returned by this->GetFrameOptions(). The function will use the - vector as a workspace, which is why it's a non-const pointer. - @param [out] feature Pointer to a vector of size this->Dim(), to which - the computed feature will be written. - */ - void Compute(BaseFloat signal_raw_log_energy, - BaseFloat vtln_warp, - VectorBase *signal_frame, - VectorBase *feature); - - ~MfccComputer(); - private: - // disallow assignment. - MfccComputer &operator = (const MfccComputer &in); - - protected: - const MelBanks *GetMelBanks(BaseFloat vtln_warp); - - MfccOptions opts_; - Vector lifter_coeffs_; - Matrix dct_matrix_; // matrix we left-multiply by to perform DCT. - BaseFloat log_energy_floor_; - std::map mel_banks_; // BaseFloat is VTLN coefficient. - SplitRadixRealFft *srfft_; - - // note: mel_energies_ is specific to the frame we're processing, it's - // just a temporary workspace. 
- Vector mel_energies_; -}; - -typedef OfflineFeatureTpl Mfcc; - - -/// @} End of "addtogroup feat" -} // namespace kaldi - - -#endif // KALDI_FEAT_FEATURE_MFCC_H_ diff --git a/speechx/speechx/kaldi/feat/feature-plp.cc b/speechx/speechx/kaldi/feat/feature-plp.cc deleted file mode 100644 index e0c270c7..00000000 --- a/speechx/speechx/kaldi/feat/feature-plp.cc +++ /dev/null @@ -1,191 +0,0 @@ -// feat/feature-plp.cc - -// Copyright 2009-2011 Petr Motlicek; Karel Vesely -// 2016 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. 
- - -#include "feat/feature-plp.h" - -namespace kaldi { - -PlpComputer::PlpComputer(const PlpOptions &opts): - opts_(opts), srfft_(NULL), - mel_energies_duplicated_(opts_.mel_opts.num_bins + 2, kUndefined), - autocorr_coeffs_(opts_.lpc_order + 1, kUndefined), - lpc_coeffs_(opts_.lpc_order, kUndefined), - raw_cepstrum_(opts_.lpc_order, kUndefined) { - - if (opts.cepstral_lifter != 0.0) { - lifter_coeffs_.Resize(opts.num_ceps); - ComputeLifterCoeffs(opts.cepstral_lifter, &lifter_coeffs_); - } - InitIdftBases(opts_.lpc_order + 1, opts_.mel_opts.num_bins + 2, - &idft_bases_); - - if (opts.energy_floor > 0.0) - log_energy_floor_ = Log(opts.energy_floor); - - int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); - if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two... - srfft_ = new SplitRadixRealFft(padded_window_size); - - // We'll definitely need the filterbanks info for VTLN warping factor 1.0. - // [note: this call caches it.] - GetMelBanks(1.0); -} - -PlpComputer::PlpComputer(const PlpComputer &other): - opts_(other.opts_), lifter_coeffs_(other.lifter_coeffs_), - idft_bases_(other.idft_bases_), log_energy_floor_(other.log_energy_floor_), - mel_banks_(other.mel_banks_), equal_loudness_(other.equal_loudness_), - srfft_(NULL), - mel_energies_duplicated_(opts_.mel_opts.num_bins + 2, kUndefined), - autocorr_coeffs_(opts_.lpc_order + 1, kUndefined), - lpc_coeffs_(opts_.lpc_order, kUndefined), - raw_cepstrum_(opts_.lpc_order, kUndefined) { - for (std::map::iterator iter = mel_banks_.begin(); - iter != mel_banks_.end(); ++iter) - iter->second = new MelBanks(*(iter->second)); - for (std::map*>::iterator - iter = equal_loudness_.begin(); - iter != equal_loudness_.end(); ++iter) - iter->second = new Vector(*(iter->second)); - if (other.srfft_ != NULL) - srfft_ = new SplitRadixRealFft(*(other.srfft_)); -} - -PlpComputer::~PlpComputer() { - for (std::map::iterator iter = mel_banks_.begin(); - iter != mel_banks_.end(); ++iter) - delete 
iter->second; - for (std::map* >::iterator - iter = equal_loudness_.begin(); - iter != equal_loudness_.end(); ++iter) - delete iter->second; - delete srfft_; -} - -const MelBanks *PlpComputer::GetMelBanks(BaseFloat vtln_warp) { - MelBanks *this_mel_banks = NULL; - std::map::iterator iter = mel_banks_.find(vtln_warp); - if (iter == mel_banks_.end()) { - this_mel_banks = new MelBanks(opts_.mel_opts, - opts_.frame_opts, - vtln_warp); - mel_banks_[vtln_warp] = this_mel_banks; - } else { - this_mel_banks = iter->second; - } - return this_mel_banks; -} - -const Vector *PlpComputer::GetEqualLoudness(BaseFloat vtln_warp) { - const MelBanks *this_mel_banks = GetMelBanks(vtln_warp); - Vector *ans = NULL; - std::map*>::iterator iter - = equal_loudness_.find(vtln_warp); - if (iter == equal_loudness_.end()) { - ans = new Vector; - GetEqualLoudnessVector(*this_mel_banks, ans); - equal_loudness_[vtln_warp] = ans; - } else { - ans = iter->second; - } - return ans; -} - -void PlpComputer::Compute(BaseFloat signal_raw_log_energy, - BaseFloat vtln_warp, - VectorBase *signal_frame, - VectorBase *feature) { - KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && - feature->Dim() == this->Dim()); - - const MelBanks &mel_banks = *GetMelBanks(vtln_warp); - const Vector &equal_loudness = *GetEqualLoudness(vtln_warp); - - - KALDI_ASSERT(opts_.num_ceps <= opts_.lpc_order+1); // our num-ceps includes C0. - - - if (opts_.use_energy && !opts_.raw_energy) - signal_raw_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), - std::numeric_limits::min())); - - if (srfft_ != NULL) // Compute FFT using split-radix algorithm. - srfft_->Compute(signal_frame->Data(), true); - else // An alternative algorithm that works for non-powers-of-two. - RealFft(signal_frame, true); - - // Convert the FFT into a power spectrum. - ComputePowerSpectrum(signal_frame); // elements 0 ... 
signal_frame->Dim()/2 - - SubVector power_spectrum(*signal_frame, - 0, signal_frame->Dim() / 2 + 1); - - int32 num_mel_bins = opts_.mel_opts.num_bins; - - SubVector mel_energies(mel_energies_duplicated_, 1, num_mel_bins); - - mel_banks.Compute(power_spectrum, &mel_energies); - - mel_energies.MulElements(equal_loudness); - - mel_energies.ApplyPow(opts_.compress_factor); - - // duplicate first and last elements - mel_energies_duplicated_(0) = mel_energies_duplicated_(1); - mel_energies_duplicated_(num_mel_bins + 1) = - mel_energies_duplicated_(num_mel_bins); - - autocorr_coeffs_.SetZero(); // In case of NaNs or infs - autocorr_coeffs_.AddMatVec(1.0, idft_bases_, kNoTrans, - mel_energies_duplicated_, 0.0); - - BaseFloat residual_log_energy = ComputeLpc(autocorr_coeffs_, &lpc_coeffs_); - - residual_log_energy = std::max(residual_log_energy, - std::numeric_limits::min()); - - Lpc2Cepstrum(opts_.lpc_order, lpc_coeffs_.Data(), raw_cepstrum_.Data()); - feature->Range(1, opts_.num_ceps - 1).CopyFromVec( - raw_cepstrum_.Range(0, opts_.num_ceps - 1)); - (*feature)(0) = residual_log_energy; - - if (opts_.cepstral_lifter != 0.0) - feature->MulElements(lifter_coeffs_); - - if (opts_.cepstral_scale != 1.0) - feature->Scale(opts_.cepstral_scale); - - if (opts_.use_energy) { - if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_) - signal_raw_log_energy = log_energy_floor_; - (*feature)(0) = signal_raw_log_energy; - } - - if (opts_.htk_compat) { // reorder the features. 
- BaseFloat log_energy = (*feature)(0); - for (int32 i = 0; i < opts_.num_ceps-1; i++) - (*feature)(i) = (*feature)(i+1); - (*feature)(opts_.num_ceps-1) = log_energy; - } -} - - -} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/feature-plp.h b/speechx/speechx/kaldi/feat/feature-plp.h deleted file mode 100644 index cce6ee1c..00000000 --- a/speechx/speechx/kaldi/feat/feature-plp.h +++ /dev/null @@ -1,176 +0,0 @@ -// feat/feature-plp.h - -// Copyright 2009-2011 Petr Motlicek; Karel Vesely - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_FEAT_FEATURE_PLP_H_ -#define KALDI_FEAT_FEATURE_PLP_H_ - -#include -#include - -#include "feat/feature-common.h" -#include "feat/feature-functions.h" -#include "feat/feature-window.h" -#include "feat/mel-computations.h" -#include "util/options-itf.h" - -namespace kaldi { -/// @addtogroup feat FeatureExtraction -/// @{ - - - -/// PlpOptions contains basic options for computing PLP features. -/// It only includes things that can be done in a "stateless" way, i.e. -/// it does not include energy max-normalization. -/// It does not include delta computation. 
-struct PlpOptions { - FrameExtractionOptions frame_opts; - MelBanksOptions mel_opts; - int32 lpc_order; - int32 num_ceps; // num cepstra including zero - bool use_energy; // use energy; else C0 - BaseFloat energy_floor; - bool raw_energy; // If true, compute energy before preemphasis and windowing - BaseFloat compress_factor; - int32 cepstral_lifter; - BaseFloat cepstral_scale; - - bool htk_compat; // if true, put energy/C0 last and introduce a factor of - // sqrt(2) on C0 to be the same as HTK. - - PlpOptions() : mel_opts(23), - // default number of mel-banks for the PLP computation; this - // seems to be common for 16kHz-sampled data. For 8kHz-sampled - // data, 15 may be better. - lpc_order(12), - num_ceps(13), - use_energy(true), - energy_floor(0.0), - raw_energy(true), - compress_factor(0.33333), - cepstral_lifter(22), - cepstral_scale(1.0), - htk_compat(false) {} - - void Register(OptionsItf *opts) { - frame_opts.Register(opts); - mel_opts.Register(opts); - opts->Register("lpc-order", &lpc_order, - "Order of LPC analysis in PLP computation"); - opts->Register("num-ceps", &num_ceps, - "Number of cepstra in PLP computation (including C0)"); - opts->Register("use-energy", &use_energy, - "Use energy (not C0) for zeroth PLP feature"); - opts->Register("energy-floor", &energy_floor, - "Floor on energy (absolute, not relative) in PLP computation. " - "Only makes a difference if --use-energy=true; only necessary if " - "--dither=0.0. Suggested values: 0.1 or 1.0"); - opts->Register("raw-energy", &raw_energy, - "If true, compute energy before preemphasis and windowing"); - opts->Register("compress-factor", &compress_factor, - "Compression factor in PLP computation"); - opts->Register("cepstral-lifter", &cepstral_lifter, - "Constant that controls scaling of PLPs"); - opts->Register("cepstral-scale", &cepstral_scale, - "Scaling constant in PLP computation"); - opts->Register("htk-compat", &htk_compat, - "If true, put energy or C0 last. 
Warning: not sufficient " - "to get HTK compatible features (need to change other " - "parameters)."); - } -}; - - -/// This is the new-style interface to the PLP computation. -class PlpComputer { - public: - typedef PlpOptions Options; - explicit PlpComputer(const PlpOptions &opts); - PlpComputer(const PlpComputer &other); - - const FrameExtractionOptions &GetFrameOptions() const { - return opts_.frame_opts; - } - - int32 Dim() const { return opts_.num_ceps; } - - bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; } - - /** - Function that computes one frame of features from - one frame of signal. - - @param [in] signal_raw_log_energy The log-energy of the frame of the signal - prior to windowing and pre-emphasis, or - log(numeric_limits::min()), whichever is greater. Must be - ignored by this function if this class returns false from - this->NeedsRawLogEnergy(). - @param [in] vtln_warp The VTLN warping factor that the user wants - to be applied when computing features for this utterance. Will - normally be 1.0, meaning no warping is to be done. The value will - be ignored for feature types that don't support VLTN, such as - spectrogram features. - @param [in] signal_frame One frame of the signal, - as extracted using the function ExtractWindow() using the options - returned by this->GetFrameOptions(). The function will use the - vector as a workspace, which is why it's a non-const pointer. - @param [out] feature Pointer to a vector of size this->Dim(), to which - the computed feature will be written. - */ - void Compute(BaseFloat signal_raw_log_energy, - BaseFloat vtln_warp, - VectorBase *signal_frame, - VectorBase *feature); - - ~PlpComputer(); - private: - - const MelBanks *GetMelBanks(BaseFloat vtln_warp); - - const Vector *GetEqualLoudness(BaseFloat vtln_warp); - - PlpOptions opts_; - Vector lifter_coeffs_; - Matrix idft_bases_; - BaseFloat log_energy_floor_; - std::map mel_banks_; // BaseFloat is VTLN coefficient. 
- std::map* > equal_loudness_; - SplitRadixRealFft *srfft_; - - // temporary vector used inside Compute; size is opts_.mel_opts.num_bins + 2 - Vector mel_energies_duplicated_; - // temporary vector used inside Compute; size is opts_.lpc_order + 1 - Vector autocorr_coeffs_; - // temporary vector used inside Compute; size is opts_.lpc_order - Vector lpc_coeffs_; - // temporary vector used inside Compute; size is opts_.lpc_order - Vector raw_cepstrum_; - - // Disallow assignment. - PlpComputer &operator =(const PlpComputer &other); -}; - -typedef OfflineFeatureTpl Plp; - -/// @} End of "addtogroup feat" - -} // namespace kaldi - - -#endif // KALDI_FEAT_FEATURE_PLP_H_ diff --git a/speechx/speechx/kaldi/feat/feature-spectrogram.cc b/speechx/speechx/kaldi/feat/feature-spectrogram.cc deleted file mode 100644 index 7eee2643..00000000 --- a/speechx/speechx/kaldi/feat/feature-spectrogram.cc +++ /dev/null @@ -1,82 +0,0 @@ -// feat/feature-spectrogram.cc - -// Copyright 2009-2012 Karel Vesely -// Copyright 2012 Navdeep Jaitly - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. 
- - -#include "feat/feature-spectrogram.h" - - -namespace kaldi { - -SpectrogramComputer::SpectrogramComputer(const SpectrogramOptions &opts) - : opts_(opts), srfft_(NULL) { - if (opts.energy_floor > 0.0) - log_energy_floor_ = Log(opts.energy_floor); - - int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); - if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two - srfft_ = new SplitRadixRealFft(padded_window_size); -} - -SpectrogramComputer::SpectrogramComputer(const SpectrogramComputer &other): - opts_(other.opts_), log_energy_floor_(other.log_energy_floor_), srfft_(NULL) { - if (other.srfft_ != NULL) - srfft_ = new SplitRadixRealFft(*other.srfft_); -} - -SpectrogramComputer::~SpectrogramComputer() { - delete srfft_; -} - -void SpectrogramComputer::Compute(BaseFloat signal_raw_log_energy, - BaseFloat vtln_warp, - VectorBase *signal_frame, - VectorBase *feature) { - KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && - feature->Dim() == this->Dim()); - - - // Compute energy after window function (not the raw one) - if (!opts_.raw_energy) - signal_raw_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), - std::numeric_limits::epsilon())); - - if (srfft_ != NULL) // Compute FFT using split-radix algorithm. - srfft_->Compute(signal_frame->Data(), true); - else // An alternative algorithm that works for non-powers-of-two - RealFft(signal_frame, true); - - // Convert the FFT into a power spectrum. - ComputePowerSpectrum(signal_frame); - SubVector power_spectrum(*signal_frame, - 0, signal_frame->Dim() / 2 + 1); - - power_spectrum.ApplyFloor(std::numeric_limits::epsilon()); - power_spectrum.ApplyLog(); - - feature->CopyFromVec(power_spectrum); - - if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_) - signal_raw_log_energy = log_energy_floor_; - // The zeroth spectrogram component is always set to the signal energy, - // instead of the square of the constant component of the signal. 
- (*feature)(0) = signal_raw_log_energy; -} - -} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/feature-spectrogram.h b/speechx/speechx/kaldi/feat/feature-spectrogram.h deleted file mode 100644 index 132a6875..00000000 --- a/speechx/speechx/kaldi/feat/feature-spectrogram.h +++ /dev/null @@ -1,117 +0,0 @@ -// feat/feature-spectrogram.h - -// Copyright 2009-2012 Karel Vesely -// Copyright 2012 Navdeep Jaitly - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_FEAT_FEATURE_SPECTROGRAM_H_ -#define KALDI_FEAT_FEATURE_SPECTROGRAM_H_ - - -#include - -#include "feat/feature-common.h" -#include "feat/feature-functions.h" -#include "feat/feature-window.h" - -namespace kaldi { -/// @addtogroup feat FeatureExtraction -/// @{ - - -/// SpectrogramOptions contains basic options for computing spectrogram -/// features. -struct SpectrogramOptions { - FrameExtractionOptions frame_opts; - BaseFloat energy_floor; - bool raw_energy; // If true, compute energy before preemphasis and windowing - - SpectrogramOptions() : - energy_floor(0.0), - raw_energy(true) {} - - void Register(OptionsItf *opts) { - frame_opts.Register(opts); - opts->Register("energy-floor", &energy_floor, - "Floor on energy (absolute, not relative) in Spectrogram " - "computation. 
Caution: this floor is applied to the zeroth " - "component, representing the total signal energy. The " - "floor on the individual spectrogram elements is fixed at " - "std::numeric_limits::epsilon()."); - opts->Register("raw-energy", &raw_energy, - "If true, compute energy before preemphasis and windowing"); - } -}; - -/// Class for computing spectrogram features. -class SpectrogramComputer { - public: - typedef SpectrogramOptions Options; - explicit SpectrogramComputer(const SpectrogramOptions &opts); - SpectrogramComputer(const SpectrogramComputer &other); - - const FrameExtractionOptions& GetFrameOptions() const { - return opts_.frame_opts; - } - - int32 Dim() const { return opts_.frame_opts.PaddedWindowSize() / 2 + 1; } - - bool NeedRawLogEnergy() const { return opts_.raw_energy; } - - - /** - Function that computes one frame of spectrogram features from - one frame of signal. - - @param [in] signal_raw_log_energy The log-energy of the frame of the signal - prior to windowing and pre-emphasis, or - log(numeric_limits::min()), whichever is greater. Must be - ignored by this function if this class returns false from - this->NeedsRawLogEnergy(). - @param [in] vtln_warp This is ignored by this function, it's only - needed for interface compatibility. - @param [in] signal_frame One frame of the signal, - as extracted using the function ExtractWindow() using the options - returned by this->GetFrameOptions(). The function will use the - vector as a workspace, which is why it's a non-const pointer. - @param [out] feature Pointer to a vector of size this->Dim(), to which - the computed feature will be written. - */ - void Compute(BaseFloat signal_raw_log_energy, - BaseFloat vtln_warp, - VectorBase *signal_frame, - VectorBase *feature); - - ~SpectrogramComputer(); - - private: - SpectrogramOptions opts_; - BaseFloat log_energy_floor_; - SplitRadixRealFft *srfft_; - - // Disallow assignment. 
- SpectrogramComputer &operator=(const SpectrogramComputer &other); -}; - -typedef OfflineFeatureTpl Spectrogram; - - -/// @} End of "addtogroup feat" -} // namespace kaldi - - -#endif // KALDI_FEAT_FEATURE_SPECTROGRAM_H_ diff --git a/speechx/speechx/kaldi/feat/feature-window.cc b/speechx/speechx/kaldi/feat/feature-window.cc deleted file mode 100644 index c5d4cc29..00000000 --- a/speechx/speechx/kaldi/feat/feature-window.cc +++ /dev/null @@ -1,222 +0,0 @@ -// feat/feature-window.cc - -// Copyright 2009-2011 Karel Vesely; Petr Motlicek; Microsoft Corporation -// 2013-2016 Johns Hopkins University (author: Daniel Povey) -// 2014 IMSL, PKU-HKUST (author: Wei Shi) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. 
- - -#include "feat/feature-window.h" -#include "matrix/matrix-functions.h" - - -namespace kaldi { - - -int64 FirstSampleOfFrame(int32 frame, - const FrameExtractionOptions &opts) { - int64 frame_shift = opts.WindowShift(); - if (opts.snip_edges) { - return frame * frame_shift; - } else { - int64 midpoint_of_frame = frame_shift * frame + frame_shift / 2, - beginning_of_frame = midpoint_of_frame - opts.WindowSize() / 2; - return beginning_of_frame; - } -} - -int32 NumFrames(int64 num_samples, - const FrameExtractionOptions &opts, - bool flush) { - int64 frame_shift = opts.WindowShift(); - int64 frame_length = opts.WindowSize(); - if (opts.snip_edges) { - // with --snip-edges=true (the default), we use a HTK-like approach to - // determining the number of frames-- all frames have to fit completely into - // the waveform, and the first frame begins at sample zero. - if (num_samples < frame_length) - return 0; - else - return (1 + ((num_samples - frame_length) / frame_shift)); - // You can understand the expression above as follows: 'num_samples - - // frame_length' is how much room we have to shift the frame within the - // waveform; 'frame_shift' is how much we shift it each time; and the ratio - // is how many times we can shift it (integer arithmetic rounds down). - } else { - // if --snip-edges=false, the number of frames is determined by rounding the - // (file-length / frame-shift) to the nearest integer. The point of this - // formula is to make the number of frames an obvious and predictable - // function of the frame shift and signal length, which makes many - // segmentation-related questions simpler. - // - // Because integer division in C++ rounds toward zero, we add (half the - // frame-shift minus epsilon) before dividing, to have the effect of - // rounding towards the closest integer. - int32 num_frames = (num_samples + (frame_shift / 2)) / frame_shift; - - if (flush) - return num_frames; - - // note: 'end' always means the last plus one, i.e. 
one past the last. - int64 end_sample_of_last_frame = FirstSampleOfFrame(num_frames - 1, opts) - + frame_length; - - // the following code is optimized more for clarity than efficiency. - // If flush == false, we can't output frames that extend past the end - // of the signal. - while (num_frames > 0 && end_sample_of_last_frame > num_samples) { - num_frames--; - end_sample_of_last_frame -= frame_shift; - } - return num_frames; - } -} - - -void Dither(VectorBase *waveform, BaseFloat dither_value) { - if (dither_value == 0.0) - return; - int32 dim = waveform->Dim(); - BaseFloat *data = waveform->Data(); - RandomState rstate; - for (int32 i = 0; i < dim; i++) - data[i] += RandGauss(&rstate) * dither_value; -} - - -void Preemphasize(VectorBase *waveform, BaseFloat preemph_coeff) { - if (preemph_coeff == 0.0) return; - KALDI_ASSERT(preemph_coeff >= 0.0 && preemph_coeff <= 1.0); - for (int32 i = waveform->Dim()-1; i > 0; i--) - (*waveform)(i) -= preemph_coeff * (*waveform)(i-1); - (*waveform)(0) -= preemph_coeff * (*waveform)(0); -} - -FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts) { - int32 frame_length = opts.WindowSize(); - KALDI_ASSERT(frame_length > 0); - window.Resize(frame_length); - double a = M_2PI / (frame_length-1); - for (int32 i = 0; i < frame_length; i++) { - double i_fl = static_cast(i); - if (opts.window_type == "hanning") { - window(i) = 0.5 - 0.5*cos(a * i_fl); - } else if (opts.window_type == "hamming") { - window(i) = 0.54 - 0.46*cos(a * i_fl); - } else if (opts.window_type == "povey") { // like hamming but goes to zero at edges. 
- window(i) = pow(0.5 - 0.5*cos(a * i_fl), 0.85); - } else if (opts.window_type == "rectangular") { - window(i) = 1.0; - } else if (opts.window_type == "blackman") { - window(i) = opts.blackman_coeff - 0.5*cos(a * i_fl) + - (0.5 - opts.blackman_coeff) * cos(2 * a * i_fl); - } else { - KALDI_ERR << "Invalid window type " << opts.window_type; - } - } -} - -void ProcessWindow(const FrameExtractionOptions &opts, - const FeatureWindowFunction &window_function, - VectorBase *window, - BaseFloat *log_energy_pre_window) { - int32 frame_length = opts.WindowSize(); - KALDI_ASSERT(window->Dim() == frame_length); - - if (opts.dither != 0.0) - Dither(window, opts.dither); - - if (opts.remove_dc_offset) - window->Add(-window->Sum() / frame_length); - - if (log_energy_pre_window != NULL) { - BaseFloat energy = std::max(VecVec(*window, *window), - std::numeric_limits::epsilon()); - *log_energy_pre_window = Log(energy); - } - - if (opts.preemph_coeff != 0.0) - Preemphasize(window, opts.preemph_coeff); - - window->MulElements(window_function.window); -} - - -// ExtractWindow extracts a windowed frame of waveform with a power-of-two, -// padded size. It does mean subtraction, pre-emphasis and dithering as -// requested. 
-void ExtractWindow(int64 sample_offset, - const VectorBase &wave, - int32 f, // with 0 <= f < NumFrames(feats, opts) - const FrameExtractionOptions &opts, - const FeatureWindowFunction &window_function, - Vector *window, - BaseFloat *log_energy_pre_window) { - KALDI_ASSERT(sample_offset >= 0 && wave.Dim() != 0); - int32 frame_length = opts.WindowSize(), - frame_length_padded = opts.PaddedWindowSize(); - int64 num_samples = sample_offset + wave.Dim(), - start_sample = FirstSampleOfFrame(f, opts), - end_sample = start_sample + frame_length; - - if (opts.snip_edges) { - KALDI_ASSERT(start_sample >= sample_offset && - end_sample <= num_samples); - } else { - KALDI_ASSERT(sample_offset == 0 || start_sample >= sample_offset); - } - - if (window->Dim() != frame_length_padded) - window->Resize(frame_length_padded, kUndefined); - - // wave_start and wave_end are start and end indexes into 'wave', for the - // piece of wave that we're trying to extract. - int32 wave_start = int32(start_sample - sample_offset), - wave_end = wave_start + frame_length; - if (wave_start >= 0 && wave_end <= wave.Dim()) { - // the normal case-- no edge effects to consider. - window->Range(0, frame_length).CopyFromVec( - wave.Range(wave_start, frame_length)); - } else { - // Deal with any end effects by reflection, if needed. This code will only - // be reached for about two frames per utterance, so we don't concern - // ourselves excessively with efficiency. - int32 wave_dim = wave.Dim(); - for (int32 s = 0; s < frame_length; s++) { - int32 s_in_wave = s + wave_start; - while (s_in_wave < 0 || s_in_wave >= wave_dim) { - // reflect around the beginning or end of the wave. - // e.g. -1 -> 0, -2 -> 1. - // dim -> dim - 1, dim + 1 -> dim - 2. - // the code supports repeated reflections, although this - // would only be needed in pathological cases. 
- if (s_in_wave < 0) s_in_wave = - s_in_wave - 1; - else s_in_wave = 2 * wave_dim - 1 - s_in_wave; - } - (*window)(s) = wave(s_in_wave); - } - } - - if (frame_length_padded > frame_length) - window->Range(frame_length, frame_length_padded - frame_length).SetZero(); - - SubVector frame(*window, 0, frame_length); - - ProcessWindow(opts, window_function, &frame, log_energy_pre_window); -} - -} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/feature-window.h b/speechx/speechx/kaldi/feat/feature-window.h deleted file mode 100644 index a7abba50..00000000 --- a/speechx/speechx/kaldi/feat/feature-window.h +++ /dev/null @@ -1,223 +0,0 @@ -// feat/feature-window.h - -// Copyright 2009-2011 Karel Vesely; Petr Motlicek; Saarland University -// 2014-2016 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_FEAT_FEATURE_WINDOW_H_ -#define KALDI_FEAT_FEATURE_WINDOW_H_ - -#include -#include - -#include "matrix/matrix-lib.h" -#include "util/common-utils.h" -#include "base/kaldi-error.h" - -namespace kaldi { -/// @addtogroup feat FeatureExtraction -/// @{ - -struct FrameExtractionOptions { - BaseFloat samp_freq; - BaseFloat frame_shift_ms; // in milliseconds. - BaseFloat frame_length_ms; // in milliseconds. - BaseFloat dither; // Amount of dithering, 0.0 means no dither. 
- BaseFloat preemph_coeff; // Preemphasis coefficient. - bool remove_dc_offset; // Subtract mean of wave before FFT. - std::string window_type; // e.g. Hamming window - // May be "hamming", "rectangular", "povey", "hanning", "blackman" - // "povey" is a window I made to be similar to Hamming but to go to zero at the - // edges, it's pow((0.5 - 0.5*cos(n/N*2*pi)), 0.85) - // I just don't think the Hamming window makes sense as a windowing function. - bool round_to_power_of_two; - BaseFloat blackman_coeff; - bool snip_edges; - bool allow_downsample; - bool allow_upsample; - int max_feature_vectors; - FrameExtractionOptions(): - samp_freq(16000), - frame_shift_ms(10.0), - frame_length_ms(25.0), - dither(1.0), - preemph_coeff(0.97), - remove_dc_offset(true), - window_type("povey"), - round_to_power_of_two(true), - blackman_coeff(0.42), - snip_edges(true), - allow_downsample(false), - allow_upsample(false), - max_feature_vectors(-1) - { } - - void Register(OptionsItf *opts) { - opts->Register("sample-frequency", &samp_freq, - "Waveform data sample frequency (must match the waveform file, " - "if specified there)"); - opts->Register("frame-length", &frame_length_ms, "Frame length in milliseconds"); - opts->Register("frame-shift", &frame_shift_ms, "Frame shift in milliseconds"); - opts->Register("preemphasis-coefficient", &preemph_coeff, - "Coefficient for use in signal preemphasis"); - opts->Register("remove-dc-offset", &remove_dc_offset, - "Subtract mean from waveform on each frame"); - opts->Register("dither", &dither, "Dithering constant (0.0 means no dither). " - "If you turn this off, you should set the --energy-floor " - "option, e.g. 
to 1.0 or 0.1"); - opts->Register("window-type", &window_type, "Type of window " - "(\"hamming\"|\"hanning\"|\"povey\"|\"rectangular\"" - "|\"blackmann\")"); - opts->Register("blackman-coeff", &blackman_coeff, - "Constant coefficient for generalized Blackman window."); - opts->Register("round-to-power-of-two", &round_to_power_of_two, - "If true, round window size to power of two by zero-padding " - "input to FFT."); - opts->Register("snip-edges", &snip_edges, - "If true, end effects will be handled by outputting only frames that " - "completely fit in the file, and the number of frames depends on the " - "frame-length. If false, the number of frames depends only on the " - "frame-shift, and we reflect the data at the ends."); - opts->Register("allow-downsample", &allow_downsample, - "If true, allow the input waveform to have a higher frequency than " - "the specified --sample-frequency (and we'll downsample)."); - opts->Register("max-feature-vectors", &max_feature_vectors, - "Memory optimization. If larger than 0, periodically remove feature " - "vectors so that only this number of the latest feature vectors is " - "retained."); - opts->Register("allow-upsample", &allow_upsample, - "If true, allow the input waveform to have a lower frequency than " - "the specified --sample-frequency (and we'll upsample)."); - } - int32 WindowShift() const { - return static_cast(samp_freq * 0.001 * frame_shift_ms); - } - int32 WindowSize() const { - return static_cast(samp_freq * 0.001 * frame_length_ms); - } - int32 PaddedWindowSize() const { - return (round_to_power_of_two ? 
RoundUpToNearestPowerOfTwo(WindowSize()) : - WindowSize()); - } -}; - - -struct FeatureWindowFunction { - FeatureWindowFunction() {} - explicit FeatureWindowFunction(const FrameExtractionOptions &opts); - FeatureWindowFunction(const FeatureWindowFunction &other): - window(other.window) { } - Vector window; -}; - - -/** - This function returns the number of frames that we can extract from a wave - file with the given number of samples in it (assumed to have the same - sampling rate as specified in 'opts'). - - @param [in] num_samples The number of samples in the wave file. - @param [in] opts The frame-extraction options class - - @param [in] flush True if we are asserting that this number of samples is - 'all there is', false if we expecting more data to possibly come - in. This only makes a difference to the answer if opts.snips_edges - == false. For offline feature extraction you always want flush == - true. In an online-decoding context, once you know (or decide) that - no more data is coming in, you'd call it with flush == true at the - end to flush out any remaining data. -*/ -int32 NumFrames(int64 num_samples, - const FrameExtractionOptions &opts, - bool flush = true); - -/* - This function returns the index of the first sample of the frame indexed - 'frame'. If snip-edges=true, it just returns frame * opts.WindowShift(); if - snip-edges=false, the formula is a little more complicated and the result may - be negative. -*/ -int64 FirstSampleOfFrame(int32 frame, - const FrameExtractionOptions &opts); - - - -void Dither(VectorBase *waveform, BaseFloat dither_value); - -void Preemphasize(VectorBase *waveform, BaseFloat preemph_coeff); - -/** - This function does all the windowing steps after actually - extracting the windowed signal: depending on the - configuration, it does dithering, dc offset removal, - preemphasis, and multiplication by the windowing function. 
- @param [in] opts The options class to be used - @param [in] window_function The windowing function-- should have - been initialized using 'opts'. - @param [in,out] window A vector of size opts.WindowSize(). Note: - it will typically be a sub-vector of a larger vector of size - opts.PaddedWindowSize(), with the remaining samples zero, - as the FFT code is more efficient if it operates on data with - power-of-two size. - @param [out] log_energy_pre_window If non-NULL, then after dithering and - DC offset removal, this function will write to this pointer the log of - the total energy (i.e. sum-squared) of the frame. - */ -void ProcessWindow(const FrameExtractionOptions &opts, - const FeatureWindowFunction &window_function, - VectorBase *window, - BaseFloat *log_energy_pre_window = NULL); - - -/* - ExtractWindow() extracts a windowed frame of waveform (possibly with a - power-of-two, padded size, depending on the config), including all the - proessing done by ProcessWindow(). - - @param [in] sample_offset If 'wave' is not the entire waveform, but - part of it to the left has been discarded, then the - number of samples prior to 'wave' that we have - already discarded. Set this to zero if you are - processing the entire waveform in one piece, or - if you get 'no matching function' compilation - errors when updating the code. - @param [in] wave The waveform - @param [in] f The frame index to be extracted, with - 0 <= f < NumFrames(sample_offset + wave.Dim(), opts, true) - @param [in] opts The options class to be used - @param [in] window_function The windowing function, as derived from the - options class. - @param [out] window The windowed, possibly-padded waveform to be - extracted. Will be resized as needed. - @param [out] log_energy_pre_window If non-NULL, the log-energy of - the signal prior to pre-emphasis and multiplying by - the windowing function will be written to here. 
-*/ -void ExtractWindow(int64 sample_offset, - const VectorBase &wave, - int32 f, - const FrameExtractionOptions &opts, - const FeatureWindowFunction &window_function, - Vector *window, - BaseFloat *log_energy_pre_window = NULL); - - -/// @} End of "addtogroup feat" -} // namespace kaldi - - -#endif // KALDI_FEAT_FEATURE_WINDOW_H_ diff --git a/speechx/speechx/kaldi/feat/mel-computations.cc b/speechx/speechx/kaldi/feat/mel-computations.cc deleted file mode 100644 index 626cb677..00000000 --- a/speechx/speechx/kaldi/feat/mel-computations.cc +++ /dev/null @@ -1,340 +0,0 @@ -// feat/mel-computations.cc - -// Copyright 2009-2011 Phonexia s.r.o.; Karel Vesely; Microsoft Corporation - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include -#include -#include - -#include "feat/feature-functions.h" -#include "feat/feature-window.h" -#include "feat/mel-computations.h" - -namespace kaldi { - - -MelBanks::MelBanks(const MelBanksOptions &opts, - const FrameExtractionOptions &frame_opts, - BaseFloat vtln_warp_factor): - htk_mode_(opts.htk_mode) { - int32 num_bins = opts.num_bins; - if (num_bins < 3) KALDI_ERR << "Must have at least 3 mel bins"; - BaseFloat sample_freq = frame_opts.samp_freq; - int32 window_length_padded = frame_opts.PaddedWindowSize(); - KALDI_ASSERT(window_length_padded % 2 == 0); - int32 num_fft_bins = window_length_padded / 2; - BaseFloat nyquist = 0.5 * sample_freq; - - BaseFloat low_freq = opts.low_freq, high_freq; - if (opts.high_freq > 0.0) - high_freq = opts.high_freq; - else - high_freq = nyquist + opts.high_freq; - - if (low_freq < 0.0 || low_freq >= nyquist - || high_freq <= 0.0 || high_freq > nyquist - || high_freq <= low_freq) - KALDI_ERR << "Bad values in options: low-freq " << low_freq - << " and high-freq " << high_freq << " vs. nyquist " - << nyquist; - - BaseFloat fft_bin_width = sample_freq / window_length_padded; - // fft-bin width [think of it as Nyquist-freq / half-window-length] - - BaseFloat mel_low_freq = MelScale(low_freq); - BaseFloat mel_high_freq = MelScale(high_freq); - - debug_ = opts.debug_mel; - - // divide by num_bins+1 in next line because of end-effects where the bins - // spread out to the sides. 
- BaseFloat mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins+1); - - BaseFloat vtln_low = opts.vtln_low, - vtln_high = opts.vtln_high; - if (vtln_high < 0.0) { - vtln_high += nyquist; - } - - if (vtln_warp_factor != 1.0 && - (vtln_low < 0.0 || vtln_low <= low_freq - || vtln_low >= high_freq - || vtln_high <= 0.0 || vtln_high >= high_freq - || vtln_high <= vtln_low)) - KALDI_ERR << "Bad values in options: vtln-low " << vtln_low - << " and vtln-high " << vtln_high << ", versus " - << "low-freq " << low_freq << " and high-freq " - << high_freq; - - bins_.resize(num_bins); - center_freqs_.Resize(num_bins); - - for (int32 bin = 0; bin < num_bins; bin++) { - BaseFloat left_mel = mel_low_freq + bin * mel_freq_delta, - center_mel = mel_low_freq + (bin + 1) * mel_freq_delta, - right_mel = mel_low_freq + (bin + 2) * mel_freq_delta; - - if (vtln_warp_factor != 1.0) { - left_mel = VtlnWarpMelFreq(vtln_low, vtln_high, low_freq, high_freq, - vtln_warp_factor, left_mel); - center_mel = VtlnWarpMelFreq(vtln_low, vtln_high, low_freq, high_freq, - vtln_warp_factor, center_mel); - right_mel = VtlnWarpMelFreq(vtln_low, vtln_high, low_freq, high_freq, - vtln_warp_factor, right_mel); - } - center_freqs_(bin) = InverseMelScale(center_mel); - // this_bin will be a vector of coefficients that is only - // nonzero where this mel bin is active. - Vector this_bin(num_fft_bins); - int32 first_index = -1, last_index = -1; - for (int32 i = 0; i < num_fft_bins; i++) { - BaseFloat freq = (fft_bin_width * i); // Center frequency of this fft - // bin. 
- BaseFloat mel = MelScale(freq); - if (mel > left_mel && mel < right_mel) { - BaseFloat weight; - if (mel <= center_mel) - weight = (mel - left_mel) / (center_mel - left_mel); - else - weight = (right_mel-mel) / (right_mel-center_mel); - this_bin(i) = weight; - if (first_index == -1) - first_index = i; - last_index = i; - } - } - //KALDI_ASSERT(first_index != -1 && last_index >= first_index - // && "You may have set --num-mel-bins too large."); - - bins_[bin].first = first_index; - int32 size = last_index + 1 - first_index; - bins_[bin].second.Resize(size); - bins_[bin].second.CopyFromVec(this_bin.Range(first_index, size)); - - // Replicate a bug in HTK, for testing purposes. - if (opts.htk_mode && bin == 0 && mel_low_freq != 0.0) - bins_[bin].second(0) = 0.0; - - } - if (debug_) { - for (size_t i = 0; i < bins_.size(); i++) { - KALDI_LOG << "bin " << i << ", offset = " << bins_[i].first - << ", vec = " << bins_[i].second; - } - } -} - -MelBanks::MelBanks(const MelBanks &other): - center_freqs_(other.center_freqs_), - bins_(other.bins_), - debug_(other.debug_), - htk_mode_(other.htk_mode_) { } - -BaseFloat MelBanks::VtlnWarpFreq(BaseFloat vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN. - BaseFloat vtln_high_cutoff, - BaseFloat low_freq, // upper+lower frequency cutoffs in mel computation - BaseFloat high_freq, - BaseFloat vtln_warp_factor, - BaseFloat freq) { - /// This computes a VTLN warping function that is not the same as HTK's one, - /// but has similar inputs (this function has the advantage of never producing - /// empty bins). - - /// This function computes a warp function F(freq), defined between low_freq and - /// high_freq inclusive, with the following properties: - /// F(low_freq) == low_freq - /// F(high_freq) == high_freq - /// The function is continuous and piecewise linear with two inflection - /// points. 
- /// The lower inflection point (measured in terms of the unwarped - /// frequency) is at frequency l, determined as described below. - /// The higher inflection point is at a frequency h, determined as - /// described below. - /// If l <= f <= h, then F(f) = f/vtln_warp_factor. - /// If the higher inflection point (measured in terms of the unwarped - /// frequency) is at h, then max(h, F(h)) == vtln_high_cutoff. - /// Since (by the last point) F(h) == h/vtln_warp_factor, then - /// max(h, h/vtln_warp_factor) == vtln_high_cutoff, so - /// h = vtln_high_cutoff / max(1, 1/vtln_warp_factor). - /// = vtln_high_cutoff * min(1, vtln_warp_factor). - /// If the lower inflection point (measured in terms of the unwarped - /// frequency) is at l, then min(l, F(l)) == vtln_low_cutoff - /// This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor) - /// = vtln_low_cutoff * max(1, vtln_warp_factor) - - - if (freq < low_freq || freq > high_freq) return freq; // in case this gets called - // for out-of-range frequencies, just return the freq. 
- - KALDI_ASSERT(vtln_low_cutoff > low_freq && - "be sure to set the --vtln-low option higher than --low-freq"); - KALDI_ASSERT(vtln_high_cutoff < high_freq && - "be sure to set the --vtln-high option lower than --high-freq [or negative]"); - BaseFloat one = 1.0; - BaseFloat l = vtln_low_cutoff * std::max(one, vtln_warp_factor); - BaseFloat h = vtln_high_cutoff * std::min(one, vtln_warp_factor); - BaseFloat scale = 1.0 / vtln_warp_factor; - BaseFloat Fl = scale * l; // F(l); - BaseFloat Fh = scale * h; // F(h); - KALDI_ASSERT(l > low_freq && h < high_freq); - // slope of left part of the 3-piece linear function - BaseFloat scale_left = (Fl - low_freq) / (l - low_freq); - // [slope of center part is just "scale"] - - // slope of right part of the 3-piece linear function - BaseFloat scale_right = (high_freq - Fh) / (high_freq - h); - - if (freq < l) { - return low_freq + scale_left * (freq - low_freq); - } else if (freq < h) { - return scale * freq; - } else { // freq >= h - return high_freq + scale_right * (freq - high_freq); - } -} - -BaseFloat MelBanks::VtlnWarpMelFreq(BaseFloat vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN. - BaseFloat vtln_high_cutoff, - BaseFloat low_freq, // upper+lower frequency cutoffs in mel computation - BaseFloat high_freq, - BaseFloat vtln_warp_factor, - BaseFloat mel_freq) { - return MelScale(VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, - low_freq, high_freq, - vtln_warp_factor, InverseMelScale(mel_freq))); -} - - -// "power_spectrum" contains fft energies. 
-void MelBanks::Compute(const VectorBase &power_spectrum, - VectorBase *mel_energies_out) const { - int32 num_bins = bins_.size(); - KALDI_ASSERT(mel_energies_out->Dim() == num_bins); - - for (int32 i = 0; i < num_bins; i++) { - int32 offset = bins_[i].first; - const Vector &v(bins_[i].second); - BaseFloat energy = VecVec(v, power_spectrum.Range(offset, v.Dim())); - // HTK-like flooring- for testing purposes (we prefer dither) - if (htk_mode_ && energy < 1.0) energy = 1.0; - (*mel_energies_out)(i) = energy; - - // The following assert was added due to a problem with OpenBlas that - // we had at one point (it was a bug in that library). Just to detect - // it early. - KALDI_ASSERT(!KALDI_ISNAN((*mel_energies_out)(i))); - } - - if (debug_) { - fprintf(stderr, "MEL BANKS:\n"); - for (int32 i = 0; i < num_bins; i++) - fprintf(stderr, " %f", (*mel_energies_out)(i)); - fprintf(stderr, "\n"); - } -} - -void ComputeLifterCoeffs(BaseFloat Q, VectorBase *coeffs) { - // Compute liftering coefficients (scaling on cepstral coeffs) - // coeffs are numbered slightly differently from HTK: the zeroth - // index is C0, which is not affected. 
- for (int32 i = 0; i < coeffs->Dim(); i++) - (*coeffs)(i) = 1.0 + 0.5 * Q * sin (M_PI * i / Q); -} - - -// Durbin's recursion - converts autocorrelation coefficients to the LPC -// pTmp - temporal place [n] -// pAC - autocorrelation coefficients [n + 1] -// pLP - linear prediction coefficients [n] (predicted_sn = sum_1^P{a[i-1] * s[n-i]}}) -// F(z) = 1 / (1 - A(z)), 1 is not stored in the demoninator -BaseFloat Durbin(int n, const BaseFloat *pAC, BaseFloat *pLP, BaseFloat *pTmp) { - BaseFloat ki; // reflection coefficient - int i; - int j; - - BaseFloat E = pAC[0]; - - for (i = 0; i < n; i++) { - // next reflection coefficient - ki = pAC[i + 1]; - for (j = 0; j < i; j++) - ki += pLP[j] * pAC[i - j]; - ki = ki / E; - - // new error - BaseFloat c = 1 - ki * ki; - if (c < 1.0e-5) // remove NaNs for constan signal - c = 1.0e-5; - E *= c; - - // new LP coefficients - pTmp[i] = -ki; - for (j = 0; j < i; j++) - pTmp[j] = pLP[j] - ki * pLP[i - j - 1]; - - for (j = 0; j <= i; j++) - pLP[j] = pTmp[j]; - } - - return E; -} - - -void Lpc2Cepstrum(int n, const BaseFloat *pLPC, BaseFloat *pCepst) { - for (int32 i = 0; i < n; i++) { - double sum = 0.0; - int j; - for (j = 0; j < i; j++) { - sum += static_cast(i - j) * pLPC[j] * pCepst[i - j - 1]; - } - pCepst[i] = -pLPC[i] - sum / static_cast(i + 1); - } -} - -void GetEqualLoudnessVector(const MelBanks &mel_banks, - Vector *ans) { - int32 n = mel_banks.NumBins(); - // Central frequency of each mel bin. - const Vector &f0 = mel_banks.GetCenterFreqs(); - ans->Resize(n); - for (int32 i = 0; i < n; i++) { - BaseFloat fsq = f0(i) * f0(i); - BaseFloat fsub = fsq / (fsq + 1.6e5); - (*ans)(i) = fsub * fsub * ((fsq + 1.44e6) / (fsq + 9.61e6)); - } -} - - -// Compute LP coefficients from autocorrelation coefficients. 
-BaseFloat ComputeLpc(const VectorBase &autocorr_in, - Vector *lpc_out) { - int32 n = autocorr_in.Dim() - 1; - KALDI_ASSERT(lpc_out->Dim() == n); - Vector tmp(n); - BaseFloat ans = Durbin(n, autocorr_in.Data(), - lpc_out->Data(), - tmp.Data()); - if (ans <= 0.0) - KALDI_WARN << "Zero energy in LPC computation"; - return -Log(1.0 / ans); // forms the C0 value -} - - -} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/mel-computations.h b/speechx/speechx/kaldi/feat/mel-computations.h deleted file mode 100644 index 0c1d41ca..00000000 --- a/speechx/speechx/kaldi/feat/mel-computations.h +++ /dev/null @@ -1,171 +0,0 @@ -// feat/mel-computations.h - -// Copyright 2009-2011 Phonexia s.r.o.; Microsoft Corporation -// 2016 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_FEAT_MEL_COMPUTATIONS_H_ -#define KALDI_FEAT_MEL_COMPUTATIONS_H_ - -#include -#include -#include -#include -#include -#include - -#include "base/kaldi-common.h" -#include "util/common-utils.h" -#include "matrix/matrix-lib.h" - - -namespace kaldi { -/// @addtogroup feat FeatureExtraction -/// @{ - -struct FrameExtractionOptions; // defined in feature-window.h - - -struct MelBanksOptions { - int32 num_bins; // e.g. 25; number of triangular bins - BaseFloat low_freq; // e.g. 
20; lower frequency cutoff - BaseFloat high_freq; // an upper frequency cutoff; 0 -> no cutoff, negative - // ->added to the Nyquist frequency to get the cutoff. - BaseFloat vtln_low; // vtln lower cutoff of warping function. - BaseFloat vtln_high; // vtln upper cutoff of warping function: if negative, added - // to the Nyquist frequency to get the cutoff. - bool debug_mel; - // htk_mode is a "hidden" config, it does not show up on command line. - // Enables more exact compatibility with HTK, for testing purposes. Affects - // mel-energy flooring and reproduces a bug in HTK. - bool htk_mode; - explicit MelBanksOptions(int num_bins = 25) - : num_bins(num_bins), low_freq(20), high_freq(0), vtln_low(100), - vtln_high(-500), debug_mel(false), htk_mode(false) {} - - void Register(OptionsItf *opts) { - opts->Register("num-mel-bins", &num_bins, - "Number of triangular mel-frequency bins"); - opts->Register("low-freq", &low_freq, - "Low cutoff frequency for mel bins"); - opts->Register("high-freq", &high_freq, - "High cutoff frequency for mel bins (if <= 0, offset from Nyquist)"); - opts->Register("vtln-low", &vtln_low, - "Low inflection point in piecewise linear VTLN warping function"); - opts->Register("vtln-high", &vtln_high, - "High inflection point in piecewise linear VTLN warping function" - " (if negative, offset from high-mel-freq"); - opts->Register("debug-mel", &debug_mel, - "Print out debugging information for mel bin computation"); - } -}; - - -class MelBanks { - public: - - static inline BaseFloat InverseMelScale(BaseFloat mel_freq) { - return 700.0f * (expf (mel_freq / 1127.0f) - 1.0f); - } - - static inline BaseFloat MelScale(BaseFloat freq) { - return 1127.0f * logf (1.0f + freq / 700.0f); - } - - static BaseFloat VtlnWarpFreq(BaseFloat vtln_low_cutoff, - BaseFloat vtln_high_cutoff, // discontinuities in warp func - BaseFloat low_freq, - BaseFloat high_freq, // upper+lower frequency cutoffs in - // the mel computation - BaseFloat vtln_warp_factor, - 
BaseFloat freq); - - static BaseFloat VtlnWarpMelFreq(BaseFloat vtln_low_cutoff, - BaseFloat vtln_high_cutoff, - BaseFloat low_freq, - BaseFloat high_freq, - BaseFloat vtln_warp_factor, - BaseFloat mel_freq); - - - MelBanks(const MelBanksOptions &opts, - const FrameExtractionOptions &frame_opts, - BaseFloat vtln_warp_factor); - - /// Compute Mel energies (note: not log enerties). - /// At input, "fft_energies" contains the FFT energies (not log). - void Compute(const VectorBase &fft_energies, - VectorBase *mel_energies_out) const; - - int32 NumBins() const { return bins_.size(); } - - // returns vector of central freq of each bin; needed by plp code. - const Vector &GetCenterFreqs() const { return center_freqs_; } - - const std::vector > >& GetBins() const { - return bins_; - } - - // Copy constructor - MelBanks(const MelBanks &other); - private: - // Disallow assignment - MelBanks &operator = (const MelBanks &other); - - // center frequencies of bins, numbered from 0 ... num_bins-1. - // Needed by GetCenterFreqs(). - Vector center_freqs_; - - // the "bins_" vector is a vector, one for each bin, of a pair: - // (the first nonzero fft-bin), (the vector of weights). - std::vector > > bins_; - - bool debug_; - bool htk_mode_; -}; - - -// Compute liftering coefficients (scaling on cepstral coeffs) -// coeffs are numbered slightly differently from HTK: the zeroth -// index is C0, which is not affected. -void ComputeLifterCoeffs(BaseFloat Q, VectorBase *coeffs); - - -// Durbin's recursion - converts autocorrelation coefficients to the LPC -// pTmp - temporal place [n] -// pAC - autocorrelation coefficients [n + 1] -// pLP - linear prediction coefficients [n] (predicted_sn = sum_1^P{a[i-1] * s[n-i]}}) -// F(z) = 1 / (1 - A(z)), 1 is not stored in the denominator -// Returns log energy of residual (I think) -BaseFloat Durbin(int n, const BaseFloat *pAC, BaseFloat *pLP, BaseFloat *pTmp); - -// Compute LP coefficients from autocorrelation coefficients. 
-// Returns log energy of residual (I think) -BaseFloat ComputeLpc(const VectorBase &autocorr_in, - Vector *lpc_out); - -void Lpc2Cepstrum(int n, const BaseFloat *pLPC, BaseFloat *pCepst); - - - -void GetEqualLoudnessVector(const MelBanks &mel_banks, - Vector *ans); - -/// @} End of "addtogroup feat" -} // namespace kaldi - -#endif // KALDI_FEAT_MEL_COMPUTATIONS_H_ diff --git a/speechx/speechx/kaldi/feat/online-feature-itf.h b/speechx/speechx/kaldi/feat/online-feature-itf.h deleted file mode 100644 index a0211c09..00000000 --- a/speechx/speechx/kaldi/feat/online-feature-itf.h +++ /dev/null @@ -1,125 +0,0 @@ -// feat/online-feature-itf.h - -// Copyright 2013 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_FEAT_ONLINE_FEATURE_ITF_H_ -#define KALDI_FEAT_ONLINE_FEATURE_ITF_H_ 1 -#include "base/kaldi-common.h" -#include "matrix/matrix-lib.h" - -namespace kaldi { -/// @ingroup Interfaces -/// @{ - -/** - OnlineFeatureInterface is an interface for online feature processing (it is - also usable in the offline setting, but currently we're not using it for - that). This is for use in the online2/ directory, and it supersedes the - interface in ../online/online-feat-input.h. 
We have a slightly different - model that puts more control in the hands of the calling thread, and won't - involve waiting on semaphores in the decoding thread. - - This interface only specifies how the object *outputs* the features. - How it obtains the features, e.g. from a previous object or objects of type - OnlineFeatureInterface, is not specified in the interface and you will - likely define new constructors or methods in the derived type to do that. - - You should appreciate that this interface is designed to allow random - access to features, as long as they are ready. That is, the user - can call GetFrame for any frame less than NumFramesReady(), and when - implementing a child class you must not make assumptions about the - order in which the user makes these calls. -*/ - -class OnlineFeatureInterface { - public: - virtual int32 Dim() const = 0; /// returns the feature dimension. - - /// Returns the total number of frames, since the start of the utterance, that - /// are now available. In an online-decoding context, this will likely - /// increase with time as more data becomes available. - virtual int32 NumFramesReady() const = 0; - - /// Returns true if this is the last frame. Frame indices are zero-based, so the - /// first frame is zero. IsLastFrame(-1) will return false, unless the file - /// is empty (which is a case that I'm not sure all the code will handle, so - /// be careful). This function may return false for some frame if - /// we haven't yet decided to terminate decoding, but later true if we decide - /// to terminate decoding. This function exists mainly to correctly handle - /// end effects in feature extraction, and is not a mechanism to determine how - /// many frames are in the decodable object (as it used to be, and for backward - /// compatibility, still is, in the Decodable interface). - virtual bool IsLastFrame(int32 frame) const = 0; - - /// Gets the feature vector for this frame. 
Before calling this for a given - /// frame, it is assumed that you called NumFramesReady() and it returned a - /// number greater than "frame". Otherwise this call will likely crash with - /// an assert failure. This function is not declared const, in case there is - /// some kind of caching going on, but most of the time it shouldn't modify - /// the class. - virtual void GetFrame(int32 frame, VectorBase *feat) = 0; - - - /// This is like GetFrame() but for a collection of frames. There is a - /// default implementation that just gets the frames one by one, but it - /// may be overridden for efficiency by child classes (since sometimes - /// it's more efficient to do things in a batch). - virtual void GetFrames(const std::vector &frames, - MatrixBase *feats) { - KALDI_ASSERT(static_cast(frames.size()) == feats->NumRows()); - for (size_t i = 0; i < frames.size(); i++) { - SubVector feat(*feats, i); - GetFrame(frames[i], &feat); - } - } - - - // Returns frame shift in seconds. Helps to estimate duration from frame - // counts. - virtual BaseFloat FrameShiftInSeconds() const = 0; - - /// Virtual destructor. Note: constructors that take another member of - /// type OnlineFeatureInterface are not expected to take ownership of - /// that pointer; the caller needs to keep track of that manually. - virtual ~OnlineFeatureInterface() { } - -}; - - -/// Add a virtual class for "source" features such as MFCC or PLP or pitch -/// features. -class OnlineBaseFeature: public OnlineFeatureInterface { - public: - /// This would be called from the application, when you get more wave data. - /// Note: the sampling_rate is typically only provided so the code can assert - /// that it matches the sampling rate expected in the options. - virtual void AcceptWaveform(BaseFloat sampling_rate, - const VectorBase &waveform) = 0; - - /// InputFinished() tells the class you won't be providing any - /// more waveform. 
This will help flush out the last few frames - /// of delta or LDA features (it will typically affect the return value - /// of IsLastFrame. - virtual void InputFinished() = 0; -}; - - -/// @} -} // namespace Kaldi - -#endif // KALDI_ITF_ONLINE_FEATURE_ITF_H_ diff --git a/speechx/speechx/kaldi/feat/online-feature.cc b/speechx/speechx/kaldi/feat/online-feature.cc deleted file mode 100644 index 047909e7..00000000 --- a/speechx/speechx/kaldi/feat/online-feature.cc +++ /dev/null @@ -1,679 +0,0 @@ -// feat/online-feature.cc - -// Copyright 2013 Johns Hopkins University (author: Daniel Povey) -// 2014 Yanqing Sun, Junjie Wang, -// Daniel Povey, Korbinian Riedhammer - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include "feat/online-feature.h" -#include "transform/cmvn.h" - -namespace kaldi { - -RecyclingVector::RecyclingVector(int items_to_hold): - items_to_hold_(items_to_hold == 0 ? 
-1 : items_to_hold), - first_available_index_(0) { -} - -RecyclingVector::~RecyclingVector() { - for (auto *item : items_) { - delete item; - } -} - -Vector *RecyclingVector::At(int index) const { - if (index < first_available_index_) { - KALDI_ERR << "Attempted to retrieve feature vector that was " - "already removed by the RecyclingVector (index = " - << index << "; " - << "first_available_index = " << first_available_index_ << "; " - << "size = " << Size() << ")"; - } - // 'at' does size checking. - return items_.at(index - first_available_index_); -} - -void RecyclingVector::PushBack(Vector *item) { - if (items_.size() == items_to_hold_) { - delete items_.front(); - items_.pop_front(); - ++first_available_index_; - } - items_.push_back(item); -} - -int RecyclingVector::Size() const { - return first_available_index_ + items_.size(); -} - -template -void OnlineGenericBaseFeature::GetFrame(int32 frame, - VectorBase *feat) { - feat->CopyFromVec(*(features_.At(frame))); -}; - -template -OnlineGenericBaseFeature::OnlineGenericBaseFeature( - const typename C::Options &opts): - computer_(opts), window_function_(computer_.GetFrameOptions()), - features_(opts.frame_opts.max_feature_vectors), - input_finished_(false), waveform_offset_(0) { - // RE the following assert: search for ONLINE_IVECTOR_LIMIT in - // online-ivector-feature.cc. - // Casting to uint32, an unsigned type, means that -1 would be treated - // as `very large`. 
- KALDI_ASSERT(static_cast(opts.frame_opts.max_feature_vectors) > 200); -} - - -template -void OnlineGenericBaseFeature::MaybeCreateResampler( - BaseFloat sampling_rate) { - BaseFloat expected_sampling_rate = computer_.GetFrameOptions().samp_freq; - - if (resampler_ != nullptr) { - KALDI_ASSERT(resampler_->GetInputSamplingRate() == sampling_rate); - KALDI_ASSERT(resampler_->GetOutputSamplingRate() == expected_sampling_rate); - } else if (((sampling_rate < expected_sampling_rate) && - computer_.GetFrameOptions().allow_downsample) || - ((sampling_rate > expected_sampling_rate) && - computer_.GetFrameOptions().allow_upsample)) { - resampler_.reset(new LinearResample( - sampling_rate, expected_sampling_rate, - std::min(sampling_rate / 2, expected_sampling_rate / 2), 6)); - } else if (sampling_rate != expected_sampling_rate) { - KALDI_ERR << "Sampling frequency mismatch, expected " - << expected_sampling_rate << ", got " << sampling_rate - << "\nPerhaps you want to use the options " - "--allow_{upsample,downsample}"; - } -} - -template -void OnlineGenericBaseFeature::InputFinished() { - if (resampler_ != nullptr) { - // There may be a few samples left once we flush the resampler_ object, telling it - // that the file has finished. This should rarely make any difference. 
- Vector appended_wave; - Vector resampled_wave; - resampler_->Resample(appended_wave, true, &resampled_wave); - - if (resampled_wave.Dim() != 0) { - appended_wave.Resize(waveform_remainder_.Dim() + - resampled_wave.Dim()); - if (waveform_remainder_.Dim() != 0) - appended_wave.Range(0, waveform_remainder_.Dim()) - .CopyFromVec(waveform_remainder_); - appended_wave.Range(waveform_remainder_.Dim(), resampled_wave.Dim()) - .CopyFromVec(resampled_wave); - waveform_remainder_.Swap(&appended_wave); - } - } - input_finished_ = true; - ComputeFeatures(); -} - -template -void OnlineGenericBaseFeature::AcceptWaveform( - BaseFloat sampling_rate, const VectorBase &original_waveform) { - if (original_waveform.Dim() == 0) - return; // Nothing to do. - if (input_finished_) - KALDI_ERR << "AcceptWaveform called after InputFinished() was called."; - - Vector appended_wave; - Vector resampled_wave; - - const VectorBase *waveform; - - MaybeCreateResampler(sampling_rate); - if (resampler_ == nullptr) { - waveform = &original_waveform; - } else { - resampler_->Resample(original_waveform, false, &resampled_wave); - waveform = &resampled_wave; - } - - appended_wave.Resize(waveform_remainder_.Dim() + waveform->Dim()); - if (waveform_remainder_.Dim() != 0) - appended_wave.Range(0, waveform_remainder_.Dim()) - .CopyFromVec(waveform_remainder_); - appended_wave.Range(waveform_remainder_.Dim(), waveform->Dim()) - .CopyFromVec(*waveform); - waveform_remainder_.Swap(&appended_wave); - ComputeFeatures(); -} - -template -void OnlineGenericBaseFeature::ComputeFeatures() { - const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions(); - int64 num_samples_total = waveform_offset_ + waveform_remainder_.Dim(); - int32 num_frames_old = features_.Size(), - num_frames_new = NumFrames(num_samples_total, frame_opts, - input_finished_); - KALDI_ASSERT(num_frames_new >= num_frames_old); - - Vector window; - bool need_raw_log_energy = computer_.NeedRawLogEnergy(); - for (int32 frame = 
num_frames_old; frame < num_frames_new; frame++) { - BaseFloat raw_log_energy = 0.0; - ExtractWindow(waveform_offset_, waveform_remainder_, frame, - frame_opts, window_function_, &window, - need_raw_log_energy ? &raw_log_energy : NULL); - Vector *this_feature = new Vector(computer_.Dim(), - kUndefined); - // note: this online feature-extraction code does not support VTLN. - BaseFloat vtln_warp = 1.0; - computer_.Compute(raw_log_energy, vtln_warp, &window, this_feature); - features_.PushBack(this_feature); - } - // OK, we will now discard any portion of the signal that will not be - // necessary to compute frames in the future. - int64 first_sample_of_next_frame = FirstSampleOfFrame(num_frames_new, - frame_opts); - int32 samples_to_discard = first_sample_of_next_frame - waveform_offset_; - if (samples_to_discard > 0) { - // discard the leftmost part of the waveform that we no longer need. - int32 new_num_samples = waveform_remainder_.Dim() - samples_to_discard; - if (new_num_samples <= 0) { - // odd, but we'll try to handle it. - waveform_offset_ += waveform_remainder_.Dim(); - waveform_remainder_.Resize(0); - } else { - Vector new_remainder(new_num_samples); - new_remainder.CopyFromVec(waveform_remainder_.Range(samples_to_discard, - new_num_samples)); - waveform_offset_ += samples_to_discard; - waveform_remainder_.Swap(&new_remainder); - } - } -} - -// instantiate the templates defined here for MFCC, PLP and filterbank classes. -template class OnlineGenericBaseFeature; -template class OnlineGenericBaseFeature; -template class OnlineGenericBaseFeature; - -OnlineCmvnState::OnlineCmvnState(const OnlineCmvnState &other): - speaker_cmvn_stats(other.speaker_cmvn_stats), - global_cmvn_stats(other.global_cmvn_stats), - frozen_state(other.frozen_state) { } - -void OnlineCmvnState::Write(std::ostream &os, bool binary) const { - WriteToken(os, binary, ""); // magic string. 
- WriteToken(os, binary, ""); - speaker_cmvn_stats.Write(os, binary); - WriteToken(os, binary, ""); - global_cmvn_stats.Write(os, binary); - WriteToken(os, binary, ""); - frozen_state.Write(os, binary); - WriteToken(os, binary, ""); -} - -void OnlineCmvnState::Read(std::istream &is, bool binary) { - ExpectToken(is, binary, ""); // magic string. - ExpectToken(is, binary, ""); - speaker_cmvn_stats.Read(is, binary); - ExpectToken(is, binary, ""); - global_cmvn_stats.Read(is, binary); - ExpectToken(is, binary, ""); - frozen_state.Read(is, binary); - ExpectToken(is, binary, ""); -} - -OnlineCmvn::OnlineCmvn(const OnlineCmvnOptions &opts, - const OnlineCmvnState &cmvn_state, - OnlineFeatureInterface *src): - opts_(opts), temp_stats_(2, src->Dim() + 1), - temp_feats_(src->Dim()), temp_feats_dbl_(src->Dim()), - src_(src) { - SetState(cmvn_state); - if (!SplitStringToIntegers(opts.skip_dims, ":", false, &skip_dims_)) - KALDI_ERR << "Bad --skip-dims option (should be colon-separated list of " - << "integers)"; -} - -OnlineCmvn::OnlineCmvn(const OnlineCmvnOptions &opts, - OnlineFeatureInterface *src): - opts_(opts), temp_stats_(2, src->Dim() + 1), - temp_feats_(src->Dim()), temp_feats_dbl_(src->Dim()), - src_(src) { - if (!SplitStringToIntegers(opts.skip_dims, ":", false, &skip_dims_)) - KALDI_ERR << "Bad --skip-dims option (should be colon-separated list of " - << "integers)"; -} - - -void OnlineCmvn::GetMostRecentCachedFrame(int32 frame, - int32 *cached_frame, - MatrixBase *stats) { - KALDI_ASSERT(frame >= 0); - InitRingBufferIfNeeded(); - // look for a cached frame on a previous frame as close as possible in time - // to "frame". Return if we get one. - for (int32 t = frame; t >= 0 && t >= frame - opts_.ring_buffer_size; t--) { - if (t % opts_.modulus == 0) { - // if this frame should be cached in cached_stats_modulo_, then - // we'll look there, and we won't go back any further in time. 
- break; - } - int32 index = t % opts_.ring_buffer_size; - if (cached_stats_ring_[index].first == t) { - *cached_frame = t; - stats->CopyFromMat(cached_stats_ring_[index].second); - return; - } - } - int32 n = frame / opts_.modulus; - if (n >= cached_stats_modulo_.size()) { - if (cached_stats_modulo_.size() == 0) { - *cached_frame = -1; - stats->SetZero(); - return; - } else { - n = static_cast(cached_stats_modulo_.size() - 1); - } - } - *cached_frame = n * opts_.modulus; - KALDI_ASSERT(cached_stats_modulo_[n] != NULL); - stats->CopyFromMat(*(cached_stats_modulo_[n])); -} - -// Initialize ring buffer for caching stats. -void OnlineCmvn::InitRingBufferIfNeeded() { - if (cached_stats_ring_.empty() && opts_.ring_buffer_size > 0) { - Matrix temp(2, this->Dim() + 1); - cached_stats_ring_.resize(opts_.ring_buffer_size, - std::pair >(-1, temp)); - } -} - -void OnlineCmvn::CacheFrame(int32 frame, const MatrixBase &stats) { - KALDI_ASSERT(frame >= 0); - if (frame % opts_.modulus == 0) { // store in cached_stats_modulo_. - int32 n = frame / opts_.modulus; - if (n >= cached_stats_modulo_.size()) { - // The following assert is a limitation on in what order you can call - // CacheFrame. Fortunately the calling code always calls it in sequence, - // which it has to because you need a previous frame to compute the - // current one. - KALDI_ASSERT(n == cached_stats_modulo_.size()); - cached_stats_modulo_.push_back(new Matrix(stats)); - } else { - KALDI_WARN << "Did not expect to reach this part of code."; - // do what seems right, but we shouldn't get here. - cached_stats_modulo_[n]->CopyFromMat(stats); - } - } else { // store in the ring buffer. 
- InitRingBufferIfNeeded(); - if (!cached_stats_ring_.empty()) { - int32 index = frame % cached_stats_ring_.size(); - cached_stats_ring_[index].first = frame; - cached_stats_ring_[index].second.CopyFromMat(stats); - } - } -} - -OnlineCmvn::~OnlineCmvn() { - for (size_t i = 0; i < cached_stats_modulo_.size(); i++) - delete cached_stats_modulo_[i]; - cached_stats_modulo_.clear(); -} - -void OnlineCmvn::ComputeStatsForFrame(int32 frame, - MatrixBase *stats_out) { - KALDI_ASSERT(frame >= 0 && frame < src_->NumFramesReady()); - - int32 dim = this->Dim(), cur_frame; - GetMostRecentCachedFrame(frame, &cur_frame, stats_out); - - Vector &feats(temp_feats_); - Vector &feats_dbl(temp_feats_dbl_); - while (cur_frame < frame) { - cur_frame++; - src_->GetFrame(cur_frame, &feats); - feats_dbl.CopyFromVec(feats); - stats_out->Row(0).Range(0, dim).AddVec(1.0, feats_dbl); - if (opts_.normalize_variance) - stats_out->Row(1).Range(0, dim).AddVec2(1.0, feats_dbl); - (*stats_out)(0, dim) += 1.0; - // it's a sliding buffer; a frame at the back may be - // leaving the buffer so we have to subtract that. - int32 prev_frame = cur_frame - opts_.cmn_window; - if (prev_frame >= 0) { - // we need to subtract frame prev_f from the stats. - src_->GetFrame(prev_frame, &feats); - feats_dbl.CopyFromVec(feats); - stats_out->Row(0).Range(0, dim).AddVec(-1.0, feats_dbl); - if (opts_.normalize_variance) - stats_out->Row(1).Range(0, dim).AddVec2(-1.0, feats_dbl); - (*stats_out)(0, dim) -= 1.0; - } - CacheFrame(cur_frame, (*stats_out)); - } -} - - -// static -void OnlineCmvn::SmoothOnlineCmvnStats(const MatrixBase &speaker_stats, - const MatrixBase &global_stats, - const OnlineCmvnOptions &opts, - MatrixBase *stats) { - if (speaker_stats.NumRows() == 2 && !opts.normalize_variance) { - // this is just for efficiency: don't operate on the variance if it's not - // needed. 
- int32 cols = speaker_stats.NumCols(); // dim + 1 - SubMatrix stats_temp(*stats, 0, 1, 0, cols); - SmoothOnlineCmvnStats(speaker_stats.RowRange(0, 1), - global_stats.RowRange(0, 1), - opts, &stats_temp); - return; - } - int32 dim = stats->NumCols() - 1; - double cur_count = (*stats)(0, dim); - // If count exceeded cmn_window it would be an error in how "window_stats" - // was accumulated. - KALDI_ASSERT(cur_count <= 1.001 * opts.cmn_window); - if (cur_count >= opts.cmn_window) - return; - if (speaker_stats.NumRows() != 0) { // if we have speaker stats.. - double count_from_speaker = opts.cmn_window - cur_count, - speaker_count = speaker_stats(0, dim); - if (count_from_speaker > opts.speaker_frames) - count_from_speaker = opts.speaker_frames; - if (count_from_speaker > speaker_count) - count_from_speaker = speaker_count; - if (count_from_speaker > 0.0) - stats->AddMat(count_from_speaker / speaker_count, - speaker_stats); - cur_count = (*stats)(0, dim); - } - if (cur_count >= opts.cmn_window) - return; - if (global_stats.NumRows() != 0) { - double count_from_global = opts.cmn_window - cur_count, - global_count = global_stats(0, dim); - KALDI_ASSERT(global_count > 0.0); - if (count_from_global > opts.global_frames) - count_from_global = opts.global_frames; - if (count_from_global > 0.0) - stats->AddMat(count_from_global / global_count, - global_stats); - } else { - KALDI_ERR << "Global CMN stats are required"; - } -} - -void OnlineCmvn::GetFrame(int32 frame, - VectorBase *feat) { - src_->GetFrame(frame, feat); - KALDI_ASSERT(feat->Dim() == this->Dim()); - int32 dim = feat->Dim(); - Matrix &stats(temp_stats_); - stats.Resize(2, dim + 1, kUndefined); // Will do nothing if size was correct. - if (frozen_state_.NumRows() != 0) { // the CMVN state has been frozen. - stats.CopyFromMat(frozen_state_); - } else { - // first get the raw CMVN stats (this involves caching..) - this->ComputeStatsForFrame(frame, &stats); - // now smooth them. 
- SmoothOnlineCmvnStats(orig_state_.speaker_cmvn_stats, - orig_state_.global_cmvn_stats, - opts_, - &stats); - } - - if (!skip_dims_.empty()) - FakeStatsForSomeDims(skip_dims_, &stats); - - // call the function ApplyCmvn declared in ../transform/cmvn.h, which - // requires a matrix. - // 1 row; num-cols == dim; stride == dim. - SubMatrix feat_mat(feat->Data(), 1, dim, dim); - // the function ApplyCmvn takes a matrix, so form a one-row matrix to give it. - if (opts_.normalize_mean) - ApplyCmvn(stats, opts_.normalize_variance, &feat_mat); - else - KALDI_ASSERT(!opts_.normalize_variance); -} - -void OnlineCmvn::Freeze(int32 cur_frame) { - int32 dim = this->Dim(); - Matrix stats(2, dim + 1); - // get the raw CMVN stats - this->ComputeStatsForFrame(cur_frame, &stats); - // now smooth them. - SmoothOnlineCmvnStats(orig_state_.speaker_cmvn_stats, - orig_state_.global_cmvn_stats, - opts_, - &stats); - this->frozen_state_ = stats; -} - -void OnlineCmvn::GetState(int32 cur_frame, - OnlineCmvnState *state_out) { - *state_out = this->orig_state_; - { // This block updates state_out->speaker_cmvn_stats - int32 dim = this->Dim(); - if (state_out->speaker_cmvn_stats.NumRows() == 0) - state_out->speaker_cmvn_stats.Resize(2, dim + 1); - Vector feat(dim); - Vector feat_dbl(dim); - for (int32 t = 0; t <= cur_frame; t++) { - src_->GetFrame(t, &feat); - feat_dbl.CopyFromVec(feat); - state_out->speaker_cmvn_stats(0, dim) += 1.0; - state_out->speaker_cmvn_stats.Row(0).Range(0, dim).AddVec(1.0, feat_dbl); - state_out->speaker_cmvn_stats.Row(1).Range(0, dim).AddVec2(1.0, feat_dbl); - } - } - // Store any frozen state (the effect of the user possibly - // having called Freeze(). 
- state_out->frozen_state = frozen_state_; -} - -void OnlineCmvn::SetState(const OnlineCmvnState &cmvn_state) { - KALDI_ASSERT(cached_stats_modulo_.empty() && - "You cannot call SetState() after processing data."); - orig_state_ = cmvn_state; - frozen_state_ = cmvn_state.frozen_state; -} - -int32 OnlineSpliceFrames::NumFramesReady() const { - int32 num_frames = src_->NumFramesReady(); - if (num_frames > 0 && src_->IsLastFrame(num_frames - 1)) - return num_frames; - else - return std::max(0, num_frames - right_context_); -} - -void OnlineSpliceFrames::GetFrame(int32 frame, VectorBase *feat) { - KALDI_ASSERT(left_context_ >= 0 && right_context_ >= 0); - KALDI_ASSERT(frame >= 0 && frame < NumFramesReady()); - int32 dim_in = src_->Dim(); - KALDI_ASSERT(feat->Dim() == dim_in * (1 + left_context_ + right_context_)); - int32 T = src_->NumFramesReady(); - for (int32 t2 = frame - left_context_; t2 <= frame + right_context_; t2++) { - int32 t2_limited = t2; - if (t2_limited < 0) t2_limited = 0; - if (t2_limited >= T) t2_limited = T - 1; - int32 n = t2 - (frame - left_context_); // 0 for left-most frame, - // increases to the right. - SubVector part(*feat, n * dim_in, dim_in); - src_->GetFrame(t2_limited, &part); - } -} - -OnlineTransform::OnlineTransform(const MatrixBase &transform, - OnlineFeatureInterface *src): - src_(src) { - int32 src_dim = src_->Dim(); - if (transform.NumCols() == src_dim) { // Linear transform - linear_term_ = transform; - offset_.Resize(transform.NumRows()); // Resize() will zero it. 
- } else if (transform.NumCols() == src_dim + 1) { // Affine transform - linear_term_ = transform.Range(0, transform.NumRows(), 0, src_dim); - offset_.Resize(transform.NumRows()); - offset_.CopyColFromMat(transform, src_dim); - } else { - KALDI_ERR << "Dimension mismatch: source features have dimension " - << src_dim << " and LDA #cols is " << transform.NumCols(); - } -} - -void OnlineTransform::GetFrame(int32 frame, VectorBase *feat) { - Vector input_feat(linear_term_.NumCols()); - src_->GetFrame(frame, &input_feat); - feat->CopyFromVec(offset_); - feat->AddMatVec(1.0, linear_term_, kNoTrans, input_feat, 1.0); -} - -void OnlineTransform::GetFrames( - const std::vector &frames, MatrixBase *feats) { - KALDI_ASSERT(static_cast(frames.size()) == feats->NumRows()); - int32 num_frames = feats->NumRows(), - input_dim = linear_term_.NumCols(); - Matrix input_feats(num_frames, input_dim, kUndefined); - src_->GetFrames(frames, &input_feats); - feats->CopyRowsFromVec(offset_); - feats->AddMatMat(1.0, input_feats, kNoTrans, linear_term_, kTrans, 1.0); -} - - -int32 OnlineDeltaFeature::Dim() const { - int32 src_dim = src_->Dim(); - return src_dim * (1 + opts_.order); -} - -int32 OnlineDeltaFeature::NumFramesReady() const { - int32 num_frames = src_->NumFramesReady(), - context = opts_.order * opts_.window; - // "context" is the number of frames on the left or (more relevant - // here) right which we need in order to produce the output. - if (num_frames > 0 && src_->IsLastFrame(num_frames-1)) - return num_frames; - else - return std::max(0, num_frames - context); -} - -void OnlineDeltaFeature::GetFrame(int32 frame, - VectorBase *feat) { - KALDI_ASSERT(frame >= 0 && frame < NumFramesReady()); - KALDI_ASSERT(feat->Dim() == Dim()); - // We'll produce a temporary matrix containing the features we want to - // compute deltas on, but truncated to the necessary context. 
- int32 context = opts_.order * opts_.window; - int32 left_frame = frame - context, - right_frame = frame + context, - src_frames_ready = src_->NumFramesReady(); - if (left_frame < 0) left_frame = 0; - if (right_frame >= src_frames_ready) - right_frame = src_frames_ready - 1; - KALDI_ASSERT(right_frame >= left_frame); - int32 temp_num_frames = right_frame + 1 - left_frame, - src_dim = src_->Dim(); - Matrix temp_src(temp_num_frames, src_dim); - for (int32 t = left_frame; t <= right_frame; t++) { - SubVector temp_row(temp_src, t - left_frame); - src_->GetFrame(t, &temp_row); - } - int32 temp_t = frame - left_frame; // temp_t is the offset of frame "frame" - // within temp_src - delta_features_.Process(temp_src, temp_t, feat); -} - - -OnlineDeltaFeature::OnlineDeltaFeature(const DeltaFeaturesOptions &opts, - OnlineFeatureInterface *src): - src_(src), opts_(opts), delta_features_(opts) { } - -void OnlineCacheFeature::GetFrame(int32 frame, VectorBase *feat) { - KALDI_ASSERT(frame >= 0); - if (static_cast(frame) < cache_.size() && cache_[frame] != NULL) { - feat->CopyFromVec(*(cache_[frame])); - } else { - if (static_cast(frame) >= cache_.size()) - cache_.resize(frame + 1, NULL); - int32 dim = this->Dim(); - cache_[frame] = new Vector(dim); - // The following call will crash if frame "frame" is not ready. - src_->GetFrame(frame, cache_[frame]); - feat->CopyFromVec(*(cache_[frame])); - } -} - -void OnlineCacheFeature::GetFrames( - const std::vector &frames, MatrixBase *feats) { - int32 num_frames = frames.size(); - // non_cached_frames will be the subset of 't' values in 'frames' which were - // not previously cached, which we therefore need to get from src_. - std::vector non_cached_frames; - // 'non_cached_indexes' stores the indexes 'i' into 'frames' corresponding to - // the corresponding frames in 'non_cached_frames'. 
- std::vector non_cached_indexes; - non_cached_frames.reserve(frames.size()); - non_cached_indexes.reserve(frames.size()); - for (int32 i = 0; i < num_frames; i++) { - int32 t = frames[i]; - if (static_cast(t) < cache_.size() && cache_[t] != NULL) { - feats->Row(i).CopyFromVec(*(cache_[t])); - } else { - non_cached_frames.push_back(t); - non_cached_indexes.push_back(i); - } - } - if (non_cached_frames.empty()) - return; - int32 num_non_cached_frames = non_cached_frames.size(), - dim = this->Dim(); - Matrix non_cached_feats(num_non_cached_frames, dim, - kUndefined); - src_->GetFrames(non_cached_frames, &non_cached_feats); - for (int32 i = 0; i < num_non_cached_frames; i++) { - int32 t = non_cached_frames[i]; - if (static_cast(t) < cache_.size() && cache_[t] != NULL) { - // We can reach this point due to repeat indexes in 'non_cached_frames'. - feats->Row(non_cached_indexes[i]).CopyFromVec(*(cache_[t])); - } else { - SubVector this_feat(non_cached_feats, i); - feats->Row(non_cached_indexes[i]).CopyFromVec(this_feat); - if (static_cast(t) >= cache_.size()) - cache_.resize(t + 1, NULL); - cache_[t] = new Vector(this_feat); - } - } -} - - -void OnlineCacheFeature::ClearCache() { - for (size_t i = 0; i < cache_.size(); i++) - delete cache_[i]; - cache_.resize(0); -} - - -void OnlineAppendFeature::GetFrame(int32 frame, VectorBase *feat) { - KALDI_ASSERT(feat->Dim() == Dim()); - - SubVector feat1(*feat, 0, src1_->Dim()); - SubVector feat2(*feat, src1_->Dim(), src2_->Dim()); - src1_->GetFrame(frame, &feat1); - src2_->GetFrame(frame, &feat2); -}; - - -} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/online-feature.h b/speechx/speechx/kaldi/feat/online-feature.h deleted file mode 100644 index f9b26ecc..00000000 --- a/speechx/speechx/kaldi/feat/online-feature.h +++ /dev/null @@ -1,632 +0,0 @@ -// feat/online-feature.h - -// Copyright 2013 Johns Hopkins University (author: Daniel Povey) -// 2014 Yanqing Sun, Junjie Wang, -// Daniel Povey, Korbinian Riedhammer - -// 
See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - - -#ifndef KALDI_FEAT_ONLINE_FEATURE_H_ -#define KALDI_FEAT_ONLINE_FEATURE_H_ - -#include -#include -#include - -#include "matrix/matrix-lib.h" -#include "util/common-utils.h" -#include "base/kaldi-error.h" -#include "feat/feature-functions.h" -#include "feat/feature-mfcc.h" -#include "feat/feature-plp.h" -#include "feat/feature-fbank.h" -#include "feat/online-feature-itf.h" - -namespace kaldi { -/// @addtogroup onlinefeat OnlineFeatureExtraction -/// @{ - - -/// This class serves as a storage for feature vectors with an option to limit -/// the memory usage by removing old elements. The deleted frames indices are -/// "remembered" so that regardless of the MAX_ITEMS setting, the user always -/// provides the indices as if no deletion was being performed. -/// This is useful when processing very long recordings which would otherwise -/// cause the memory to eventually blow up when the features are not being removed. -class RecyclingVector { -public: - /// By default it does not remove any elements. - RecyclingVector(int items_to_hold = -1); - - /// The ownership is being retained by this collection - do not delete the item. - Vector *At(int index) const; - - /// The ownership of the item is passed to this collection - do not delete the item. 
- void PushBack(Vector *item); - - /// This method returns the size as if no "recycling" had happened, - /// i.e. equivalent to the number of times the PushBack method has been called. - int Size() const; - - ~RecyclingVector(); - -private: - std::deque*> items_; - int items_to_hold_; - int first_available_index_; -}; - - -/// This is a templated class for online feature extraction; -/// it's templated on a class like MfccComputer or PlpComputer -/// that does the basic feature extraction. -template -class OnlineGenericBaseFeature: public OnlineBaseFeature { - public: - // - // First, functions that are present in the interface: - // - virtual int32 Dim() const { return computer_.Dim(); } - - // Note: IsLastFrame() will only ever return true if you have called - // InputFinished() (and this frame is the last frame). - virtual bool IsLastFrame(int32 frame) const { - return input_finished_ && frame == NumFramesReady() - 1; - } - virtual BaseFloat FrameShiftInSeconds() const { - return computer_.GetFrameOptions().frame_shift_ms / 1000.0f; - } - - virtual int32 NumFramesReady() const { return features_.Size(); } - - virtual void GetFrame(int32 frame, VectorBase *feat); - - // Next, functions that are not in the interface. - - - // Constructor from options class - explicit OnlineGenericBaseFeature(const typename C::Options &opts); - - // This would be called from the application, when you get - // more wave data. Note: the sampling_rate is only provided so - // the code can assert that it matches the sampling rate - // expected in the options. - virtual void AcceptWaveform(BaseFloat sampling_rate, - const VectorBase &waveform); - - - // InputFinished() tells the class you won't be providing any - // more waveform. This will help flush out the last frame or two - // of features, in the case where snip-edges == false; it also - // affects the return value of IsLastFrame(). 
- virtual void InputFinished(); - - private: - // This function computes any additional feature frames that it is possible to - // compute from 'waveform_remainder_', which at this point may contain more - // than just a remainder-sized quantity (because AcceptWaveform() appends to - // waveform_remainder_ before calling this function). It adds these feature - // frames to features_, and shifts off any now-unneeded samples of input from - // waveform_remainder_ while incrementing waveform_offset_ by the same amount. - void ComputeFeatures(); - - void MaybeCreateResampler(BaseFloat sampling_rate); - - C computer_; // class that does the MFCC or PLP or filterbank computation - - // resampler in cases when the input sampling frequency is not equal to - // the expected sampling rate - std::unique_ptr resampler_; - - FeatureWindowFunction window_function_; - - // features_ is the Mfcc or Plp or Fbank features that we have already computed. - - RecyclingVector features_; - - // True if the user has called "InputFinished()" - bool input_finished_; - - // The sampling frequency, extracted from the config. Should - // be identical to the waveform supplied. - BaseFloat sampling_frequency_; - - // waveform_offset_ is the number of samples of waveform that we have - // already discarded, i.e. that were prior to 'waveform_remainder_'. - int64 waveform_offset_; - - // waveform_remainder_ is a short piece of waveform that we may need to keep - // after extracting all the whole frames we can (whatever length of feature - // will be required for the next phase of computation). - Vector waveform_remainder_; -}; - -typedef OnlineGenericBaseFeature OnlineMfcc; -typedef OnlineGenericBaseFeature OnlinePlp; -typedef OnlineGenericBaseFeature OnlineFbank; - - -/// This class takes a Matrix and wraps it as an -/// OnlineFeatureInterface: this can be useful where some earlier stage of -/// feature processing has been done offline but you want to use part of the -/// online pipeline. 
-class OnlineMatrixFeature: public OnlineFeatureInterface { - public: - /// Caution: this class maintains the const reference from the constructor, so - /// don't let it go out of scope while this object exists. - explicit OnlineMatrixFeature(const MatrixBase &mat): mat_(mat) { } - - virtual int32 Dim() const { return mat_.NumCols(); } - - virtual BaseFloat FrameShiftInSeconds() const { - return 0.01f; - } - - virtual int32 NumFramesReady() const { return mat_.NumRows(); } - - virtual void GetFrame(int32 frame, VectorBase *feat) { - feat->CopyFromVec(mat_.Row(frame)); - } - - virtual bool IsLastFrame(int32 frame) const { - return (frame + 1 == mat_.NumRows()); - } - - - private: - const MatrixBase &mat_; -}; - - -// Note the similarity with SlidingWindowCmnOptions, but there -// are also differences. One which doesn't appear in the config -// itself, because it's a difference between the setups, is that -// in OnlineCmn, we carry over data from the previous utterance, -// or, if no previous utterance is available, from global stats, -// or, if previous utterances are available but the total amount -// of data is less than prev_frames, we pad with up to "global_frames" -// frames from the global stats. -struct OnlineCmvnOptions { - int32 cmn_window; - int32 speaker_frames; // must be <= cmn_window - int32 global_frames; // must be <= speaker_frames. - bool normalize_mean; // Must be true if normalize_variance==true. - bool normalize_variance; - - int32 modulus; // not configurable from command line, relates to how the - // class computes the cmvn internally. smaller->more - // time-efficient but less memory-efficient. Must be >= 1. - int32 ring_buffer_size; // not configurable from command line; size of ring - // buffer used for caching CMVN stats. Must be >= - // modulus. - std::string skip_dims; // Colon-separated list of dimensions to skip normalization - // of, e.g. 13:14:15. 
- - OnlineCmvnOptions(): - cmn_window(600), - speaker_frames(600), - global_frames(200), - normalize_mean(true), - normalize_variance(false), - modulus(20), - ring_buffer_size(20), - skip_dims("") { } - - void Check() const { - KALDI_ASSERT(speaker_frames <= cmn_window && global_frames <= speaker_frames - && modulus > 0); - } - - void Register(ParseOptions *po) { - po->Register("cmn-window", &cmn_window, "Number of frames of sliding " - "context for cepstral mean normalization."); - po->Register("global-frames", &global_frames, "Number of frames of " - "global-average cepstral mean normalization stats to use for " - "first utterance of a speaker"); - po->Register("speaker-frames", &speaker_frames, "Number of frames of " - "previous utterance(s) from this speaker to use in cepstral " - "mean normalization"); - // we name the config string "norm-vars" for compatibility with - // ../featbin/apply-cmvn.cc - po->Register("norm-vars", &normalize_variance, "If true, do " - "cepstral variance normalization in addition to cepstral mean " - "normalization "); - po->Register("norm-means", &normalize_mean, "If true, do mean normalization " - "(note: you cannot normalize the variance but not the mean)"); - po->Register("skip-dims", &skip_dims, "Dimensions to skip normalization of " - "(colon-separated list of integers)");} -}; - - - -/** Struct OnlineCmvnState stores the state of CMVN adaptation between - utterances (but not the state of the computation within an utterance). It - stores the global CMVN stats and the stats of the current speaker (if we - have seen previous utterances for this speaker), and possibly will have a - member "frozen_state": if the user has called the function Freeze() of class - OnlineCmvn, to fix the CMVN so we can estimate fMLLR on top of the fixed - value of cmvn. If nonempty, "frozen_state" will reflect how we were - normalizing the mean and (if applicable) variance at the time when that - function was called. 
-*/ -struct OnlineCmvnState { - // The following is the total CMVN stats for this speaker (up till now), in - // the same format. - Matrix speaker_cmvn_stats; - - // The following is the global CMVN stats, in the usual - // format, of dimension 2 x (dim+1), as [ sum-stats count - // sum-squared-stats 0 ] - Matrix global_cmvn_stats; - - // If nonempty, contains CMVN stats representing the "frozen" state - // of CMVN that reflects how we were normalizing the data when the - // user called the Freeze() function in class OnlineCmvn. - Matrix frozen_state; - - OnlineCmvnState() { } - - explicit OnlineCmvnState(const Matrix &global_stats): - global_cmvn_stats(global_stats) { } - - // Copy constructor - OnlineCmvnState(const OnlineCmvnState &other); - - void Write(std::ostream &os, bool binary) const; - void Read(std::istream &is, bool binary); - - // Use the default assignment operator. -}; - -/** - This class does an online version of the cepstral mean and [optionally] - variance, but note that this is not equivalent to the offline version. This - is necessarily so, as the offline computation involves looking into the - future. If you plan to use features normalized with this type of CMVN then - you need to train in a `matched' way, i.e. with the same type of features. - We normally only do so in the "online" GMM-based decoding, e.g. in - online2bin/online2-wav-gmm-latgen-faster.cc; see also the script - steps/online/prepare_online_decoding.sh and steps/online/decode.sh. - - In the steady state (in the middle of a long utterance), this class - accumulates CMVN statistics from the previous "cmn_window" frames (default 600 - frames, or 6 seconds), and uses these to normalize the mean and possibly - variance of the current frame. - - The config variables "speaker_frames" and "global_frames" relate to what - happens at the beginning of the utterance when we have seen fewer than - "cmn_window" frames of context, and so might not have very good stats to - normalize with. 
Basically, we first augment any existing stats with up - to "speaker_frames" frames of stats from previous utterances of the current - speaker, and if this doesn't take us up to the required "cmn_window" frame - count, we further augment with up to "global_frames" frames of global - stats. The global stats are CMVN stats accumulated from training or testing - data, that give us a reasonable source of mean and variance for "typical" - data. - */ -class OnlineCmvn: public OnlineFeatureInterface { - public: - - // - // First, functions that are present in the interface: - // - virtual int32 Dim() const { return src_->Dim(); } - - virtual bool IsLastFrame(int32 frame) const { - return src_->IsLastFrame(frame); - } - virtual BaseFloat FrameShiftInSeconds() const { - return src_->FrameShiftInSeconds(); - } - - // The online cmvn does not introduce any additional latency. - virtual int32 NumFramesReady() const { return src_->NumFramesReady(); } - - virtual void GetFrame(int32 frame, VectorBase *feat); - - // - // Next, functions that are not in the interface. - // - - /// Initializer that sets the cmvn state. If you don't have previous - /// utterances from the same speaker you are supposed to initialize the CMVN - /// state from some global CMVN stats, which you can get from summing all cmvn - /// stats you have in your training data using "sum-matrix". This just gives - /// it a reasonable starting point at the start of the file. - /// If you do have previous utterances from the same speaker or at least a - /// similar environment, you are supposed to initialize it by calling GetState - /// from the previous utterance - OnlineCmvn(const OnlineCmvnOptions &opts, - const OnlineCmvnState &cmvn_state, - OnlineFeatureInterface *src); - - /// Initializer that does not set the cmvn state: - /// after calling this, you should call SetState(). 
- OnlineCmvn(const OnlineCmvnOptions &opts, - OnlineFeatureInterface *src); - - // Outputs any state information from this utterance to "cmvn_state". - // The value of "cmvn_state" before the call does not matter: the output - // depends on the value of OnlineCmvnState the class was initialized - // with, the input feature values up to cur_frame, and the effects - // of the user possibly having called Freeze(). - // If cur_frame is -1, it will just output the unmodified original - // state that was supplied to this object. - void GetState(int32 cur_frame, - OnlineCmvnState *cmvn_state); - - // This function can be used to modify the state of the CMVN computation - // from outside, but must only be called before you have processed any data - // (otherwise it will crash). This "state" is really just the information - // that is propagated between utterances, not the state of the computation - // inside an utterance. - void SetState(const OnlineCmvnState &cmvn_state); - - // From this point it will freeze the CMN to what it would have been if - // measured at frame "cur_frame", and it will stop it from changing - // further. This also applies retroactively for this utterance, so if you - // call GetFrame() on previous frames, it will use the CMVN stats - // from cur_frame; and it applies in the future too if you then - // call OutputState() and use this state to initialize the next - // utterance's CMVN object. - void Freeze(int32 cur_frame); - - virtual ~OnlineCmvn(); - private: - - /// Smooth the CMVN stats "stats" (which are stored in the normal format as a - /// 2 x (dim+1) matrix), by possibly adding some stats from "global_stats" - /// and/or "speaker_stats", controlled by the config. The best way to - /// understand the smoothing rule we use is just to look at the code. 
- static void SmoothOnlineCmvnStats(const MatrixBase &speaker_stats, - const MatrixBase &global_stats, - const OnlineCmvnOptions &opts, - MatrixBase *stats); - - /// Get the most recent cached frame of CMVN stats. [If no frames - /// were cached, sets up empty stats for frame zero and returns that]. - void GetMostRecentCachedFrame(int32 frame, - int32 *cached_frame, - MatrixBase *stats); - - /// Cache this frame of stats. - void CacheFrame(int32 frame, const MatrixBase &stats); - - /// Initialize ring buffer for caching stats. - inline void InitRingBufferIfNeeded(); - - /// Computes the raw CMVN stats for this frame, making use of (and updating if - /// necessary) the cached statistics in raw_stats_. This means the (x, - /// x^2, count) stats for the last up to opts_.cmn_window frames. - void ComputeStatsForFrame(int32 frame, - MatrixBase *stats); - - - OnlineCmvnOptions opts_; - std::vector skip_dims_; // Skip CMVN for these dimensions. Derived from opts_. - OnlineCmvnState orig_state_; // reflects the state before we saw this - // utterance. - Matrix frozen_state_; // If the user called Freeze(), this variable - // will reflect the CMVN state that we froze - // at. - - // The variable below reflects the raw (count, x, x^2) statistics of the - // input, computed every opts_.modulus frames. raw_stats_[n / opts_.modulus] - // contains the (count, x, x^2) statistics for the frames from - // std::max(0, n - opts_.cmn_window) through n. - std::vector*> cached_stats_modulo_; - // the variable below is a ring-buffer of cached stats. the int32 is the - // frame index. - std::vector > > cached_stats_ring_; - - // Some temporary variables used inside functions of this class, which - // put here to avoid reallocation. 
- Matrix temp_stats_; - Vector temp_feats_; - Vector temp_feats_dbl_; - - OnlineFeatureInterface *src_; // Not owned here -}; - - -struct OnlineSpliceOptions { - int32 left_context; - int32 right_context; - OnlineSpliceOptions(): left_context(4), right_context(4) { } - void Register(ParseOptions *po) { - po->Register("left-context", &left_context, "Left-context for frame " - "splicing prior to LDA"); - po->Register("right-context", &right_context, "Right-context for frame " - "splicing prior to LDA"); - } -}; - -class OnlineSpliceFrames: public OnlineFeatureInterface { - public: - // - // First, functions that are present in the interface: - // - virtual int32 Dim() const { - return src_->Dim() * (1 + left_context_ + right_context_); - } - - virtual bool IsLastFrame(int32 frame) const { - return src_->IsLastFrame(frame); - } - virtual BaseFloat FrameShiftInSeconds() const { - return src_->FrameShiftInSeconds(); - } - - virtual int32 NumFramesReady() const; - - virtual void GetFrame(int32 frame, VectorBase *feat); - - // - // Next, functions that are not in the interface. - // - OnlineSpliceFrames(const OnlineSpliceOptions &opts, - OnlineFeatureInterface *src): - left_context_(opts.left_context), right_context_(opts.right_context), - src_(src) { } - - private: - int32 left_context_; - int32 right_context_; - OnlineFeatureInterface *src_; // Not owned here -}; - -/// This online-feature class implements any affine or linear transform. 
-class OnlineTransform: public OnlineFeatureInterface { - public: - // - // First, functions that are present in the interface: - // - virtual int32 Dim() const { return offset_.Dim(); } - - virtual bool IsLastFrame(int32 frame) const { - return src_->IsLastFrame(frame); - } - virtual BaseFloat FrameShiftInSeconds() const { - return src_->FrameShiftInSeconds(); - } - - virtual int32 NumFramesReady() const { return src_->NumFramesReady(); } - - virtual void GetFrame(int32 frame, VectorBase *feat); - - virtual void GetFrames(const std::vector &frames, - MatrixBase *feats); - - // - // Next, functions that are not in the interface. - // - - /// The transform can be a linear transform, or an affine transform - /// where the last column is the offset. - OnlineTransform(const MatrixBase &transform, - OnlineFeatureInterface *src); - - - private: - OnlineFeatureInterface *src_; // Not owned here - Matrix linear_term_; - Vector offset_; -}; - -class OnlineDeltaFeature: public OnlineFeatureInterface { - public: - // - // First, functions that are present in the interface: - // - virtual int32 Dim() const; - - virtual bool IsLastFrame(int32 frame) const { - return src_->IsLastFrame(frame); - } - virtual BaseFloat FrameShiftInSeconds() const { - return src_->FrameShiftInSeconds(); - } - - virtual int32 NumFramesReady() const; - - virtual void GetFrame(int32 frame, VectorBase *feat); - - // - // Next, functions that are not in the interface. - // - OnlineDeltaFeature(const DeltaFeaturesOptions &opts, - OnlineFeatureInterface *src); - - private: - OnlineFeatureInterface *src_; // Not owned here - DeltaFeaturesOptions opts_; - DeltaFeatures delta_features_; // This class contains just a few - // coefficients. -}; - - -/// This feature type can be used to cache its input, to avoid -/// repetition of computation in a multi-pass decoding context. 
-class OnlineCacheFeature: public OnlineFeatureInterface { - public: - virtual int32 Dim() const { return src_->Dim(); } - - virtual bool IsLastFrame(int32 frame) const { - return src_->IsLastFrame(frame); - } - virtual BaseFloat FrameShiftInSeconds() const { - return src_->FrameShiftInSeconds(); - } - - virtual int32 NumFramesReady() const { return src_->NumFramesReady(); } - - virtual void GetFrame(int32 frame, VectorBase *feat); - - virtual void GetFrames(const std::vector &frames, - MatrixBase *feats); - - virtual ~OnlineCacheFeature() { ClearCache(); } - - // Things that are not in the shared interface: - - void ClearCache(); // this should be called if you change the underlying - // features in some way. - - explicit OnlineCacheFeature(OnlineFeatureInterface *src): src_(src) { } - private: - - OnlineFeatureInterface *src_; // Not owned here - std::vector* > cache_; -}; - - - - -/// This online-feature class implements combination of two feature -/// streams (such as pitch, plp) into one stream. 
-class OnlineAppendFeature: public OnlineFeatureInterface { - public: - virtual int32 Dim() const { return src1_->Dim() + src2_->Dim(); } - - virtual bool IsLastFrame(int32 frame) const { - return (src1_->IsLastFrame(frame) || src2_->IsLastFrame(frame)); - } - // Hopefully sources have the same rate - virtual BaseFloat FrameShiftInSeconds() const { - return src1_->FrameShiftInSeconds(); - } - - virtual int32 NumFramesReady() const { - return std::min(src1_->NumFramesReady(), src2_->NumFramesReady()); - } - - virtual void GetFrame(int32 frame, VectorBase *feat); - - virtual ~OnlineAppendFeature() { } - - OnlineAppendFeature(OnlineFeatureInterface *src1, - OnlineFeatureInterface *src2): src1_(src1), src2_(src2) { } - private: - - OnlineFeatureInterface *src1_; - OnlineFeatureInterface *src2_; -}; - -/// @} End of "addtogroup onlinefeat" -} // namespace kaldi - -#endif // KALDI_FEAT_ONLINE_FEATURE_H_ diff --git a/speechx/speechx/kaldi/feat/pitch-functions.cc b/speechx/speechx/kaldi/feat/pitch-functions.cc deleted file mode 100644 index 430e9bdb..00000000 --- a/speechx/speechx/kaldi/feat/pitch-functions.cc +++ /dev/null @@ -1,1667 +0,0 @@ -// feat/pitch-functions.cc - -// Copyright 2013 Pegah Ghahremani -// 2014 IMSL, PKU-HKUST (author: Wei Shi) -// 2014 Yanqing Sun, Junjie Wang, -// Daniel Povey, Korbinian Riedhammer -// Xin Lei - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. 
-// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "feat/feature-functions.h" -#include "feat/mel-computations.h" -#include "feat/online-feature.h" -#include "feat/pitch-functions.h" -#include "feat/resample.h" -#include "matrix/matrix-functions.h" - -namespace kaldi { - -/** - This function processes the NCCF n to a POV feature f by applying the formula - f = (1.0001 - n)^0.15 - 1.0 - This is a nonlinear function designed to make the output reasonably Gaussian - distributed. Before doing this, the NCCF distribution is in the range [-1, - 1] but has a strong peak just before 1.0, which this function smooths out. -*/ - -BaseFloat NccfToPovFeature(BaseFloat n) { - if (n > 1.0) { - n = 1.0; - } else if (n < -1.0) { - n = -1.0; - } - BaseFloat f = pow((1.0001 - n), 0.15) - 1.0; - KALDI_ASSERT(f - f == 0); // check for NaN,inf. - return f; -} - -/** - This function processes the NCCF n to a reasonably accurate probability - of voicing p by applying the formula: - - n' = fabs(n) - r = -5.2 + 5.4 * exp(7.5 * (n' - 1.0)) + - 4.8 * n' - 2.0 * exp(-10.0 * n') + 4.2 * exp(20.0 * (n' - 1.0)); - p = 1.0 / (1 + exp(-1.0 * r)); - - How did we get this formula? We plotted the empirical log-prob-ratio of voicing - r = log( p[voiced] / p[not-voiced] ) - [on the Keele database where voicing is marked], as a function of the NCCF at - the delay picked by our algorithm. This was done on intervals of the NCCF, so - we had enough statistics to get that ratio. The NCCF covers [-1, 1]; almost - all of the probability mass is on [0, 1] but the empirical POV seems fairly - symmetric with a minimum near zero, so we chose to make it a function of n' = fabs(n). 
- - Then we manually tuned a function (the one you see above) that approximated - the log-prob-ratio of voicing fairly well as a function of the absolute-value - NCCF n'; however, wasn't a very exact match since we were also trying to make - the transformed NCCF fairly Gaussian distributed, with a view to using it as - a feature-- an idea we later abandoned after a simpler formula worked better. - */ -BaseFloat NccfToPov(BaseFloat n) { - BaseFloat ndash = fabs(n); - if (ndash > 1.0) ndash = 1.0; // just in case it was slightly outside [-1, 1] - - BaseFloat r = -5.2 + 5.4 * Exp(7.5 * (ndash - 1.0)) + 4.8 * ndash - - 2.0 * Exp(-10.0 * ndash) + 4.2 * Exp(20.0 * (ndash - 1.0)); - // r is the approximate log-prob-ratio of voicing, log(p/(1-p)). - BaseFloat p = 1.0 / (1 + Exp(-1.0 * r)); - KALDI_ASSERT(p - p == 0); // Check for NaN/inf - return p; -} - -/** - This function computes some dot products that are required - while computing the NCCF. - For each integer lag from start to end-1, this function - outputs to (*inner_prod)(lag - start), the dot-product - of a window starting at 0 with a window starting at - lag. All windows are of length nccf_window_size. It - outputs to (*norm_prod)(lag - start), e1 * e2, where - e1 is the dot-product of the un-shifted window with itself, - and d2 is the dot-product of the window shifted by "lag" - with itself. - */ -void ComputeCorrelation(const VectorBase &wave, - int32 first_lag, int32 last_lag, - int32 nccf_window_size, - VectorBase *inner_prod, - VectorBase *norm_prod) { - Vector zero_mean_wave(wave); - // TODO: possibly fix this, the mean normalization is done in a strange way. 
- SubVector wave_part(wave, 0, nccf_window_size); - // subtract mean-frame from wave - zero_mean_wave.Add(-wave_part.Sum() / nccf_window_size); - BaseFloat e1, e2, sum; - SubVector sub_vec1(zero_mean_wave, 0, nccf_window_size); - e1 = VecVec(sub_vec1, sub_vec1); - for (int32 lag = first_lag; lag <= last_lag; lag++) { - SubVector sub_vec2(zero_mean_wave, lag, nccf_window_size); - e2 = VecVec(sub_vec2, sub_vec2); - sum = VecVec(sub_vec1, sub_vec2); - (*inner_prod)(lag - first_lag) = sum; - (*norm_prod)(lag - first_lag) = e1 * e2; - } -} - -/** - Computes the NCCF as a fraction of the numerator term (a dot product between - two vectors) and a denominator term which equals sqrt(e1*e2 + nccf_ballast) - where e1 and e2 are both dot-products of bits of the wave with themselves, - and e1*e2 is supplied as "norm_prod". These quantities are computed by - "ComputeCorrelation". -*/ -void ComputeNccf(const VectorBase &inner_prod, - const VectorBase &norm_prod, - BaseFloat nccf_ballast, - VectorBase *nccf_vec) { - KALDI_ASSERT(inner_prod.Dim() == norm_prod.Dim() && - inner_prod.Dim() == nccf_vec->Dim()); - for (int32 lag = 0; lag < inner_prod.Dim(); lag++) { - BaseFloat numerator = inner_prod(lag), - denominator = pow(norm_prod(lag) + nccf_ballast, 0.5), - nccf; - if (denominator != 0.0) { - nccf = numerator / denominator; - } else { - KALDI_ASSERT(numerator == 0.0); - nccf = 0.0; - } - KALDI_ASSERT(nccf < 1.01 && nccf > -1.01); - (*nccf_vec)(lag) = nccf; - } -} - -/** - This function selects the lags at which we measure the NCCF: we need - to select lags from 1/max_f0 to 1/min_f0, in a geometric progression - with ratio 1 + d. 
- */ -void SelectLags(const PitchExtractionOptions &opts, - Vector *lags) { - // choose lags relative to acceptable pitch tolerance - BaseFloat min_lag = 1.0 / opts.max_f0, max_lag = 1.0 / opts.min_f0; - - std::vector tmp_lags; - for (BaseFloat lag = min_lag; lag <= max_lag; lag *= 1.0 + opts.delta_pitch) - tmp_lags.push_back(lag); - lags->Resize(tmp_lags.size()); - std::copy(tmp_lags.begin(), tmp_lags.end(), lags->Data()); -} - - -/** - This function computes the local-cost for the Viterbi computation, - see eq. (5) in the paper. - @param opts The options as provided by the user - @param nccf_pitch The nccf as computed for the pitch computation (with ballast). - @param lags The log-spaced lags at which nccf_pitch is sampled. - @param local_cost We output the local-cost to here. -*/ -void ComputeLocalCost(const VectorBase &nccf_pitch, - const VectorBase &lags, - const PitchExtractionOptions &opts, - VectorBase *local_cost) { - // from the paper, eq. 5, local_cost = 1 - Phi(t,i)(1 - soft_min_f0 L_i) - // nccf is the nccf on this frame measured at the lags in "lags". - KALDI_ASSERT(nccf_pitch.Dim() == local_cost->Dim() && - nccf_pitch.Dim() == lags.Dim()); - local_cost->Set(1.0); - // add the term -Phi(t,i): - local_cost->AddVec(-1.0, nccf_pitch); - // add the term soft_min_f0 Phi(t,i) L_i - local_cost->AddVecVec(opts.soft_min_f0, lags, nccf_pitch, 1.0); -} - - - -// class PitchFrameInfo is used inside class OnlinePitchFeatureImpl. -// It stores the information we need to keep around for a single frame -// of the pitch computation. -class PitchFrameInfo { - public: - /// This function resizes the arrays for this object and updates the reference - /// counts for the previous object (by decrementing those reference counts - /// when we destroy a StateInfo object). A StateInfo object is considered to - /// be destroyed when we delete it, not when its reference counts goes to - /// zero. 
- void Cleanup(PitchFrameInfo *prev_frame); - - /// This function may be called for the last (most recent) PitchFrameInfo - /// object with the best state (obtained from the externally held - /// forward-costs). It traces back as far as needed to set the - /// cur_best_state_, and as it's going it sets the lag-index and pov_nccf in - /// pitch_pov_iter, which when it's called is an iterator to where to put the - /// info for the final state; the iterator will be decremented inside this - /// function. - void SetBestState(int32 best_state, - std::vector > &lag_nccf); - - /// This function may be called on the last (most recent) PitchFrameInfo - /// object; it computes how many frames of latency there is because the - /// traceback has not yet settled on a single value for frames in the past. - /// It actually returns the minimum of max_latency and the actual latency, - /// which is an optimization because we won't care about latency past - /// a user-specified maximum latency. - int32 ComputeLatency(int32 max_latency); - - /// This function updates - bool UpdatePreviousBestState(PitchFrameInfo *prev_frame); - - /// This constructor is used for frame -1; it sets the costs to be all zeros - /// the pov_nccf's to zero and the backpointers to -1. - explicit PitchFrameInfo(int32 num_states); - - /// This constructor is used for subsequent frames (not -1). - PitchFrameInfo(PitchFrameInfo *prev); - - /// Record the nccf_pov value. - /// @param nccf_pov The nccf as computed for the POV computation (without ballast). - void SetNccfPov(const VectorBase &nccf_pov); - - /// This constructor is used for frames apart from frame -1; the bulk of - /// the Viterbi computation takes place inside this constructor. - /// @param opts The options as provided by the user - /// @param nccf_pitch The nccf as computed for the pitch computation - /// (with ballast). - /// @param nccf_pov The nccf as computed for the POV computation - /// (without ballast). 
- /// @param lags The log-spaced lags at which nccf_pitch and - /// nccf_pov are sampled. - /// @param prev_frame_forward_cost The forward-cost vector for the - /// previous frame. - /// @param index_info A pointer to a temporary vector used by this function - /// @param this_forward_cost The forward-cost vector for this frame - /// (to be computed). - void ComputeBacktraces(const PitchExtractionOptions &opts, - const VectorBase &nccf_pitch, - const VectorBase &lags, - const VectorBase &prev_forward_cost, - std::vector > *index_info, - VectorBase *this_forward_cost); - private: - // struct StateInfo is the information we keep for a single one of the - // log-spaced lags, for a single frame. This is a state in the Viterbi - // computation. - struct StateInfo { - /// The state index on the previous frame that is the best preceding state - /// for this state. - int32 backpointer; - /// the version of the NCCF we keep for the POV computation (without the - /// ballast term). - BaseFloat pov_nccf; - StateInfo(): backpointer(0), pov_nccf(0.0) { } - }; - std::vector state_info_; - /// the state index of the first entry in "state_info"; this will initially be - /// zero, but after cleanup might be nonzero. - int32 state_offset_; - - /// The current best state in the backtrace from the end. - int32 cur_best_state_; - - /// The structure for the previous frame. - PitchFrameInfo *prev_info_; -}; - - -// This constructor is used for frame -1; it sets the costs to be all zeros -// the pov_nccf's to zero and the backpointers to -1. -PitchFrameInfo::PitchFrameInfo(int32 num_states) - :state_info_(num_states), state_offset_(0), - cur_best_state_(-1), prev_info_(NULL) { } - - -bool pitch_use_naive_search = false; // This is used in unit-tests. 
- - -PitchFrameInfo::PitchFrameInfo(PitchFrameInfo *prev_info): - state_info_(prev_info->state_info_.size()), state_offset_(0), - cur_best_state_(-1), prev_info_(prev_info) { } - -void PitchFrameInfo::SetNccfPov(const VectorBase &nccf_pov) { - int32 num_states = nccf_pov.Dim(); - KALDI_ASSERT(num_states == state_info_.size()); - for (int32 i = 0; i < num_states; i++) - state_info_[i].pov_nccf = nccf_pov(i); -} - -void PitchFrameInfo::ComputeBacktraces( - const PitchExtractionOptions &opts, - const VectorBase &nccf_pitch, - const VectorBase &lags, - const VectorBase &prev_forward_cost_vec, - std::vector > *index_info, - VectorBase *this_forward_cost_vec) { - int32 num_states = nccf_pitch.Dim(); - - Vector local_cost(num_states, kUndefined); - ComputeLocalCost(nccf_pitch, lags, opts, &local_cost); - - const BaseFloat delta_pitch_sq = pow(Log(1.0 + opts.delta_pitch), 2.0), - inter_frame_factor = delta_pitch_sq * opts.penalty_factor; - - // index local_cost, prev_forward_cost and this_forward_cost using raw pointer - // indexing not operator (), since this is the very inner loop and a lot of - // time is taken here. - const BaseFloat *prev_forward_cost = prev_forward_cost_vec.Data(); - BaseFloat *this_forward_cost = this_forward_cost_vec->Data(); - - if (index_info->empty()) - index_info->resize(num_states); - - // make it a reference for more concise indexing. - std::vector > &bounds = *index_info; - - /* bounds[i].first will be a lower bound on the backpointer for state i, - bounds[i].second will be an upper bound on it. We progressively tighten - these bounds till we know the backpointers exactly. - */ - - if (pitch_use_naive_search) { - // This branch is only taken in unit-testing code. 
- for (int32 i = 0; i < num_states; i++) { - BaseFloat best_cost = std::numeric_limits::infinity(); - int32 best_j = -1; - for (int32 j = 0; j < num_states; j++) { - BaseFloat this_cost = (j - i) * (j - i) * inter_frame_factor - + prev_forward_cost[j]; - if (this_cost < best_cost) { - best_cost = this_cost; - best_j = j; - } - } - this_forward_cost[i] = best_cost; - state_info_[i].backpointer = best_j; - } - } else { - int32 last_backpointer = 0; - for (int32 i = 0; i < num_states; i++) { - int32 start_j = last_backpointer; - BaseFloat best_cost = (start_j - i) * (start_j - i) * inter_frame_factor - + prev_forward_cost[start_j]; - int32 best_j = start_j; - - for (int32 j = start_j + 1; j < num_states; j++) { - BaseFloat this_cost = (j - i) * (j - i) * inter_frame_factor - + prev_forward_cost[j]; - if (this_cost < best_cost) { - best_cost = this_cost; - best_j = j; - } else { // as soon as the costs stop improving, we stop searching. - break; // this is a loose lower bound we're getting. - } - } - state_info_[i].backpointer = best_j; - this_forward_cost[i] = best_cost; - bounds[i].first = best_j; // this is now a lower bound on the - // backpointer. - bounds[i].second = num_states - 1; // we have no meaningful upper bound - // yet. - last_backpointer = best_j; - } - - // We iterate, progressively refining the upper and lower bounds until they - // meet and we know that the resulting backtraces are optimal. Each - // iteration takes time linear in num_states. We won't normally iterate as - // far as num_states; normally we only do two iterations; when printing out - // the number of iterations, it's rarely more than that (once I saw seven - // iterations). Anyway, this part of the computation does not dominate. 
- for (int32 iter = 0; iter < num_states; iter++) { - bool changed = false; - if (iter % 2 == 0) { // go backwards through the states - last_backpointer = num_states - 1; - for (int32 i = num_states - 1; i >= 0; i--) { - int32 lower_bound = bounds[i].first, - upper_bound = std::min(last_backpointer, bounds[i].second); - if (upper_bound == lower_bound) { - last_backpointer = lower_bound; - continue; - } - BaseFloat best_cost = this_forward_cost[i]; - int32 best_j = state_info_[i].backpointer, initial_best_j = best_j; - - if (best_j == upper_bound) { - // if best_j already equals upper bound, don't bother tightening the - // upper bound, we'll tighten the lower bound when the time comes. - last_backpointer = best_j; - continue; - } - // Below, we have j > lower_bound + 1 because we know we've already - // evaluated lower_bound and lower_bound + 1 [via knowledge of - // this algorithm.] - for (int32 j = upper_bound; j > lower_bound + 1; j--) { - BaseFloat this_cost = (j - i) * (j - i) * inter_frame_factor - + prev_forward_cost[j]; - if (this_cost < best_cost) { - best_cost = this_cost; - best_j = j; - } else { // as soon as the costs stop improving, we stop searching, - // unless the best j is still lower than j, in which case - // we obviously need to keep moving. - if (best_j > j) - break; // this is a loose lower bound we're getting. - } - } - // our "best_j" is now an upper bound on the backpointer. - bounds[i].second = best_j; - if (best_j != initial_best_j) { - this_forward_cost[i] = best_cost; - state_info_[i].backpointer = best_j; - changed = true; - } - last_backpointer = best_j; - } - } else { // go forwards through the states. 
- last_backpointer = 0; - for (int32 i = 0; i < num_states; i++) { - int32 lower_bound = std::max(last_backpointer, bounds[i].first), - upper_bound = bounds[i].second; - if (upper_bound == lower_bound) { - last_backpointer = lower_bound; - continue; - } - BaseFloat best_cost = this_forward_cost[i]; - int32 best_j = state_info_[i].backpointer, initial_best_j = best_j; - - if (best_j == lower_bound) { - // if best_j already equals lower bound, we don't bother tightening - // the lower bound, we'll tighten the upper bound when the time - // comes. - last_backpointer = best_j; - continue; - } - // Below, we have j < upper_bound because we know we've already - // evaluated that point. - for (int32 j = lower_bound; j < upper_bound - 1; j++) { - BaseFloat this_cost = (j - i) * (j - i) * inter_frame_factor - + prev_forward_cost[j]; - if (this_cost < best_cost) { - best_cost = this_cost; - best_j = j; - } else { // as soon as the costs stop improving, we stop searching, - // unless the best j is still higher than j, in which case - // we obviously need to keep moving. - if (best_j < j) - break; // this is a loose lower bound we're getting. - } - } - // our "best_j" is now a lower bound on the backpointer. - bounds[i].first = best_j; - if (best_j != initial_best_j) { - this_forward_cost[i] = best_cost; - state_info_[i].backpointer = best_j; - changed = true; - } - last_backpointer = best_j; - } - } - if (!changed) - break; - } - } - // The next statement is needed due to RecomputeBacktraces: we have to - // invalidate the previously computed best-state info. - cur_best_state_ = -1; - this_forward_cost_vec->AddVec(1.0, local_cost); -} - -void PitchFrameInfo::SetBestState( - int32 best_state, - std::vector > &lag_nccf) { - - // This function would naturally be recursive, but we have coded this to avoid - // recursion, which would otherwise eat up the stack. Think of it as a static - // member function, except we do use "this" right at the beginning. 
- - std::vector >::reverse_iterator iter = lag_nccf.rbegin(); - - PitchFrameInfo *this_info = this; // it will change in the loop. - while (this_info != NULL) { - PitchFrameInfo *prev_info = this_info->prev_info_; - if (best_state == this_info->cur_best_state_) - return; // no change - if (prev_info != NULL) // don't write anything for frame -1. - iter->first = best_state; - size_t state_info_index = best_state - this_info->state_offset_; - KALDI_ASSERT(state_info_index < this_info->state_info_.size()); - this_info->cur_best_state_ = best_state; - best_state = this_info->state_info_[state_info_index].backpointer; - if (prev_info != NULL) // don't write anything for frame -1. - iter->second = this_info->state_info_[state_info_index].pov_nccf; - this_info = prev_info; - if (this_info != NULL) ++iter; - } -} - -int32 PitchFrameInfo::ComputeLatency(int32 max_latency) { - if (max_latency <= 0) return 0; - - int32 latency = 0; - - // This function would naturally be recursive, but we have coded this to avoid - // recursion, which would otherwise eat up the stack. Think of it as a static - // member function, except we do use "this" right at the beginning. - // This function is called only on the most recent PitchFrameInfo object. - int32 num_states = state_info_.size(); - int32 min_living_state = 0, max_living_state = num_states - 1; - PitchFrameInfo *this_info = this; // it will change in the loop. 
- - - for (; this_info != NULL && latency < max_latency;) { - int32 offset = this_info->state_offset_; - KALDI_ASSERT(min_living_state >= offset && - max_living_state - offset < this_info->state_info_.size()); - min_living_state = - this_info->state_info_[min_living_state - offset].backpointer; - max_living_state = - this_info->state_info_[max_living_state - offset].backpointer; - if (min_living_state == max_living_state) { - return latency; - } - this_info = this_info->prev_info_; - if (this_info != NULL) // avoid incrementing latency for frame -1, - latency++; // as it's not a real frame. - } - return latency; -} - -void PitchFrameInfo::Cleanup(PitchFrameInfo *prev_frame) { - KALDI_ERR << "Cleanup not implemented."; -} - - -// struct NccfInfo is used to cache certain quantities that we need for online -// operation, for the first "recompute_frame" frames of the file (e.g. 300); -// after that many frames, or after the user calls InputFinished(), we redo the -// initial backtraces, as we'll then have a better estimate of the average signal -// energy. -struct NccfInfo { - - Vector nccf_pitch_resampled; // resampled nccf_pitch - BaseFloat avg_norm_prod; // average value of e1 * e2. - BaseFloat mean_square_energy; // mean_square energy we used when computing the - // original ballast term for - // "nccf_pitch_resampled". - - NccfInfo(BaseFloat avg_norm_prod, - BaseFloat mean_square_energy): - avg_norm_prod(avg_norm_prod), - mean_square_energy(mean_square_energy) { } -}; - - - -// We could inherit from OnlineBaseFeature as we have the same interface, -// but this will unnecessary force a lot of our functions to be virtual. 
-class OnlinePitchFeatureImpl { - public: - explicit OnlinePitchFeatureImpl(const PitchExtractionOptions &opts); - - int32 Dim() const { return 2; } - - BaseFloat FrameShiftInSeconds() const; - - int32 NumFramesReady() const; - - bool IsLastFrame(int32 frame) const; - - void GetFrame(int32 frame, VectorBase *feat); - - void AcceptWaveform(BaseFloat sampling_rate, - const VectorBase &waveform); - - void InputFinished(); - - ~OnlinePitchFeatureImpl(); - - - // Copy-constructor, can be used to obtain a new copy of this object, - // any state from this utterance. - OnlinePitchFeatureImpl(const OnlinePitchFeatureImpl &other); - - private: - - /// This function works out from the signal how many frames are currently - /// available to process (this is called from inside AcceptWaveform()). - /// Note: the number of frames differs slightly from the number the - /// old pitch code gave. - /// Note: the number this returns depends on whether input_finished_ == true; - /// if it is, it will "force out" a final frame or two. - int32 NumFramesAvailable(int64 num_downsampled_samples, bool snip_edges) const; - - /// This function extracts from the signal the samples numbered from - /// "sample_index" (numbered in the full downsampled signal, not just this - /// part), and of length equal to window->Dim(). It uses the data members - /// downsampled_samples_discarded_ and downsampled_signal_remainder_, as well - /// as the more recent part of the downsampled wave "downsampled_wave_part" - /// which is provided. - /// - /// @param downsampled_wave_part One chunk of the downsampled wave, - /// starting from sample-index downsampled_samples_discarded_. - /// @param sample_index The desired starting sample index (measured from - /// the start of the whole signal, not just this part). - /// @param window The part of the signal is output to here. 
- void ExtractFrame(const VectorBase &downsampled_wave_part, - int64 frame_index, - VectorBase *window); - - - /// This function is called after we reach frame "recompute_frame", or when - /// InputFinished() is called, whichever comes sooner. It recomputes the - /// backtraces for frames zero through recompute_frame, if needed because the - /// average energy of the signal has changed, affecting the nccf ballast term. - /// It works out the average signal energy from - /// downsampled_samples_processed_, signal_sum_ and signal_sumsq_ (which, if - /// you see the calling code, might include more frames than just - /// "recompute_frame", it might include up to the end of the current chunk). - void RecomputeBacktraces(); - - - /// This function updates downsampled_signal_remainder_, - /// downsampled_samples_processed_, signal_sum_ and signal_sumsq_; it's called - /// from AcceptWaveform(). - void UpdateRemainder(const VectorBase &downsampled_wave_part); - - - // The following variables don't change throughout the lifetime - // of this object. - PitchExtractionOptions opts_; - - // the first lag of the downsampled signal at which we measure NCCF - int32 nccf_first_lag_; - // the last lag of the downsampled signal at which we measure NCCF - int32 nccf_last_lag_; - - // The log-spaced lags at which we will resample the NCCF - Vector lags_; - - // This object is used to resample from evenly spaced to log-evenly-spaced - // nccf values. It's a pointer for convenience of initialization, so we don't - // have to use the initializer from the constructor. - ArbitraryResample *nccf_resampler_; - - // The following objects may change during the lifetime of this object. - - // This object is used to resample the signal. - LinearResample *signal_resampler_; - - // frame_info_ is indexed by [frame-index + 1]. frame_info_[0] is an object - // that corresponds to frame -1, which is not a real frame. 
- std::vector frame_info_; - - - // nccf_info_ is indexed by frame-index, from frame 0 to at most - // opts_.recompute_frame - 1. It contains some information we'll - // need to recompute the tracebacks after getting a better estimate - // of the average energy of the signal. - std::vector nccf_info_; - - // Current number of frames which we can't output because Viterbi has not - // converged for them, or opts_.max_frames_latency if we have reached that - // limit. - int32 frames_latency_; - - // The forward-cost at the current frame (the last frame in frame_info_); - // this has the same dimension as lags_. We normalize each time so - // the lowest cost is zero, for numerical accuracy and so we can use float. - Vector forward_cost_; - - // stores the constant part of forward_cost_. - double forward_cost_remainder_; - - // The resampled-lag index and the NCCF (as computed for POV, without ballast - // term) for each frame, as determined by Viterbi traceback from the best - // final state. - std::vector > lag_nccf_; - - bool input_finished_; - - /// sum-squared of previously processed parts of signal; used to get NCCF - /// ballast term. Denominator is downsampled_samples_processed_. - double signal_sumsq_; - - /// sum of previously processed parts of signal; used to do mean-subtraction - /// when getting sum-squared, along with signal_sumsq_. - double signal_sum_; - - /// downsampled_samples_processed is the number of samples (after - /// downsampling) that we got in previous calls to AcceptWaveform(). - int64 downsampled_samples_processed_; - /// This is a small remainder of the previous downsampled signal; - /// it's used by ExtractFrame for frames near the boundary of two - /// waveforms supplied to AcceptWaveform(). 
- Vector downsampled_signal_remainder_; -}; - - -OnlinePitchFeatureImpl::OnlinePitchFeatureImpl( - const PitchExtractionOptions &opts): - opts_(opts), forward_cost_remainder_(0.0), input_finished_(false), - signal_sumsq_(0.0), signal_sum_(0.0), downsampled_samples_processed_(0) { - signal_resampler_ = new LinearResample(opts.samp_freq, opts.resample_freq, - opts.lowpass_cutoff, - opts.lowpass_filter_width); - - double outer_min_lag = 1.0 / opts.max_f0 - - (opts.upsample_filter_width/(2.0 * opts.resample_freq)); - double outer_max_lag = 1.0 / opts.min_f0 + - (opts.upsample_filter_width/(2.0 * opts.resample_freq)); - nccf_first_lag_ = ceil(opts.resample_freq * outer_min_lag); - nccf_last_lag_ = floor(opts.resample_freq * outer_max_lag); - - frames_latency_ = 0; // will be set in AcceptWaveform() - - // Choose the lags at which we resample the NCCF. - SelectLags(opts, &lags_); - - // upsample_cutoff is the filter cutoff for upsampling the NCCF, which is the - // Nyquist of the resampling frequency. The NCCF is (almost completely) - // bandlimited to around "lowpass_cutoff" (1000 by default), and when the - // spectrum of this bandlimited signal is convolved with the spectrum of an - // impulse train with frequency "resample_freq", which are separated by 4kHz, - // we get energy at -5000,-3000, -1000...1000, 3000..5000, etc. Filtering at - // half the Nyquist (2000 by default) is sufficient to get only the first - // repetition. - BaseFloat upsample_cutoff = opts.resample_freq * 0.5; - - - Vector lags_offset(lags_); - // lags_offset equals lags_ (which are the log-spaced lag values we want to - // measure the NCCF at) with nccf_first_lag_ / opts.resample_freq subtracted - // from each element, so we can treat the measured NCCF values as as starting - // from sample zero in a signal that starts at the point start / - // opts.resample_freq. This is necessary because the ArbitraryResample code - // assumes that the input signal starts from sample zero. 
- lags_offset.Add(-nccf_first_lag_ / opts.resample_freq); - - int32 num_measured_lags = nccf_last_lag_ + 1 - nccf_first_lag_; - - nccf_resampler_ = new ArbitraryResample(num_measured_lags, opts.resample_freq, - upsample_cutoff, lags_offset, - opts.upsample_filter_width); - - // add a PitchInfo object for frame -1 (not a real frame). - frame_info_.push_back(new PitchFrameInfo(lags_.Dim())); - // zeroes forward_cost_; this is what we want for the fake frame -1. - forward_cost_.Resize(lags_.Dim()); -} - - -int32 OnlinePitchFeatureImpl::NumFramesAvailable( - int64 num_downsampled_samples, bool snip_edges) const { - int32 frame_shift = opts_.NccfWindowShift(), - frame_length = opts_.NccfWindowSize(); - // Use the "full frame length" to compute the number - // of frames only if the input is not finished. - if (!input_finished_) - frame_length += nccf_last_lag_; - if (num_downsampled_samples < frame_length) { - return 0; - } else { - if (!snip_edges) { - if (input_finished_) { - return static_cast(num_downsampled_samples * 1.0f / - frame_shift + 0.5f); - } else { - return static_cast((num_downsampled_samples - frame_length / 2) * - 1.0f / frame_shift + 0.5f); - } - } else { - return static_cast((num_downsampled_samples - frame_length) / - frame_shift + 1); - } - } -} - -void OnlinePitchFeatureImpl::UpdateRemainder( - const VectorBase &downsampled_wave_part) { - // frame_info_ has an extra element at frame-1, so subtract - // one from the length. - int64 num_frames = static_cast(frame_info_.size()) - 1, - next_frame = num_frames, - frame_shift = opts_.NccfWindowShift(), - next_frame_sample = frame_shift * next_frame; - - signal_sumsq_ += VecVec(downsampled_wave_part, downsampled_wave_part); - signal_sum_ += downsampled_wave_part.Sum(); - - // next_frame_sample is the first sample index we'll need for the - // next frame. 
- int64 next_downsampled_samples_processed = - downsampled_samples_processed_ + downsampled_wave_part.Dim(); - - if (next_frame_sample > next_downsampled_samples_processed) { - // this could only happen in the weird situation that the full frame length - // is less than the frame shift. - int32 full_frame_length = opts_.NccfWindowSize() + nccf_last_lag_; - KALDI_ASSERT(full_frame_length < frame_shift && "Code error"); - downsampled_signal_remainder_.Resize(0); - } else { - Vector new_remainder(next_downsampled_samples_processed - - next_frame_sample); - // note: next_frame_sample is the index into the entire signal, of - // new_remainder(0). - // i is the absolute index of the signal. - for (int64 i = next_frame_sample; - i < next_downsampled_samples_processed; i++) { - if (i >= downsampled_samples_processed_) { // in current signal. - new_remainder(i - next_frame_sample) = - downsampled_wave_part(i - downsampled_samples_processed_); - } else { // in old remainder; only reach here if waveform supplied is - new_remainder(i - next_frame_sample) = // tiny. - downsampled_signal_remainder_(i - downsampled_samples_processed_ + - downsampled_signal_remainder_.Dim()); - } - } - downsampled_signal_remainder_.Swap(&new_remainder); - } - downsampled_samples_processed_ = next_downsampled_samples_processed; -} - -void OnlinePitchFeatureImpl::ExtractFrame( - const VectorBase &downsampled_wave_part, - int64 sample_index, - VectorBase *window) { - int32 full_frame_length = window->Dim(); - int32 offset = static_cast(sample_index - - downsampled_samples_processed_); - - // Treat edge cases first - if (sample_index < 0) { - // Part of the frame is before the beginning of the signal. This - // should only happen if opts_.snip_edges == false, when we are - // processing the first few frames of signal. In this case - // we pad with zeros. 
- KALDI_ASSERT(opts_.snip_edges == false); - int32 sub_frame_length = sample_index + full_frame_length; - int32 sub_frame_index = full_frame_length - sub_frame_length; - KALDI_ASSERT(sub_frame_length > 0 && sub_frame_index > 0); - window->SetZero(); - SubVector sub_window(*window, sub_frame_index, sub_frame_length); - ExtractFrame(downsampled_wave_part, 0, &sub_window); - return; - } - - if (offset + full_frame_length > downsampled_wave_part.Dim()) { - // Requested frame is past end of the signal. This should only happen if - // input_finished_ == true, when we're flushing out the last couple of - // frames of signal. In this case we pad with zeros. - KALDI_ASSERT(input_finished_); - int32 sub_frame_length = downsampled_wave_part.Dim() - offset; - KALDI_ASSERT(sub_frame_length > 0); - window->SetZero(); - SubVector sub_window(*window, 0, sub_frame_length); - ExtractFrame(downsampled_wave_part, sample_index, &sub_window); - return; - } - - // "offset" is the offset of the start of the frame, into this - // signal. - if (offset >= 0) { - // frame is full inside the new part of the signal. - window->CopyFromVec(downsampled_wave_part.Range(offset, full_frame_length)); - } else { - // frame is partly in the remainder and partly in the new part. - int32 remainder_offset = downsampled_signal_remainder_.Dim() + offset; - KALDI_ASSERT(remainder_offset >= 0); // or we didn't keep enough remainder. - KALDI_ASSERT(offset + full_frame_length > 0); // or we should have - // processed this frame last - // time. 
- - int32 old_length = -offset, new_length = offset + full_frame_length; - window->Range(0, old_length).CopyFromVec( - downsampled_signal_remainder_.Range(remainder_offset, old_length)); - window->Range(old_length, new_length).CopyFromVec( - downsampled_wave_part.Range(0, new_length)); - } - if (opts_.preemph_coeff != 0.0) { - BaseFloat preemph_coeff = opts_.preemph_coeff; - for (int32 i = window->Dim() - 1; i > 0; i--) - (*window)(i) -= preemph_coeff * (*window)(i-1); - (*window)(0) *= (1.0 - preemph_coeff); - } -} - -bool OnlinePitchFeatureImpl::IsLastFrame(int32 frame) const { - int32 T = NumFramesReady(); - KALDI_ASSERT(frame < T); - return (input_finished_ && frame + 1 == T); -} - -BaseFloat OnlinePitchFeatureImpl::FrameShiftInSeconds() const { - return opts_.frame_shift_ms / 1000.0f; -} - -int32 OnlinePitchFeatureImpl::NumFramesReady() const { - int32 num_frames = lag_nccf_.size(), - latency = frames_latency_; - KALDI_ASSERT(latency <= num_frames); - return num_frames - latency; -} - - -void OnlinePitchFeatureImpl::GetFrame(int32 frame, - VectorBase *feat) { - KALDI_ASSERT(frame < NumFramesReady() && feat->Dim() == 2); - (*feat)(0) = lag_nccf_[frame].second; - (*feat)(1) = 1.0 / lags_(lag_nccf_[frame].first); -} - -void OnlinePitchFeatureImpl::InputFinished() { - input_finished_ = true; - // Process an empty waveform; this has an effect because - // after setting input_finished_ to true, NumFramesAvailable() - // will return a slightly larger number. - AcceptWaveform(opts_.samp_freq, Vector()); - int32 num_frames = static_cast(frame_info_.size() - 1); - if (num_frames < opts_.recompute_frame && !opts_.nccf_ballast_online) - RecomputeBacktraces(); - frames_latency_ = 0; - KALDI_VLOG(3) << "Pitch-tracking Viterbi cost is " - << (forward_cost_remainder_ / num_frames) - << " per frame, over " << num_frames << " frames."; -} - -// see comment with declaration. This is only relevant for online -// operation (it gets called for non-online mode, but is a no-op). 
-void OnlinePitchFeatureImpl::RecomputeBacktraces() { - KALDI_ASSERT(!opts_.nccf_ballast_online); - int32 num_frames = static_cast(frame_info_.size()) - 1; - - // The assertion reflects how we believe this function will be called. - KALDI_ASSERT(num_frames <= opts_.recompute_frame); - KALDI_ASSERT(nccf_info_.size() == static_cast(num_frames)); - if (num_frames == 0) - return; - double num_samp = downsampled_samples_processed_, sum = signal_sum_, - sumsq = signal_sumsq_, mean = sum / num_samp; - BaseFloat mean_square = sumsq / num_samp - mean * mean; - - bool must_recompute = false; - BaseFloat threshold = 0.01; - for (int32 frame = 0; frame < num_frames; frame++) - if (!ApproxEqual(nccf_info_[frame]->mean_square_energy, - mean_square, threshold)) - must_recompute = true; - - if (!must_recompute) { - // Nothing to do. We'll reach here, for instance, if everything was in one - // chunk and opts_.nccf_ballast_online == false. This is the case for - // offline processing. - for (size_t i = 0; i < nccf_info_.size(); i++) - delete nccf_info_[i]; - nccf_info_.clear(); - return; - } - - int32 num_states = forward_cost_.Dim(), - basic_frame_length = opts_.NccfWindowSize(); - - BaseFloat new_nccf_ballast = pow(mean_square * basic_frame_length, 2) * - opts_.nccf_ballast; - - double forward_cost_remainder = 0.0; - Vector forward_cost(num_states), // start off at zero. 
- next_forward_cost(forward_cost); - std::vector > index_info; - - for (int32 frame = 0; frame < num_frames; frame++) { - NccfInfo &nccf_info = *nccf_info_[frame]; - BaseFloat old_mean_square = nccf_info_[frame]->mean_square_energy, - avg_norm_prod = nccf_info_[frame]->avg_norm_prod, - old_nccf_ballast = pow(old_mean_square * basic_frame_length, 2) * - opts_.nccf_ballast, - nccf_scale = pow((old_nccf_ballast + avg_norm_prod) / - (new_nccf_ballast + avg_norm_prod), - static_cast(0.5)); - // The "nccf_scale" is an estimate of the scaling factor by which the NCCF - // would change on this frame, on average, by changing the ballast term from - // "old_nccf_ballast" to "new_nccf_ballast". It's not exact because the - // "avg_norm_prod" is just an average of the product e1 * e2 of frame - // energies of the (frame, shifted-frame), but these won't change that much - // within a frame, and even if they do, the inaccuracy of the scaled NCCF - // will still be very small if the ballast term didn't change much, or if - // it's much larger or smaller than e1*e2. By doing it as a simple scaling, - // we save the overhead of the NCCF resampling, which is a considerable part - // of the whole computation. 
- nccf_info.nccf_pitch_resampled.Scale(nccf_scale); - - frame_info_[frame + 1]->ComputeBacktraces( - opts_, nccf_info.nccf_pitch_resampled, lags_, - forward_cost, &index_info, &next_forward_cost); - - forward_cost.Swap(&next_forward_cost); - BaseFloat remainder = forward_cost.Min(); - forward_cost_remainder += remainder; - forward_cost.Add(-remainder); - } - KALDI_VLOG(3) << "Forward-cost per frame changed from " - << (forward_cost_remainder_ / num_frames) << " to " - << (forward_cost_remainder / num_frames); - - forward_cost_remainder_ = forward_cost_remainder; - forward_cost_.Swap(&forward_cost); - - int32 best_final_state; - forward_cost_.Min(&best_final_state); - - if (lag_nccf_.size() != static_cast(num_frames)) - lag_nccf_.resize(num_frames); - - frame_info_.back()->SetBestState(best_final_state, lag_nccf_); - frames_latency_ = - frame_info_.back()->ComputeLatency(opts_.max_frames_latency); - for (size_t i = 0; i < nccf_info_.size(); i++) - delete nccf_info_[i]; - nccf_info_.clear(); -} - -OnlinePitchFeatureImpl::~OnlinePitchFeatureImpl() { - delete nccf_resampler_; - delete signal_resampler_; - for (size_t i = 0; i < frame_info_.size(); i++) - delete frame_info_[i]; - for (size_t i = 0; i < nccf_info_.size(); i++) - delete nccf_info_[i]; -} - -void OnlinePitchFeatureImpl::AcceptWaveform( - BaseFloat sampling_rate, - const VectorBase &wave) { - // flush out the last few samples of input waveform only if input_finished_ == - // true. - const bool flush = input_finished_; - - Vector downsampled_wave; - signal_resampler_->Resample(wave, flush, &downsampled_wave); - - // these variables will be used to compute the root-mean-square value of the - // signal for the ballast term. 
- double cur_sumsq = signal_sumsq_, cur_sum = signal_sum_; - int64 cur_num_samp = downsampled_samples_processed_, - prev_frame_end_sample = 0; - if (!opts_.nccf_ballast_online) { - cur_sumsq += VecVec(downsampled_wave, downsampled_wave); - cur_sum += downsampled_wave.Sum(); - cur_num_samp += downsampled_wave.Dim(); - } - - // end_frame is the total number of frames we can now process, including - // previously processed ones. - int32 end_frame = NumFramesAvailable( - downsampled_samples_processed_ + downsampled_wave.Dim(), opts_.snip_edges); - // "start_frame" is the first frame-index we process - int32 start_frame = frame_info_.size() - 1, - num_new_frames = end_frame - start_frame; - - if (num_new_frames == 0) { - UpdateRemainder(downsampled_wave); - return; - // continuing to the rest of the code would generate - // an error when sizing matrices with zero rows, and - // anyway is a waste of time. - } - - int32 num_measured_lags = nccf_last_lag_ + 1 - nccf_first_lag_, - num_resampled_lags = lags_.Dim(), - frame_shift = opts_.NccfWindowShift(), - basic_frame_length = opts_.NccfWindowSize(), - full_frame_length = basic_frame_length + nccf_last_lag_; - - Vector window(full_frame_length), - inner_prod(num_measured_lags), - norm_prod(num_measured_lags); - Matrix nccf_pitch(num_new_frames, num_measured_lags), - nccf_pov(num_new_frames, num_measured_lags); - - Vector cur_forward_cost(num_resampled_lags); - - - // Because the resampling of the NCCF is more efficient when grouped together, - // we first compute the NCCF for all frames, then resample as a matrix, then - // do the Viterbi [that happens inside the constructor of PitchFrameInfo]. - - for (int32 frame = start_frame; frame < end_frame; frame++) { - // start_sample is index into the whole wave, not just this part. 
- int64 start_sample; - if (opts_.snip_edges) { - // Usual case: offset starts at 0 - start_sample = static_cast(frame) * frame_shift; - } else { - // When we are not snipping the edges, the first offsets may be - // negative. In this case we will pad with zeros, it should not impact - // the pitch tracker. - start_sample = - static_cast((frame + 0.5) * frame_shift) - full_frame_length / 2; - } - ExtractFrame(downsampled_wave, start_sample, &window); - if (opts_.nccf_ballast_online) { - // use only up to end of current frame to compute root-mean-square value. - // end_sample will be the sample-index into "downsampled_wave", so - // not really comparable to start_sample. - int64 end_sample = start_sample + full_frame_length - - downsampled_samples_processed_; - KALDI_ASSERT(end_sample > 0); // or should have processed this frame last - // time. Note: end_sample is one past last - // sample. - if (end_sample > downsampled_wave.Dim()) { - KALDI_ASSERT(input_finished_); - end_sample = downsampled_wave.Dim(); - } - SubVector new_part(downsampled_wave, prev_frame_end_sample, - end_sample - prev_frame_end_sample); - cur_num_samp += new_part.Dim(); - cur_sumsq += VecVec(new_part, new_part); - cur_sum += new_part.Sum(); - prev_frame_end_sample = end_sample; - } - double mean_square = cur_sumsq / cur_num_samp - - pow(cur_sum / cur_num_samp, 2.0); - - ComputeCorrelation(window, nccf_first_lag_, nccf_last_lag_, - basic_frame_length, &inner_prod, &norm_prod); - double nccf_ballast_pov = 0.0, - nccf_ballast_pitch = pow(mean_square * basic_frame_length, 2) * - opts_.nccf_ballast, - avg_norm_prod = norm_prod.Sum() / norm_prod.Dim(); - SubVector nccf_pitch_row(nccf_pitch, frame - start_frame); - ComputeNccf(inner_prod, norm_prod, nccf_ballast_pitch, - &nccf_pitch_row); - SubVector nccf_pov_row(nccf_pov, frame - start_frame); - ComputeNccf(inner_prod, norm_prod, nccf_ballast_pov, - &nccf_pov_row); - if (frame < opts_.recompute_frame) - nccf_info_.push_back(new 
NccfInfo(avg_norm_prod, mean_square)); - } - - Matrix nccf_pitch_resampled(num_new_frames, num_resampled_lags); - nccf_resampler_->Resample(nccf_pitch, &nccf_pitch_resampled); - nccf_pitch.Resize(0, 0); // no longer needed. - Matrix nccf_pov_resampled(num_new_frames, num_resampled_lags); - nccf_resampler_->Resample(nccf_pov, &nccf_pov_resampled); - nccf_pov.Resize(0, 0); // no longer needed. - - // We've finished dealing with the waveform so we can call UpdateRemainder - // now; we need to call it before we possibly call RecomputeBacktraces() - // below, which is why we don't do it at the very end. - UpdateRemainder(downsampled_wave); - - std::vector > index_info; - - for (int32 frame = start_frame; frame < end_frame; frame++) { - int32 frame_idx = frame - start_frame; - PitchFrameInfo *prev_info = frame_info_.back(), - *cur_info = new PitchFrameInfo(prev_info); - cur_info->SetNccfPov(nccf_pov_resampled.Row(frame_idx)); - cur_info->ComputeBacktraces(opts_, nccf_pitch_resampled.Row(frame_idx), - lags_, forward_cost_, &index_info, - &cur_forward_cost); - forward_cost_.Swap(&cur_forward_cost); - // Renormalize forward_cost so smallest element is zero. - BaseFloat remainder = forward_cost_.Min(); - forward_cost_remainder_ += remainder; - forward_cost_.Add(-remainder); - frame_info_.push_back(cur_info); - if (frame < opts_.recompute_frame) - nccf_info_[frame]->nccf_pitch_resampled = - nccf_pitch_resampled.Row(frame_idx); - if (frame == opts_.recompute_frame - 1 && !opts_.nccf_ballast_online) - RecomputeBacktraces(); - } - - // Trace back the best-path. - int32 best_final_state; - forward_cost_.Min(&best_final_state); - lag_nccf_.resize(frame_info_.size() - 1); // will keep any existing data. 
- frame_info_.back()->SetBestState(best_final_state, lag_nccf_); - frames_latency_ = - frame_info_.back()->ComputeLatency(opts_.max_frames_latency); - KALDI_VLOG(4) << "Latency is " << frames_latency_; -} - - - -// Some functions that forward from OnlinePitchFeature to -// OnlinePitchFeatureImpl. -int32 OnlinePitchFeature::NumFramesReady() const { - return impl_->NumFramesReady(); -} - -OnlinePitchFeature::OnlinePitchFeature(const PitchExtractionOptions &opts) - :impl_(new OnlinePitchFeatureImpl(opts)) { } - -bool OnlinePitchFeature::IsLastFrame(int32 frame) const { - return impl_->IsLastFrame(frame); -} - -BaseFloat OnlinePitchFeature::FrameShiftInSeconds() const { - return impl_->FrameShiftInSeconds(); -} - -void OnlinePitchFeature::GetFrame(int32 frame, VectorBase *feat) { - impl_->GetFrame(frame, feat); -} - -void OnlinePitchFeature::AcceptWaveform( - BaseFloat sampling_rate, - const VectorBase &waveform) { - impl_->AcceptWaveform(sampling_rate, waveform); -} - -void OnlinePitchFeature::InputFinished() { - impl_->InputFinished(); -} - -OnlinePitchFeature::~OnlinePitchFeature() { - delete impl_; -} - - -/** - This function is called from ComputeKaldiPitch when the user - specifies opts.simulate_first_pass_online == true. It gives - the "first-pass" version of the features, which you would get - on the first decoding pass in an online setting. These may - differ slightly from the final features due to both the - way the Viterbi traceback works (this is affected by - opts.max_frames_latency), and the online way we compute - the average signal energy. 
-*/ -void ComputeKaldiPitchFirstPass( - const PitchExtractionOptions &opts, - const VectorBase &wave, - Matrix *output) { - - int32 cur_rows = 100; - Matrix feats(cur_rows, 2); - - OnlinePitchFeature pitch_extractor(opts); - KALDI_ASSERT(opts.frames_per_chunk > 0 && - "--simulate-first-pass-online option does not make sense " - "unless you specify --frames-per-chunk"); - - int32 cur_offset = 0, cur_frame = 0, samp_per_chunk = - opts.frames_per_chunk * opts.samp_freq * opts.frame_shift_ms / 1000.0f; - - while (cur_offset < wave.Dim()) { - int32 num_samp = std::min(samp_per_chunk, wave.Dim() - cur_offset); - SubVector wave_chunk(wave, cur_offset, num_samp); - pitch_extractor.AcceptWaveform(opts.samp_freq, wave_chunk); - cur_offset += num_samp; - if (cur_offset == wave.Dim()) - pitch_extractor.InputFinished(); - // Get each frame as soon as it is ready. - for (; cur_frame < pitch_extractor.NumFramesReady(); cur_frame++) { - if (cur_frame >= cur_rows) { - cur_rows *= 2; - feats.Resize(cur_rows, 2, kCopyData); - } - SubVector row(feats, cur_frame); - pitch_extractor.GetFrame(cur_frame, &row); - } - } - if (cur_frame == 0) { - KALDI_WARN << "No features output since wave file too short"; - output->Resize(0, 0); - } else { - *output = feats.RowRange(0, cur_frame); - } -} - - - -void ComputeKaldiPitch(const PitchExtractionOptions &opts, - const VectorBase &wave, - Matrix *output) { - if (opts.simulate_first_pass_online) { - ComputeKaldiPitchFirstPass(opts, wave, output); - return; - } - OnlinePitchFeature pitch_extractor(opts); - - if (opts.frames_per_chunk == 0) { - pitch_extractor.AcceptWaveform(opts.samp_freq, wave); - } else { - // the user may set opts.frames_per_chunk for better compatibility with - // online operation. 
- KALDI_ASSERT(opts.frames_per_chunk > 0); - int32 cur_offset = 0, samp_per_chunk = - opts.frames_per_chunk * opts.samp_freq * opts.frame_shift_ms / 1000.0f; - while (cur_offset < wave.Dim()) { - int32 num_samp = std::min(samp_per_chunk, wave.Dim() - cur_offset); - SubVector wave_chunk(wave, cur_offset, num_samp); - pitch_extractor.AcceptWaveform(opts.samp_freq, wave_chunk); - cur_offset += num_samp; - } - } - pitch_extractor.InputFinished(); - int32 num_frames = pitch_extractor.NumFramesReady(); - if (num_frames == 0) { - KALDI_WARN << "No frames output in pitch extraction"; - output->Resize(0, 0); - return; - } - output->Resize(num_frames, 2); - for (int32 frame = 0; frame < num_frames; frame++) { - SubVector row(*output, frame); - pitch_extractor.GetFrame(frame, &row); - } -} - - -/* - This comment describes our invesigation of how much latency the - online-processing algorithm introduces, i.e. how many frames you would - typically have to wait until the traceback converges, if you were to set the - --max-frames-latency to a very large value. - - This was done on a couple of files of language-id data. - - /home/dpovey/kaldi-online/src/featbin/compute-kaldi-pitch-feats --frames-per-chunk=10 --max-frames-latency=100 --verbose=4 --sample-frequency=8000 --resample-frequency=2600 "scp:head -n 2 data/train/wav.scp |" ark:/dev/null 2>&1 | grep Latency | wc - 4871 24355 443991 - /home/dpovey/kaldi-online/src/featbin/compute-kaldi-pitch-feats --frames-per-chunk=10 --max-frames-latency=100 --verbose=4 --sample-frequency=8000 --resample-frequency=2600 "scp:head -n 2 data/train/wav.scp |" ark:/dev/null 2>&1 | grep Latency | grep 100 | wc - 1534 7670 141128 - -# as above, but with 50 instead of 10 in the --max-frames-latency and grep statements. - 2070 10350 188370 -# as above, but with 10 instead of 50. 
- 4067 20335 370097 - - This says that out of 4871 selected frames [we measured the latency every 10 - frames, since --frames-per-chunk=10], in 1534 frames (31%), the latency was - >= 100 frames, i.e. >= 1 second. Including the other numbers, we can see - that - - 31% of frames had latency >= 1 second - 42% of frames had latency >= 0.5 second - 83% of frames had latency >= 0.1 second. - - This doesn't necessarily mean that we actually have a latency of >= 1 second 31% of - the time when using these features, since by using the --max-frames-latency option - (default: 30 frames), it will limit the latency to, say, 0.3 seconds, and trace back - from the best current pitch. Most of the time this will probably cause no change in - the pitch traceback since the best current pitch is probably the "right" point to - trace back from. And anyway, in the online-decoding, we will most likely rescore - the features at the end anyway, and the traceback gets recomputed, so there will - be no inaccuracy (assuming the first-pass lattice had everything we needed). - - Probably the greater source of inaccuracy due to the online algorithm is the - online energy-normalization, which affects the NCCF-ballast term, and which, - for reasons of efficiency, we don't attempt to "correct" in a later rescoring - pass. This will make the most difference in the first few frames of the file, - before the first voicing, where it will tend to produce more pitch movement - than the offline version of the algorithm. -*/ - - -// Function to do data accumulation for on-line usage -template -inline void AppendVector(const VectorBase &src, Vector *dst) { - if (src.Dim() == 0) return; - dst->Resize(dst->Dim() + src.Dim(), kCopyData); - dst->Range(dst->Dim() - src.Dim(), src.Dim()).CopyFromVec(src); -} - -/** - Note on the implementation of OnlineProcessPitch: the - OnlineFeatureInterface allows random access to features (i.e. not necessarily - sequential order), so we need to support that. 
But we don't need to support - it very efficiently, and our implementation is most efficient if frames are - accessed in sequential order. - - Also note: we have to be a bit careful in this implementation because - the input features may change. That is: if we call - src_->GetFrame(t, &vec) from GetFrame(), we can't guarantee that a later - call to src_->GetFrame(t, &vec) from another GetFrame() will return the - same value. In fact, while designing this class we used some knowledge - of how the OnlinePitchFeature class works to minimize the amount of - re-querying we had to do. -*/ -OnlineProcessPitch::OnlineProcessPitch( - const ProcessPitchOptions &opts, - OnlineFeatureInterface *src): - opts_(opts), src_(src), - dim_ ((opts.add_pov_feature ? 1 : 0) - + (opts.add_normalized_log_pitch ? 1 : 0) - + (opts.add_delta_pitch ? 1 : 0) - + (opts.add_raw_log_pitch ? 1 : 0)) { - KALDI_ASSERT(dim_ > 0 && - " At least one of the pitch features should be chosen. " - "Check your post-process-pitch options."); - KALDI_ASSERT(src->Dim() == kRawFeatureDim && - "Input feature must be pitch feature (should have dimension 2)"); -} - - -void OnlineProcessPitch::GetFrame(int32 frame, - VectorBase *feat) { - int32 frame_delayed = frame < opts_.delay ? 
0 : frame - opts_.delay; - KALDI_ASSERT(feat->Dim() == dim_ && - frame_delayed < NumFramesReady()); - int32 index = 0; - if (opts_.add_pov_feature) - (*feat)(index++) = GetPovFeature(frame_delayed); - if (opts_.add_normalized_log_pitch) - (*feat)(index++) = GetNormalizedLogPitchFeature(frame_delayed); - if (opts_.add_delta_pitch) - (*feat)(index++) = GetDeltaPitchFeature(frame_delayed); - if (opts_.add_raw_log_pitch) - (*feat)(index++) = GetRawLogPitchFeature(frame_delayed); - KALDI_ASSERT(index == dim_); -} - -BaseFloat OnlineProcessPitch::GetPovFeature(int32 frame) const { - Vector tmp(kRawFeatureDim); - src_->GetFrame(frame, &tmp); // (NCCF, pitch) from pitch extractor - BaseFloat nccf = tmp(0); - return opts_.pov_scale * NccfToPovFeature(nccf) - + opts_.pov_offset; -} - -BaseFloat OnlineProcessPitch::GetDeltaPitchFeature(int32 frame) { - // Rather than computing the delta pitch directly in code here, - // which might seem easier, we accumulate a small window of features - // and call ComputeDeltas. This might seem like overkill; the reason - // we do it this way is to ensure that the end effects (at file - // beginning and end) are handled in a consistent way. - int32 context = opts_.delta_window; - int32 start_frame = std::max(0, frame - context), - end_frame = std::min(frame + context + 1, src_->NumFramesReady()), - frames_in_window = end_frame - start_frame; - Matrix feats(frames_in_window, 1), - delta_feats; - - for (int32 f = start_frame; f < end_frame; f++) - feats(f - start_frame, 0) = GetRawLogPitchFeature(f); - - DeltaFeaturesOptions delta_opts; - delta_opts.order = 1; - delta_opts.window = opts_.delta_window; - ComputeDeltas(delta_opts, feats, &delta_feats); - while (delta_feature_noise_.size() <= static_cast(frame)) { - delta_feature_noise_.push_back(RandGauss() * - opts_.delta_pitch_noise_stddev); - } - // note: delta_feats will have two columns, second contains deltas. 
- return (delta_feats(frame - start_frame, 1) + delta_feature_noise_[frame]) * - opts_.delta_pitch_scale; -} - -BaseFloat OnlineProcessPitch::GetRawLogPitchFeature(int32 frame) const { - Vector tmp(kRawFeatureDim); - src_->GetFrame(frame, &tmp); - BaseFloat pitch = tmp(1); - KALDI_ASSERT(pitch > 0); - return Log(pitch); -} - -BaseFloat OnlineProcessPitch::GetNormalizedLogPitchFeature(int32 frame) { - UpdateNormalizationStats(frame); - BaseFloat log_pitch = GetRawLogPitchFeature(frame), - avg_log_pitch = normalization_stats_[frame].sum_log_pitch_pov / - normalization_stats_[frame].sum_pov, - normalized_log_pitch = log_pitch - avg_log_pitch; - return normalized_log_pitch * opts_.pitch_scale; -} - - -// inline -void OnlineProcessPitch::GetNormalizationWindow(int32 t, - int32 src_frames_ready, - int32 *window_begin, - int32 *window_end) const { - int32 left_context = opts_.normalization_left_context; - int32 right_context = opts_.normalization_right_context; - *window_begin = std::max(0, t - left_context); - *window_end = std::min(t + right_context + 1, src_frames_ready); -} - - -// Makes sure the entry in normalization_stats_ for this frame is up to date; -// called from GetNormalizedLogPitchFeature. -// the cur_num_frames and input_finished variables are needed because the -// pitch features for a given frame may change as we see more data. -void OnlineProcessPitch::UpdateNormalizationStats(int32 frame) { - KALDI_ASSERT(frame >= 0); - if (normalization_stats_.size() <= frame) - normalization_stats_.resize(frame + 1); - int32 cur_num_frames = src_->NumFramesReady(); - bool input_finished = src_->IsLastFrame(cur_num_frames - 1); - - NormalizationStats &this_stats = normalization_stats_[frame]; - if (this_stats.cur_num_frames == cur_num_frames && - this_stats.input_finished == input_finished) { - // Stats are fully up-to-date. 
- return; - } - int32 this_window_begin, this_window_end; - GetNormalizationWindow(frame, cur_num_frames, - &this_window_begin, &this_window_end); - - if (frame > 0) { - const NormalizationStats &prev_stats = normalization_stats_[frame - 1]; - if (prev_stats.cur_num_frames == cur_num_frames && - prev_stats.input_finished == input_finished) { - // we'll derive this_stats efficiently from prev_stats. - // Checking that cur_num_frames and input_finished have not changed - // ensures that the underlying features will not have changed. - this_stats = prev_stats; - int32 prev_window_begin, prev_window_end; - GetNormalizationWindow(frame - 1, cur_num_frames, - &prev_window_begin, &prev_window_end); - if (this_window_begin != prev_window_begin) { - KALDI_ASSERT(this_window_begin == prev_window_begin + 1); - Vector tmp(kRawFeatureDim); - src_->GetFrame(prev_window_begin, &tmp); - BaseFloat accurate_pov = NccfToPov(tmp(0)), - log_pitch = Log(tmp(1)); - this_stats.sum_pov -= accurate_pov; - this_stats.sum_log_pitch_pov -= accurate_pov * log_pitch; - } - if (this_window_end != prev_window_end) { - KALDI_ASSERT(this_window_end == prev_window_end + 1); - Vector tmp(kRawFeatureDim); - src_->GetFrame(prev_window_end, &tmp); - BaseFloat accurate_pov = NccfToPov(tmp(0)), - log_pitch = Log(tmp(1)); - this_stats.sum_pov += accurate_pov; - this_stats.sum_log_pitch_pov += accurate_pov * log_pitch; - } - return; - } - } - // The way we do it here is not the most efficient way to do it; - // we'll see if it becomes a problem. The issue is we have to redo - // this computation from scratch each time we process a new chunk, which - // may be a little inefficient if the chunk-size is very small. 
- this_stats.cur_num_frames = cur_num_frames; - this_stats.input_finished = input_finished; - this_stats.sum_pov = 0.0; - this_stats.sum_log_pitch_pov = 0.0; - Vector tmp(kRawFeatureDim); - for (int32 f = this_window_begin; f < this_window_end; f++) { - src_->GetFrame(f, &tmp); - BaseFloat accurate_pov = NccfToPov(tmp(0)), - log_pitch = Log(tmp(1)); - this_stats.sum_pov += accurate_pov; - this_stats.sum_log_pitch_pov += accurate_pov * log_pitch; - } -} - -int32 OnlineProcessPitch::NumFramesReady() const { - int32 src_frames_ready = src_->NumFramesReady(); - if (src_frames_ready == 0) { - return 0; - } else if (src_->IsLastFrame(src_frames_ready - 1)) { - return src_frames_ready + opts_.delay; - } else { - return std::max(0, src_frames_ready - - opts_.normalization_right_context + opts_.delay); - } -} - -void ProcessPitch(const ProcessPitchOptions &opts, - const MatrixBase &input, - Matrix *output) { - OnlineMatrixFeature pitch_feat(input); - - OnlineProcessPitch online_process_pitch(opts, &pitch_feat); - - output->Resize(online_process_pitch.NumFramesReady(), - online_process_pitch.Dim()); - for (int32 t = 0; t < online_process_pitch.NumFramesReady(); t++) { - SubVector row(*output, t); - online_process_pitch.GetFrame(t, &row); - } -} - - -void ComputeAndProcessKaldiPitch( - const PitchExtractionOptions &pitch_opts, - const ProcessPitchOptions &process_opts, - const VectorBase &wave, - Matrix *output) { - - OnlinePitchFeature pitch_extractor(pitch_opts); - - if (pitch_opts.simulate_first_pass_online) { - KALDI_ASSERT(pitch_opts.frames_per_chunk > 0 && - "--simulate-first-pass-online option does not make sense " - "unless you specify --frames-per-chunk"); - } - - OnlineProcessPitch post_process(process_opts, &pitch_extractor); - - int32 cur_rows = 100; - Matrix feats(cur_rows, post_process.Dim()); - - int32 cur_offset = 0, cur_frame = 0, - samp_per_chunk = pitch_opts.frames_per_chunk * - pitch_opts.samp_freq * pitch_opts.frame_shift_ms / 1000.0f; - - // We request 
the first-pass features as soon as they are available, - // regardless of whether opts.simulate_first_pass_online == true. If - // opts.simulate_first_pass_online == true this should - // not affect the features generated, but it helps us to test the code - // in a way that's closer to what online decoding would see. - - while (cur_offset < wave.Dim()) { - int32 num_samp; - if (samp_per_chunk > 0) - num_samp = std::min(samp_per_chunk, wave.Dim() - cur_offset); - else // user left opts.frames_per_chunk at zero. - num_samp = wave.Dim(); - SubVector wave_chunk(wave, cur_offset, num_samp); - pitch_extractor.AcceptWaveform(pitch_opts.samp_freq, wave_chunk); - cur_offset += num_samp; - if (cur_offset == wave.Dim()) - pitch_extractor.InputFinished(); - - // Get each frame as soon as it is ready. - for (; cur_frame < post_process.NumFramesReady(); cur_frame++) { - if (cur_frame >= cur_rows) { - cur_rows *= 2; - feats.Resize(cur_rows, post_process.Dim(), kCopyData); - } - SubVector row(feats, cur_frame); - post_process.GetFrame(cur_frame, &row); - } - } - - if (pitch_opts.simulate_first_pass_online) { - if (cur_frame == 0) { - KALDI_WARN << "No features output since wave file too short"; - output->Resize(0, 0); - } else { - *output = feats.RowRange(0, cur_frame); - } - } else { - // want the "final" features for second pass, so get them again. 
- output->Resize(post_process.NumFramesReady(), post_process.Dim()); - for (int32 frame = 0; frame < post_process.NumFramesReady(); frame++) { - SubVector row(*output, frame); - post_process.GetFrame(frame, &row); - } - } -} - - -} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/pitch-functions.h b/speechx/speechx/kaldi/feat/pitch-functions.h deleted file mode 100644 index 9edf6c9f..00000000 --- a/speechx/speechx/kaldi/feat/pitch-functions.h +++ /dev/null @@ -1,450 +0,0 @@ -// feat/pitch-functions.h - -// Copyright 2013 Pegah Ghahremani -// 2014 IMSL, PKU-HKUST (author: Wei Shi) -// 2014 Yanqing Sun, Junjie Wang, -// Daniel Povey, Korbinian Riedhammer -// Xin Lei - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_FEAT_PITCH_FUNCTIONS_H_ -#define KALDI_FEAT_PITCH_FUNCTIONS_H_ - -#include -#include -#include -#include - -#include "base/kaldi-error.h" -#include "feat/mel-computations.h" -#include "feat/online-feature-itf.h" -#include "matrix/matrix-lib.h" -#include "util/common-utils.h" - -namespace kaldi { -/// @addtogroup feat FeatureExtraction -/// @{ - -struct PitchExtractionOptions { - // FrameExtractionOptions frame_opts; - BaseFloat samp_freq; // sample frequency in hertz - BaseFloat frame_shift_ms; // in milliseconds. - BaseFloat frame_length_ms; // in milliseconds. 
- BaseFloat preemph_coeff; // Preemphasis coefficient. [use is deprecated.] - BaseFloat min_f0; // min f0 to search (Hz) - BaseFloat max_f0; // max f0 to search (Hz) - BaseFloat soft_min_f0; // Minimum f0, applied in soft way, must not - // exceed min-f0 - BaseFloat penalty_factor; // cost factor for FO change - BaseFloat lowpass_cutoff; // cutoff frequency for Low pass filter - BaseFloat resample_freq; // Integer that determines filter width when - // upsampling NCCF - BaseFloat delta_pitch; // the pitch tolerance in pruning lags - BaseFloat nccf_ballast; // Increasing this factor reduces NCCF for - // quiet frames, helping ensure pitch - // continuity in unvoiced region - int32 lowpass_filter_width; // Integer that determines filter width of - // lowpass filter - int32 upsample_filter_width; // Integer that determines filter width when - // upsampling NCCF - - // Below are newer config variables, not present in the original paper, - // that relate to the online pitch extraction algorithm. - - // The maximum number of frames of latency that we allow the pitch-processing - // to introduce, for online operation. If you set this to a large value, - // there would be no inaccuracy from the Viterbi traceback (but it might make - // you wait to see the pitch). This is not very relevant for the online - // operation: normalization-right-context is more relevant, you - // can just leave this value at zero. - int32 max_frames_latency; - - // Only relevant for the function ComputeKaldiPitch which is called by - // compute-kaldi-pitch-feats. If nonzero, we provide the input as chunks of - // this size. This affects the energy normalization which has a small effect - // on the resulting features, especially at the beginning of a file. For best - // compatibility with online operation (e.g. if you plan to train models for - // the online-deocding setup), you might want to set this to a small value, - // like one frame. 
- int32 frames_per_chunk; - - // Only relevant for the function ComputeKaldiPitch which is called by - // compute-kaldi-pitch-feats, and only relevant if frames_per_chunk is - // nonzero. If true, it will query the features as soon as they are - // available, which simulates the first-pass features you would get in online - // decoding. If false, the features you will get will be the same as those - // available at the end of the utterance, after InputFinished() has been - // called: e.g. during lattice rescoring. - bool simulate_first_pass_online; - - // Only relevant for online operation or when emulating online operation - // (e.g. when setting frames_per_chunk). This is the frame-index on which we - // recompute the NCCF (e.g. frame-index 500 = after 5 seconds); if the - // segment ends before this we do it when the segment ends. We do this by - // re-computing the signal average energy, which affects the NCCF via the - // "ballast term", scaling the resampled NCCF by a factor derived from the - // average change in the "ballast term", and re-doing the backtrace - // computation. Making this infinity would be the most exact, but would - // introduce unwanted latency at the end of long utterances, for little - // benefit. - int32 recompute_frame; - - // This is a "hidden config" used only for testing the online pitch - // extraction. If true, we compute the signal root-mean-squared for the - // ballast term, only up to the current frame, rather than the end of the - // current chunk of signal. This makes the output insensitive to the - // chunking, which is useful for testing purposes. 
- bool nccf_ballast_online; - bool snip_edges; - PitchExtractionOptions(): - samp_freq(16000), - frame_shift_ms(10.0), - frame_length_ms(25.0), - preemph_coeff(0.0), - min_f0(50), - max_f0(400), - soft_min_f0(10.0), - penalty_factor(0.1), - lowpass_cutoff(1000), - resample_freq(4000), - delta_pitch(0.005), - nccf_ballast(7000), - lowpass_filter_width(1), - upsample_filter_width(5), - max_frames_latency(0), - frames_per_chunk(0), - simulate_first_pass_online(false), - recompute_frame(500), - nccf_ballast_online(false), - snip_edges(true) { } - - void Register(OptionsItf *opts) { - opts->Register("sample-frequency", &samp_freq, - "Waveform data sample frequency (must match the waveform " - "file, if specified there)"); - opts->Register("frame-length", &frame_length_ms, "Frame length in " - "milliseconds"); - opts->Register("frame-shift", &frame_shift_ms, "Frame shift in " - "milliseconds"); - opts->Register("preemphasis-coefficient", &preemph_coeff, - "Coefficient for use in signal preemphasis (deprecated)"); - opts->Register("min-f0", &min_f0, - "min. F0 to search for (Hz)"); - opts->Register("max-f0", &max_f0, - "max. F0 to search for (Hz)"); - opts->Register("soft-min-f0", &soft_min_f0, - "Minimum f0, applied in soft way, must not exceed min-f0"); - opts->Register("penalty-factor", &penalty_factor, - "cost factor for FO change."); - opts->Register("lowpass-cutoff", &lowpass_cutoff, - "cutoff frequency for LowPass filter (Hz) "); - opts->Register("resample-frequency", &resample_freq, - "Frequency that we down-sample the signal to. 
Must be " - "more than twice lowpass-cutoff"); - opts->Register("delta-pitch", &delta_pitch, - "Smallest relative change in pitch that our algorithm " - "measures"); - opts->Register("nccf-ballast", &nccf_ballast, - "Increasing this factor reduces NCCF for quiet frames"); - opts->Register("nccf-ballast-online", &nccf_ballast_online, - "This is useful mainly for debug; it affects how the NCCF " - "ballast is computed."); - opts->Register("lowpass-filter-width", &lowpass_filter_width, - "Integer that determines filter width of " - "lowpass filter, more gives sharper filter"); - opts->Register("upsample-filter-width", &upsample_filter_width, - "Integer that determines filter width when upsampling NCCF"); - opts->Register("frames-per-chunk", &frames_per_chunk, "Only relevant for " - "offline pitch extraction (e.g. compute-kaldi-pitch-feats), " - "you can set it to a small nonzero value, such as 10, for " - "better feature compatibility with online decoding (affects " - "energy normalization in the algorithm)"); - opts->Register("simulate-first-pass-online", &simulate_first_pass_online, - "If true, compute-kaldi-pitch-feats will output features " - "that correspond to what an online decoder would see in the " - "first pass of decoding-- not the final version of the " - "features, which is the default. Relevant if " - "--frames-per-chunk > 0"); - opts->Register("recompute-frame", &recompute_frame, "Only relevant for " - "online pitch extraction, or for compatibility with online " - "pitch extraction. A non-critical parameter; the frame at " - "which we recompute some of the forward pointers, after " - "revising our estimate of the signal energy. 
Relevant if" - "--frames-per-chunk > 0"); - opts->Register("max-frames-latency", &max_frames_latency, "Maximum number " - "of frames of latency that we allow pitch tracking to " - "introduce into the feature processing (affects output only " - "if --frames-per-chunk > 0 and " - "--simulate-first-pass-online=true"); - opts->Register("snip-edges", &snip_edges, "If this is set to false, the " - "incomplete frames near the ending edge won't be snipped, " - "so that the number of frames is the file size divided by " - "the frame-shift. This makes different types of features " - "give the same number of frames."); - } - /// Returns the window-size in samples, after resampling. This is the - /// "basic window size", not the full window size after extending by max-lag. - // Because of floating point representation, it is more reliable to divide - // by 1000 instead of multiplying by 0.001, but it is a bit slower. - int32 NccfWindowSize() const { - return static_cast(resample_freq * frame_length_ms / 1000.0); - } - /// Returns the window-shift in samples, after resampling. - int32 NccfWindowShift() const { - return static_cast(resample_freq * frame_shift_ms / 1000.0); - } -}; - -struct ProcessPitchOptions { - BaseFloat pitch_scale; // the final normalized-log-pitch feature is scaled - // with this value - BaseFloat pov_scale; // the final POV feature is scaled with this value - BaseFloat pov_offset; // An offset that can be added to the final POV - // feature (useful for online-decoding, where we don't - // do CMN to the pitch-derived features. 
- - BaseFloat delta_pitch_scale; - BaseFloat delta_pitch_noise_stddev; // stddev of noise we add to delta-pitch - int32 normalization_left_context; // left-context used for sliding-window - // normalization - int32 normalization_right_context; // this should be reduced in online - // decoding to reduce latency - - int32 delta_window; - int32 delay; - - bool add_pov_feature; - bool add_normalized_log_pitch; - bool add_delta_pitch; - bool add_raw_log_pitch; - - ProcessPitchOptions() : - pitch_scale(2.0), - pov_scale(2.0), - pov_offset(0.0), - delta_pitch_scale(10.0), - delta_pitch_noise_stddev(0.005), - normalization_left_context(75), - normalization_right_context(75), - delta_window(2), - delay(0), - add_pov_feature(true), - add_normalized_log_pitch(true), - add_delta_pitch(true), - add_raw_log_pitch(false) { } - - - void Register(ParseOptions *opts) { - opts->Register("pitch-scale", &pitch_scale, - "Scaling factor for the final normalized log-pitch value"); - opts->Register("pov-scale", &pov_scale, - "Scaling factor for final POV (probability of voicing) " - "feature"); - opts->Register("pov-offset", &pov_offset, - "This can be used to add an offset to the POV feature. " - "Intended for use in online decoding as a substitute for " - " CMN."); - opts->Register("delta-pitch-scale", &delta_pitch_scale, - "Term to scale the final delta log-pitch feature"); - opts->Register("delta-pitch-noise-stddev", &delta_pitch_noise_stddev, - "Standard deviation for noise we add to the delta log-pitch " - "(before scaling); should be about the same as delta-pitch " - "option to pitch creation. 
The purpose is to get rid of " - "peaks in the delta-pitch caused by discretization of pitch " - "values."); - opts->Register("normalization-left-context", &normalization_left_context, - "Left-context (in frames) for moving window normalization"); - opts->Register("normalization-right-context", &normalization_right_context, - "Right-context (in frames) for moving window normalization"); - opts->Register("delta-window", &delta_window, - "Number of frames on each side of central frame, to use for " - "delta window."); - opts->Register("delay", &delay, - "Number of frames by which the pitch information is " - "delayed."); - opts->Register("add-pov-feature", &add_pov_feature, - "If true, the warped NCCF is added to output features"); - opts->Register("add-normalized-log-pitch", &add_normalized_log_pitch, - "If true, the log-pitch with POV-weighted mean subtraction " - "over 1.5 second window is added to output features"); - opts->Register("add-delta-pitch", &add_delta_pitch, - "If true, time derivative of log-pitch is added to output " - "features"); - opts->Register("add-raw-log-pitch", &add_raw_log_pitch, - "If true, log(pitch) is added to output features"); - } -}; - - -// We don't want to expose the pitch-extraction internals here as it's -// quite complex, so we use a private implementation. -class OnlinePitchFeatureImpl; - - -// Note: to start on a new waveform, just construct a new version -// of this object. -class OnlinePitchFeature: public OnlineBaseFeature { - public: - explicit OnlinePitchFeature(const PitchExtractionOptions &opts); - - virtual int32 Dim() const { return 2; /* (NCCF, pitch) */ } - - virtual int32 NumFramesReady() const; - - virtual BaseFloat FrameShiftInSeconds() const; - - virtual bool IsLastFrame(int32 frame) const; - - /// Outputs the two-dimensional feature consisting of (pitch, NCCF). You - /// should probably post-process this using class OnlineProcessPitch. 
- virtual void GetFrame(int32 frame, VectorBase *feat); - - virtual void AcceptWaveform(BaseFloat sampling_rate, - const VectorBase &waveform); - - virtual void InputFinished(); - - virtual ~OnlinePitchFeature(); - - private: - OnlinePitchFeatureImpl *impl_; -}; - - -/// This online-feature class implements post processing of pitch features. -/// Inputs are original 2 dims (nccf, pitch). It can produce various -/// kinds of outputs, using the default options it will be (pov-feature, -/// normalized-log-pitch, delta-log-pitch). -class OnlineProcessPitch: public OnlineFeatureInterface { - public: - virtual int32 Dim() const { return dim_; } - - virtual bool IsLastFrame(int32 frame) const { - if (frame <= -1) - return src_->IsLastFrame(-1); - else if (frame < opts_.delay) - return src_->IsLastFrame(-1) == true ? false : src_->IsLastFrame(0); - else - return src_->IsLastFrame(frame - opts_.delay); - } - virtual BaseFloat FrameShiftInSeconds() const { - return src_->FrameShiftInSeconds(); - } - - virtual int32 NumFramesReady() const; - - virtual void GetFrame(int32 frame, VectorBase *feat); - - virtual ~OnlineProcessPitch() { } - - // Does not take ownership of "src". - OnlineProcessPitch(const ProcessPitchOptions &opts, - OnlineFeatureInterface *src); - - private: - enum { kRawFeatureDim = 2}; // anonymous enum to define a constant. - // kRawFeatureDim defines the dimension - // of the input: (nccf, pitch) - - ProcessPitchOptions opts_; - OnlineFeatureInterface *src_; - int32 dim_; // Output feature dimension, set in initializer. - - struct NormalizationStats { - int32 cur_num_frames; // value of src_->NumFramesReady() when - // "mean_pitch" was set. - bool input_finished; // true if input data was finished when - // "mean_pitch" was computed. 
- double sum_pov; // sum of pov over relevant range - double sum_log_pitch_pov; // sum of log(pitch) * pov over relevant range - - NormalizationStats(): cur_num_frames(-1), input_finished(false), - sum_pov(0.0), sum_log_pitch_pov(0.0) { } - }; - - std::vector delta_feature_noise_; - - std::vector normalization_stats_; - - /// Computes and returns the POV feature for this frame. - /// Called from GetFrame(). - inline BaseFloat GetPovFeature(int32 frame) const; - - /// Computes and returns the delta-log-pitch feature for this frame. - /// Called from GetFrame(). - inline BaseFloat GetDeltaPitchFeature(int32 frame); - - /// Computes and returns the raw log-pitch feature for this frame. - /// Called from GetFrame(). - inline BaseFloat GetRawLogPitchFeature(int32 frame) const; - - /// Computes and returns the mean-subtracted log-pitch feature for this frame. - /// Called from GetFrame(). - inline BaseFloat GetNormalizedLogPitchFeature(int32 frame); - - /// Computes the normalization window sizes. - inline void GetNormalizationWindow(int32 frame, - int32 src_frames_ready, - int32 *window_begin, - int32 *window_end) const; - - /// Makes sure the entry in normalization_stats_ for this frame is up to date; - /// called from GetNormalizedLogPitchFeature. - inline void UpdateNormalizationStats(int32 frame); -}; - - -/// This function extracts (pitch, NCCF) per frame, using the pitch extraction -/// method described in "A Pitch Extraction Algorithm Tuned for Automatic Speech -/// Recognition", Pegah Ghahremani, Bagher BabaAli, Daniel Povey, Korbinian -/// Riedhammer, Jan Trmal and Sanjeev Khudanpur, ICASSP 2014. The output will -/// have as many rows as there are frames, and two columns corresponding to -/// (NCCF, pitch) -void ComputeKaldiPitch(const PitchExtractionOptions &opts, - const VectorBase &wave, - Matrix *output); - -/// This function processes the raw (NCCF, pitch) quantities computed by -/// ComputeKaldiPitch, and processes them into features. 
By default it will -/// output three-dimensional features, (POV-feature, mean-subtracted-log-pitch, -/// delta-of-raw-pitch), but this is configurable in the options. The number of -/// rows of "output" will be the number of frames (rows) in "input", and the -/// number of columns will be the number of different types of features -/// requested (by default, 3; 4 is the max). The four config variables -/// --add-pov-feature, --add-normalized-log-pitch, --add-delta-pitch, -/// --add-raw-log-pitch determine which features we create; by default we create -/// the first three. -void ProcessPitch(const ProcessPitchOptions &opts, - const MatrixBase &input, - Matrix *output); - -/// This function combines ComputeKaldiPitch and ProcessPitch. The reason -/// why we need a separate function to do this is in order to be able to -/// accurately simulate the online pitch-processing, for testing and for -/// training models matched to the "first-pass" features. It is sensitive to -/// the variables in pitch_opts that relate to online processing, -/// i.e. max_frames_latency, frames_per_chunk, simulate_first_pass_online, -/// recompute_frame. 
-void ComputeAndProcessKaldiPitch(const PitchExtractionOptions &pitch_opts, - const ProcessPitchOptions &process_opts, - const VectorBase &wave, - Matrix *output); - - -/// @} End of "addtogroup feat" -} // namespace kaldi -#endif // KALDI_FEAT_PITCH_FUNCTIONS_H_ diff --git a/speechx/speechx/kaldi/feat/resample.cc b/speechx/speechx/kaldi/feat/resample.cc deleted file mode 100644 index 11f4c62b..00000000 --- a/speechx/speechx/kaldi/feat/resample.cc +++ /dev/null @@ -1,377 +0,0 @@ -// feat/resample.cc - -// Copyright 2013 Pegah Ghahremani -// 2014 IMSL, PKU-HKUST (author: Wei Shi) -// 2014 Yanqing Sun, Junjie Wang -// 2014 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. 
- - -#include -#include -#include "feat/feature-functions.h" -#include "matrix/matrix-functions.h" -#include "feat/resample.h" - -namespace kaldi { - - -LinearResample::LinearResample(int32 samp_rate_in_hz, - int32 samp_rate_out_hz, - BaseFloat filter_cutoff_hz, - int32 num_zeros): - samp_rate_in_(samp_rate_in_hz), - samp_rate_out_(samp_rate_out_hz), - filter_cutoff_(filter_cutoff_hz), - num_zeros_(num_zeros) { - KALDI_ASSERT(samp_rate_in_hz > 0.0 && - samp_rate_out_hz > 0.0 && - filter_cutoff_hz > 0.0 && - filter_cutoff_hz*2 <= samp_rate_in_hz && - filter_cutoff_hz*2 <= samp_rate_out_hz && - num_zeros > 0); - - // base_freq is the frequency of the repeating unit, which is the gcd - // of the input frequencies. - int32 base_freq = Gcd(samp_rate_in_, samp_rate_out_); - input_samples_in_unit_ = samp_rate_in_ / base_freq; - output_samples_in_unit_ = samp_rate_out_ / base_freq; - - SetIndexesAndWeights(); - Reset(); -} - -int64 LinearResample::GetNumOutputSamples(int64 input_num_samp, - bool flush) const { - // For exact computation, we measure time in "ticks" of 1.0 / tick_freq, - // where tick_freq is the least common multiple of samp_rate_in_ and - // samp_rate_out_. - int32 tick_freq = Lcm(samp_rate_in_, samp_rate_out_); - int32 ticks_per_input_period = tick_freq / samp_rate_in_; - - // work out the number of ticks in the time interval - // [ 0, input_num_samp/samp_rate_in_ ). - int64 interval_length_in_ticks = input_num_samp * ticks_per_input_period; - if (!flush) { - BaseFloat window_width = num_zeros_ / (2.0 * filter_cutoff_); - // To count the window-width in ticks we take the floor. This - // is because since we're looking for the largest integer num-out-samp - // that fits in the interval, which is open on the right, a reduction - // in interval length of less than a tick will never make a difference. - // For example, the largest integer in the interval [ 0, 2 ) and the - // largest integer in the interval [ 0, 2 - 0.9 ) are the same (both one). 
- // So when we're subtracting the window-width we can ignore the fractional - // part. - int32 window_width_ticks = floor(window_width * tick_freq); - // The time-period of the output that we can sample gets reduced - // by the window-width (which is actually the distance from the - // center to the edge of the windowing function) if we're not - // "flushing the output". - interval_length_in_ticks -= window_width_ticks; - } - if (interval_length_in_ticks <= 0) - return 0; - int32 ticks_per_output_period = tick_freq / samp_rate_out_; - // Get the last output-sample in the closed interval, i.e. replacing [ ) with - // [ ]. Note: integer division rounds down. See - // http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of - // the notation. - int64 last_output_samp = interval_length_in_ticks / ticks_per_output_period; - // We need the last output-sample in the open interval, so if it takes us to - // the end of the interval exactly, subtract one. - if (last_output_samp * ticks_per_output_period == interval_length_in_ticks) - last_output_samp--; - // First output-sample index is zero, so the number of output samples - // is the last output-sample plus one. - int64 num_output_samp = last_output_samp + 1; - return num_output_samp; -} - -void LinearResample::SetIndexesAndWeights() { - first_index_.resize(output_samples_in_unit_); - weights_.resize(output_samples_in_unit_); - - double window_width = num_zeros_ / (2.0 * filter_cutoff_); - - for (int32 i = 0; i < output_samples_in_unit_; i++) { - double output_t = i / static_cast(samp_rate_out_); - double min_t = output_t - window_width, max_t = output_t + window_width; - // we do ceil on the min and floor on the max, because if we did it - // the other way around we would unnecessarily include indexes just - // outside the window, with zero coefficients. It's possible - // if the arguments to the ceil and floor expressions are integers - // (e.g. 
if filter_cutoff_ has an exact ratio with the sample rates), - // that we unnecessarily include something with a zero coefficient, - // but this is only a slight efficiency issue. - int32 min_input_index = ceil(min_t * samp_rate_in_), - max_input_index = floor(max_t * samp_rate_in_), - num_indices = max_input_index - min_input_index + 1; - first_index_[i] = min_input_index; - weights_[i].Resize(num_indices); - for (int32 j = 0; j < num_indices; j++) { - int32 input_index = min_input_index + j; - double input_t = input_index / static_cast(samp_rate_in_), - delta_t = input_t - output_t; - // sign of delta_t doesn't matter. - weights_[i](j) = FilterFunc(delta_t) / samp_rate_in_; - } - } -} - - -// inline -void LinearResample::GetIndexes(int64 samp_out, - int64 *first_samp_in, - int32 *samp_out_wrapped) const { - // A unit is the smallest nonzero amount of time that is an exact - // multiple of the input and output sample periods. The unit index - // is the answer to "which numbered unit we are in". - int64 unit_index = samp_out / output_samples_in_unit_; - // samp_out_wrapped is equal to samp_out % output_samples_in_unit_ - *samp_out_wrapped = static_cast(samp_out - - unit_index * output_samples_in_unit_); - *first_samp_in = first_index_[*samp_out_wrapped] + - unit_index * input_samples_in_unit_; -} - - -void LinearResample::Resample(const VectorBase &input, - bool flush, - Vector *output) { - int32 input_dim = input.Dim(); - int64 tot_input_samp = input_sample_offset_ + input_dim, - tot_output_samp = GetNumOutputSamples(tot_input_samp, flush); - - KALDI_ASSERT(tot_output_samp >= output_sample_offset_); - - output->Resize(tot_output_samp - output_sample_offset_); - - // samp_out is the index into the total output signal, not just the part - // of it we are producing here. 
- for (int64 samp_out = output_sample_offset_; - samp_out < tot_output_samp; - samp_out++) { - int64 first_samp_in; - int32 samp_out_wrapped; - GetIndexes(samp_out, &first_samp_in, &samp_out_wrapped); - const Vector &weights = weights_[samp_out_wrapped]; - // first_input_index is the first index into "input" that we have a weight - // for. - int32 first_input_index = static_cast(first_samp_in - - input_sample_offset_); - BaseFloat this_output; - if (first_input_index >= 0 && - first_input_index + weights.Dim() <= input_dim) { - SubVector input_part(input, first_input_index, weights.Dim()); - this_output = VecVec(input_part, weights); - } else { // Handle edge cases. - this_output = 0.0; - for (int32 i = 0; i < weights.Dim(); i++) { - BaseFloat weight = weights(i); - int32 input_index = first_input_index + i; - if (input_index < 0 && input_remainder_.Dim() + input_index >= 0) { - this_output += weight * - input_remainder_(input_remainder_.Dim() + input_index); - } else if (input_index >= 0 && input_index < input_dim) { - this_output += weight * input(input_index); - } else if (input_index >= input_dim) { - // We're past the end of the input and are adding zero; should only - // happen if the user specified flush == true, or else we would not - // be trying to output this sample. - KALDI_ASSERT(flush); - } - } - } - int32 output_index = static_cast(samp_out - output_sample_offset_); - (*output)(output_index) = this_output; - } - - if (flush) { - Reset(); // Reset the internal state. - } else { - SetRemainder(input); - input_sample_offset_ = tot_input_samp; - output_sample_offset_ = tot_output_samp; - } -} - -void LinearResample::SetRemainder(const VectorBase &input) { - Vector old_remainder(input_remainder_); - // max_remainder_needed is the width of the filter from side to side, - // measured in input samples. 
you might think it should be half that, - // but you have to consider that you might be wanting to output samples - // that are "in the past" relative to the beginning of the latest - // input... anyway, storing more remainder than needed is not harmful. - int32 max_remainder_needed = ceil(samp_rate_in_ * num_zeros_ / - filter_cutoff_); - input_remainder_.Resize(max_remainder_needed); - for (int32 index = - input_remainder_.Dim(); index < 0; index++) { - // we interpret "index" as an offset from the end of "input" and - // from the end of input_remainder_. - int32 input_index = index + input.Dim(); - if (input_index >= 0) - input_remainder_(index + input_remainder_.Dim()) = input(input_index); - else if (input_index + old_remainder.Dim() >= 0) - input_remainder_(index + input_remainder_.Dim()) = - old_remainder(input_index + old_remainder.Dim()); - // else leave it at zero. - } -} - -void LinearResample::Reset() { - input_sample_offset_ = 0; - output_sample_offset_ = 0; - input_remainder_.Resize(0); -} - -/** Here, t is a time in seconds representing an offset from - the center of the windowed filter function, and FilterFunction(t) - returns the windowed filter function, described - in the header as h(t) = f(t)g(t), evaluated at t. 
-*/ -BaseFloat LinearResample::FilterFunc(BaseFloat t) const { - BaseFloat window, // raised-cosine (Hanning) window of width - // num_zeros_/2*filter_cutoff_ - filter; // sinc filter function - if (fabs(t) < num_zeros_ / (2.0 * filter_cutoff_)) - window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t)); - else - window = 0.0; // outside support of window function - if (t != 0) - filter = sin(M_2PI * filter_cutoff_ * t) / (M_PI * t); - else - filter = 2 * filter_cutoff_; // limit of the function at t = 0 - return filter * window; -} - - -ArbitraryResample::ArbitraryResample( - int32 num_samples_in, BaseFloat samp_rate_in, - BaseFloat filter_cutoff, const Vector &sample_points, - int32 num_zeros): - num_samples_in_(num_samples_in), - samp_rate_in_(samp_rate_in), - filter_cutoff_(filter_cutoff), - num_zeros_(num_zeros) { - KALDI_ASSERT(num_samples_in > 0 && samp_rate_in > 0.0 && - filter_cutoff > 0.0 && - filter_cutoff * 2.0 <= samp_rate_in - && num_zeros > 0); - // set up weights_ and indices_. Please try to keep all functions short and - SetIndexes(sample_points); - SetWeights(sample_points); -} - - -void ArbitraryResample::Resample(const MatrixBase &input, - MatrixBase *output) const { - // each row of "input" corresponds to the data to resample; - // the corresponding row of "output" is the resampled data. 
- - KALDI_ASSERT(input.NumRows() == output->NumRows() && - input.NumCols() == num_samples_in_ && - output->NumCols() == weights_.size()); - - Vector output_col(output->NumRows()); - for (int32 i = 0; i < NumSamplesOut(); i++) { - SubMatrix input_part(input, 0, input.NumRows(), - first_index_[i], - weights_[i].Dim()); - const Vector &weight_vec(weights_[i]); - output_col.AddMatVec(1.0, input_part, - kNoTrans, weight_vec, 0.0); - output->CopyColFromVec(output_col, i); - } -} - -void ArbitraryResample::Resample(const VectorBase &input, - VectorBase *output) const { - KALDI_ASSERT(input.Dim() == num_samples_in_ && - output->Dim() == weights_.size()); - - int32 output_dim = output->Dim(); - for (int32 i = 0; i < output_dim; i++) { - SubVector input_part(input, first_index_[i], weights_[i].Dim()); - (*output)(i) = VecVec(input_part, weights_[i]); - } -} - -void ArbitraryResample::SetIndexes(const Vector &sample_points) { - int32 num_samples = sample_points.Dim(); - first_index_.resize(num_samples); - weights_.resize(num_samples); - BaseFloat filter_width = num_zeros_ / (2.0 * filter_cutoff_); - for (int32 i = 0; i < num_samples; i++) { - // the t values are in seconds. - BaseFloat t = sample_points(i), - t_min = t - filter_width, t_max = t + filter_width; - int32 index_min = ceil(samp_rate_in_ * t_min), - index_max = floor(samp_rate_in_ * t_max); - // the ceil on index min and the floor on index_max are because there - // is no point using indices just outside the window (coeffs would be zero). 
- if (index_min < 0) - index_min = 0; - if (index_max >= num_samples_in_) - index_max = num_samples_in_ - 1; - first_index_[i] = index_min; - weights_[i].Resize(index_max - index_min + 1); - } -} - -void ArbitraryResample::SetWeights(const Vector &sample_points) { - int32 num_samples_out = NumSamplesOut(); - for (int32 i = 0; i < num_samples_out; i++) { - for (int32 j = 0 ; j < weights_[i].Dim(); j++) { - BaseFloat delta_t = sample_points(i) - - (first_index_[i] + j) / samp_rate_in_; - // Include at this point the factor of 1.0 / samp_rate_in_ which - // appears in the math. - weights_[i](j) = FilterFunc(delta_t) / samp_rate_in_; - } - } -} - -/** Here, t is a time in seconds representing an offset from - the center of the windowed filter function, and FilterFunction(t) - returns the windowed filter function, described - in the header as h(t) = f(t)g(t), evaluated at t. -*/ -BaseFloat ArbitraryResample::FilterFunc(BaseFloat t) const { - BaseFloat window, // raised-cosine (Hanning) window of width - // num_zeros_/2*filter_cutoff_ - filter; // sinc filter function - if (fabs(t) < num_zeros_ / (2.0 * filter_cutoff_)) - window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t)); - else - window = 0.0; // outside support of window function - if (t != 0.0) - filter = sin(M_2PI * filter_cutoff_ * t) / (M_PI * t); - else - filter = 2.0 * filter_cutoff_; // limit of the function at zero. 
- return filter * window; -} - -void ResampleWaveform(BaseFloat orig_freq, const VectorBase &wave, - BaseFloat new_freq, Vector *new_wave) { - BaseFloat min_freq = std::min(orig_freq, new_freq); - BaseFloat lowpass_cutoff = 0.99 * 0.5 * min_freq; - int32 lowpass_filter_width = 6; - LinearResample resampler(orig_freq, new_freq, - lowpass_cutoff, lowpass_filter_width); - resampler.Resample(wave, true, new_wave); -} -} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/resample.h b/speechx/speechx/kaldi/feat/resample.h deleted file mode 100644 index e0b4688c..00000000 --- a/speechx/speechx/kaldi/feat/resample.h +++ /dev/null @@ -1,287 +0,0 @@ -// feat/resample.h - -// Copyright 2013 Pegah Ghahremani -// 2014 IMSL, PKU-HKUST (author: Wei Shi) -// 2014 Yanqing Sun, Junjie Wang -// 2014 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - - -#ifndef KALDI_FEAT_RESAMPLE_H_ -#define KALDI_FEAT_RESAMPLE_H_ - -#include -#include -#include -#include - - -#include "matrix/matrix-lib.h" -#include "util/common-utils.h" -#include "base/kaldi-error.h" - -namespace kaldi { -/// @addtogroup feat FeatureExtraction -/// @{ - -/** - \file[resample.h] - - This header contains declarations of classes for resampling signals. 
The - normal cases of resampling a signal are upsampling and downsampling - (increasing and decreasing the sample rate of a signal, respectively), - although the ArbitraryResample class allows a more generic case where - we want to get samples of a signal at uneven intervals (for instance, - log-spaced). - - The input signal is always evenly spaced, say sampled with frequency S, and - we assume the original signal was band-limited to S/2 or lower. The n'th - input sample x_n (with n = 0, 1, ...) is interpreted as the original - signal's value at time n/S. - - For resampling, it is convenient to view the input signal as a - continuous function x(t) of t, where each sample x_n becomes a delta function - with magnitude x_n/S, at time n/S. If we band limit this to the Nyquist - frequency S/2, we can show that this is the same as the original signal - that was sampled. [assuming the original signal was periodic and band - limited.] In general we want to bandlimit to lower than S/2, because - we don't have a perfect filter and also because if we want to resample - at a lower frequency than S, we need to bandlimit to below half of that. - Anyway, suppose we want to bandlimit to C, with 0 < C < S/2. The perfect - rectangular filter with cutoff C is the sinc function, - \f[ f(t) = 2C sinc(2Ct), \f] - where sinc is the normalized sinc function \f$ sinc(t) = sin(pi t) / (pi t) \f$, with - \f$ sinc(0) = 1 \f$. This is not a practical filter, though, because it has - infinite support. At the cost of less-than-perfect rolloff, we can choose - a suitable windowing function g(t), and use f(t) g(t) as the filter. For - a windowing function we choose raised-cosine (Hanning) window with support - on [-w/2C, w/2C], where w >= 2 is an integer chosen by the user. w = 1 - means we window the sinc function out to its first zero on the left and right, - w = 2 means the second zero, and so on; we normally choose w to be at least two. - We call this num_zeros, not w, in the code. 
- - Convolving the signal x(t) with this windowed filter h(t) = f(t)g(t) and evaluating the resulting - signal s(t) at an arbitrary time t is easy: we have - \f[ s(t) = 1/S \sum_n x_n h(t - n/S) \f]. - (note: the sign of t - n/S might be wrong, but it doesn't matter as the filter - and window are symmetric). - This is true for arbitrary values of t. What the class ArbitraryResample does - is to allow you to evaluate the signal for specified values of t. -*/ - - -/** - Class ArbitraryResample allows you to resample a signal (assumed zero outside - the sample region, not periodic) at arbitrary specified time values, which - don't have to be linearly spaced. The low-pass filter cutoff - "filter_cutoff_hz" should be less than half the sample rate; - "num_zeros" should probably be at least two preferably more; higher numbers give - sharper filters but will be less efficient. -*/ -class ArbitraryResample { - public: - ArbitraryResample(int32 num_samples_in, - BaseFloat samp_rate_hz, - BaseFloat filter_cutoff_hz, - const Vector &sample_points_secs, - int32 num_zeros); - - int32 NumSamplesIn() const { return num_samples_in_; } - - int32 NumSamplesOut() const { return weights_.size(); } - - /// This function does the resampling. - /// input.NumRows() and output.NumRows() should be equal - /// and nonzero. - /// input.NumCols() should equal NumSamplesIn() - /// and output.NumCols() should equal NumSamplesOut(). - void Resample(const MatrixBase &input, - MatrixBase *output) const; - - /// This version of the Resample function processes just - /// one vector. 
- void Resample(const VectorBase &input, - VectorBase *output) const; - private: - void SetIndexes(const Vector &sample_points); - - void SetWeights(const Vector &sample_points); - - BaseFloat FilterFunc(BaseFloat t) const; - - int32 num_samples_in_; - BaseFloat samp_rate_in_; - BaseFloat filter_cutoff_; - int32 num_zeros_; - - std::vector first_index_; // The first input-sample index that we sum - // over, for this output-sample index. - std::vector > weights_; -}; - - -/** - LinearResample is a special case of ArbitraryResample, where we want to - resample a signal at linearly spaced intervals (this means we want to - upsample or downsample the signal). It is more efficient than - ArbitraryResample because we can construct it just once. - - We require that the input and output sampling rate be specified as - integers, as this is an easy way to specify that their ratio be rational. -*/ - -class LinearResample { - public: - /// Constructor. We make the input and output sample rates integers, because - /// we are going to need to find a common divisor. This should just remind - /// you that they need to be integers. The filter cutoff needs to be less - /// than samp_rate_in_hz/2 and less than samp_rate_out_hz/2. num_zeros - /// controls the sharpness of the filter, more == sharper but less efficient. - /// We suggest around 4 to 10 for normal use. - LinearResample(int32 samp_rate_in_hz, - int32 samp_rate_out_hz, - BaseFloat filter_cutoff_hz, - int32 num_zeros); - - /// This function does the resampling. If you call it with flush == true and - /// you have never called it with flush == false, it just resamples the input - /// signal (it resizes the output to a suitable number of samples). - /// - /// You can also use this function to process a signal a piece at a time. - /// suppose you break it into piece1, piece2, ... pieceN. 
You can call - /// \code{.cc} - /// Resample(piece1, &output1, false); - /// Resample(piece2, &output2, false); - /// Resample(piece3, &output3, true); - /// \endcode - /// If you call it with flush == false, it won't output the last few samples - /// but will remember them, so that if you later give it a second piece of - /// the input signal it can process it correctly. - /// If your most recent call to the object was with flush == false, it will - /// have internal state; you can remove this by calling Reset(). - /// Empty input is acceptable. - void Resample(const VectorBase &input, - bool flush, - Vector *output); - - /// Calling the function Reset() resets the state of the object prior to - /// processing a new signal; it is only necessary if you have called - /// Resample(x, y, false) for some signal, leading to a remainder of the - /// signal being called, but then abandon processing the signal before calling - /// Resample(x, y, true) for the last piece. Call it unnecessarily between - /// signals will not do any harm. - void Reset(); - - //// Return the input and output sampling rates (for checks, for example) - inline int32 GetInputSamplingRate() { return samp_rate_in_; } - inline int32 GetOutputSamplingRate() { return samp_rate_out_; } - private: - /// This function outputs the number of output samples we will output - /// for a signal with "input_num_samp" input samples. If flush == true, - /// we return the largest n such that - /// (n/samp_rate_out_) is in the interval [ 0, input_num_samp/samp_rate_in_ ), - /// and note that the interval is half-open. If flush == false, - /// define window_width as num_zeros / (2.0 * filter_cutoff_); - /// we return the largest n such that (n/samp_rate_out_) is in the interval - /// [ 0, input_num_samp/samp_rate_in_ - window_width ). 
- int64 GetNumOutputSamples(int64 input_num_samp, bool flush) const; - - - /// Given an output-sample index, this function outputs to *first_samp_in the - /// first input-sample index that we have a weight on (may be negative), - /// and to *samp_out_wrapped the index into weights_ where we can get the - /// corresponding weights on the input. - inline void GetIndexes(int64 samp_out, - int64 *first_samp_in, - int32 *samp_out_wrapped) const; - - void SetRemainder(const VectorBase &input); - - void SetIndexesAndWeights(); - - BaseFloat FilterFunc(BaseFloat) const; - - // The following variables are provided by the user. - int32 samp_rate_in_; - int32 samp_rate_out_; - BaseFloat filter_cutoff_; - int32 num_zeros_; - - int32 input_samples_in_unit_; ///< The number of input samples in the - ///< smallest repeating unit: num_samp_in_ = - ///< samp_rate_in_hz / Gcd(samp_rate_in_hz, - ///< samp_rate_out_hz) - int32 output_samples_in_unit_; ///< The number of output samples in the - ///< smallest repeating unit: num_samp_out_ = - ///< samp_rate_out_hz / Gcd(samp_rate_in_hz, - ///< samp_rate_out_hz) - - - /// The first input-sample index that we sum over, for this output-sample - /// index. May be negative; any truncation at the beginning is handled - /// separately. This is just for the first few output samples, but we can - /// extrapolate the correct input-sample index for arbitrary output samples. - std::vector first_index_; - - /// Weights on the input samples, for this output-sample index. - std::vector > weights_; - - // the following variables keep track of where we are in a particular signal, - // if it is being provided over multiple calls to Resample(). - - int64 input_sample_offset_; ///< The number of input samples we have - ///< already received for this signal - ///< (including anything in remainder_) - int64 output_sample_offset_; ///< The number of samples we have already - ///< output for this signal. 
- Vector input_remainder_; ///< A small trailing part of the - ///< previously seen input signal. -}; - -/** - Downsample or upsample a waveform. This is a convenience wrapper for the - class 'LinearResample'. - The low-pass filter cutoff used in 'LinearResample' is 0.99 of the Nyquist, - where the Nyquist is half of the minimum of (orig_freq, new_freq). The - resampling is done with a symmetric FIR filter with N_z (number of zeros) - as 6. - - We compared the downsampling results with those from the sox resampling - toolkit. - Sox's design is inspired by Laurent De Soras' paper, - https://ccrma.stanford.edu/~jos/resample/Implementation.html - - Note: we expect that while orig_freq and new_freq are of type BaseFloat, they - are actually required to have exact integer values (like 16000 or 8000) with - a ratio between them that can be expressed as a rational number with - reasonably small integer factors. -*/ -void ResampleWaveform(BaseFloat orig_freq, const VectorBase &wave, - BaseFloat new_freq, Vector *new_wave); - - -/// This function is deprecated. It is provided for backward compatibility, to avoid -/// breaking older code. -inline void DownsampleWaveForm(BaseFloat orig_freq, const VectorBase &wave, - BaseFloat new_freq, Vector *new_wave) { - ResampleWaveform(orig_freq, wave, new_freq, new_wave); -} - - -/// @} End of "addtogroup feat" -} // namespace kaldi -#endif // KALDI_FEAT_RESAMPLE_H_ diff --git a/speechx/speechx/kaldi/feat/signal.cc b/speechx/speechx/kaldi/feat/signal.cc deleted file mode 100644 index a206d399..00000000 --- a/speechx/speechx/kaldi/feat/signal.cc +++ /dev/null @@ -1,129 +0,0 @@ -// feat/signal.cc - -// Copyright 2015 Tom Ko - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include "base/kaldi-common.h" -#include "util/common-utils.h" -#include "feat/signal.h" - -namespace kaldi { - -void ElementwiseProductOfFft(const Vector &a, Vector *b) { - int32 num_fft_bins = a.Dim() / 2; - for (int32 i = 0; i < num_fft_bins; i++) { - // do complex multiplication - ComplexMul(a(2*i), a(2*i + 1), &((*b)(2*i)), &((*b)(2*i + 1))); - } -} - -void ConvolveSignals(const Vector &filter, Vector *signal) { - int32 signal_length = signal->Dim(); - int32 filter_length = filter.Dim(); - int32 output_length = signal_length + filter_length - 1; - Vector signal_padded(output_length); - signal_padded.SetZero(); - for (int32 i = 0; i < signal_length; i++) { - for (int32 j = 0; j < filter_length; j++) { - signal_padded(i + j) += (*signal)(i) * filter(j); - } - } - signal->Resize(output_length); - signal->CopyFromVec(signal_padded); -} - - -void FFTbasedConvolveSignals(const Vector &filter, Vector *signal) { - int32 signal_length = signal->Dim(); - int32 filter_length = filter.Dim(); - int32 output_length = signal_length + filter_length - 1; - - int32 fft_length = RoundUpToNearestPowerOfTwo(output_length); - KALDI_VLOG(1) << "fft_length for full signal convolution is " << fft_length; - - SplitRadixRealFft srfft(fft_length); - - Vector filter_padded(fft_length); - filter_padded.Range(0, filter_length).CopyFromVec(filter); - srfft.Compute(filter_padded.Data(), true); - - Vector signal_padded(fft_length); - signal_padded.Range(0, signal_length).CopyFromVec(*signal); - 
srfft.Compute(signal_padded.Data(), true); - - ElementwiseProductOfFft(filter_padded, &signal_padded); - - srfft.Compute(signal_padded.Data(), false); - signal_padded.Scale(1.0 / fft_length); - - signal->Resize(output_length); - signal->CopyFromVec(signal_padded.Range(0, output_length)); -} - -void FFTbasedBlockConvolveSignals(const Vector &filter, Vector *signal) { - int32 signal_length = signal->Dim(); - int32 filter_length = filter.Dim(); - int32 output_length = signal_length + filter_length - 1; - signal->Resize(output_length, kCopyData); - - KALDI_VLOG(1) << "Length of the filter is " << filter_length; - - int32 fft_length = RoundUpToNearestPowerOfTwo(4 * filter_length); - KALDI_VLOG(1) << "Best FFT length is " << fft_length; - - int32 block_length = fft_length - filter_length + 1; - KALDI_VLOG(1) << "Block size is " << block_length; - SplitRadixRealFft srfft(fft_length); - - Vector filter_padded(fft_length); - filter_padded.Range(0, filter_length).CopyFromVec(filter); - srfft.Compute(filter_padded.Data(), true); - - Vector temp_pad(filter_length - 1); - temp_pad.SetZero(); - Vector signal_block_padded(fft_length); - - for (int32 po = 0; po < output_length; po += block_length) { - // get a block of the signal - int32 process_length = std::min(block_length, output_length - po); - signal_block_padded.SetZero(); - signal_block_padded.Range(0, process_length).CopyFromVec(signal->Range(po, process_length)); - - srfft.Compute(signal_block_padded.Data(), true); - - ElementwiseProductOfFft(filter_padded, &signal_block_padded); - - srfft.Compute(signal_block_padded.Data(), false); - signal_block_padded.Scale(1.0 / fft_length); - - // combine the block - if (po + block_length < output_length) { // current block is not the last block - signal->Range(po, block_length).CopyFromVec(signal_block_padded.Range(0, block_length)); - signal->Range(po, filter_length - 1).AddVec(1.0, temp_pad); - temp_pad.CopyFromVec(signal_block_padded.Range(block_length, filter_length - 1)); - } 
else { - signal->Range(po, output_length - po).CopyFromVec( - signal_block_padded.Range(0, output_length - po)); - if (filter_length - 1 < output_length - po) - signal->Range(po, filter_length - 1).AddVec(1.0, temp_pad); - else - signal->Range(po, output_length - po).AddVec(1.0, temp_pad.Range(0, output_length - po)); - } - } -} -} - diff --git a/speechx/speechx/kaldi/feat/signal.h b/speechx/speechx/kaldi/feat/signal.h deleted file mode 100644 index c6c3eb50..00000000 --- a/speechx/speechx/kaldi/feat/signal.h +++ /dev/null @@ -1,58 +0,0 @@ -// feat/signal.h - -// Copyright 2015 Tom Ko - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_FEAT_SIGNAL_H_ -#define KALDI_FEAT_SIGNAL_H_ - -#include "base/kaldi-common.h" -#include "util/common-utils.h" - -namespace kaldi { - -/* - The following three functions are having the same functionality but - different implementations so as the efficiency. After the convolution, - the length of the signal will be extended to (original signal length + - filter length - 1). -*/ - -/* - This function implements a simple non-FFT-based convolution of two signals. - It is suggested to use the FFT-based convolution function which is more - efficient. -*/ -void ConvolveSignals(const Vector &filter, Vector *signal); - -/* - This function implements FFT-based convolution of two signals. 
- However this should be an inefficient version of BlockConvolveSignals() - as it processes the entire signal with a single FFT. -*/ -void FFTbasedConvolveSignals(const Vector &filter, Vector *signal); - -/* - This function implements FFT-based block convolution of two signals using - overlap-add method. This is an efficient way to evaluate the discrete - convolution of a long signal with a finite impulse response filter. -*/ -void FFTbasedBlockConvolveSignals(const Vector &filter, Vector *signal); - -} // namespace kaldi - -#endif // KALDI_FEAT_SIGNAL_H_ diff --git a/speechx/speechx/kaldi/matrix/CMakeLists.txt b/speechx/speechx/kaldi/matrix/CMakeLists.txt deleted file mode 100644 index a4dbde2e..00000000 --- a/speechx/speechx/kaldi/matrix/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ - -add_library(kaldi-matrix -compressed-matrix.cc -kaldi-matrix.cc -kaldi-vector.cc -matrix-functions.cc -optimization.cc -packed-matrix.cc -qr.cc -sparse-matrix.cc -sp-matrix.cc -srfft.cc -tp-matrix.cc -) - -target_link_libraries(kaldi-matrix gfortran kaldi-base libopenblas.a) diff --git a/speechx/speechx/kaldi/matrix/cblas-wrappers.h b/speechx/speechx/kaldi/matrix/cblas-wrappers.h deleted file mode 100644 index f869ab7e..00000000 --- a/speechx/speechx/kaldi/matrix/cblas-wrappers.h +++ /dev/null @@ -1,491 +0,0 @@ -// matrix/cblas-wrappers.h - -// Copyright 2012 Johns Hopkins University (author: Daniel Povey); -// Haihua Xu; Wei Shi - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. -#ifndef KALDI_MATRIX_CBLAS_WRAPPERS_H_ -#define KALDI_MATRIX_CBLAS_WRAPPERS_H_ 1 - - -#include -#include "matrix/sp-matrix.h" -#include "matrix/kaldi-vector.h" -#include "matrix/kaldi-matrix.h" -#include "matrix/matrix-functions.h" -#include "matrix/kaldi-blas.h" - -// Do not include this file directly. It is to be included -// by .cc files in this directory. - -namespace kaldi { - - -inline void cblas_Xcopy(const int N, const float *X, const int incX, float *Y, - const int incY) { - cblas_scopy(N, X, incX, Y, incY); -} - -inline void cblas_Xcopy(const int N, const double *X, const int incX, double *Y, - const int incY) { - cblas_dcopy(N, X, incX, Y, incY); -} - - -inline float cblas_Xasum(const int N, const float *X, const int incX) { - return cblas_sasum(N, X, incX); -} - -inline double cblas_Xasum(const int N, const double *X, const int incX) { - return cblas_dasum(N, X, incX); -} - -inline void cblas_Xrot(const int N, float *X, const int incX, float *Y, - const int incY, const float c, const float s) { - cblas_srot(N, X, incX, Y, incY, c, s); -} -inline void cblas_Xrot(const int N, double *X, const int incX, double *Y, - const int incY, const double c, const double s) { - cblas_drot(N, X, incX, Y, incY, c, s); -} -inline float cblas_Xdot(const int N, const float *const X, - const int incX, const float *const Y, - const int incY) { - return cblas_sdot(N, X, incX, Y, incY); -} -inline double cblas_Xdot(const int N, const double *const X, - const int incX, const double *const Y, - const 
int incY) { - return cblas_ddot(N, X, incX, Y, incY); -} -inline void cblas_Xaxpy(const int N, const float alpha, const float *X, - const int incX, float *Y, const int incY) { - cblas_saxpy(N, alpha, X, incX, Y, incY); -} -inline void cblas_Xaxpy(const int N, const double alpha, const double *X, - const int incX, double *Y, const int incY) { - cblas_daxpy(N, alpha, X, incX, Y, incY); -} -inline void cblas_Xscal(const int N, const float alpha, float *data, - const int inc) { - cblas_sscal(N, alpha, data, inc); -} -inline void cblas_Xscal(const int N, const double alpha, double *data, - const int inc) { - cblas_dscal(N, alpha, data, inc); -} -inline void cblas_Xspmv(const float alpha, const int num_rows, const float *Mdata, - const float *v, const int v_inc, - const float beta, float *y, const int y_inc) { - cblas_sspmv(CblasRowMajor, CblasLower, num_rows, alpha, Mdata, v, v_inc, beta, y, y_inc); -} -inline void cblas_Xspmv(const double alpha, const int num_rows, const double *Mdata, - const double *v, const int v_inc, - const double beta, double *y, const int y_inc) { - cblas_dspmv(CblasRowMajor, CblasLower, num_rows, alpha, Mdata, v, v_inc, beta, y, y_inc); -} -inline void cblas_Xtpmv(MatrixTransposeType trans, const float *Mdata, - const int num_rows, float *y, const int y_inc) { - cblas_stpmv(CblasRowMajor, CblasLower, static_cast(trans), - CblasNonUnit, num_rows, Mdata, y, y_inc); -} -inline void cblas_Xtpmv(MatrixTransposeType trans, const double *Mdata, - const int num_rows, double *y, const int y_inc) { - cblas_dtpmv(CblasRowMajor, CblasLower, static_cast(trans), - CblasNonUnit, num_rows, Mdata, y, y_inc); -} - - -inline void cblas_Xtpsv(MatrixTransposeType trans, const float *Mdata, - const int num_rows, float *y, const int y_inc) { - cblas_stpsv(CblasRowMajor, CblasLower, static_cast(trans), - CblasNonUnit, num_rows, Mdata, y, y_inc); -} -inline void cblas_Xtpsv(MatrixTransposeType trans, const double *Mdata, - const int num_rows, double *y, const int 
y_inc) { - cblas_dtpsv(CblasRowMajor, CblasLower, static_cast(trans), - CblasNonUnit, num_rows, Mdata, y, y_inc); -} - -// x = alpha * M * y + beta * x -inline void cblas_Xspmv(MatrixIndexT dim, float alpha, const float *Mdata, - const float *ydata, MatrixIndexT ystride, - float beta, float *xdata, MatrixIndexT xstride) { - cblas_sspmv(CblasRowMajor, CblasLower, dim, alpha, Mdata, - ydata, ystride, beta, xdata, xstride); -} -inline void cblas_Xspmv(MatrixIndexT dim, double alpha, const double *Mdata, - const double *ydata, MatrixIndexT ystride, - double beta, double *xdata, MatrixIndexT xstride) { - cblas_dspmv(CblasRowMajor, CblasLower, dim, alpha, Mdata, - ydata, ystride, beta, xdata, xstride); -} - -// Implements A += alpha * (x y' + y x'); A is symmetric matrix. -inline void cblas_Xspr2(MatrixIndexT dim, float alpha, const float *Xdata, - MatrixIndexT incX, const float *Ydata, MatrixIndexT incY, - float *Adata) { - cblas_sspr2(CblasRowMajor, CblasLower, dim, alpha, Xdata, - incX, Ydata, incY, Adata); -} -inline void cblas_Xspr2(MatrixIndexT dim, double alpha, const double *Xdata, - MatrixIndexT incX, const double *Ydata, MatrixIndexT incY, - double *Adata) { - cblas_dspr2(CblasRowMajor, CblasLower, dim, alpha, Xdata, - incX, Ydata, incY, Adata); -} - -// Implements A += alpha * (x x'); A is symmetric matrix. -inline void cblas_Xspr(MatrixIndexT dim, float alpha, const float *Xdata, - MatrixIndexT incX, float *Adata) { - cblas_sspr(CblasRowMajor, CblasLower, dim, alpha, Xdata, incX, Adata); -} -inline void cblas_Xspr(MatrixIndexT dim, double alpha, const double *Xdata, - MatrixIndexT incX, double *Adata) { - cblas_dspr(CblasRowMajor, CblasLower, dim, alpha, Xdata, incX, Adata); -} - -// sgemv,dgemv: y = alpha M x + beta y. 
-inline void cblas_Xgemv(MatrixTransposeType trans, MatrixIndexT num_rows, - MatrixIndexT num_cols, float alpha, const float *Mdata, - MatrixIndexT stride, const float *xdata, - MatrixIndexT incX, float beta, float *ydata, MatrixIndexT incY) { - cblas_sgemv(CblasRowMajor, static_cast(trans), num_rows, - num_cols, alpha, Mdata, stride, xdata, incX, beta, ydata, incY); -} -inline void cblas_Xgemv(MatrixTransposeType trans, MatrixIndexT num_rows, - MatrixIndexT num_cols, double alpha, const double *Mdata, - MatrixIndexT stride, const double *xdata, - MatrixIndexT incX, double beta, double *ydata, MatrixIndexT incY) { - cblas_dgemv(CblasRowMajor, static_cast(trans), num_rows, - num_cols, alpha, Mdata, stride, xdata, incX, beta, ydata, incY); -} - -// sgbmv, dgmmv: y = alpha M x + + beta * y. -inline void cblas_Xgbmv(MatrixTransposeType trans, MatrixIndexT num_rows, - MatrixIndexT num_cols, MatrixIndexT num_below, - MatrixIndexT num_above, float alpha, const float *Mdata, - MatrixIndexT stride, const float *xdata, - MatrixIndexT incX, float beta, float *ydata, MatrixIndexT incY) { - cblas_sgbmv(CblasRowMajor, static_cast(trans), num_rows, - num_cols, num_below, num_above, alpha, Mdata, stride, xdata, - incX, beta, ydata, incY); -} -inline void cblas_Xgbmv(MatrixTransposeType trans, MatrixIndexT num_rows, - MatrixIndexT num_cols, MatrixIndexT num_below, - MatrixIndexT num_above, double alpha, const double *Mdata, - MatrixIndexT stride, const double *xdata, - MatrixIndexT incX, double beta, double *ydata, MatrixIndexT incY) { - cblas_dgbmv(CblasRowMajor, static_cast(trans), num_rows, - num_cols, num_below, num_above, alpha, Mdata, stride, xdata, - incX, beta, ydata, incY); -} - - -template -inline void Xgemv_sparsevec(MatrixTransposeType trans, MatrixIndexT num_rows, - MatrixIndexT num_cols, Real alpha, const Real *Mdata, - MatrixIndexT stride, const Real *xdata, - MatrixIndexT incX, Real beta, Real *ydata, - MatrixIndexT incY) { - if (trans == kNoTrans) { - if (beta != 
1.0) cblas_Xscal(num_rows, beta, ydata, incY); - for (MatrixIndexT i = 0; i < num_cols; i++) { - Real x_i = xdata[i * incX]; - if (x_i == 0.0) continue; - // Add to ydata, the i'th column of M, times alpha * x_i - cblas_Xaxpy(num_rows, x_i * alpha, Mdata + i, stride, ydata, incY); - } - } else { - if (beta != 1.0) cblas_Xscal(num_cols, beta, ydata, incY); - for (MatrixIndexT i = 0; i < num_rows; i++) { - Real x_i = xdata[i * incX]; - if (x_i == 0.0) continue; - // Add to ydata, the i'th row of M, times alpha * x_i - cblas_Xaxpy(num_cols, x_i * alpha, - Mdata + (i * stride), 1, ydata, incY); - } - } -} - -inline void cblas_Xgemm(const float alpha, - MatrixTransposeType transA, - const float *Adata, - MatrixIndexT a_num_rows, MatrixIndexT a_num_cols, MatrixIndexT a_stride, - MatrixTransposeType transB, - const float *Bdata, MatrixIndexT b_stride, - const float beta, - float *Mdata, - MatrixIndexT num_rows, MatrixIndexT num_cols,MatrixIndexT stride) { - cblas_sgemm(CblasRowMajor, static_cast(transA), - static_cast(transB), - num_rows, num_cols, transA == kNoTrans ? a_num_cols : a_num_rows, - alpha, Adata, a_stride, Bdata, b_stride, - beta, Mdata, stride); -} -inline void cblas_Xgemm(const double alpha, - MatrixTransposeType transA, - const double *Adata, - MatrixIndexT a_num_rows, MatrixIndexT a_num_cols, MatrixIndexT a_stride, - MatrixTransposeType transB, - const double *Bdata, MatrixIndexT b_stride, - const double beta, - double *Mdata, - MatrixIndexT num_rows, MatrixIndexT num_cols,MatrixIndexT stride) { - cblas_dgemm(CblasRowMajor, static_cast(transA), - static_cast(transB), - num_rows, num_cols, transA == kNoTrans ? 
a_num_cols : a_num_rows, - alpha, Adata, a_stride, Bdata, b_stride, - beta, Mdata, stride); -} - - -inline void cblas_Xsymm(const float alpha, - MatrixIndexT sz, - const float *Adata,MatrixIndexT a_stride, - const float *Bdata,MatrixIndexT b_stride, - const float beta, - float *Mdata, MatrixIndexT stride) { - cblas_ssymm(CblasRowMajor, CblasLeft, CblasLower, sz, sz, alpha, Adata, - a_stride, Bdata, b_stride, beta, Mdata, stride); -} -inline void cblas_Xsymm(const double alpha, - MatrixIndexT sz, - const double *Adata,MatrixIndexT a_stride, - const double *Bdata,MatrixIndexT b_stride, - const double beta, - double *Mdata, MatrixIndexT stride) { - cblas_dsymm(CblasRowMajor, CblasLeft, CblasLower, sz, sz, alpha, Adata, - a_stride, Bdata, b_stride, beta, Mdata, stride); -} -// ger: M += alpha x y^T. -inline void cblas_Xger(MatrixIndexT num_rows, MatrixIndexT num_cols, float alpha, - const float *xdata, MatrixIndexT incX, const float *ydata, - MatrixIndexT incY, float *Mdata, MatrixIndexT stride) { - cblas_sger(CblasRowMajor, num_rows, num_cols, alpha, xdata, 1, ydata, 1, - Mdata, stride); -} -inline void cblas_Xger(MatrixIndexT num_rows, MatrixIndexT num_cols, double alpha, - const double *xdata, MatrixIndexT incX, const double *ydata, - MatrixIndexT incY, double *Mdata, MatrixIndexT stride) { - cblas_dger(CblasRowMajor, num_rows, num_cols, alpha, xdata, 1, ydata, 1, - Mdata, stride); -} - -// syrk: symmetric rank-k update. -// if trans==kNoTrans, then C = alpha A A^T + beta C -// else C = alpha A^T A + beta C. -// note: dim_c is dim(C), other_dim_a is the "other" dimension of A, i.e. -// num-cols(A) if kNoTrans, or num-rows(A) if kTrans. -// We only need the row-major and lower-triangular option of this, and this -// is hard-coded. 
-inline void cblas_Xsyrk ( - const MatrixTransposeType trans, const MatrixIndexT dim_c, - const MatrixIndexT other_dim_a, const float alpha, const float *A, - const MatrixIndexT a_stride, const float beta, float *C, - const MatrixIndexT c_stride) { - cblas_ssyrk(CblasRowMajor, CblasLower, static_cast(trans), - dim_c, other_dim_a, alpha, A, a_stride, beta, C, c_stride); -} - -inline void cblas_Xsyrk( - const MatrixTransposeType trans, const MatrixIndexT dim_c, - const MatrixIndexT other_dim_a, const double alpha, const double *A, - const MatrixIndexT a_stride, const double beta, double *C, - const MatrixIndexT c_stride) { - cblas_dsyrk(CblasRowMajor, CblasLower, static_cast(trans), - dim_c, other_dim_a, alpha, A, a_stride, beta, C, c_stride); -} - -/// matrix-vector multiply using a banded matrix; we always call this -/// with b = 1 meaning we're multiplying by a diagonal matrix. This is used for -/// elementwise multiplication. We miss some of the arguments out of this -/// wrapper. -inline void cblas_Xsbmv1( - const MatrixIndexT dim, - const double *A, - const double alpha, - const double *x, - const double beta, - double *y) { - cblas_dsbmv(CblasRowMajor, CblasLower, dim, 0, alpha, A, - 1, x, 1, beta, y, 1); -} - -inline void cblas_Xsbmv1( - const MatrixIndexT dim, - const float *A, - const float alpha, - const float *x, - const float beta, - float *y) { - cblas_ssbmv(CblasRowMajor, CblasLower, dim, 0, alpha, A, - 1, x, 1, beta, y, 1); -} - -/// This is not really a wrapper for CBLAS as CBLAS does not have this; in future we could -/// extend this somehow. -inline void mul_elements( - const MatrixIndexT dim, - const double *a, - double *b) { // does b *= a, elementwise. 
- double c1, c2, c3, c4; - MatrixIndexT i; - for (i = 0; i + 4 <= dim; i += 4) { - c1 = a[i] * b[i]; - c2 = a[i+1] * b[i+1]; - c3 = a[i+2] * b[i+2]; - c4 = a[i+3] * b[i+3]; - b[i] = c1; - b[i+1] = c2; - b[i+2] = c3; - b[i+3] = c4; - } - for (; i < dim; i++) - b[i] *= a[i]; -} - -inline void mul_elements( - const MatrixIndexT dim, - const float *a, - float *b) { // does b *= a, elementwise. - float c1, c2, c3, c4; - MatrixIndexT i; - for (i = 0; i + 4 <= dim; i += 4) { - c1 = a[i] * b[i]; - c2 = a[i+1] * b[i+1]; - c3 = a[i+2] * b[i+2]; - c4 = a[i+3] * b[i+3]; - b[i] = c1; - b[i+1] = c2; - b[i+2] = c3; - b[i+3] = c4; - } - for (; i < dim; i++) - b[i] *= a[i]; -} - - - -// add clapack here -#if !defined(HAVE_ATLAS) -inline void clapack_Xtptri(KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *result) { - stptri_(const_cast("U"), const_cast("N"), num_rows, Mdata, result); -} -inline void clapack_Xtptri(KaldiBlasInt *num_rows, double *Mdata, KaldiBlasInt *result) { - dtptri_(const_cast("U"), const_cast("N"), num_rows, Mdata, result); -} -// -inline void clapack_Xgetrf2(KaldiBlasInt *num_rows, KaldiBlasInt *num_cols, - float *Mdata, KaldiBlasInt *stride, KaldiBlasInt *pivot, - KaldiBlasInt *result) { - sgetrf_(num_rows, num_cols, Mdata, stride, pivot, result); -} -inline void clapack_Xgetrf2(KaldiBlasInt *num_rows, KaldiBlasInt *num_cols, - double *Mdata, KaldiBlasInt *stride, KaldiBlasInt *pivot, - KaldiBlasInt *result) { - dgetrf_(num_rows, num_cols, Mdata, stride, pivot, result); -} - -// -inline void clapack_Xgetri2(KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *stride, - KaldiBlasInt *pivot, float *p_work, - KaldiBlasInt *l_work, KaldiBlasInt *result) { - sgetri_(num_rows, Mdata, stride, pivot, p_work, l_work, result); -} -inline void clapack_Xgetri2(KaldiBlasInt *num_rows, double *Mdata, KaldiBlasInt *stride, - KaldiBlasInt *pivot, double *p_work, - KaldiBlasInt *l_work, KaldiBlasInt *result) { - dgetri_(num_rows, Mdata, stride, pivot, p_work, l_work, 
result); -} -// -inline void clapack_Xgesvd(char *v, char *u, KaldiBlasInt *num_cols, - KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *stride, - float *sv, float *Vdata, KaldiBlasInt *vstride, - float *Udata, KaldiBlasInt *ustride, float *p_work, - KaldiBlasInt *l_work, KaldiBlasInt *result) { - sgesvd_(v, u, - num_cols, num_rows, Mdata, stride, - sv, Vdata, vstride, Udata, ustride, - p_work, l_work, result); -} -inline void clapack_Xgesvd(char *v, char *u, KaldiBlasInt *num_cols, - KaldiBlasInt *num_rows, double *Mdata, KaldiBlasInt *stride, - double *sv, double *Vdata, KaldiBlasInt *vstride, - double *Udata, KaldiBlasInt *ustride, double *p_work, - KaldiBlasInt *l_work, KaldiBlasInt *result) { - dgesvd_(v, u, - num_cols, num_rows, Mdata, stride, - sv, Vdata, vstride, Udata, ustride, - p_work, l_work, result); -} -// -void inline clapack_Xsptri(KaldiBlasInt *num_rows, float *Mdata, - KaldiBlasInt *ipiv, float *work, KaldiBlasInt *result) { - ssptri_(const_cast("U"), num_rows, Mdata, ipiv, work, result); -} -void inline clapack_Xsptri(KaldiBlasInt *num_rows, double *Mdata, - KaldiBlasInt *ipiv, double *work, KaldiBlasInt *result) { - dsptri_(const_cast("U"), num_rows, Mdata, ipiv, work, result); -} -// -void inline clapack_Xsptrf(KaldiBlasInt *num_rows, float *Mdata, - KaldiBlasInt *ipiv, KaldiBlasInt *result) { - ssptrf_(const_cast("U"), num_rows, Mdata, ipiv, result); -} -void inline clapack_Xsptrf(KaldiBlasInt *num_rows, double *Mdata, - KaldiBlasInt *ipiv, KaldiBlasInt *result) { - dsptrf_(const_cast("U"), num_rows, Mdata, ipiv, result); -} -#else -inline void clapack_Xgetrf(MatrixIndexT num_rows, MatrixIndexT num_cols, - float *Mdata, MatrixIndexT stride, - int *pivot, int *result) { - *result = clapack_sgetrf(CblasColMajor, num_rows, num_cols, - Mdata, stride, pivot); -} - -inline void clapack_Xgetrf(MatrixIndexT num_rows, MatrixIndexT num_cols, - double *Mdata, MatrixIndexT stride, - int *pivot, int *result) { - *result = clapack_dgetrf(CblasColMajor, 
num_rows, num_cols, - Mdata, stride, pivot); -} -// -inline int clapack_Xtrtri(int num_rows, float *Mdata, MatrixIndexT stride) { - return clapack_strtri(CblasColMajor, CblasUpper, CblasNonUnit, num_rows, - Mdata, stride); -} - -inline int clapack_Xtrtri(int num_rows, double *Mdata, MatrixIndexT stride) { - return clapack_dtrtri(CblasColMajor, CblasUpper, CblasNonUnit, num_rows, - Mdata, stride); -} -// -inline void clapack_Xgetri(MatrixIndexT num_rows, float *Mdata, MatrixIndexT stride, - int *pivot, int *result) { - *result = clapack_sgetri(CblasColMajor, num_rows, Mdata, stride, pivot); -} -inline void clapack_Xgetri(MatrixIndexT num_rows, double *Mdata, MatrixIndexT stride, - int *pivot, int *result) { - *result = clapack_dgetri(CblasColMajor, num_rows, Mdata, stride, pivot); -} -#endif - -} -// namespace kaldi - -#endif diff --git a/speechx/speechx/kaldi/matrix/compressed-matrix.cc b/speechx/speechx/kaldi/matrix/compressed-matrix.cc deleted file mode 100644 index 13214b25..00000000 --- a/speechx/speechx/kaldi/matrix/compressed-matrix.cc +++ /dev/null @@ -1,876 +0,0 @@ -// matrix/compressed-matrix.cc - -// Copyright 2012 Johns Hopkins University (author: Daniel Povey) -// Frantisek Skala, Wei Shi -// 2015 Tom Ko - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. 
- -#include "matrix/compressed-matrix.h" -#include - -namespace kaldi { - -//static -MatrixIndexT CompressedMatrix::DataSize(const GlobalHeader &header) { - // Returns size in bytes of the data. - DataFormat format = static_cast(header.format); - if (format == kOneByteWithColHeaders) { - return sizeof(GlobalHeader) + - header.num_cols * (sizeof(PerColHeader) + header.num_rows); - } else if (format == kTwoByte) { - return sizeof(GlobalHeader) + - 2 * header.num_rows * header.num_cols; - } else { - KALDI_ASSERT(format == kOneByte); - return sizeof(GlobalHeader) + - header.num_rows * header.num_cols; - } -} - -// scale all element of matrix by scaling floats -// in GlobalHeader with alpha. -void CompressedMatrix::Scale(float alpha) { - if (data_ != NULL) { - GlobalHeader *h = reinterpret_cast(data_); - // scale the floating point values in each PerColHolder - // and leave all integers the same. - h->min_value *= alpha; - h->range *= alpha; - } -} - -template // static inline -void CompressedMatrix::ComputeGlobalHeader( - const MatrixBase &mat, CompressionMethod method, - GlobalHeader *header) { - if (method == kAutomaticMethod) { - if (mat.NumRows() > 8) method = kSpeechFeature; - else method = kTwoByteAuto; - } - - switch (method) { - case kSpeechFeature: - header->format = static_cast(kOneByteWithColHeaders); // 1. - break; - case kTwoByteAuto: case kTwoByteSignedInteger: - header->format = static_cast(kTwoByte); // 2. - break; - case kOneByteAuto: case kOneByteUnsignedInteger: case kOneByteZeroOne: - header->format = static_cast(kOneByte); // 3. - break; - default: - KALDI_ERR << "Invalid compression type: " - << static_cast(method); - } - - header->num_rows = mat.NumRows(); - header->num_cols = mat.NumCols(); - - // Now compute 'min_value' and 'range'. 
- switch (method) { - case kSpeechFeature: case kTwoByteAuto: case kOneByteAuto: { - float min_value = mat.Min(), max_value = mat.Max(); - // ensure that max_value is strictly greater than min_value, even if matrix is - // constant; this avoids crashes in ComputeColHeader when compressing speech - // featupres. - if (max_value == min_value) - max_value = min_value + (1.0 + fabs(min_value)); - KALDI_ASSERT(min_value - min_value == 0 && - max_value - max_value == 0 && - "Cannot compress a matrix with Nan's or Inf's"); - - header->min_value = min_value; - header->range = max_value - min_value; - - // we previously checked that max_value != min_value, so their - // difference should be nonzero. - KALDI_ASSERT(header->range > 0.0); - break; - } - case kTwoByteSignedInteger: { - header->min_value = -32768.0; - header->range = 65535.0; - break; - } - case kOneByteUnsignedInteger: { - header->min_value = 0.0; - header->range = 255.0; - break; - } - case kOneByteZeroOne: { - header->min_value = 0.0; - header->range = 1.0; - break; - } - default: - KALDI_ERR << "Unknown compression method = " - << static_cast(method); - } - KALDI_COMPILE_TIME_ASSERT(sizeof(*header) == 20); // otherwise - // something weird is happening and our code probably won't work or - // won't be robust across platforms. -} - -template -void CompressedMatrix::CopyFromMat( - const MatrixBase &mat, CompressionMethod method) { - if (data_ != NULL) { - delete [] static_cast(data_); // call delete [] because was allocated with new float[] - data_ = NULL; - } - if (mat.NumRows() == 0) { return; } // Zero-size matrix stored as zero pointer. 
- - - GlobalHeader global_header; - ComputeGlobalHeader(mat, method, &global_header); - - int32 data_size = DataSize(global_header); - - data_ = AllocateData(data_size); - - *(reinterpret_cast(data_)) = global_header; - - DataFormat format = static_cast(global_header.format); - if (format == kOneByteWithColHeaders) { - PerColHeader *header_data = - reinterpret_cast(static_cast(data_) + - sizeof(GlobalHeader)); - uint8 *byte_data = - reinterpret_cast(header_data + global_header.num_cols); - - const Real *matrix_data = mat.Data(); - - for (int32 col = 0; col < global_header.num_cols; col++) { - CompressColumn(global_header, - matrix_data + col, mat.Stride(), - global_header.num_rows, - header_data, byte_data); - header_data++; - byte_data += global_header.num_rows; - } - } else if (format == kTwoByte) { - uint16 *data = reinterpret_cast(static_cast(data_) + - sizeof(GlobalHeader)); - int32 num_rows = mat.NumRows(), num_cols = mat.NumCols(); - for (int32 r = 0; r < num_rows; r++) { - const Real *row_data = mat.RowData(r); - for (int32 c = 0; c < num_cols; c++) - data[c] = FloatToUint16(global_header, row_data[c]); - data += num_cols; - } - } else { - KALDI_ASSERT(format == kOneByte); - uint8 *data = reinterpret_cast(static_cast(data_) + - sizeof(GlobalHeader)); - int32 num_rows = mat.NumRows(), num_cols = mat.NumCols(); - for (int32 r = 0; r < num_rows; r++) { - const Real *row_data = mat.RowData(r); - for (int32 c = 0; c < num_cols; c++) - data[c] = FloatToUint8(global_header, row_data[c]); - data += num_cols; - } - } -} - -// Instantiate the template for float and double. 
-template -void CompressedMatrix::CopyFromMat(const MatrixBase &mat, - CompressionMethod method); - -template -void CompressedMatrix::CopyFromMat(const MatrixBase &mat, - CompressionMethod method); - - -CompressedMatrix::CompressedMatrix( - const CompressedMatrix &cmat, - const MatrixIndexT row_offset, - const MatrixIndexT num_rows, - const MatrixIndexT col_offset, - const MatrixIndexT num_cols, - bool allow_padding): data_(NULL) { - int32 old_num_rows = cmat.NumRows(), old_num_cols = cmat.NumCols(); - - if (old_num_rows == 0) { - KALDI_ASSERT(num_rows == 0 && num_cols == 0); - // The empty matrix is stored as a zero pointer. - return; - } - - KALDI_ASSERT(row_offset < old_num_rows); - KALDI_ASSERT(col_offset < old_num_cols); - KALDI_ASSERT(row_offset >= 0 || allow_padding); - KALDI_ASSERT(col_offset >= 0); - KALDI_ASSERT(row_offset + num_rows <= old_num_rows || allow_padding); - KALDI_ASSERT(col_offset + num_cols <= old_num_cols); - - if (num_rows == 0 || num_cols == 0) { return; } - - bool padding_is_used = (row_offset < 0 || - row_offset + num_rows > old_num_rows); - - GlobalHeader new_global_header; - KALDI_COMPILE_TIME_ASSERT(sizeof(new_global_header) == 20); - - GlobalHeader *old_global_header = reinterpret_cast(cmat.Data()); - - new_global_header = *old_global_header; - new_global_header.num_cols = num_cols; - new_global_header.num_rows = num_rows; - - // We don't switch format from 1 -> 2 (in case of size reduction) yet; if this - // is needed, we will do this below by creating a temporary Matrix. 
- new_global_header.format = old_global_header->format; - - data_ = AllocateData(DataSize(new_global_header)); // allocate memory - *(reinterpret_cast(data_)) = new_global_header; - - - DataFormat format = static_cast(old_global_header->format); - if (format == kOneByteWithColHeaders) { - PerColHeader *old_per_col_header = - reinterpret_cast(old_global_header + 1); - uint8 *old_byte_data = - reinterpret_cast(old_per_col_header + - old_global_header->num_cols); - PerColHeader *new_per_col_header = - reinterpret_cast( - reinterpret_cast(data_) + 1); - - memcpy(new_per_col_header, old_per_col_header + col_offset, - sizeof(PerColHeader) * num_cols); - - uint8 *new_byte_data = - reinterpret_cast(new_per_col_header + num_cols); - if (!padding_is_used) { - uint8 *old_start_of_subcol = - old_byte_data + row_offset + (col_offset * old_num_rows), - *new_start_of_col = new_byte_data; - for (int32 i = 0; i < num_cols; i++) { - memcpy(new_start_of_col, old_start_of_subcol, num_rows); - new_start_of_col += num_rows; - old_start_of_subcol += old_num_rows; - } - } else { - uint8 *old_start_of_col = - old_byte_data + (col_offset * old_num_rows), - *new_start_of_col = new_byte_data; - for (int32 i = 0; i < num_cols; i++) { - - for (int32 j = 0; j < num_rows; j++) { - int32 old_j = j + row_offset; - if (old_j < 0) old_j = 0; - else if (old_j >= old_num_rows) old_j = old_num_rows - 1; - new_start_of_col[j] = old_start_of_col[old_j]; - } - new_start_of_col += num_rows; - old_start_of_col += old_num_rows; - } - } - } else if (format == kTwoByte) { - const uint16 *old_data = - reinterpret_cast(old_global_header + 1); - uint16 *new_row_data = - reinterpret_cast(reinterpret_cast(data_) + 1); - - for (int32 row = 0; row < num_rows; row++) { - int32 old_row = row + row_offset; - // The next two lines are only relevant if padding_is_used. 
- if (old_row < 0) old_row = 0; - else if (old_row >= old_num_rows) old_row = old_num_rows - 1; - const uint16 *old_row_data = - old_data + col_offset + (old_num_cols * old_row); - memcpy(new_row_data, old_row_data, sizeof(uint16) * num_cols); - new_row_data += num_cols; - } - } else { - KALDI_ASSERT(format == kOneByte); - const uint8 *old_data = - reinterpret_cast(old_global_header + 1); - uint8 *new_row_data = - reinterpret_cast(reinterpret_cast(data_) + 1); - - for (int32 row = 0; row < num_rows; row++) { - int32 old_row = row + row_offset; - // The next two lines are only relevant if padding_is_used. - if (old_row < 0) old_row = 0; - else if (old_row >= old_num_rows) old_row = old_num_rows - 1; - const uint8 *old_row_data = - old_data + col_offset + (old_num_cols * old_row); - memcpy(new_row_data, old_row_data, sizeof(uint8) * num_cols); - new_row_data += num_cols; - } - } - - if (num_rows < 8 && format == kOneByteWithColHeaders) { - // format was 1 but we want it to be 2 -> create a temporary - // Matrix (uncompress), re-compress, and swap. - // This gives us almost exact reconstruction while saving - // memory (the elements take more space but there will be - // no per-column headers). - Matrix temp(this->NumRows(), this->NumCols(), - kUndefined); - this->CopyToMat(&temp); - CompressedMatrix temp_cmat(temp, kTwoByteAuto); - this->Swap(&temp_cmat); - } -} - - -template -CompressedMatrix &CompressedMatrix::operator =(const MatrixBase &mat) { - this->CopyFromMat(mat); - return *this; -} - -// Instantiate the template for float and double. -template -CompressedMatrix& CompressedMatrix::operator =(const MatrixBase &mat); - -template -CompressedMatrix& CompressedMatrix::operator =(const MatrixBase &mat); - -inline uint16 CompressedMatrix::FloatToUint16( - const GlobalHeader &global_header, - float value) { - float f = (value - global_header.min_value) / - global_header.range; - if (f > 1.0) f = 1.0; // Note: this should not happen. 
- if (f < 0.0) f = 0.0; // Note: this should not happen. - return static_cast(f * 65535 + 0.499); // + 0.499 is to - // round to closest int; avoids bias. -} - - -inline uint8 CompressedMatrix::FloatToUint8( - const GlobalHeader &global_header, - float value) { - float f = (value - global_header.min_value) / - global_header.range; - if (f > 1.0) f = 1.0; // Note: this should not happen. - if (f < 0.0) f = 0.0; // Note: this should not happen. - return static_cast(f * 255 + 0.499); // + 0.499 is to - // round to closest int; avoids bias. -} - - -inline float CompressedMatrix::Uint16ToFloat( - const GlobalHeader &global_header, - uint16 value) { - // the constant 1.52590218966964e-05 is 1/65535. - return global_header.min_value - + global_header.range * 1.52590218966964e-05F * value; -} - -template // static -void CompressedMatrix::ComputeColHeader( - const GlobalHeader &global_header, - const Real *data, MatrixIndexT stride, - int32 num_rows, CompressedMatrix::PerColHeader *header) { - KALDI_ASSERT(num_rows > 0); - std::vector sdata(num_rows); // the sorted data. - for (size_t i = 0, size = sdata.size(); i < size; i++) - sdata[i] = data[i*stride]; - - if (num_rows >= 5) { - int quarter_nr = num_rows/4; - // std::sort(sdata.begin(), sdata.end()); - // The elements at positions 0, quarter_nr, - // 3*quarter_nr, and num_rows-1 need to be in sorted order. - std::nth_element(sdata.begin(), sdata.begin() + quarter_nr, sdata.end()); - // Now, sdata.begin() + quarter_nr contains the element that would appear - // in sorted order, in that position. - std::nth_element(sdata.begin(), sdata.begin(), sdata.begin() + quarter_nr); - // Now, sdata.begin() and sdata.begin() + quarter_nr contain the elements - // that would appear at those positions in sorted order. 
- std::nth_element(sdata.begin() + quarter_nr + 1, - sdata.begin() + (3*quarter_nr), sdata.end()); - // Now, sdata.begin(), sdata.begin() + quarter_nr, and sdata.begin() + - // 3*quarter_nr, contain the elements that would appear at those positions - // in sorted order. - std::nth_element(sdata.begin() + (3*quarter_nr) + 1, sdata.end() - 1, - sdata.end()); - // Now, sdata.begin(), sdata.begin() + quarter_nr, and sdata.begin() + - // 3*quarter_nr, and sdata.end() - 1, contain the elements that would appear - // at those positions in sorted order. - - header->percentile_0 = - std::min(FloatToUint16(global_header, sdata[0]), 65532); - header->percentile_25 = - std::min( - std::max( - FloatToUint16(global_header, sdata[quarter_nr]), - header->percentile_0 + static_cast(1)), 65533); - header->percentile_75 = - std::min( - std::max( - FloatToUint16(global_header, sdata[3*quarter_nr]), - header->percentile_25 + static_cast(1)), 65534); - header->percentile_100 = std::max( - FloatToUint16(global_header, sdata[num_rows-1]), - header->percentile_75 + static_cast(1)); - - } else { // handle this pathological case. - std::sort(sdata.begin(), sdata.end()); - // Note: we know num_rows is at least 1. 
- header->percentile_0 = - std::min(FloatToUint16(global_header, sdata[0]), - 65532); - if (num_rows > 1) - header->percentile_25 = - std::min( - std::max(FloatToUint16(global_header, sdata[1]), - header->percentile_0 + 1), 65533); - else - header->percentile_25 = header->percentile_0 + 1; - if (num_rows > 2) - header->percentile_75 = - std::min( - std::max(FloatToUint16(global_header, sdata[2]), - header->percentile_25 + 1), 65534); - else - header->percentile_75 = header->percentile_25 + 1; - if (num_rows > 3) - header->percentile_100 = - std::max(FloatToUint16(global_header, sdata[3]), - header->percentile_75 + 1); - else - header->percentile_100 = header->percentile_75 + 1; - } -} - -// static -inline uint8 CompressedMatrix::FloatToChar( - float p0, float p25, float p75, float p100, - float value) { - int ans; - if (value < p25) { // range [ p0, p25 ) covered by - // characters 0 .. 64. We round to the closest int. - float f = (value - p0) / (p25 - p0); - ans = static_cast(f * 64 + 0.5); - // Note: the checks on the next two lines - // are necessary in pathological cases when all the elements in a row - // are the same and the percentile_* values are separated by one. - if (ans < 0) ans = 0; - if (ans > 64) ans = 64; - } else if (value < p75) { // range [ p25, p75 )covered - // by characters 64 .. 192. We round to the closest int. - float f = (value - p25) / (p75 - p25); - ans = 64 + static_cast(f * 128 + 0.5); - if (ans < 64) ans = 64; - if (ans > 192) ans = 192; - } else { // range [ p75, p100 ] covered by - // characters 192 .. 255. Note: this last range - // has fewer characters than the left range, because - // we go up to 255, not 256. 
- float f = (value - p75) / (p100 - p75); - ans = 192 + static_cast(f * 63 + 0.5); - if (ans < 192) ans = 192; - if (ans > 255) ans = 255; - } - return static_cast(ans); -} - - -// static -inline float CompressedMatrix::CharToFloat( - float p0, float p25, float p75, float p100, - uint8 value) { - if (value <= 64) { - return p0 + (p25 - p0) * value * (1/64.0); - } else if (value <= 192) { - return p25 + (p75 - p25) * (value - 64) * (1/128.0); - } else { - return p75 + (p100 - p75) * (value - 192) * (1/63.0); - } -} - - -template // static -void CompressedMatrix::CompressColumn( - const GlobalHeader &global_header, - const Real *data, MatrixIndexT stride, - int32 num_rows, CompressedMatrix::PerColHeader *header, - uint8 *byte_data) { - ComputeColHeader(global_header, data, stride, - num_rows, header); - - float p0 = Uint16ToFloat(global_header, header->percentile_0), - p25 = Uint16ToFloat(global_header, header->percentile_25), - p75 = Uint16ToFloat(global_header, header->percentile_75), - p100 = Uint16ToFloat(global_header, header->percentile_100); - - for (int32 i = 0; i < num_rows; i++) { - Real this_data = data[i * stride]; - byte_data[i] = FloatToChar(p0, p25, p75, p100, this_data); - } -} - -// static -void* CompressedMatrix::AllocateData(int32 num_bytes) { - KALDI_ASSERT(num_bytes > 0); - KALDI_COMPILE_TIME_ASSERT(sizeof(float) == 4); - // round size up to nearest number of floats. 
- return reinterpret_cast(new float[(num_bytes/3) + 4]); -} - -void CompressedMatrix::Write(std::ostream &os, bool binary) const { - if (binary) { // Binary-mode write: - if (data_ != NULL) { - GlobalHeader &h = *reinterpret_cast(data_); - DataFormat format = static_cast(h.format); - if (format == kOneByteWithColHeaders) { - WriteToken(os, binary, "CM"); - } else if (format == kTwoByte) { - WriteToken(os, binary, "CM2"); - } else if (format == kOneByte) { - WriteToken(os, binary, "CM3"); - } - MatrixIndexT size = DataSize(h); // total size of data in data_ - // We don't write out the "int32 format", hence the + 4, - 4. - os.write(reinterpret_cast(data_) + 4, size - 4); - } else { // special case: where data_ == NULL, we treat it as an empty - // matrix. - WriteToken(os, binary, "CM"); - GlobalHeader h; - h.range = h.min_value = 0.0; - h.num_rows = h.num_cols = 0; - os.write(reinterpret_cast(&h), sizeof(h)); - } - } else { - // In text mode, just use the same format as a regular matrix. - // This is not compressed. - Matrix temp_mat(this->NumRows(), this->NumCols(), - kUndefined); - this->CopyToMat(&temp_mat); - temp_mat.Write(os, binary); - } - if (os.fail()) - KALDI_ERR << "Error writing compressed matrix to stream."; -} - -void CompressedMatrix::Read(std::istream &is, bool binary) { - if (data_ != NULL) { - delete [] (static_cast(data_)); - data_ = NULL; - } - if (binary) { - int peekval = Peek(is, binary); - if (peekval == 'C') { - std::string tok; // Should be CM (format 1) or CM2 (format 2) - ReadToken(is, binary, &tok); - GlobalHeader h; - if (tok == "CM") { h.format = 1; } // kOneByteWithColHeaders - else if (tok == "CM2") { h.format = 2; } // kTwoByte - else if (tok == "CM3") { h.format = 3; } // kOneByte - else { - KALDI_ERR << "Unexpected token " << tok << ", expecting CM, CM2 or CM3"; - } - // don't read the "format" -> hence + 4, - 4. 
- is.read(reinterpret_cast(&h) + 4, sizeof(h) - 4); - if (is.fail()) - KALDI_ERR << "Failed to read header"; - if (h.num_cols == 0) // empty matrix. - return; - int32 size = DataSize(h), remaining_size = size - sizeof(GlobalHeader); - data_ = AllocateData(size); - *(reinterpret_cast(data_)) = h; - is.read(reinterpret_cast(data_) + sizeof(GlobalHeader), - remaining_size); - } else { - // Assume that what we're reading is a regular Matrix. This might be the - // case if you changed your code, making a Matrix into a CompressedMatrix, - // and you want back-compatibility for reading. - Matrix M; - M.Read(is, binary); // This will crash if it was not a Matrix. - this->CopyFromMat(M); - } - } else { // Text-mode read. In this case you don't get to - // choose the compression type. Anyway this branch would only - // be taken when debugging. - Matrix temp; - temp.Read(is, binary); - this->CopyFromMat(temp); - } - if (is.fail()) - KALDI_ERR << "Failed to read data."; -} - -template -void CompressedMatrix::CopyToMat(MatrixBase *mat, - MatrixTransposeType trans) const { - if (trans == kTrans) { - Matrix temp(this->NumCols(), this->NumRows()); - CopyToMat(&temp, kNoTrans); - mat->CopyFromMat(temp, kTrans); - return; - } - - if (data_ == NULL) { - KALDI_ASSERT(mat->NumRows() == 0); - KALDI_ASSERT(mat->NumCols() == 0); - return; - } - GlobalHeader *h = reinterpret_cast(data_); - int32 num_cols = h->num_cols, num_rows = h->num_rows; - KALDI_ASSERT(mat->NumRows() == num_rows); - KALDI_ASSERT(mat->NumCols() == num_cols); - - DataFormat format = static_cast(h->format); - if (format == kOneByteWithColHeaders) { - PerColHeader *per_col_header = reinterpret_cast(h+1); - uint8 *byte_data = reinterpret_cast(per_col_header + - h->num_cols); - for (int32 i = 0; i < num_cols; i++, per_col_header++) { - float p0 = Uint16ToFloat(*h, per_col_header->percentile_0), - p25 = Uint16ToFloat(*h, per_col_header->percentile_25), - p75 = Uint16ToFloat(*h, per_col_header->percentile_75), - p100 = 
Uint16ToFloat(*h, per_col_header->percentile_100); - for (int32 j = 0; j < num_rows; j++, byte_data++) { - float f = CharToFloat(p0, p25, p75, p100, *byte_data); - (*mat)(j, i) = f; - } - } - } else if (format == kTwoByte) { - const uint16 *data = reinterpret_cast(h + 1); - float min_value = h->min_value, - increment = h->range * (1.0 / 65535.0); - for (int32 i = 0; i < num_rows; i++) { - Real *row_data = mat->RowData(i); - for (int32 j = 0; j < num_cols; j++) - row_data[j] = min_value + data[j] * increment; - data += num_cols; - } - } else { - KALDI_ASSERT(format == kOneByte); - float min_value = h->min_value, increment = h->range * (1.0 / 255.0); - - const uint8 *data = reinterpret_cast(h + 1); - for (int32 i = 0; i < num_rows; i++) { - Real *row_data = mat->RowData(i); - for (int32 j = 0; j < num_cols; j++) - row_data[j] = min_value + data[j] * increment; - data += num_cols; - } - } -} - -// Instantiate the template for float and double. -template -void CompressedMatrix::CopyToMat(MatrixBase *mat, - MatrixTransposeType trans) const; -template -void CompressedMatrix::CopyToMat(MatrixBase *mat, - MatrixTransposeType trans) const; - -template -void CompressedMatrix::CopyRowToVec(MatrixIndexT row, - VectorBase *v) const { - KALDI_ASSERT(row < this->NumRows()); - KALDI_ASSERT(row >= 0); - KALDI_ASSERT(v->Dim() == this->NumCols()); - - GlobalHeader *h = reinterpret_cast(data_); - DataFormat format = static_cast(h->format); - if (format == kOneByteWithColHeaders) { - PerColHeader *per_col_header = reinterpret_cast(h+1); - uint8 *byte_data = reinterpret_cast(per_col_header + - h->num_cols); - byte_data += row; // point to first value we are interested in - for (int32 i = 0; i < h->num_cols; - i++, per_col_header++, byte_data += h->num_rows) { - float p0 = Uint16ToFloat(*h, per_col_header->percentile_0), - p25 = Uint16ToFloat(*h, per_col_header->percentile_25), - p75 = Uint16ToFloat(*h, per_col_header->percentile_75), - p100 = Uint16ToFloat(*h, 
per_col_header->percentile_100); - float f = CharToFloat(p0, p25, p75, p100, *byte_data); - (*v)(i) = f; - } - } else if (format == kTwoByte) { - int32 num_cols = h->num_cols; - float min_value = h->min_value, - increment = h->range * (1.0 / 65535.0); - const uint16 *row_data = reinterpret_cast(h + 1) + (num_cols * row); - Real *v_data = v->Data(); - for (int32 c = 0; c < num_cols; c++) - v_data[c] = min_value + row_data[c] * increment; - } else { - KALDI_ASSERT(format == kOneByte); - int32 num_cols = h->num_cols; - float min_value = h->min_value, - increment = h->range * (1.0 / 255.0); - const uint8 *row_data = reinterpret_cast(h + 1) + (num_cols * row); - Real *v_data = v->Data(); - for (int32 c = 0; c < num_cols; c++) - v_data[c] = min_value + row_data[c] * increment; - } -} - -template -void CompressedMatrix::CopyColToVec(MatrixIndexT col, - VectorBase *v) const { - KALDI_ASSERT(col < this->NumCols()); - KALDI_ASSERT(col >= 0); - KALDI_ASSERT(v->Dim() == this->NumRows()); - - GlobalHeader *h = reinterpret_cast(data_); - - DataFormat format = static_cast(h->format); - if (format == kOneByteWithColHeaders) { - PerColHeader *per_col_header = reinterpret_cast(h+1); - uint8 *byte_data = reinterpret_cast(per_col_header + - h->num_cols); - byte_data += col*h->num_rows; // point to first value in the column we want - per_col_header += col; - float p0 = Uint16ToFloat(*h, per_col_header->percentile_0), - p25 = Uint16ToFloat(*h, per_col_header->percentile_25), - p75 = Uint16ToFloat(*h, per_col_header->percentile_75), - p100 = Uint16ToFloat(*h, per_col_header->percentile_100); - for (int32 i = 0; i < h->num_rows; i++, byte_data++) { - float f = CharToFloat(p0, p25, p75, p100, *byte_data); - (*v)(i) = f; - } - } else if (format == kTwoByte) { - int32 num_rows = h->num_rows, num_cols = h->num_cols; - float min_value = h->min_value, - increment = h->range * (1.0 / 65535.0); - const uint16 *col_data = reinterpret_cast(h + 1) + col; - Real *v_data = v->Data(); - for (int32 r = 
0; r < num_rows; r++) - v_data[r] = min_value + increment * col_data[r * num_cols]; - } else { - KALDI_ASSERT(format == kOneByte); - int32 num_rows = h->num_rows, num_cols = h->num_cols; - float min_value = h->min_value, - increment = h->range * (1.0 / 255.0); - const uint8 *col_data = reinterpret_cast(h + 1) + col; - Real *v_data = v->Data(); - for (int32 r = 0; r < num_rows; r++) - v_data[r] = min_value + increment * col_data[r * num_cols]; - } -} - -// instantiate the templates. -template void -CompressedMatrix::CopyColToVec(MatrixIndexT, VectorBase *) const; -template void -CompressedMatrix::CopyColToVec(MatrixIndexT, VectorBase *) const; -template void -CompressedMatrix::CopyRowToVec(MatrixIndexT, VectorBase *) const; -template void -CompressedMatrix::CopyRowToVec(MatrixIndexT, VectorBase *) const; - -template -void CompressedMatrix::CopyToMat(int32 row_offset, - int32 col_offset, - MatrixBase *dest) const { - KALDI_PARANOID_ASSERT(row_offset < this->NumRows()); - KALDI_PARANOID_ASSERT(col_offset < this->NumCols()); - KALDI_PARANOID_ASSERT(row_offset >= 0); - KALDI_PARANOID_ASSERT(col_offset >= 0); - KALDI_ASSERT(row_offset+dest->NumRows() <= this->NumRows()); - KALDI_ASSERT(col_offset+dest->NumCols() <= this->NumCols()); - // everything is OK - GlobalHeader *h = reinterpret_cast(data_); - int32 num_rows = h->num_rows, num_cols = h->num_cols, - tgt_cols = dest->NumCols(), tgt_rows = dest->NumRows(); - - DataFormat format = static_cast(h->format); - if (format == kOneByteWithColHeaders) { - PerColHeader *per_col_header = reinterpret_cast(h+1); - uint8 *byte_data = reinterpret_cast(per_col_header + - h->num_cols); - - uint8 *start_of_subcol = byte_data+row_offset; // skip appropriate - // number of columns - start_of_subcol += col_offset*num_rows; // skip appropriate number of rows - - per_col_header += col_offset; // skip the appropriate number of headers - - for (int32 i = 0; - i < tgt_cols; - i++, per_col_header++, start_of_subcol+=num_rows) { - byte_data = 
start_of_subcol; - float p0 = Uint16ToFloat(*h, per_col_header->percentile_0), - p25 = Uint16ToFloat(*h, per_col_header->percentile_25), - p75 = Uint16ToFloat(*h, per_col_header->percentile_75), - p100 = Uint16ToFloat(*h, per_col_header->percentile_100); - for (int32 j = 0; j < tgt_rows; j++, byte_data++) { - float f = CharToFloat(p0, p25, p75, p100, *byte_data); - (*dest)(j, i) = f; - } - } - } else if (format == kTwoByte) { - const uint16 *data = reinterpret_cast(h+1) + col_offset + - (num_cols * row_offset); - float min_value = h->min_value, - increment = h->range * (1.0 / 65535.0); - - for (int32 row = 0; row < tgt_rows; row++) { - Real *dest_row = dest->RowData(row); - for (int32 col = 0; col < tgt_cols; col++) - dest_row[col] = min_value + increment * data[col]; - data += num_cols; - } - } else { - KALDI_ASSERT(format == kOneByte); - const uint8 *data = reinterpret_cast(h+1) + col_offset + - (num_cols * row_offset); - float min_value = h->min_value, - increment = h->range * (1.0 / 255.0); - for (int32 row = 0; row < tgt_rows; row++) { - Real *dest_row = dest->RowData(row); - for (int32 col = 0; col < tgt_cols; col++) - dest_row[col] = min_value + increment * data[col]; - data += num_cols; - } - } -} - -// instantiate the templates. -template void CompressedMatrix::CopyToMat(int32, - int32, - MatrixBase *dest) const; -template void CompressedMatrix::CopyToMat(int32, - int32, - MatrixBase *dest) const; - -void CompressedMatrix::Clear() { - if (data_ != NULL) { - delete [] static_cast(data_); - data_ = NULL; - } -} - -CompressedMatrix::CompressedMatrix(const CompressedMatrix &mat): data_(NULL) { - *this = mat; // use assignment operator. -} - -CompressedMatrix &CompressedMatrix::operator = (const CompressedMatrix &mat) { - Clear(); // now this->data_ == NULL. 
- if (mat.data_ != NULL) { - MatrixIndexT data_size = DataSize(*static_cast(mat.data_)); - data_ = AllocateData(data_size); - memcpy(static_cast(data_), - static_cast(mat.data_), - data_size); - } - return *this; -} - - -} // namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/compressed-matrix.h b/speechx/speechx/kaldi/matrix/compressed-matrix.h deleted file mode 100644 index 78105b9b..00000000 --- a/speechx/speechx/kaldi/matrix/compressed-matrix.h +++ /dev/null @@ -1,283 +0,0 @@ -// matrix/compressed-matrix.h - -// Copyright 2012 Johns Hopkins University (author: Daniel Povey) -// Frantisek Skala, Wei Shi - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_MATRIX_COMPRESSED_MATRIX_H_ -#define KALDI_MATRIX_COMPRESSED_MATRIX_H_ 1 - -#include "matrix/kaldi-matrix.h" - -namespace kaldi { - -/// \addtogroup matrix_group -/// @{ - - - -/* - The enum CompressionMethod is used when creating a CompressedMatrix (a lossily - compressed matrix) from a regular Matrix. It dictates how we choose the - compressed format and how we choose the ranges of floats that are represented - by particular integers. - - kAutomaticMethod = 1 This is the default when you don't specify the - compression method. It is a shorthand for using - kSpeechFeature if the num-rows is more than 8, and - kTwoByteAuto otherwise. 
- kSpeechFeature = 2 This is the most complicated of the compression methods, - and was designed for speech features which have a roughly - Gaussian distribution with different ranges for each - dimension. Each element is stored in one byte, but there - is an 8-byte header per column; the spacing of the - integer values is not uniform but is in 3 ranges. - kTwoByteAuto = 3 Each element is stored in two bytes as a uint16, with - the representable range of values chosen automatically - with the minimum and maximum elements of the matrix as - its edges. - kTwoByteSignedInteger = 4 - Each element is stored in two bytes as a uint16, with - the representable range of value chosen to coincide with - what you'd get if you stored signed integers, i.e. - [-32768.0, 32767.0]. Suitable for waveform data that - was previously stored as 16-bit PCM. - kOneByteAuto = 5 Each element is stored in one byte as a uint8, with the - representable range of values chosen automatically with - the minimum and maximum elements of the matrix as its - edges. - kOneByteUnsignedInteger = 6 Each element is stored in - one byte as a uint8, with the representable range of - values equal to [0.0, 255.0]. - kOneByteZeroOne = 7 Each element is stored in - one byte as a uint8, with the representable range of - values equal to [0.0, 1.0]. Suitable for image data - that has previously been compressed as int8. - - // We can add new methods here as needed: if they just imply different ways - // of selecting the min_value and range, and a num-bytes = 1 or 2, they will - // be trivial to implement. -*/ -enum CompressionMethod { - kAutomaticMethod = 1, - kSpeechFeature = 2, - kTwoByteAuto = 3, - kTwoByteSignedInteger = 4, - kOneByteAuto = 5, - kOneByteUnsignedInteger = 6, - kOneByteZeroOne = 7 -}; - - -/* - This class does lossy compression of a matrix. It supports various compression - methods, see enum CompressionMethod. 
-*/ - -class CompressedMatrix { - public: - CompressedMatrix(): data_(NULL) { } - - ~CompressedMatrix() { Clear(); } - - template - explicit CompressedMatrix(const MatrixBase &mat, - CompressionMethod method = kAutomaticMethod): - data_(NULL) { CopyFromMat(mat, method); } - - /// Initializer that can be used to select part of an existing - /// CompressedMatrix without un-compressing and re-compressing (note: unlike - /// similar initializers for class Matrix, it doesn't point to the same memory - /// location). - /// - /// This creates a CompressedMatrix with the size (num_rows, num_cols) - /// starting at (row_offset, col_offset). - /// - /// If you specify allow_padding = true, - /// it is permitted to have row_offset < 0 and - /// row_offset + num_rows > mat.NumRows(), and the result will contain - /// repeats of the first and last rows of 'mat' as necessary. - CompressedMatrix(const CompressedMatrix &mat, - const MatrixIndexT row_offset, - const MatrixIndexT num_rows, - const MatrixIndexT col_offset, - const MatrixIndexT num_cols, - bool allow_padding = false); - - void *Data() const { return this->data_; } - - /// This will resize *this and copy the contents of mat to *this. - template - void CopyFromMat(const MatrixBase &mat, - CompressionMethod method = kAutomaticMethod); - - CompressedMatrix(const CompressedMatrix &mat); - - CompressedMatrix &operator = (const CompressedMatrix &mat); // assignment operator. - - template - CompressedMatrix &operator = (const MatrixBase &mat); // assignment operator. - - /// Copies contents to matrix. Note: mat must have the correct size. - /// The kTrans case uses a temporary. - template - void CopyToMat(MatrixBase *mat, - MatrixTransposeType trans = kNoTrans) const; - - void Write(std::ostream &os, bool binary) const; - - void Read(std::istream &is, bool binary); - - /// Returns number of rows (or zero for emtpy matrix). - inline MatrixIndexT NumRows() const { return (data_ == NULL) ? 
0 : - (*reinterpret_cast(data_)).num_rows; } - - /// Returns number of columns (or zero for emtpy matrix). - inline MatrixIndexT NumCols() const { return (data_ == NULL) ? 0 : - (*reinterpret_cast(data_)).num_cols; } - - /// Copies row #row of the matrix into vector v. - /// Note: v must have same size as #cols. - template - void CopyRowToVec(MatrixIndexT row, VectorBase *v) const; - - /// Copies column #col of the matrix into vector v. - /// Note: v must have same size as #rows. - template - void CopyColToVec(MatrixIndexT col, VectorBase *v) const; - - /// Copies submatrix of compressed matrix into matrix dest. - /// Submatrix starts at row row_offset and column column_offset and its size - /// is defined by size of provided matrix dest - template - void CopyToMat(int32 row_offset, - int32 column_offset, - MatrixBase *dest) const; - - void Swap(CompressedMatrix *other) { std::swap(data_, other->data_); } - - void Clear(); - - /// scales all elements of matrix by alpha. - /// It scales the floating point values in GlobalHeader by alpha. - void Scale(float alpha); - - friend class Matrix; - friend class Matrix; - private: - - // This enum describes the different compressed-data formats: these are - // distinct from the compression methods although all of the methods apart - // from kAutomaticMethod dictate a particular compressed-data format. - // - // kOneByteWithColHeaders means there is a GlobalHeader and each - // column has a PerColHeader; the actual data is stored in - // one byte per element, in column-major order (the mapping - // from integers to floats is a little complicated). 
- // kTwoByte means there is a global header but no PerColHeader; - // the actual data is stored in two bytes per element in - // row-major order; it's decompressed as: - // uint16 i; GlobalHeader g; - // float f = g.min_value + i * (g.range / 65535.0) - // kOneByte means there is a global header but not PerColHeader; - // the data is stored in one byte per element in row-major - // order and is decompressed as: - // uint8 i; GlobalHeader g; - // float f = g.min_value + i * (g.range / 255.0) - enum DataFormat { - kOneByteWithColHeaders = 1, - kTwoByte = 2, - kOneByte = 3 - }; - - - // allocates data using new [], ensures byte alignment - // sufficient for float. - static void *AllocateData(int32 num_bytes); - - struct GlobalHeader { - int32 format; // Represents the enum DataFormat. - float min_value; // min_value and range represent the ranges of the integer - // data in the kTwoByte and kOneByte formats, and the - // range of the PerColHeader uint16's in the - // kOneByteWithColheaders format. - float range; - int32 num_rows; - int32 num_cols; - }; - - // This function computes the global header for compressing this data. - template - static inline void ComputeGlobalHeader(const MatrixBase &mat, - CompressionMethod method, - GlobalHeader *header); - - - // The number of bytes we need to request when allocating 'data_'. - static MatrixIndexT DataSize(const GlobalHeader &header); - - // This struct is only used in format kOneByteWithColHeaders. 
- struct PerColHeader { - uint16 percentile_0; - uint16 percentile_25; - uint16 percentile_75; - uint16 percentile_100; - }; - - template - static void CompressColumn(const GlobalHeader &global_header, - const Real *data, MatrixIndexT stride, - int32 num_rows, PerColHeader *header, - uint8 *byte_data); - template - static void ComputeColHeader(const GlobalHeader &global_header, - const Real *data, MatrixIndexT stride, - int32 num_rows, PerColHeader *header); - - static inline uint16 FloatToUint16(const GlobalHeader &global_header, - float value); - - // this is used only in the kOneByte compression format. - static inline uint8 FloatToUint8(const GlobalHeader &global_header, - float value); - - static inline float Uint16ToFloat(const GlobalHeader &global_header, - uint16 value); - - // this is used only in the kOneByteWithColHeaders compression format. - static inline uint8 FloatToChar(float p0, float p25, - float p75, float p100, - float value); - - // this is used only in the kOneByteWithColHeaders compression format. - static inline float CharToFloat(float p0, float p25, - float p75, float p100, - uint8 value); - - void *data_; // first GlobalHeader, then PerColHeader (repeated), then - // the byte data for each column (repeated). Note: don't intersperse - // the byte data with the PerColHeaders, because of alignment issues. - -}; - -/// @} end of \addtogroup matrix_group - - -} // namespace kaldi - - -#endif // KALDI_MATRIX_COMPRESSED_MATRIX_H_ diff --git a/speechx/speechx/kaldi/matrix/jama-eig.h b/speechx/speechx/kaldi/matrix/jama-eig.h deleted file mode 100644 index 92d8c27e..00000000 --- a/speechx/speechx/kaldi/matrix/jama-eig.h +++ /dev/null @@ -1,924 +0,0 @@ -// matrix/jama-eig.h - -// Copyright 2009-2011 Microsoft Corporation - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -// This file consists of a port and modification of materials from -// JAMA: A Java Matrix Package -// under the following notice: This software is a cooperative product of -// The MathWorks and the National Institute of Standards and Technology (NIST) -// which has been released to the public. This notice and the original code are -// available at http://math.nist.gov/javanumerics/jama/domain.notice - - - -#ifndef KALDI_MATRIX_JAMA_EIG_H_ -#define KALDI_MATRIX_JAMA_EIG_H_ 1 - -#include "matrix/kaldi-matrix.h" - -namespace kaldi { - -// This class is not to be used externally. See the Eig function in the Matrix -// class in kaldi-matrix.h. This is the external interface. - -template class EigenvalueDecomposition { - // This class is based on the EigenvalueDecomposition class from the JAMA - // library (version 1.0.2). - public: - EigenvalueDecomposition(const MatrixBase &A); - - ~EigenvalueDecomposition(); // free memory. - - void GetV(MatrixBase *V_out) { // V is what we call P externally; it's the matrix of - // eigenvectors. - KALDI_ASSERT(V_out->NumRows() == static_cast(n_) - && V_out->NumCols() == static_cast(n_)); - for (int i = 0; i < n_; i++) - for (int j = 0; j < n_; j++) - (*V_out)(i, j) = V(i, j); // V(i, j) is member function. - } - void GetRealEigenvalues(VectorBase *r_out) { - // returns real part of eigenvalues. 
- KALDI_ASSERT(r_out->Dim() == static_cast(n_)); - for (int i = 0; i < n_; i++) - (*r_out)(i) = d_[i]; - } - void GetImagEigenvalues(VectorBase *i_out) { - // returns imaginary part of eigenvalues. - KALDI_ASSERT(i_out->Dim() == static_cast(n_)); - for (int i = 0; i < n_; i++) - (*i_out)(i) = e_[i]; - } - private: - - inline Real &H(int r, int c) { return H_[r*n_ + c]; } - inline Real &V(int r, int c) { return V_[r*n_ + c]; } - - // complex division - inline static void cdiv(Real xr, Real xi, Real yr, Real yi, Real *cdivr, Real *cdivi) { - Real r, d; - if (std::abs(yr) > std::abs(yi)) { - r = yi/yr; - d = yr + r*yi; - *cdivr = (xr + r*xi)/d; - *cdivi = (xi - r*xr)/d; - } else { - r = yr/yi; - d = yi + r*yr; - *cdivr = (r*xr + xi)/d; - *cdivi = (r*xi - xr)/d; - } - } - - // Nonsymmetric reduction from Hessenberg to real Schur form. - void Hqr2 (); - - - int n_; // matrix dimension. - - Real *d_, *e_; // real and imaginary parts of eigenvalues. - Real *V_; // the eigenvectors (P in our external notation) - Real *H_; // the nonsymmetric Hessenberg form. - Real *ort_; // working storage for nonsymmetric algorithm. - - // Symmetric Householder reduction to tridiagonal form. - void Tred2 (); - - // Symmetric tridiagonal QL algorithm. - void Tql2 (); - - // Nonsymmetric reduction to Hessenberg form. - void Orthes (); - -}; - -template class EigenvalueDecomposition; // force instantiation. -template class EigenvalueDecomposition; // force instantiation. - -template void EigenvalueDecomposition::Tred2() { - // This is derived from the Algol procedures tred2 by - // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for - // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding - // Fortran subroutine in EISPACK. - - for (int j = 0; j < n_; j++) { - d_[j] = V(n_-1, j); - } - - // Householder reduction to tridiagonal form. - - for (int i = n_-1; i > 0; i--) { - - // Scale to avoid under/overflow. 
- - Real scale = 0.0; - Real h = 0.0; - for (int k = 0; k < i; k++) { - scale = scale + std::abs(d_[k]); - } - if (scale == 0.0) { - e_[i] = d_[i-1]; - for (int j = 0; j < i; j++) { - d_[j] = V(i-1, j); - V(i, j) = 0.0; - V(j, i) = 0.0; - } - } else { - - // Generate Householder vector. - - for (int k = 0; k < i; k++) { - d_[k] /= scale; - h += d_[k] * d_[k]; - } - Real f = d_[i-1]; - Real g = std::sqrt(h); - if (f > 0) { - g = -g; - } - e_[i] = scale * g; - h = h - f * g; - d_[i-1] = f - g; - for (int j = 0; j < i; j++) { - e_[j] = 0.0; - } - - // Apply similarity transformation to remaining columns. - - for (int j = 0; j < i; j++) { - f = d_[j]; - V(j, i) = f; - g =e_[j] + V(j, j) * f; - for (int k = j+1; k <= i-1; k++) { - g += V(k, j) * d_[k]; - e_[k] += V(k, j) * f; - } - e_[j] = g; - } - f = 0.0; - for (int j = 0; j < i; j++) { - e_[j] /= h; - f += e_[j] * d_[j]; - } - Real hh = f / (h + h); - for (int j = 0; j < i; j++) { - e_[j] -= hh * d_[j]; - } - for (int j = 0; j < i; j++) { - f = d_[j]; - g = e_[j]; - for (int k = j; k <= i-1; k++) { - V(k, j) -= (f * e_[k] + g * d_[k]); - } - d_[j] = V(i-1, j); - V(i, j) = 0.0; - } - } - d_[i] = h; - } - - // Accumulate transformations. - - for (int i = 0; i < n_-1; i++) { - V(n_-1, i) = V(i, i); - V(i, i) = 1.0; - Real h = d_[i+1]; - if (h != 0.0) { - for (int k = 0; k <= i; k++) { - d_[k] = V(k, i+1) / h; - } - for (int j = 0; j <= i; j++) { - Real g = 0.0; - for (int k = 0; k <= i; k++) { - g += V(k, i+1) * V(k, j); - } - for (int k = 0; k <= i; k++) { - V(k, j) -= g * d_[k]; - } - } - } - for (int k = 0; k <= i; k++) { - V(k, i+1) = 0.0; - } - } - for (int j = 0; j < n_; j++) { - d_[j] = V(n_-1, j); - V(n_-1, j) = 0.0; - } - V(n_-1, n_-1) = 1.0; - e_[0] = 0.0; -} - -template void EigenvalueDecomposition::Tql2() { - // This is derived from the Algol procedures tql2, by - // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for - // Auto. 
Comp., Vol.ii-Linear Algebra, and the corresponding - // Fortran subroutine in EISPACK. - - for (int i = 1; i < n_; i++) { - e_[i-1] = e_[i]; - } - e_[n_-1] = 0.0; - - Real f = 0.0; - Real tst1 = 0.0; - Real eps = std::numeric_limits::epsilon(); - for (int l = 0; l < n_; l++) { - - // Find small subdiagonal element - - tst1 = std::max(tst1, std::abs(d_[l]) + std::abs(e_[l])); - int m = l; - while (m < n_) { - if (std::abs(e_[m]) <= eps*tst1) { - break; - } - m++; - } - - // If m == l, d_[l] is an eigenvalue, - // otherwise, iterate. - - if (m > l) { - int iter = 0; - do { - iter = iter + 1; // (Could check iteration count here.) - - // Compute implicit shift - - Real g = d_[l]; - Real p = (d_[l+1] - g) / (2.0 *e_[l]); - Real r = Hypot(p, static_cast(1.0)); // This is a Kaldi version of hypot that works with templates. - if (p < 0) { - r = -r; - } - d_[l] =e_[l] / (p + r); - d_[l+1] =e_[l] * (p + r); - Real dl1 = d_[l+1]; - Real h = g - d_[l]; - for (int i = l+2; i < n_; i++) { - d_[i] -= h; - } - f = f + h; - - // Implicit QL transformation. - - p = d_[m]; - Real c = 1.0; - Real c2 = c; - Real c3 = c; - Real el1 =e_[l+1]; - Real s = 0.0; - Real s2 = 0.0; - for (int i = m-1; i >= l; i--) { - c3 = c2; - c2 = c; - s2 = s; - g = c *e_[i]; - h = c * p; - r = Hypot(p, e_[i]); // This is a Kaldi version of Hypot that works with templates. - e_[i+1] = s * r; - s =e_[i] / r; - c = p / r; - p = c * d_[i] - s * g; - d_[i+1] = h + s * (c * g + s * d_[i]); - - // Accumulate transformation. - - for (int k = 0; k < n_; k++) { - h = V(k, i+1); - V(k, i+1) = s * V(k, i) + c * h; - V(k, i) = c * V(k, i) - s * h; - } - } - p = -s * s2 * c3 * el1 *e_[l] / dl1; - e_[l] = s * p; - d_[l] = c * p; - - // Check for convergence. - - } while (std::abs(e_[l]) > eps*tst1); - } - d_[l] = d_[l] + f; - e_[l] = 0.0; - } - - // Sort eigenvalues and corresponding vectors. 
- - for (int i = 0; i < n_-1; i++) { - int k = i; - Real p = d_[i]; - for (int j = i+1; j < n_; j++) { - if (d_[j] < p) { - k = j; - p = d_[j]; - } - } - if (k != i) { - d_[k] = d_[i]; - d_[i] = p; - for (int j = 0; j < n_; j++) { - p = V(j, i); - V(j, i) = V(j, k); - V(j, k) = p; - } - } - } -} - -template -void EigenvalueDecomposition::Orthes() { - - // This is derived from the Algol procedures orthes and ortran, - // by Martin and Wilkinson, Handbook for Auto. Comp., - // Vol.ii-Linear Algebra, and the corresponding - // Fortran subroutines in EISPACK. - - int low = 0; - int high = n_-1; - - for (int m = low+1; m <= high-1; m++) { - - // Scale column. - - Real scale = 0.0; - for (int i = m; i <= high; i++) { - scale = scale + std::abs(H(i, m-1)); - } - if (scale != 0.0) { - - // Compute Householder transformation. - - Real h = 0.0; - for (int i = high; i >= m; i--) { - ort_[i] = H(i, m-1)/scale; - h += ort_[i] * ort_[i]; - } - Real g = std::sqrt(h); - if (ort_[m] > 0) { - g = -g; - } - h = h - ort_[m] * g; - ort_[m] = ort_[m] - g; - - // Apply Householder similarity transformation - // H = (I-u*u'/h)*H*(I-u*u')/h) - - for (int j = m; j < n_; j++) { - Real f = 0.0; - for (int i = high; i >= m; i--) { - f += ort_[i]*H(i, j); - } - f = f/h; - for (int i = m; i <= high; i++) { - H(i, j) -= f*ort_[i]; - } - } - - for (int i = 0; i <= high; i++) { - Real f = 0.0; - for (int j = high; j >= m; j--) { - f += ort_[j]*H(i, j); - } - f = f/h; - for (int j = m; j <= high; j++) { - H(i, j) -= f*ort_[j]; - } - } - ort_[m] = scale*ort_[m]; - H(m, m-1) = scale*g; - } - } - - // Accumulate transformations (Algol's ortran). - - for (int i = 0; i < n_; i++) { - for (int j = 0; j < n_; j++) { - V(i, j) = (i == j ? 
1.0 : 0.0); - } - } - - for (int m = high-1; m >= low+1; m--) { - if (H(m, m-1) != 0.0) { - for (int i = m+1; i <= high; i++) { - ort_[i] = H(i, m-1); - } - for (int j = m; j <= high; j++) { - Real g = 0.0; - for (int i = m; i <= high; i++) { - g += ort_[i] * V(i, j); - } - // Double division avoids possible underflow - g = (g / ort_[m]) / H(m, m-1); - for (int i = m; i <= high; i++) { - V(i, j) += g * ort_[i]; - } - } - } - } -} - -template void EigenvalueDecomposition::Hqr2() { - // This is derived from the Algol procedure hqr2, - // by Martin and Wilkinson, Handbook for Auto. Comp., - // Vol.ii-Linear Algebra, and the corresponding - // Fortran subroutine in EISPACK. - - int nn = n_; - int n = nn-1; - int low = 0; - int high = nn-1; - Real eps = std::numeric_limits::epsilon(); - Real exshift = 0.0; - Real p = 0, q = 0, r = 0, s = 0, z=0, t, w, x, y; - - // Store roots isolated by balanc and compute matrix norm - - Real norm = 0.0; - for (int i = 0; i < nn; i++) { - if (i < low || i > high) { - d_[i] = H(i, i); - e_[i] = 0.0; - } - for (int j = std::max(i-1, 0); j < nn; j++) { - norm = norm + std::abs(H(i, j)); - } - } - - // Outer loop over eigenvalue index - - int iter = 0; - while (n >= low) { - - // Look for single small sub-diagonal element - - int l = n; - while (l > low) { - s = std::abs(H(l-1, l-1)) + std::abs(H(l, l)); - if (s == 0.0) { - s = norm; - } - if (std::abs(H(l, l-1)) < eps * s) { - break; - } - l--; - } - - // Check for convergence - // One root found - - if (l == n) { - H(n, n) = H(n, n) + exshift; - d_[n] = H(n, n); - e_[n] = 0.0; - n--; - iter = 0; - - // Two roots found - - } else if (l == n-1) { - w = H(n, n-1) * H(n-1, n); - p = (H(n-1, n-1) - H(n, n)) / 2.0; - q = p * p + w; - z = std::sqrt(std::abs(q)); - H(n, n) = H(n, n) + exshift; - H(n-1, n-1) = H(n-1, n-1) + exshift; - x = H(n, n); - - // Real pair - - if (q >= 0) { - if (p >= 0) { - z = p + z; - } else { - z = p - z; - } - d_[n-1] = x + z; - d_[n] = d_[n-1]; - if (z != 0.0) { - 
d_[n] = x - w / z; - } - e_[n-1] = 0.0; - e_[n] = 0.0; - x = H(n, n-1); - s = std::abs(x) + std::abs(z); - p = x / s; - q = z / s; - r = std::sqrt(p * p+q * q); - p = p / r; - q = q / r; - - // Row modification - - for (int j = n-1; j < nn; j++) { - z = H(n-1, j); - H(n-1, j) = q * z + p * H(n, j); - H(n, j) = q * H(n, j) - p * z; - } - - // Column modification - - for (int i = 0; i <= n; i++) { - z = H(i, n-1); - H(i, n-1) = q * z + p * H(i, n); - H(i, n) = q * H(i, n) - p * z; - } - - // Accumulate transformations - - for (int i = low; i <= high; i++) { - z = V(i, n-1); - V(i, n-1) = q * z + p * V(i, n); - V(i, n) = q * V(i, n) - p * z; - } - - // Complex pair - - } else { - d_[n-1] = x + p; - d_[n] = x + p; - e_[n-1] = z; - e_[n] = -z; - } - n = n - 2; - iter = 0; - - // No convergence yet - - } else { - - // Form shift - - x = H(n, n); - y = 0.0; - w = 0.0; - if (l < n) { - y = H(n-1, n-1); - w = H(n, n-1) * H(n-1, n); - } - - // Wilkinson's original ad hoc shift - - if (iter == 10) { - exshift += x; - for (int i = low; i <= n; i++) { - H(i, i) -= x; - } - s = std::abs(H(n, n-1)) + std::abs(H(n-1, n-2)); - x = y = 0.75 * s; - w = -0.4375 * s * s; - } - - // MATLAB's new ad hoc shift - - if (iter == 30) { - s = (y - x) / 2.0; - s = s * s + w; - if (s > 0) { - s = std::sqrt(s); - if (y < x) { - s = -s; - } - s = x - w / ((y - x) / 2.0 + s); - for (int i = low; i <= n; i++) { - H(i, i) -= s; - } - exshift += s; - x = y = w = 0.964; - } - } - - iter = iter + 1; // (Could check iteration count here.) 
- - // Look for two consecutive small sub-diagonal elements - - int m = n-2; - while (m >= l) { - z = H(m, m); - r = x - z; - s = y - z; - p = (r * s - w) / H(m+1, m) + H(m, m+1); - q = H(m+1, m+1) - z - r - s; - r = H(m+2, m+1); - s = std::abs(p) + std::abs(q) + std::abs(r); - p = p / s; - q = q / s; - r = r / s; - if (m == l) { - break; - } - if (std::abs(H(m, m-1)) * (std::abs(q) + std::abs(r)) < - eps * (std::abs(p) * (std::abs(H(m-1, m-1)) + std::abs(z) + - std::abs(H(m+1, m+1))))) { - break; - } - m--; - } - - for (int i = m+2; i <= n; i++) { - H(i, i-2) = 0.0; - if (i > m+2) { - H(i, i-3) = 0.0; - } - } - - // Double QR step involving rows l:n and columns m:n - - for (int k = m; k <= n-1; k++) { - bool notlast = (k != n-1); - if (k != m) { - p = H(k, k-1); - q = H(k+1, k-1); - r = (notlast ? H(k+2, k-1) : 0.0); - x = std::abs(p) + std::abs(q) + std::abs(r); - if (x != 0.0) { - p = p / x; - q = q / x; - r = r / x; - } - } - if (x == 0.0) { - break; - } - s = std::sqrt(p * p + q * q + r * r); - if (p < 0) { - s = -s; - } - if (s != 0) { - if (k != m) { - H(k, k-1) = -s * x; - } else if (l != m) { - H(k, k-1) = -H(k, k-1); - } - p = p + s; - x = p / s; - y = q / s; - z = r / s; - q = q / p; - r = r / p; - - // Row modification - - for (int j = k; j < nn; j++) { - p = H(k, j) + q * H(k+1, j); - if (notlast) { - p = p + r * H(k+2, j); - H(k+2, j) = H(k+2, j) - p * z; - } - H(k, j) = H(k, j) - p * x; - H(k+1, j) = H(k+1, j) - p * y; - } - - // Column modification - - for (int i = 0; i <= std::min(n, k+3); i++) { - p = x * H(i, k) + y * H(i, k+1); - if (notlast) { - p = p + z * H(i, k+2); - H(i, k+2) = H(i, k+2) - p * r; - } - H(i, k) = H(i, k) - p; - H(i, k+1) = H(i, k+1) - p * q; - } - - // Accumulate transformations - - for (int i = low; i <= high; i++) { - p = x * V(i, k) + y * V(i, k+1); - if (notlast) { - p = p + z * V(i, k+2); - V(i, k+2) = V(i, k+2) - p * r; - } - V(i, k) = V(i, k) - p; - V(i, k+1) = V(i, k+1) - p * q; - } - } // (s != 0) - } // k loop - } 
// check convergence - } // while (n >= low) - - // Backsubstitute to find vectors of upper triangular form - - if (norm == 0.0) { - return; - } - - for (n = nn-1; n >= 0; n--) { - p = d_[n]; - q = e_[n]; - - // Real vector - - if (q == 0) { - int l = n; - H(n, n) = 1.0; - for (int i = n-1; i >= 0; i--) { - w = H(i, i) - p; - r = 0.0; - for (int j = l; j <= n; j++) { - r = r + H(i, j) * H(j, n); - } - if (e_[i] < 0.0) { - z = w; - s = r; - } else { - l = i; - if (e_[i] == 0.0) { - if (w != 0.0) { - H(i, n) = -r / w; - } else { - H(i, n) = -r / (eps * norm); - } - - // Solve real equations - - } else { - x = H(i, i+1); - y = H(i+1, i); - q = (d_[i] - p) * (d_[i] - p) +e_[i] *e_[i]; - t = (x * s - z * r) / q; - H(i, n) = t; - if (std::abs(x) > std::abs(z)) { - H(i+1, n) = (-r - w * t) / x; - } else { - H(i+1, n) = (-s - y * t) / z; - } - } - - // Overflow control - - t = std::abs(H(i, n)); - if ((eps * t) * t > 1) { - for (int j = i; j <= n; j++) { - H(j, n) = H(j, n) / t; - } - } - } - } - - // Complex vector - - } else if (q < 0) { - int l = n-1; - - // Last vector component imaginary so matrix is triangular - - if (std::abs(H(n, n-1)) > std::abs(H(n-1, n))) { - H(n-1, n-1) = q / H(n, n-1); - H(n-1, n) = -(H(n, n) - p) / H(n, n-1); - } else { - Real cdivr, cdivi; - cdiv(0.0, -H(n-1, n), H(n-1, n-1)-p, q, &cdivr, &cdivi); - H(n-1, n-1) = cdivr; - H(n-1, n) = cdivi; - } - H(n, n-1) = 0.0; - H(n, n) = 1.0; - for (int i = n-2; i >= 0; i--) { - Real ra, sa, vr, vi; - ra = 0.0; - sa = 0.0; - for (int j = l; j <= n; j++) { - ra = ra + H(i, j) * H(j, n-1); - sa = sa + H(i, j) * H(j, n); - } - w = H(i, i) - p; - - if (e_[i] < 0.0) { - z = w; - r = ra; - s = sa; - } else { - l = i; - if (e_[i] == 0) { - Real cdivr, cdivi; - cdiv(-ra, -sa, w, q, &cdivr, &cdivi); - H(i, n-1) = cdivr; - H(i, n) = cdivi; - } else { - Real cdivr, cdivi; - // Solve complex equations - - x = H(i, i+1); - y = H(i+1, i); - vr = (d_[i] - p) * (d_[i] - p) +e_[i] *e_[i] - q * q; - vi = (d_[i] - p) * 2.0 
* q; - if (vr == 0.0 && vi == 0.0) { - vr = eps * norm * (std::abs(w) + std::abs(q) + - std::abs(x) + std::abs(y) + std::abs(z)); - } - cdiv(x*r-z*ra+q*sa, x*s-z*sa-q*ra, vr, vi, &cdivr, &cdivi); - H(i, n-1) = cdivr; - H(i, n) = cdivi; - if (std::abs(x) > (std::abs(z) + std::abs(q))) { - H(i+1, n-1) = (-ra - w * H(i, n-1) + q * H(i, n)) / x; - H(i+1, n) = (-sa - w * H(i, n) - q * H(i, n-1)) / x; - } else { - cdiv(-r-y*H(i, n-1), -s-y*H(i, n), z, q, &cdivr, &cdivi); - H(i+1, n-1) = cdivr; - H(i+1, n) = cdivi; - } - } - - // Overflow control - - t = std::max(std::abs(H(i, n-1)), std::abs(H(i, n))); - if ((eps * t) * t > 1) { - for (int j = i; j <= n; j++) { - H(j, n-1) = H(j, n-1) / t; - H(j, n) = H(j, n) / t; - } - } - } - } - } - } - - // Vectors of isolated roots - - for (int i = 0; i < nn; i++) { - if (i < low || i > high) { - for (int j = i; j < nn; j++) { - V(i, j) = H(i, j); - } - } - } - - // Back transformation to get eigenvectors of original matrix - - for (int j = nn-1; j >= low; j--) { - for (int i = low; i <= high; i++) { - z = 0.0; - for (int k = low; k <= std::min(j, high); k++) { - z = z + V(i, k) * H(k, j); - } - V(i, j) = z; - } - } -} - -template -EigenvalueDecomposition::EigenvalueDecomposition(const MatrixBase &A) { - KALDI_ASSERT(A.NumCols() == A.NumRows() && A.NumCols() >= 1); - n_ = A.NumRows(); - V_ = new Real[n_*n_]; - d_ = new Real[n_]; - e_ = new Real[n_]; - H_ = NULL; - ort_ = NULL; - if (A.IsSymmetric(0.0)) { - - for (int i = 0; i < n_; i++) - for (int j = 0; j < n_; j++) - V(i, j) = A(i, j); // Note that V(i, j) is a member function; A(i, j) is an operator - // of the matrix A. - // Tridiagonalize. - Tred2(); - - // Diagonalize. - Tql2(); - } else { - H_ = new Real[n_*n_]; - ort_ = new Real[n_]; - for (int i = 0; i < n_; i++) - for (int j = 0; j < n_; j++) - H(i, j) = A(i, j); // as before: H is member function, A(i, j) is operator of matrix. - - // Reduce to Hessenberg form. - Orthes(); - - // Reduce Hessenberg to real Schur form. 
- Hqr2(); - } -} - -template -EigenvalueDecomposition::~EigenvalueDecomposition() { - delete [] d_; - delete [] e_; - delete [] V_; - delete [] H_; - delete [] ort_; -} - -// see function MatrixBase::Eig in kaldi-matrix.cc - - -} // namespace kaldi - -#endif // KALDI_MATRIX_JAMA_EIG_H_ diff --git a/speechx/speechx/kaldi/matrix/jama-svd.h b/speechx/speechx/kaldi/matrix/jama-svd.h deleted file mode 100644 index 8304dac6..00000000 --- a/speechx/speechx/kaldi/matrix/jama-svd.h +++ /dev/null @@ -1,531 +0,0 @@ -// matrix/jama-svd.h - -// Copyright 2009-2011 Microsoft Corporation - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -// This file consists of a port and modification of materials from -// JAMA: A Java Matrix Package -// under the following notice: This software is a cooperative product of -// The MathWorks and the National Institute of Standards and Technology (NIST) -// which has been released to the public. 
This notice and the original code are -// available at http://math.nist.gov/javanumerics/jama/domain.notice - - -#ifndef KALDI_MATRIX_JAMA_SVD_H_ -#define KALDI_MATRIX_JAMA_SVD_H_ 1 - - -#include "matrix/kaldi-matrix.h" -#include "matrix/sp-matrix.h" -#include "matrix/cblas-wrappers.h" - -namespace kaldi { - -#if defined(HAVE_ATLAS) || defined(USE_KALDI_SVD) -// using ATLAS as our math library, which doesn't have SVD -> need -// to implement it. - -// This routine is a modified form of jama_svd.h which is part of the TNT distribution. -// (originally comes from JAMA). - -/** Singular Value Decomposition. - *

- * For an m-by-n matrix A with m >= n, the singular value decomposition is - * an m-by-n orthogonal matrix U, an n-by-n diagonal matrix S, and - * an n-by-n orthogonal matrix V so that A = U*S*V'. - *

- * The singular values, sigma[k] = S(k, k), are ordered so that - * sigma[0] >= sigma[1] >= ... >= sigma[n-1]. - *

- * The singular value decompostion always exists, so the constructor will - * never fail. The matrix condition number and the effective numerical - * rank can be computed from this decomposition. - - *

- * (Adapted from JAMA, a Java Matrix Library, developed by jointly - * by the Mathworks and NIST; see http://math.nist.gov/javanumerics/jama). - */ - - -template -bool MatrixBase::JamaSvd(VectorBase *s_in, - MatrixBase *U_in, - MatrixBase *V_in) { // Destructive! - KALDI_ASSERT(s_in != NULL && U_in != this && V_in != this); - int wantu = (U_in != NULL), wantv = (V_in != NULL); - Matrix Utmp, Vtmp; - MatrixBase &U = (U_in ? *U_in : Utmp), &V = (V_in ? *V_in : Vtmp); - VectorBase &s = *s_in; - - int m = num_rows_, n = num_cols_; - KALDI_ASSERT(m>=n && m != 0 && n != 0); - if (wantu) KALDI_ASSERT((int)U.num_rows_ == m && (int)U.num_cols_ == n); - if (wantv) KALDI_ASSERT((int)V.num_rows_ == n && (int)V.num_cols_ == n); - KALDI_ASSERT((int)s.Dim() == n); // n<=m so n is min. - - int nu = n; - U.SetZero(); // make sure all zero. - Vector e(n); - Vector work(m); - MatrixBase &A(*this); - Real *adata = A.Data(), *workdata = work.Data(), *edata = e.Data(), - *udata = U.Data(), *vdata = V.Data(); - int astride = static_cast(A.Stride()), - ustride = static_cast(U.Stride()), - vstride = static_cast(V.Stride()); - int i = 0, j = 0, k = 0; - - // Reduce A to bidiagonal form, storing the diagonal elements - // in s and the super-diagonal elements in e. - - int nct = std::min(m-1, n); - int nrt = std::max(0, std::min(n-2, m)); - for (k = 0; k < std::max(nct, nrt); k++) { - if (k < nct) { - - // Compute the transformation for the k-th column and - // place the k-th diagonal in s(k). - // Compute 2-norm of k-th column without under/overflow. - s(k) = 0; - for (i = k; i < m; i++) { - s(k) = hypot(s(k), A(i, k)); - } - if (s(k) != 0.0) { - if (A(k, k) < 0.0) { - s(k) = -s(k); - } - for (i = k; i < m; i++) { - A(i, k) /= s(k); - } - A(k, k) += 1.0; - } - s(k) = -s(k); - } - for (j = k+1; j < n; j++) { - if ((k < nct) && (s(k) != 0.0)) { - - // Apply the transformation. 
- - Real t = cblas_Xdot(m - k, adata + astride*k + k, astride, - adata + astride*k + j, astride); - /*for (i = k; i < m; i++) { - t += adata[i*astride + k]*adata[i*astride + j]; // A(i, k)*A(i, j); // 3 - }*/ - t = -t/A(k, k); - cblas_Xaxpy(m - k, t, adata + k*astride + k, astride, - adata + k*astride + j, astride); - /*for (i = k; i < m; i++) { - adata[i*astride + j] += t*adata[i*astride + k]; // A(i, j) += t*A(i, k); // 5 - }*/ - } - - // Place the k-th row of A into e for the - // subsequent calculation of the row transformation. - - e(j) = A(k, j); - } - if (wantu & (k < nct)) { - - // Place the transformation in U for subsequent back - // multiplication. - - for (i = k; i < m; i++) { - U(i, k) = A(i, k); - } - } - if (k < nrt) { - - // Compute the k-th row transformation and place the - // k-th super-diagonal in e(k). - // Compute 2-norm without under/overflow. - e(k) = 0; - for (i = k+1; i < n; i++) { - e(k) = hypot(e(k), e(i)); - } - if (e(k) != 0.0) { - if (e(k+1) < 0.0) { - e(k) = -e(k); - } - for (i = k+1; i < n; i++) { - e(i) /= e(k); - } - e(k+1) += 1.0; - } - e(k) = -e(k); - if ((k+1 < m) & (e(k) != 0.0)) { - - // Apply the transformation. - - for (i = k+1; i < m; i++) { - work(i) = 0.0; - } - for (j = k+1; j < n; j++) { - for (i = k+1; i < m; i++) { - workdata[i] += edata[j] * adata[i*astride + j]; // work(i) += e(j)*A(i, j); // 5 - } - } - for (j = k+1; j < n; j++) { - Real t(-e(j)/e(k+1)); - cblas_Xaxpy(m - (k+1), t, workdata + (k+1), 1, - adata + (k+1)*astride + j, astride); - /* - for (i = k+1; i < m; i++) { - adata[i*astride + j] += t*workdata[i]; // A(i, j) += t*work(i); // 5 - }*/ - } - } - if (wantv) { - - // Place the transformation in V for subsequent - // back multiplication. - - for (i = k+1; i < n; i++) { - V(i, k) = e(i); - } - } - } - } - - // Set up the final bidiagonal matrix or order p. 
- - int p = std::min(n, m+1); - if (nct < n) { - s(nct) = A(nct, nct); - } - if (m < p) { - s(p-1) = 0.0; - } - if (nrt+1 < p) { - e(nrt) = A(nrt, p-1); - } - e(p-1) = 0.0; - - // If required, generate U. - - if (wantu) { - for (j = nct; j < nu; j++) { - for (i = 0; i < m; i++) { - U(i, j) = 0.0; - } - U(j, j) = 1.0; - } - for (k = nct-1; k >= 0; k--) { - if (s(k) != 0.0) { - for (j = k+1; j < nu; j++) { - Real t = cblas_Xdot(m - k, udata + k*ustride + k, ustride, udata + k*ustride + j, ustride); - //for (i = k; i < m; i++) { - // t += udata[i*ustride + k]*udata[i*ustride + j]; // t += U(i, k)*U(i, j); // 8 - // } - t = -t/U(k, k); - cblas_Xaxpy(m - k, t, udata + ustride*k + k, ustride, - udata + k*ustride + j, ustride); - /*for (i = k; i < m; i++) { - udata[i*ustride + j] += t*udata[i*ustride + k]; // U(i, j) += t*U(i, k); // 4 - }*/ - } - for (i = k; i < m; i++ ) { - U(i, k) = -U(i, k); - } - U(k, k) = 1.0 + U(k, k); - for (i = 0; i < k-1; i++) { - U(i, k) = 0.0; - } - } else { - for (i = 0; i < m; i++) { - U(i, k) = 0.0; - } - U(k, k) = 1.0; - } - } - } - - // If required, generate V. - - if (wantv) { - for (k = n-1; k >= 0; k--) { - if ((k < nrt) & (e(k) != 0.0)) { - for (j = k+1; j < nu; j++) { - Real t = cblas_Xdot(n - (k+1), vdata + (k+1)*vstride + k, vstride, - vdata + (k+1)*vstride + j, vstride); - /*Real t (0.0); - for (i = k+1; i < n; i++) { - t += vdata[i*vstride + k]*vdata[i*vstride + j]; // t += V(i, k)*V(i, j); // 7 - }*/ - t = -t/V(k+1, k); - cblas_Xaxpy(n - (k+1), t, vdata + (k+1)*vstride + k, vstride, - vdata + (k+1)*vstride + j, vstride); - /*for (i = k+1; i < n; i++) { - vdata[i*vstride + j] += t*vdata[i*vstride + k]; // V(i, j) += t*V(i, k); // 7 - }*/ - } - } - for (i = 0; i < n; i++) { - V(i, k) = 0.0; - } - V(k, k) = 1.0; - } - } - - // Main iteration loop for the singular values. 
- - int pp = p-1; - int iter = 0; - // note: -52.0 is from Jama code; the -23 is the extension - // to float, because mantissa length in (double, float) - // is (52, 23) bits respectively. - Real eps(pow(2.0, sizeof(Real) == 4 ? -23.0 : -52.0)); - // Note: the -966 was taken from Jama code, but the -120 is a guess - // of how to extend this to float... the exponent in double goes - // from -1022 .. 1023, and in float from -126..127. I'm not sure - // what the significance of 966 is, so -120 just represents a number - // that's a bit less negative than -126. If we get convergence - // failure in float only, this may mean that we have to make the - // -120 value less negative. - Real tiny(pow(2.0, sizeof(Real) == 4 ? -120.0: -966.0 )); - - while (p > 0) { - int k = 0; - int kase = 0; - - if (iter == 500 || iter == 750) { - KALDI_WARN << "Svd taking a long time: making convergence criterion less exact."; - eps = pow(static_cast(0.8), eps); - tiny = pow(static_cast(0.8), tiny); - } - if (iter > 1000) { - KALDI_WARN << "Svd not converging on matrix of size " << m << " by " <= -1; k--) { - if (k == -1) { - break; - } - if (std::abs(e(k)) <= - tiny + eps*(std::abs(s(k)) + std::abs(s(k+1)))) { - e(k) = 0.0; - break; - } - } - if (k == p-2) { - kase = 4; - } else { - int ks; - for (ks = p-1; ks >= k; ks--) { - if (ks == k) { - break; - } - Real t( (ks != p ? std::abs(e(ks)) : 0.) + - (ks != k+1 ? std::abs(e(ks-1)) : 0.)); - if (std::abs(s(ks)) <= tiny + eps*t) { - s(ks) = 0.0; - break; - } - } - if (ks == k) { - kase = 3; - } else if (ks == p-1) { - kase = 1; - } else { - kase = 2; - k = ks; - } - } - k++; - - // Perform the task indicated by kase. - - switch (kase) { - - // Deflate negligible s(p). 
- - case 1: { - Real f(e(p-2)); - e(p-2) = 0.0; - for (j = p-2; j >= k; j--) { - Real t( hypot(s(j), f)); - Real cs(s(j)/t); - Real sn(f/t); - s(j) = t; - if (j != k) { - f = -sn*e(j-1); - e(j-1) = cs*e(j-1); - } - if (wantv) { - for (i = 0; i < n; i++) { - t = cs*V(i, j) + sn*V(i, p-1); - V(i, p-1) = -sn*V(i, j) + cs*V(i, p-1); - V(i, j) = t; - } - } - } - } - break; - - // Split at negligible s(k). - - case 2: { - Real f(e(k-1)); - e(k-1) = 0.0; - for (j = k; j < p; j++) { - Real t(hypot(s(j), f)); - Real cs( s(j)/t); - Real sn(f/t); - s(j) = t; - f = -sn*e(j); - e(j) = cs*e(j); - if (wantu) { - for (i = 0; i < m; i++) { - t = cs*U(i, j) + sn*U(i, k-1); - U(i, k-1) = -sn*U(i, j) + cs*U(i, k-1); - U(i, j) = t; - } - } - } - } - break; - - // Perform one qr step. - - case 3: { - - // Calculate the shift. - - Real scale = std::max(std::max(std::max(std::max( - std::abs(s(p-1)), std::abs(s(p-2))), std::abs(e(p-2))), - std::abs(s(k))), std::abs(e(k))); - Real sp = s(p-1)/scale; - Real spm1 = s(p-2)/scale; - Real epm1 = e(p-2)/scale; - Real sk = s(k)/scale; - Real ek = e(k)/scale; - Real b = ((spm1 + sp)*(spm1 - sp) + epm1*epm1)/2.0; - Real c = (sp*epm1)*(sp*epm1); - Real shift = 0.0; - if ((b != 0.0) || (c != 0.0)) { - shift = std::sqrt(b*b + c); - if (b < 0.0) { - shift = -shift; - } - shift = c/(b + shift); - } - Real f = (sk + sp)*(sk - sp) + shift; - Real g = sk*ek; - - // Chase zeros. 
- - for (j = k; j < p-1; j++) { - Real t = hypot(f, g); - Real cs = f/t; - Real sn = g/t; - if (j != k) { - e(j-1) = t; - } - f = cs*s(j) + sn*e(j); - e(j) = cs*e(j) - sn*s(j); - g = sn*s(j+1); - s(j+1) = cs*s(j+1); - if (wantv) { - cblas_Xrot(n, vdata + j, vstride, vdata + j+1, vstride, cs, sn); - /*for (i = 0; i < n; i++) { - t = cs*vdata[i*vstride + j] + sn*vdata[i*vstride + j+1]; // t = cs*V(i, j) + sn*V(i, j+1); // 13 - vdata[i*vstride + j+1] = -sn*vdata[i*vstride + j] + cs*vdata[i*vstride + j+1]; // V(i, j+1) = -sn*V(i, j) + cs*V(i, j+1); // 5 - vdata[i*vstride + j] = t; // V(i, j) = t; // 4 - }*/ - } - t = hypot(f, g); - cs = f/t; - sn = g/t; - s(j) = t; - f = cs*e(j) + sn*s(j+1); - s(j+1) = -sn*e(j) + cs*s(j+1); - g = sn*e(j+1); - e(j+1) = cs*e(j+1); - if (wantu && (j < m-1)) { - cblas_Xrot(m, udata + j, ustride, udata + j+1, ustride, cs, sn); - /*for (i = 0; i < m; i++) { - t = cs*udata[i*ustride + j] + sn*udata[i*ustride + j+1]; // t = cs*U(i, j) + sn*U(i, j+1); // 7 - udata[i*ustride + j+1] = -sn*udata[i*ustride + j] +cs*udata[i*ustride + j+1]; // U(i, j+1) = -sn*U(i, j) + cs*U(i, j+1); // 8 - udata[i*ustride + j] = t; // U(i, j) = t; // 1 - }*/ - } - } - e(p-2) = f; - iter = iter + 1; - } - break; - - // Convergence. - - case 4: { - - // Make the singular values positive. - - if (s(k) <= 0.0) { - s(k) = (s(k) < 0.0 ? -s(k) : 0.0); - if (wantv) { - for (i = 0; i <= pp; i++) { - V(i, k) = -V(i, k); - } - } - } - - // Order the singular values. 
- - while (k < pp) { - if (s(k) >= s(k+1)) { - break; - } - Real t = s(k); - s(k) = s(k+1); - s(k+1) = t; - if (wantv && (k < n-1)) { - for (i = 0; i < n; i++) { - t = V(i, k+1); V(i, k+1) = V(i, k); V(i, k) = t; - } - } - if (wantu && (k < m-1)) { - for (i = 0; i < m; i++) { - t = U(i, k+1); U(i, k+1) = U(i, k); U(i, k) = t; - } - } - k++; - } - iter = 0; - p--; - } - break; - } - } - return true; -} - -#endif // defined(HAVE_ATLAS) || defined(USE_KALDI_SVD) - -} // namespace kaldi - -#endif // KALDI_MATRIX_JAMA_SVD_H_ diff --git a/speechx/speechx/kaldi/matrix/kaldi-blas.h b/speechx/speechx/kaldi/matrix/kaldi-blas.h deleted file mode 100644 index e8a703c0..00000000 --- a/speechx/speechx/kaldi/matrix/kaldi-blas.h +++ /dev/null @@ -1,139 +0,0 @@ -// matrix/kaldi-blas.h - -// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. -#ifndef KALDI_MATRIX_KALDI_BLAS_H_ -#define KALDI_MATRIX_KALDI_BLAS_H_ - -// This file handles the #includes for BLAS, LAPACK and so on. -// It manipulates the declarations into a common format that kaldi can handle. -// However, the kaldi code will check whether HAVE_ATLAS is defined as that -// code is called a bit differently from CLAPACK that comes from other sources. 
- -// There are three alternatives: -// (i) you have ATLAS, which includes the ATLAS implementation of CBLAS -// plus a subset of CLAPACK (but with clapack_ in the function declarations). -// In this case, define HAVE_ATLAS and make sure the relevant directories are -// in the include path. - -// (ii) you have CBLAS (some implementation thereof) plus CLAPACK. -// In this case, define HAVE_CLAPACK. -// [Since CLAPACK depends on BLAS, the presence of BLAS is implicit]. - -// (iii) you have the MKL library, which includes CLAPACK and CBLAS. - -// Note that if we are using ATLAS, no Svd implementation is supplied, -// so we define HAVE_Svd to be zero and this directs our implementation to -// supply its own "by hand" implementation which is based on TNT code. - - - -#define HAVE_OPENBLAS - -#if (defined(HAVE_CLAPACK) && (defined(HAVE_ATLAS) || defined(HAVE_MKL))) \ - || (defined(HAVE_ATLAS) && defined(HAVE_MKL)) -#error "Do not define more than one of HAVE_CLAPACK, HAVE_ATLAS and HAVE_MKL" -#endif - -#ifdef HAVE_ATLAS - extern "C" { - #include "cblas.h" - #include "clapack.h" - } -#elif defined(HAVE_CLAPACK) - #ifdef __APPLE__ - #ifndef __has_extension - #define __has_extension(x) 0 - #endif - #define vImage_Utilities_h - #define vImage_CVUtilities_h - #include - typedef __CLPK_integer integer; - typedef __CLPK_logical logical; - typedef __CLPK_real real; - typedef __CLPK_doublereal doublereal; - typedef __CLPK_complex complex; - typedef __CLPK_doublecomplex doublecomplex; - typedef __CLPK_ftnlen ftnlen; - #else - extern "C" { - // May be in /usr/[local]/include if installed; else this uses the one - // from the tools/CLAPACK_include directory. - #include - #include - #include - - // get rid of macros from f2c.h -- these are dangerous. 
- #undef abs - #undef dabs - #undef min - #undef max - #undef dmin - #undef dmax - #undef bit_test - #undef bit_clear - #undef bit_set - } - #endif -#elif defined(HAVE_MKL) - extern "C" { - #include - } -#elif defined(HAVE_OPENBLAS) - // getting cblas.h and lapacke.h from /. - // putting in "" not <> to search -I before system libraries. - #if defined(_MSC_VER) - #include - #define LAPACK_COMPLEX_CUSTOM - #define lapack_complex_float _Fcomplex - #define lapack_complex_double _Dcomplex - #endif - #include "cblas.h" - #include "lapacke.h" - #undef I - #undef complex - // get rid of macros from f2c.h -- these are dangerous. - #undef abs - #undef dabs - #undef min - #undef max - #undef dmin - #undef dmax - #undef bit_test - #undef bit_clear - #undef bit_set -#else - #error "You need to define (using the preprocessor) either HAVE_CLAPACK or HAVE_ATLAS or HAVE_MKL (but not more than one)" -#endif - -#ifdef HAVE_OPENBLAS -typedef int KaldiBlasInt; // try int. -#endif -#ifdef HAVE_CLAPACK -typedef integer KaldiBlasInt; -#endif -#ifdef HAVE_MKL -typedef MKL_INT KaldiBlasInt; -#endif - -#ifdef HAVE_ATLAS -// in this case there is no need for KaldiBlasInt-- this typedef is only needed -// for Svd code which is not included in ATLAS (we re-implement it). 
-#endif - - -#endif // KALDI_MATRIX_KALDI_BLAS_H_ diff --git a/speechx/speechx/kaldi/matrix/kaldi-vector.h b/speechx/speechx/kaldi/matrix/kaldi-vector.h deleted file mode 100644 index 2a032354..00000000 --- a/speechx/speechx/kaldi/matrix/kaldi-vector.h +++ /dev/null @@ -1,612 +0,0 @@ -// matrix/kaldi-vector.h - -// Copyright 2009-2012 Ondrej Glembek; Microsoft Corporation; Lukas Burget; -// Saarland University (Author: Arnab Ghoshal); -// Ariya Rastrow; Petr Schwarz; Yanmin Qian; -// Karel Vesely; Go Vivace Inc.; Arnab Ghoshal -// Wei Shi; -// 2015 Guoguo Chen -// 2017 Daniel Galvez -// 2019 Yiwen Shao - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_MATRIX_KALDI_VECTOR_H_ -#define KALDI_MATRIX_KALDI_VECTOR_H_ 1 - -#include "matrix/matrix-common.h" - -namespace kaldi { - -/// \addtogroup matrix_group -/// @{ - -/// Provides a vector abstraction class. -/// This class provides a way to work with vectors in kaldi. -/// It encapsulates basic operations and memory optimizations. -template -class VectorBase { - public: - /// Set vector to all zeros. - void SetZero(); - - /// Returns true if matrix is all zeros. - bool IsZero(Real cutoff = 1.0e-06) const; // replace magic number - - /// Set all members of a vector to a specified value. - void Set(Real f); - - /// Set vector to random normally-distributed noise. 
- void SetRandn(); - - /// Sets to numbers uniformly distributed on (0,1) - void SetRandUniform(); - - /// This function returns a random index into this vector, - /// chosen with probability proportional to the corresponding - /// element. Requires that this->Min() >= 0 and this->Sum() > 0. - MatrixIndexT RandCategorical() const; - - /// Returns the dimension of the vector. - inline MatrixIndexT Dim() const { return dim_; } - - /// Returns the size in memory of the vector, in bytes. - inline MatrixIndexT SizeInBytes() const { return (dim_*sizeof(Real)); } - - /// Returns a pointer to the start of the vector's data. - inline Real* Data() { return data_; } - - /// Returns a pointer to the start of the vector's data (const). - inline const Real* Data() const { return data_; } - - /// Indexing operator (const). - inline Real operator() (MatrixIndexT i) const { - KALDI_PARANOID_ASSERT(static_cast(i) < - static_cast(dim_)); - return *(data_ + i); - } - - /// Indexing operator (non-const). - inline Real & operator() (MatrixIndexT i) { - KALDI_PARANOID_ASSERT(static_cast(i) < - static_cast(dim_)); - return *(data_ + i); - } - - /** @brief Returns a sub-vector of a vector (a range of elements). - * @param o [in] Origin, 0 < o < Dim() - * @param l [in] Length 0 < l < Dim()-o - * @return A SubVector object that aliases the data of the Vector object. - * See @c SubVector class for details */ - SubVector Range(const MatrixIndexT o, const MatrixIndexT l) { - return SubVector(*this, o, l); - } - - /** @brief Returns a const sub-vector of a vector (a range of elements). - * @param o [in] Origin, 0 < o < Dim() - * @param l [in] Length 0 < l < Dim()-o - * @return A SubVector object that aliases the data of the Vector object. - * See @c SubVector class for details */ - const SubVector Range(const MatrixIndexT o, - const MatrixIndexT l) const { - return SubVector(*this, o, l); - } - - /// Copy data from another vector (must match own size). 
- void CopyFromVec(const VectorBase &v); - - /// Copy data from a SpMatrix or TpMatrix (must match own size). - template - void CopyFromPacked(const PackedMatrix &M); - - /// Copy data from another vector of different type (double vs. float) - template - void CopyFromVec(const VectorBase &v); - - /// Copy from CuVector. This is defined in ../cudamatrix/cu-vector.h - template - void CopyFromVec(const CuVectorBase &v); - - /// Applies floor to all elements. Returns number of elements - /// floored in floored_count if it is non-null. - void Floor(const VectorBase &v, Real floor_val, MatrixIndexT *floored_count = nullptr); - - /// Applies ceiling to all elements. Returns number of elements - /// changed in ceiled_count if it is non-null. - void Ceiling(const VectorBase &v, Real ceil_val, MatrixIndexT *ceiled_count = nullptr); - - void Pow(const VectorBase &v, Real power); - - /// Apply natural log to all elements. Throw if any element of - /// the vector is negative (but doesn't complain about zero; the - /// log will be -infinity - void ApplyLog(); - - /// Apply natural log to another vector and put result in *this. - void ApplyLogAndCopy(const VectorBase &v); - - /// Apply exponential to each value in vector. - void ApplyExp(); - - /// Take absolute value of each of the elements - void ApplyAbs(); - - /// Applies floor to all elements. Returns number of elements - /// floored in floored_count if it is non-null. - inline void ApplyFloor(Real floor_val, MatrixIndexT *floored_count = nullptr) { - this->Floor(*this, floor_val, floored_count); - }; - - /// Applies ceiling to all elements. Returns number of elements - /// changed in ceiled_count if it is non-null. - inline void ApplyCeiling(Real ceil_val, MatrixIndexT *ceiled_count = nullptr) { - this->Ceiling(*this, ceil_val, ceiled_count); - }; - - /// Applies floor to all elements. Returns number of elements floored. 
- MatrixIndexT ApplyFloor(const VectorBase &floor_vec); - - /// Apply soft-max to vector and return normalizer (log sum of exponentials). - /// This is the same as: \f$ x(i) = exp(x(i)) / \sum_i exp(x(i)) \f$ - Real ApplySoftMax(); - - /// Applies log soft-max to vector and returns normalizer (log sum of - /// exponentials). - /// This is the same as: \f$ x(i) = x(i) - log(\sum_i exp(x(i))) \f$ - Real ApplyLogSoftMax(); - - /// Sets each element of *this to the tanh of the corresponding element of "src". - void Tanh(const VectorBase &src); - - /// Sets each element of *this to the sigmoid function of the corresponding - /// element of "src". - void Sigmoid(const VectorBase &src); - - /// Take all elements of vector to a power. - inline void ApplyPow(Real power) { - this->Pow(*this, power); - }; - - /// Take the absolute value of all elements of a vector to a power. - /// Include the sign of the input element if include_sign == true. - /// If power is negative and the input value is zero, the output is set zero. - void ApplyPowAbs(Real power, bool include_sign=false); - - /// Compute the p-th norm of the vector. - Real Norm(Real p) const; - - /// Returns true if ((*this)-other).Norm(2.0) <= tol * (*this).Norm(2.0). - bool ApproxEqual(const VectorBase &other, float tol = 0.01) const; - - /// Invert all elements. - void InvertElements(); - - /// Add vector : *this = *this + alpha * rv (with casting between floats and - /// doubles) - template - void AddVec(const Real alpha, const VectorBase &v); - - /// Add vector : *this = *this + alpha * rv^2 [element-wise squaring]. - void AddVec2(const Real alpha, const VectorBase &v); - - /// Add vector : *this = *this + alpha * rv^2 [element-wise squaring], - /// with casting between floats and doubles. - template - void AddVec2(const Real alpha, const VectorBase &v); - - /// Add matrix times vector : this <-- beta*this + alpha*M*v. - /// Calls BLAS GEMV. 
- void AddMatVec(const Real alpha, const MatrixBase &M, - const MatrixTransposeType trans, const VectorBase &v, - const Real beta); // **beta previously defaulted to 0.0** - - /// This is as AddMatVec, except optimized for where v contains a lot - /// of zeros. - void AddMatSvec(const Real alpha, const MatrixBase &M, - const MatrixTransposeType trans, const VectorBase &v, - const Real beta); // **beta previously defaulted to 0.0** - - - /// Add symmetric positive definite matrix times vector: - /// this <-- beta*this + alpha*M*v. Calls BLAS SPMV. - void AddSpVec(const Real alpha, const SpMatrix &M, - const VectorBase &v, const Real beta); // **beta previously defaulted to 0.0** - - /// Add triangular matrix times vector: this <-- beta*this + alpha*M*v. - /// Works even if rv == *this. - void AddTpVec(const Real alpha, const TpMatrix &M, - const MatrixTransposeType trans, const VectorBase &v, - const Real beta); // **beta previously defaulted to 0.0** - - /// Set each element to y = (x == orig ? changed : x). - void ReplaceValue(Real orig, Real changed); - - /// Multiply element-by-element by another vector. - void MulElements(const VectorBase &v); - /// Multiply element-by-element by another vector of different type. - template - void MulElements(const VectorBase &v); - - /// Divide element-by-element by a vector. - void DivElements(const VectorBase &v); - /// Divide element-by-element by a vector of different type. - template - void DivElements(const VectorBase &v); - - /// Add a constant to each element of a vector. - void Add(Real c); - - /// Add element-by-element product of vectors: - // this <-- alpha * v .* r + beta*this . - void AddVecVec(Real alpha, const VectorBase &v, - const VectorBase &r, Real beta); - - /// Add element-by-element quotient of two vectors. - /// this <---- alpha*v/r + beta*this - void AddVecDivVec(Real alpha, const VectorBase &v, - const VectorBase &r, Real beta); - - /// Multiplies all elements by this constant. 
- void Scale(Real alpha); - - /// Multiplies this vector by lower-triangular matrix: *this <-- *this *M - void MulTp(const TpMatrix &M, const MatrixTransposeType trans); - - /// If trans == kNoTrans, solves M x = b, where b is the value of *this at input - /// and x is the value of *this at output. - /// If trans == kTrans, solves M' x = b. - /// Does not test for M being singular or near-singular, so test it before - /// calling this routine. - void Solve(const TpMatrix &M, const MatrixTransposeType trans); - - /// Performs a row stack of the matrix M - void CopyRowsFromMat(const MatrixBase &M); - template - void CopyRowsFromMat(const MatrixBase &M); - - /// The following is implemented in ../cudamatrix/cu-matrix.cc - void CopyRowsFromMat(const CuMatrixBase &M); - - /// Performs a column stack of the matrix M - void CopyColsFromMat(const MatrixBase &M); - - /// Extracts a row of the matrix M. Could also do this with - /// this->Copy(M[row]). - void CopyRowFromMat(const MatrixBase &M, MatrixIndexT row); - /// Extracts a row of the matrix M with type conversion. - template - void CopyRowFromMat(const MatrixBase &M, MatrixIndexT row); - - /// Extracts a row of the symmetric matrix S. - template - void CopyRowFromSp(const SpMatrix &S, MatrixIndexT row); - - /// Extracts a column of the matrix M. - template - void CopyColFromMat(const MatrixBase &M , MatrixIndexT col); - - /// Extracts the diagonal of the matrix M. - void CopyDiagFromMat(const MatrixBase &M); - - /// Extracts the diagonal of a packed matrix M; works for Sp or Tp. - void CopyDiagFromPacked(const PackedMatrix &M); - - - /// Extracts the diagonal of a symmetric matrix. - inline void CopyDiagFromSp(const SpMatrix &M) { CopyDiagFromPacked(M); } - - /// Extracts the diagonal of a triangular matrix. - inline void CopyDiagFromTp(const TpMatrix &M) { CopyDiagFromPacked(M); } - - /// Returns the maximum value of any element, or -infinity for the empty vector. 
- Real Max() const; - - /// Returns the maximum value of any element, and the associated index. - /// Error if vector is empty. - Real Max(MatrixIndexT *index) const; - - /// Returns the minimum value of any element, or +infinity for the empty vector. - Real Min() const; - - /// Returns the minimum value of any element, and the associated index. - /// Error if vector is empty. - Real Min(MatrixIndexT *index) const; - - /// Returns sum of the elements - Real Sum() const; - - /// Returns sum of the logs of the elements. More efficient than - /// just taking log of each. Will return NaN if any elements are - /// negative. - Real SumLog() const; - - /// Does *this = alpha * (sum of rows of M) + beta * *this. - void AddRowSumMat(Real alpha, const MatrixBase &M, Real beta = 1.0); - - /// Does *this = alpha * (sum of columns of M) + beta * *this. - void AddColSumMat(Real alpha, const MatrixBase &M, Real beta = 1.0); - - /// Add the diagonal of a matrix times itself: - /// *this = diag(M M^T) + beta * *this (if trans == kNoTrans), or - /// *this = diag(M^T M) + beta * *this (if trans == kTrans). - void AddDiagMat2(Real alpha, const MatrixBase &M, - MatrixTransposeType trans = kNoTrans, Real beta = 1.0); - - /// Add the diagonal of a matrix product: *this = diag(M N), assuming the - /// "trans" arguments are both kNoTrans; for transpose arguments, it behaves - /// as you would expect. - void AddDiagMatMat(Real alpha, const MatrixBase &M, MatrixTransposeType transM, - const MatrixBase &N, MatrixTransposeType transN, - Real beta = 1.0); - - /// Returns log(sum(exp())) without exp overflow - /// If prune > 0.0, ignores terms less than the max - prune. - /// [Note: in future, if prune = 0.0, it will take the max. - /// For now, use -1 if you don't want it to prune.] - Real LogSumExp(Real prune = -1.0) const; - - /// Reads from C++ stream (option to add to existing contents). 
- /// Throws exception on failure - void Read(std::istream &in, bool binary, bool add = false); - - /// Writes to C++ stream (option to write in binary). - void Write(std::ostream &Out, bool binary) const; - - friend class VectorBase; - friend class VectorBase; - friend class CuVectorBase; - friend class CuVector; - protected: - /// Destructor; does not deallocate memory, this is handled by child classes. - /// This destructor is protected so this object can only be - /// deleted via a child. - ~VectorBase() {} - - /// Empty initializer, corresponds to vector of zero size. - explicit VectorBase(): data_(NULL), dim_(0) { - KALDI_ASSERT_IS_FLOATING_TYPE(Real); - } - -// Took this out since it is not currently used, and it is possible to create -// objects where the allocated memory is not the same size as dim_ : Arnab -// /// Initializer from a pointer and a size; keeps the pointer internally -// /// (ownership or non-ownership depends on the child class). -// explicit VectorBase(Real* data, MatrixIndexT dim) -// : data_(data), dim_(dim) {} - - // Arnab : made this protected since it is unsafe too. - /// Load data into the vector: sz must match own size. - void CopyFromPtr(const Real* Data, MatrixIndexT sz); - - /// data memory area - Real* data_; - /// dimension of vector - MatrixIndexT dim_; - KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase); -}; // class VectorBase - -/** @brief A class representing a vector. - * - * This class provides a way to work with vectors in kaldi. - * It encapsulates basic operations and memory optimizations. */ -template -class Vector: public VectorBase { - public: - /// Constructor that takes no arguments. Initializes to empty. - Vector(): VectorBase() {} - - /// Constructor with specific size. Sets to all-zero by default - /// if set_zero == false, memory contents are undefined. 
- explicit Vector(const MatrixIndexT s, - MatrixResizeType resize_type = kSetZero) - : VectorBase() { Resize(s, resize_type); } - - /// Copy constructor from CUDA vector - /// This is defined in ../cudamatrix/cu-vector.h - template - explicit Vector(const CuVectorBase &cu); - - /// Copy constructor. The need for this is controversial. - Vector(const Vector &v) : VectorBase() { // (cannot be explicit) - Resize(v.Dim(), kUndefined); - this->CopyFromVec(v); - } - - /// Copy-constructor from base-class, needed to copy from SubVector. - explicit Vector(const VectorBase &v) : VectorBase() { - Resize(v.Dim(), kUndefined); - this->CopyFromVec(v); - } - - /// Type conversion constructor. - template - explicit Vector(const VectorBase &v): VectorBase() { - Resize(v.Dim(), kUndefined); - this->CopyFromVec(v); - } - -// Took this out since it is unsafe : Arnab -// /// Constructor from a pointer and a size; copies the data to a location -// /// it owns. -// Vector(const Real* Data, const MatrixIndexT s): VectorBase() { -// Resize(s); - // CopyFromPtr(Data, s); -// } - - - /// Swaps the contents of *this and *other. Shallow swap. - void Swap(Vector *other); - - /// Destructor. Deallocates memory. - ~Vector() { Destroy(); } - - /// Read function using C++ streams. Can also add to existing contents - /// of matrix. - void Read(std::istream &in, bool binary, bool add = false); - - /// Set vector to a specified size (can be zero). - /// The value of the new data depends on resize_type: - /// -if kSetZero, the new data will be zero - /// -if kUndefined, the new data will be undefined - /// -if kCopyData, the new data will be the same as the old data in any - /// shared positions, and zero elsewhere. - /// This function takes time proportional to the number of data elements. - void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero); - - /// Remove one element and shifts later elements down. - void RemoveElement(MatrixIndexT i); - - /// Assignment operator. 
- Vector &operator = (const Vector &other) { - Resize(other.Dim(), kUndefined); - this->CopyFromVec(other); - return *this; - } - - /// Assignment operator that takes VectorBase. - Vector &operator = (const VectorBase &other) { - Resize(other.Dim(), kUndefined); - this->CopyFromVec(other); - return *this; - } - private: - /// Init assumes the current contents of the class are invalid (i.e. junk or - /// has already been freed), and it sets the vector to newly allocated memory - /// with the specified dimension. dim == 0 is acceptable. The memory contents - /// pointed to by data_ will be undefined. - void Init(const MatrixIndexT dim); - - /// Destroy function, called internally. - void Destroy(); - -}; - - -/// Represents a non-allocating general vector which can be defined -/// as a sub-vector of higher-level vector [or as the row of a matrix]. -template -class SubVector : public VectorBase { - public: - /// Constructor from a Vector or SubVector. - /// SubVectors are not const-safe and it's very hard to make them - /// so for now we just give up. This function contains const_cast. - SubVector(const VectorBase &t, const MatrixIndexT origin, - const MatrixIndexT length) : VectorBase() { - // following assert equiv to origin>=0 && length>=0 && - // origin+length <= rt.dim_ - KALDI_ASSERT(static_cast(origin)+ - static_cast(length) <= - static_cast(t.Dim())); - VectorBase::data_ = const_cast (t.Data()+origin); - VectorBase::dim_ = length; - } - - /// This constructor initializes the vector to point at the contents - /// of this packed matrix (SpMatrix or TpMatrix). - SubVector(const PackedMatrix &M) { - VectorBase::data_ = const_cast (M.Data()); - VectorBase::dim_ = (M.NumRows()*(M.NumRows()+1))/2; - } - - /// Copy constructor - SubVector(const SubVector &other) : VectorBase () { - // this copy constructor needed for Range() to work in base class. 
- VectorBase::data_ = other.data_; - VectorBase::dim_ = other.dim_; - } - - /// Constructor from a pointer to memory and a length. Keeps a pointer - /// to the data but does not take ownership (will never delete). - /// Caution: this constructor enables you to evade const constraints. - SubVector(const Real *data, MatrixIndexT length) : VectorBase () { - VectorBase::data_ = const_cast(data); - VectorBase::dim_ = length; - } - - /// This operation does not preserve const-ness, so be careful. - SubVector(const MatrixBase &matrix, MatrixIndexT row) { - VectorBase::data_ = const_cast(matrix.RowData(row)); - VectorBase::dim_ = matrix.NumCols(); - } - - ~SubVector() {} ///< Destructor (does nothing; no pointers are owned here). - - private: - /// Disallow assignment operator. - SubVector & operator = (const SubVector &other) {} -}; - -/// @} end of "addtogroup matrix_group" -/// \addtogroup matrix_funcs_io -/// @{ -/// Output to a C++ stream. Non-binary by default (use Write for -/// binary output). -template -std::ostream & operator << (std::ostream & out, const VectorBase & v); - -/// Input from a C++ stream. Will automatically read text or -/// binary data from the stream. -template -std::istream & operator >> (std::istream & in, VectorBase & v); - -/// Input from a C++ stream. Will automatically read text or -/// binary data from the stream. -template -std::istream & operator >> (std::istream & in, Vector & v); -/// @} end of \addtogroup matrix_funcs_io - -/// \addtogroup matrix_funcs_scalar -/// @{ - - -template -bool ApproxEqual(const VectorBase &a, - const VectorBase &b, Real tol = 0.01) { - return a.ApproxEqual(b, tol); -} - -template -inline void AssertEqual(VectorBase &a, VectorBase &b, - float tol = 0.01) { - KALDI_ASSERT(a.ApproxEqual(b, tol)); -} - - -/// Returns dot product between v1 and v2. 
-template -Real VecVec(const VectorBase &v1, const VectorBase &v2); - -template -Real VecVec(const VectorBase &v1, const VectorBase &v2); - - -/// Returns \f$ v_1^T M v_2 \f$ . -/// Not as efficient as it could be where v1 == v2. -template -Real VecMatVec(const VectorBase &v1, const MatrixBase &M, - const VectorBase &v2); - -/// @} End of "addtogroup matrix_funcs_scalar" - - -} // namespace kaldi - -// we need to include the implementation -#include "matrix/kaldi-vector-inl.h" - - - -#endif // KALDI_MATRIX_KALDI_VECTOR_H_ diff --git a/speechx/speechx/kaldi/matrix/matrix-functions-inl.h b/speechx/speechx/kaldi/matrix/matrix-functions-inl.h deleted file mode 100644 index 9fac851e..00000000 --- a/speechx/speechx/kaldi/matrix/matrix-functions-inl.h +++ /dev/null @@ -1,56 +0,0 @@ -// matrix/matrix-functions-inl.h - -// Copyright 2009-2011 Microsoft Corporation -// -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. -// -// (*) incorporates, with permission, FFT code from his book -// "Signal Processing with Lapped Transforms", Artech, 1992. - - - -#ifndef KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ -#define KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ - -namespace kaldi { - -//! ComplexMul implements, inline, the complex multiplication b *= a. 
-template inline void ComplexMul(const Real &a_re, const Real &a_im, - Real *b_re, Real *b_im) { - Real tmp_re = (*b_re * a_re) - (*b_im * a_im); - *b_im = *b_re * a_im + *b_im * a_re; - *b_re = tmp_re; -} - -template inline void ComplexAddProduct(const Real &a_re, const Real &a_im, - const Real &b_re, const Real &b_im, - Real *c_re, Real *c_im) { - *c_re += b_re*a_re - b_im*a_im; - *c_im += b_re*a_im + b_im*a_re; -} - - -template inline void ComplexImExp(Real x, Real *a_re, Real *a_im) { - *a_re = std::cos(x); - *a_im = std::sin(x); -} - - -} // end namespace kaldi - - -#endif // KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ - diff --git a/speechx/speechx/kaldi/matrix/matrix-functions.cc b/speechx/speechx/kaldi/matrix/matrix-functions.cc deleted file mode 100644 index 496c09f5..00000000 --- a/speechx/speechx/kaldi/matrix/matrix-functions.cc +++ /dev/null @@ -1,773 +0,0 @@ -// matrix/matrix-functions.cc - -// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc.; Jan Silovsky -// Yanmin Qian; Saarland University; Johns Hopkins University (Author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. -// -// (*) incorporates, with permission, FFT code from his book -// "Signal Processing with Lapped Transforms", Artech, 1992. 
- -#include "matrix/matrix-functions.h" -#include "matrix/sp-matrix.h" - -namespace kaldi { - -template void ComplexFt (const VectorBase &in, - VectorBase *out, bool forward) { - int exp_sign = (forward ? -1 : 1); - KALDI_ASSERT(out != NULL); - KALDI_ASSERT(in.Dim() == out->Dim()); - KALDI_ASSERT(in.Dim() % 2 == 0); - int twoN = in.Dim(), N = twoN / 2; - const Real *data_in = in.Data(); - Real *data_out = out->Data(); - - Real exp1N_re, exp1N_im; // forward -> exp(-2pi / N), backward -> exp(2pi / N). - Real fraction = exp_sign * M_2PI / static_cast(N); // forward -> -2pi/N, backward->-2pi/N - ComplexImExp(fraction, &exp1N_re, &exp1N_im); - - Real expm_re = 1.0, expm_im = 0.0; // forward -> exp(-2pi m / N). - - for (int two_m = 0; two_m < twoN; two_m+=2) { // For each output component. - Real expmn_re = 1.0, expmn_im = 0.0; // forward -> exp(-2pi m n / N). - Real sum_re = 0.0, sum_im = 0.0; // complex output for index m (the sum expression) - for (int two_n = 0; two_n < twoN; two_n+=2) { - ComplexAddProduct(data_in[two_n], data_in[two_n+1], - expmn_re, expmn_im, - &sum_re, &sum_im); - ComplexMul(expm_re, expm_im, &expmn_re, &expmn_im); - } - data_out[two_m] = sum_re; - data_out[two_m + 1] = sum_im; - - - if (two_m % 10 == 0) { // occasionally renew "expm" from scratch to avoid - // loss of precision. - int nextm = 1 + two_m/2; - Real fraction_mult = fraction * nextm; - ComplexImExp(fraction_mult, &expm_re, &expm_im); - } else { - ComplexMul(exp1N_re, exp1N_im, &expm_re, &expm_im); - } - } -} - -template -void ComplexFt (const VectorBase &in, - VectorBase *out, bool forward); -template -void ComplexFt (const VectorBase &in, - VectorBase *out, bool forward); - - -#define KALDI_COMPLEXFFT_BLOCKSIZE 8192 -// This #define affects how we recurse in ComplexFftRecursive. -// We assume that memory-caching happens on a scale at -// least as small as this. - - -//! ComplexFftRecursive is a recursive function that computes the -//! complex FFT of size N. 
The "nffts" arguments specifies how many -//! separate FFTs to compute in parallel (we assume the data for -//! each one is consecutive in memory). The "forward argument" -//! specifies whether to do the FFT (true) or IFFT (false), although -//! note that we do not include the factor of 1/N (the user should -//! do this if required. The iterators factor_begin and factor_end -//! point to the beginning and end (i.e. one past the last element) -//! of an array of small factors of N (typically prime factors). -//! See the comments below this code for the detailed equations -//! of the recursion. - - -template -void ComplexFftRecursive (Real *data, int nffts, int N, - const int *factor_begin, - const int *factor_end, bool forward, - Vector *tmp_vec) { - if (factor_begin == factor_end) { - KALDI_ASSERT(N == 1); - return; - } - - { // an optimization: compute in smaller blocks. - // this block of code could be removed and it would still work. - MatrixIndexT size_perblock = N * 2 * sizeof(Real); - if (nffts > 1 && size_perblock*nffts > KALDI_COMPLEXFFT_BLOCKSIZE) { // can break it up... - // Break up into multiple blocks. This is an optimization. We make - // no progress on the FFT when we do this. - int block_skip = KALDI_COMPLEXFFT_BLOCKSIZE / size_perblock; // n blocks per call - if (block_skip == 0) block_skip = 1; - if (block_skip < nffts) { - int blocks_left = nffts; - while (blocks_left > 0) { - int skip_now = std::min(blocks_left, block_skip); - ComplexFftRecursive(data, skip_now, N, factor_begin, factor_end, forward, tmp_vec); - blocks_left -= skip_now; - data += skip_now * N*2; - } - return; - } // else do the actual algorithm. - } // else do the actual algorithm. - } - - int P = *factor_begin; - KALDI_ASSERT(P > 1); - int Q = N / P; - - - if (P > 1 && Q > 1) { // Do the rearrangement. C.f. eq. (8) below. Transform - // (a) to (b). 
- Real *data_thisblock = data; - if (tmp_vec->Dim() < (MatrixIndexT)N) tmp_vec->Resize(N); - Real *data_tmp = tmp_vec->Data(); - for (int thisfft = 0; thisfft < nffts; thisfft++, data_thisblock+=N*2) { - for (int offset = 0; offset < 2; offset++) { // 0 == real, 1 == im. - for (int p = 0; p < P; p++) { - for (int q = 0; q < Q; q++) { - int aidx = q*P + p, bidx = p*Q + q; - data_tmp[bidx] = data_thisblock[2*aidx+offset]; - } - } - for (int n = 0;n < P*Q;n++) data_thisblock[2*n+offset] = data_tmp[n]; - } - } - } - - { // Recurse. - ComplexFftRecursive(data, nffts*P, Q, factor_begin+1, factor_end, forward, tmp_vec); - } - - int exp_sign = (forward ? -1 : 1); - Real rootN_re, rootN_im; // Nth root of unity. - ComplexImExp(static_cast(exp_sign * M_2PI / N), &rootN_re, &rootN_im); - - Real rootP_re, rootP_im; // Pth root of unity. - ComplexImExp(static_cast(exp_sign * M_2PI / P), &rootP_re, &rootP_im); - - { // Do the multiplication - // could avoid a bunch of complex multiplies by moving the loop over data_thisblock - // inside. - if (tmp_vec->Dim() < (MatrixIndexT)(P*2)) tmp_vec->Resize(P*2); - Real *temp_a = tmp_vec->Data(); - - Real *data_thisblock = data, *data_end = data+(N*2*nffts); - for (; data_thisblock != data_end; data_thisblock += N*2) { // for each separate fft. - Real qd_re = 1.0, qd_im = 0.0; // 1^(q'/N) - for (int qd = 0; qd < Q; qd++) { - Real pdQ_qd_re = qd_re, pdQ_qd_im = qd_im; // 1^((p'Q+q') / N) == 1^((p'/P) + (q'/N)) - // Initialize to q'/N, corresponding to p' == 0. - for (int pd = 0; pd < P; pd++) { // pd == p' - { // This is the p = 0 case of the loop below [an optimization]. - temp_a[pd*2] = data_thisblock[qd*2]; - temp_a[pd*2 + 1] = data_thisblock[qd*2 + 1]; - } - { // This is the p = 1 case of the loop below [an optimization] - // **** MOST OF THE TIME (>60% I think) gets spent here. 
*** - ComplexAddProduct(pdQ_qd_re, pdQ_qd_im, - data_thisblock[(qd+Q)*2], data_thisblock[(qd+Q)*2 + 1], - &(temp_a[pd*2]), &(temp_a[pd*2 + 1])); - } - if (P > 2) { - Real p_pdQ_qd_re = pdQ_qd_re, p_pdQ_qd_im = pdQ_qd_im; // 1^(p(p'Q+q')/N) - for (int p = 2; p < P; p++) { - ComplexMul(pdQ_qd_re, pdQ_qd_im, &p_pdQ_qd_re, &p_pdQ_qd_im); // p_pdQ_qd *= pdQ_qd. - int data_idx = p*Q + qd; - ComplexAddProduct(p_pdQ_qd_re, p_pdQ_qd_im, - data_thisblock[data_idx*2], data_thisblock[data_idx*2 + 1], - &(temp_a[pd*2]), &(temp_a[pd*2 + 1])); - } - } - if (pd != P-1) - ComplexMul(rootP_re, rootP_im, &pdQ_qd_re, &pdQ_qd_im); // pdQ_qd *= (rootP == 1^{1/P}) - // (using 1/P == Q/N) - } - for (int pd = 0; pd < P; pd++) { - data_thisblock[(pd*Q + qd)*2] = temp_a[pd*2]; - data_thisblock[(pd*Q + qd)*2 + 1] = temp_a[pd*2 + 1]; - } - ComplexMul(rootN_re, rootN_im, &qd_re, &qd_im); // qd *= rootN. - } - } - } -} - -/* Equations for ComplexFftRecursive. - We consider here one of the "nffts" separate ffts; it's just a question of - doing them all in parallel. We also write all equations in terms of - complex math (the conversion to real arithmetic is not hard, and anyway - takes place inside function calls). - - - Let the input (i.e. "data" at start) be a_n, n = 0..N-1, and - the output (Fourier transform) be d_k, k = 0..N-1. We use these letters because - there will be two intermediate variables b and c. - We want to compute: - - d_k = \sum_n a_n 1^(kn/N) (1) - - where we use 1^x as shorthand for exp(-2pi x) for the forward algorithm - and exp(2pi x) for the backward one. - - We factorize N = P Q (P small, Q usually large). - With p = 0..P-1 and q = 0..Q-1, and also p'=0..P-1 and q'=0..P-1, we let: - - k == p'Q + q' (2) - n == qP + p (3) - - That is, we let p, q, p', q' range over these indices and observe that this way we - can cover all n, k. 
Expanding (1) using (2) and (3), we can write: - - d_k = \sum_{p, q} a_n 1^((p'Q+q')(qP+p)/N) - = \sum_{p, q} a_n 1^(p'pQ/N) 1^(q'qP/N) 1^(q'p/N) (4) - - using 1^(PQ/N) = 1 to get rid of the terms with PQ in them. Rearranging (4), - - d_k = \sum_p 1^(p'pQ/N) 1^(q'p/N) \sum_q 1^(q'qP/N) a_n (5) - - The point here is to separate the index q. Now we can expand out the remaining - instances of k and n using (2) and (3): - - d_(p'Q+q') = \sum_p 1^(p'pQ/N) 1^(q'p/N) \sum_q 1^(q'qP/N) a_(qP+p) (6) - - The expression \sum_q varies with the indices p and q'. Let us define - - C_{p, q'} = \sum_q 1^(q'qP/N) a_(qP+p) (7) - - Here, C_{p, q'}, viewed as a sequence in q', is just the DFT of the points - a_(qP+p) for q = 1..Q-1. These points are not consecutive in memory though, - they jump by P each time. Let us define b as a rearranged version of a, - so that - - b_(pQ+q) = a_(qP+p) (8) - - How to do this rearrangement in place? In - - We can rearrange (7) to be written in terms of the b's, using (8), so that - - C_{p, q'} = \sum_q 1^(q'q (P/N)) b_(pQ+q) (9) - - Here, the sequence of C_{p, q'} over q'=0..Q-1, is just the DFT of the sequence - of b_(pQ) .. b_(p(Q+1)-1). Let's arrange the C_{p, q'} in a single array in - memory in the same way as the b's, i.e. we define - c_(pQ+q') == C_{p, q'}. (10) - Note that we could have written (10) with q in place of q', as there is only - one index of type q present, but q' is just a more natural variable name to use - since we use q' elsewhere to subscript c and C. - - Rewriting (9), we have: - c_(pQ+q') = \sum_q 1^(q'q (P/N)) b_(pQ+q) (11) - which is the DFT computed by the recursive call to this function [after computing - the b's by rearranging the a's]. From the c's we want to compute the d's. - Taking (6), substituting in the sum (7), and using (10) to write it as an array, - we have: - d_(p'Q+q') = \sum_p 1^(p'pQ/N) 1^(q'p/N) c_(pQ+q') (12) - This sum is independent for different values of q'. Note that d overwrites c - in memory. 
We compute this in a direct way, using a little array of size P to - store the computed d values for one value of q' (we reuse the array for each value - of q'). - - So the overall picture is this: - We get a call to compute DFT on size N. - - - If N == 1 we return (nothing to do). - - We factor N = P Q (typically, P is small). - - Using (8), we rearrange the data in memory so that we have b not a in memory - (this is the block "do the rearrangement"). - The pseudocode for this is as follows. For simplicity we use a temporary array. - - for p = 0..P-1 - for q = 0..Q-1 - bidx = pQ + q - aidx = qP + p - tmp[bidx] = data[aidx]. - end - end - data <-- tmp - else - - endif - - - The reason this accomplishes (8) is that we want pQ+q and qP+p to be swapped - over for each p, q, and the "if m > n" is a convenient way of ensuring that - this swapping happens only once (otherwise it would happen twice, since pQ+q - and qP+p both range over the entire set of numbers 0..N-1). - - - We do the DFT on the smaller block size to compute c from b (this eq eq. (11)). - Note that this is actually multiple DFTs, one for each value of p, but this - goes to the "nffts" argument of the function call, which we have ignored up to now. - - -We compute eq. (12) via a loop, as follows - allocate temporary array e of size P. - For q' = 0..Q-1: - for p' = 0..P-1: - set sum to zero [this will go in e[p']] - for p = p..P-1: - sum += 1^(p'pQ/N) 1^(q'p/N) c_(pQ+q') - end - e[p'] = sum - end - for p' = 0..P-1: - d_(p'Q+q') = e[p'] - end - end - delete temporary array e - -*/ - -// This is the outer-layer calling code for ComplexFftRecursive. -// It factorizes the dimension and then calls the FFT routine. -template void ComplexFft(VectorBase *v, bool forward, Vector *tmp_in) { - KALDI_ASSERT(v != NULL); - - if (v->Dim()<=1) return; - KALDI_ASSERT(v->Dim() % 2 == 0); // complex input. 
- int N = v->Dim() / 2; - std::vector factors; - Factorize(N, &factors); - int *factor_beg = NULL; - if (factors.size() > 0) - factor_beg = &(factors[0]); - Vector tmp; // allocated in ComplexFftRecursive. - ComplexFftRecursive(v->Data(), 1, N, factor_beg, factor_beg+factors.size(), forward, (tmp_in?tmp_in:&tmp)); -} - -//! Inefficient version of Fourier transform, for testing purposes. -template void RealFftInefficient (VectorBase *v, bool forward) { - KALDI_ASSERT(v != NULL); - MatrixIndexT N = v->Dim(); - KALDI_ASSERT(N%2 == 0); - if (N == 0) return; - Vector vtmp(N*2); // store as complex. - if (forward) { - for (MatrixIndexT i = 0; i < N; i++) vtmp(i*2) = (*v)(i); - ComplexFft(&vtmp, forward); // this is already tested so we can use this. - v->CopyFromVec( vtmp.Range(0, N) ); - (*v)(1) = vtmp(N); // Copy the N/2'th fourier component, which is real, - // to the imaginary part of the 1st complex output. - } else { - // reverse the transformation above to get the complex spectrum. - vtmp(0) = (*v)(0); // copy F_0 which is real - vtmp(N) = (*v)(1); // copy F_{N/2} which is real - for (MatrixIndexT i = 1; i < N/2; i++) { - // Copy i'th to i'th fourier component - vtmp(2*i) = (*v)(2*i); - vtmp(2*i+1) = (*v)(2*i+1); - // Copy i'th to N-i'th, conjugated. - vtmp(2*(N-i)) = (*v)(2*i); - vtmp(2*(N-i)+1) = -(*v)(2*i+1); - } - ComplexFft(&vtmp, forward); // actually backward since forward == false - // Copy back real part. Complex part should be zero. - for (MatrixIndexT i = 0; i < N; i++) - (*v)(i) = vtmp(i*2); - } -} - -template void RealFftInefficient (VectorBase *v, bool forward); -template void RealFftInefficient (VectorBase *v, bool forward); - -template -void ComplexFft(VectorBase *v, bool forward, Vector *tmp_in); -template -void ComplexFft(VectorBase *v, bool forward, Vector *tmp_in); - - -// See the long comment below for the math behind this. 
-template void RealFft (VectorBase *v, bool forward) { - KALDI_ASSERT(v != NULL); - MatrixIndexT N = v->Dim(), N2 = N/2; - KALDI_ASSERT(N%2 == 0); - if (N == 0) return; - - if (forward) ComplexFft(v, true); - - Real *data = v->Data(); - Real rootN_re, rootN_im; // exp(-2pi/N), forward; exp(2pi/N), backward - int forward_sign = forward ? -1 : 1; - ComplexImExp(static_cast(M_2PI/N *forward_sign), &rootN_re, &rootN_im); - Real kN_re = -forward_sign, kN_im = 0.0; // exp(-2pik/N), forward; exp(-2pik/N), backward - // kN starts out as 1.0 for forward algorithm but -1.0 for backward. - for (MatrixIndexT k = 1; 2*k <= N2; k++) { - ComplexMul(rootN_re, rootN_im, &kN_re, &kN_im); - - Real Ck_re, Ck_im, Dk_re, Dk_im; - // C_k = 1/2 (B_k + B_{N/2 - k}^*) : - Ck_re = 0.5 * (data[2*k] + data[N - 2*k]); - Ck_im = 0.5 * (data[2*k + 1] - data[N - 2*k + 1]); - // re(D_k)= 1/2 (im(B_k) + im(B_{N/2-k})): - Dk_re = 0.5 * (data[2*k + 1] + data[N - 2*k + 1]); - // im(D_k) = -1/2 (re(B_k) - re(B_{N/2-k})) - Dk_im =-0.5 * (data[2*k] - data[N - 2*k]); - // A_k = C_k + 1^(k/N) D_k: - data[2*k] = Ck_re; // A_k <-- C_k - data[2*k+1] = Ck_im; - // now A_k += D_k 1^(k/N) - ComplexAddProduct(Dk_re, Dk_im, kN_re, kN_im, &(data[2*k]), &(data[2*k+1])); - - MatrixIndexT kdash = N2 - k; - if (kdash != k) { - // Next we handle the index k' = N/2 - k. This is necessary - // to do now, to avoid invalidating data that we will later need. - // The quantities C_{k'} and D_{k'} are just the conjugates of C_k - // and D_k, so the equations are simple modifications of the above, - // replacing Ck_im and Dk_im with their negatives. - data[2*kdash] = Ck_re; // A_k' <-- C_k' - data[2*kdash+1] = -Ck_im; - // now A_k' += D_k' 1^(k'/N) - // We use 1^(k'/N) = 1^((N/2 - k) / N) = 1^(1/2) 1^(-k/N) = -1 * (1^(k/N))^* - // so it's the same as 1^(k/N) but with the real part negated. - ComplexAddProduct(Dk_re, -Dk_im, -kN_re, kN_im, &(data[2*kdash]), &(data[2*kdash+1])); - } - } - - { // Now handle k = 0. 
- // In simple terms: after the complex fft, data[0] becomes the sum of real - // parts input[0], input[2]... and data[1] becomes the sum of imaginary - // pats input[1], input[3]... - // "zeroth" [A_0] is just the sum of input[0]+input[1]+input[2].. - // and "n2th" [A_{N/2}] is input[0]-input[1]+input[2]... . - Real zeroth = data[0] + data[1], - n2th = data[0] - data[1]; - data[0] = zeroth; - data[1] = n2th; - if (!forward) { - data[0] /= 2; - data[1] /= 2; - } - } - - if (!forward) { - ComplexFft(v, false); - v->Scale(2.0); // This is so we get a factor of N increase, rather than N/2 which we would - // otherwise get from [ComplexFft, forward] + [ComplexFft, backward] in dimension N/2. - // It's for consistency with our normal FFT convensions. - } -} - -template void RealFft (VectorBase *v, bool forward); -template void RealFft (VectorBase *v, bool forward); - -/* Notes for real FFTs. - We are using the same convention as above, 1^x to mean exp(-2\pi x) for the forward transform. - Actually, in a slight abuse of notation, we use this meaning for 1^x in both the forward and - backward cases because it's more convenient in this section. - - Suppose we have real data a[0...N-1], with N even, and want to compute its Fourier transform. - We can make do with the first N/2 points of the transform, since the remaining ones are complex - conjugates of the first. We want to compute: - for k = 0...N/2-1, - A_k = \sum_{n = 0}^{N-1} a_n 1^(kn/N) (1) - - We treat a[0..N-1] as a complex sequence of length N/2, i.e. a sequence b[0..N/2 - 1]. - Viewed as sequences of length N/2, we have: - b = c + i d, - where c = a_0, a_2 ... and d = a_1, a_3 ... - - We can recover the length-N/2 Fourier transforms of c and d by doing FT on b and - then doing the equations below. Derivation is marked by (*) in a comment below (search - for it). Let B, C, D be the FTs. 
- We have - C_k = 1/2 (B_k + B_{N/2 - k}^*) (z0) - D_k =-1/2i (B_k - B_{N/2 - k}^*) (z1) -so: re(D_k)= 1/2 (im(B_k) + im(B_{N/2-k})) (z2) - im(D_k) = -1/2 (re(B_k) - re(B_{N/2-k})) (z3) - - To recover the FT A from C and D, we write, rearranging (1): - - A_k = \sum_{n = 0, 2, ..., N-2} a_n 1^(kn/N) - +\sum_{n = 1, 3, ..., N-1} a_n 1^(kn/N) - = \sum_{n = 0, 1, ..., N/2-1} a_n 1^(2kn/N) + a_{n+1} 1^(2kn/N) 1^(k/N) - = \sum_{n = 0, 1, ..., N/2-1} c_n 1^(2kn/N) + d_n 1^(2kn/N) 1^(k/N) - A_k = C_k + 1^(k/N) D_k (a0) - - This equation is valid for k = 0...N/2-1, which is the range of the sequences B_k and - C_k. We don't use is for k = 0, which is a special case considered below. For - 1 < k < N/2, it's convenient to consider the pair k, k', where k' = N/2 - k. - Remember that C_k' = C_k^ *and D_k' = D_k^* [where * is conjugation]. Also, - 1^(N/2 / N) = -1. So we have: - A_k' = C_k^* - 1^(k/N) D_k^* (a0b) - We do (a0) and (a0b) together. - - - - By symmetry this gives us the Fourier components for N/2+1, ... N, if we want - them. However, it doesn't give us the value for exactly k = N/2. For k = 0 and k = N/2, it - is easiest to argue directly about the meaning of the A_k, B_k and C_k in terms of - sums of points. - A_0 and A_{N/2} are both real, with A_0=\sum_n a_n, and A_1 an alternating sum - A_1 = a_0 - a_1 + a_2 ... - It's easy to show that - A_0 = B_0 + C_0 (a1) - A_{N/2} = B_0 - C_0. (a2) - Since B_0 and C_0 are both real, B_0 is the real coefficient of D_0 and C_0 is the - imaginary coefficient. - - *REVERSING THE PROCESS* - - Next we want to reverse this process. We just need to work out C_k and D_k from the - sequence A_k. Then we do the inverse complex fft and we get back where we started. - For 0 and N/2, working from (a1) and (a2) above, we can see that: - B_0 = 1/2 (A_0 + A_{N/2}) (y0) - C_0 = 1/2 (A_0 + A_{N/2}) (y1) - and we use - D_0 = B_0 + i C_0 - to get the 1st complex coefficient of D. 
This is exactly the same as the forward process - except with an extra factor of 1/2. - - Consider equations (a0) and (a0b). We want to work out C_k and D_k from A_k and A_k'. Remember - k' = N/2 - k. - - Write down - A_k = C_k + 1^(k/N) D_k (copying a0) - A_k'^* = C_k - 1^(k/N) D_k (conjugate of a0b) - So - C_k = 0.5 (A_k + A_k'^*) (p0) - D_k = 1^(-k/N) . 0.5 (A_k - A_k'^*) (p1) - Next, we want to compute B_k and B_k' from C_k and D_k. C.f. (z0)..(z3), and remember - that k' = N/2-k. We can see - that - B_k = C_k + i D_k (p2) - B_k' = C_k - i D_k (p3) - - We would like to make the equations (p0) ... (p3) look like the forward equations (z0), (z1), - (a0) and (a0b) so we can reuse the code. Define E_k = -i 1^(k/N) D_k. Then write down (p0)..(p3). - We have - C_k = 0.5 (A_k + A_k'^*) (p0') - E_k = -0.5 i (A_k - A_k'^*) (p1') - B_k = C_k - 1^(-k/N) E_k (p2') - B_k' = C_k + 1^(-k/N) E_k (p3') - So these are exactly the same as (z0), (z1), (a0), (a0b) except replacing 1^(k/N) with - -1^(-k/N) . Remember that we defined 1^x above to be exp(-2pi x/N), so the signs here - might be opposite to what you see in the code. - - MODIFICATION: we need to take care of a factor of two. The complex FFT we implemented - does not divide by N in the reverse case. So upon inversion we get larger by N/2. - However, this is not consistent with normal FFT conventions where you get a factor of N. - For this reason we multiply by two after the process described above. - -*/ - - -/* - (*) [this token is referred to in a comment above]. - - Notes for separating 2 real transforms from one complex one. Note that the - letters here (A, B, C and N) are all distinct from the same letters used in the - place where this comment is used. - Suppose we - have two sequences a_n and b_n, n = 0..N-1. We combine them into a complex - number, - c_n = a_n + i b_n. - Then we take the fourier transform to get - C_k = \sum_{n = 0}^{N-1} c_n 1^(n/N) . - Then we use symmetry. 
Define A_k and B_k as the DFTs of a and b. - We use A_k = A_{N-k}^*, and B_k = B_{N-k}^*, since a and b are real. Using - C_k = A_k + i B_k, - C_{N-k} = A_k^* + i B_k^* - = A_k^* - (i B_k)^* - So: - A_k = 1/2 (C_k + C_{N-k}^*) - i B_k = 1/2 (C_k - C_{N-k}^*) --> B_k =-1/2i (C_k - C_{N-k}^*) --> re(B_k) = 1/2 (im(C_k) + im(C_{N-k})) - im(B_k) =-1/2 (re(C_k) - re(C_{N-k})) - - */ - -template void ComputeDctMatrix(Matrix *M) { - //KALDI_ASSERT(M->NumRows() == M->NumCols()); - MatrixIndexT K = M->NumRows(); - MatrixIndexT N = M->NumCols(); - - KALDI_ASSERT(K > 0); - KALDI_ASSERT(N > 0); - Real normalizer = std::sqrt(1.0 / static_cast(N)); // normalizer for - // X_0. - for (MatrixIndexT j = 0; j < N; j++) (*M)(0, j) = normalizer; - normalizer = std::sqrt(2.0 / static_cast(N)); // normalizer for other - // elements. - for (MatrixIndexT k = 1; k < K; k++) - for (MatrixIndexT n = 0; n < N; n++) - (*M)(k, n) = normalizer - * std::cos( static_cast(M_PI)/N * (n + 0.5) * k ); -} - - -template void ComputeDctMatrix(Matrix *M); -template void ComputeDctMatrix(Matrix *M); - - -template -void ComputePca(const MatrixBase &X, - MatrixBase *U, - MatrixBase *A, - bool print_eigs, - bool exact) { - // Note that some of these matrices may be transposed w.r.t. the - // way it's most natural to describe them in math... it's the rows - // of X and U that correspond to the (data-points, basis elements). - MatrixIndexT N = X.NumRows(), D = X.NumCols(); - // N = #points, D = feature dim. - KALDI_ASSERT(U != NULL && U->NumCols() == D); - MatrixIndexT G = U->NumRows(); // # of retained basis elements. - KALDI_ASSERT(A == NULL || (A->NumRows() == N && A->NumCols() == G)); - KALDI_ASSERT(G <= N && G <= D); - if (D < N) { // Do conventional PCA. - SpMatrix Msp(D); // Matrix of outer products. 
- Msp.AddMat2(1.0, X, kTrans, 0.0); // M <-- X^T X - Matrix Utmp; - Vector l; - if (exact) { - Utmp.Resize(D, D); - l.Resize(D); - //Matrix M(Msp); - //M.DestructiveSvd(&l, &Utmp, NULL); - Msp.Eig(&l, &Utmp); - } else { - Utmp.Resize(D, G); - l.Resize(G); - Msp.TopEigs(&l, &Utmp); - } - SortSvd(&l, &Utmp); - - for (MatrixIndexT g = 0; g < G; g++) - U->Row(g).CopyColFromMat(Utmp, g); - if (print_eigs) - KALDI_LOG << (exact ? "" : "Retained ") - << "PCA eigenvalues are " << l; - if (A != NULL) - A->AddMatMat(1.0, X, kNoTrans, *U, kTrans, 0.0); - } else { // Do inner-product PCA. - SpMatrix Nsp(N); // Matrix of inner products. - Nsp.AddMat2(1.0, X, kNoTrans, 0.0); // M <-- X X^T - - Matrix Vtmp; - Vector l; - if (exact) { - Vtmp.Resize(N, N); - l.Resize(N); - Matrix Nmat(Nsp); - Nmat.DestructiveSvd(&l, &Vtmp, NULL); - } else { - Vtmp.Resize(N, G); - l.Resize(G); - Nsp.TopEigs(&l, &Vtmp); - } - - MatrixIndexT num_zeroed = 0; - for (MatrixIndexT g = 0; g < G; g++) { - if (l(g) < 0.0) { - KALDI_WARN << "In PCA, setting element " << l(g) << " to zero."; - l(g) = 0.0; - num_zeroed++; - } - } - SortSvd(&l, &Vtmp); // Make sure zero elements are last, this - // is necessary for Orthogonalize() to work properly later. - - Vtmp.Transpose(); // So eigenvalues are the rows. - - for (MatrixIndexT g = 0; g < G; g++) { - Real sqrtlg = sqrt(l(g)); - if (l(g) != 0.0) { - U->Row(g).AddMatVec(1.0 / sqrtlg, X, kTrans, Vtmp.Row(g), 0.0); - } else { - U->Row(g).SetZero(); - (*U)(g, g) = 1.0; // arbitrary direction. Will later orthogonalize. - } - if (A != NULL) - for (MatrixIndexT n = 0; n < N; n++) - (*A)(n, g) = sqrtlg * Vtmp(g, n); - } - // Now orthogonalize. This is mainly useful in - // case there were zero eigenvalues, but we do it - // for all of them. 
- U->OrthogonalizeRows(); - if (print_eigs) - KALDI_LOG << "(inner-product) PCA eigenvalues are " << l; - } -} - - -template -void ComputePca(const MatrixBase &X, - MatrixBase *U, - MatrixBase *A, - bool print_eigs, - bool exact); - -template -void ComputePca(const MatrixBase &X, - MatrixBase *U, - MatrixBase *A, - bool print_eigs, - bool exact); - - -// Added by Dan, Feb. 13 2012. -// This function does: *plus += max(0, a b^T), -// *minus += max(0, -(a b^T)). -template -void AddOuterProductPlusMinus(Real alpha, - const VectorBase &a, - const VectorBase &b, - MatrixBase *plus, - MatrixBase *minus) { - KALDI_ASSERT(a.Dim() == plus->NumRows() && b.Dim() == plus->NumCols() - && a.Dim() == minus->NumRows() && b.Dim() == minus->NumCols()); - int32 nrows = a.Dim(), ncols = b.Dim(), pskip = plus->Stride() - ncols, - mskip = minus->Stride() - ncols; - const Real *adata = a.Data(), *bdata = b.Data(); - Real *plusdata = plus->Data(), *minusdata = minus->Data(); - - for (int32 i = 0; i < nrows; i++) { - const Real *btmp = bdata; - Real multiple = alpha * *adata; - if (multiple > 0.0) { - for (int32 j = 0; j < ncols; j++, plusdata++, minusdata++, btmp++) { - if (*btmp > 0.0) *plusdata += multiple * *btmp; - else *minusdata -= multiple * *btmp; - } - } else { - for (int32 j = 0; j < ncols; j++, plusdata++, minusdata++, btmp++) { - if (*btmp < 0.0) *plusdata += multiple * *btmp; - else *minusdata -= multiple * *btmp; - } - } - plusdata += pskip; - minusdata += mskip; - adata++; - } -} - -// Instantiate template -template -void AddOuterProductPlusMinus(float alpha, - const VectorBase &a, - const VectorBase &b, - MatrixBase *plus, - MatrixBase *minus); -template -void AddOuterProductPlusMinus(double alpha, - const VectorBase &a, - const VectorBase &b, - MatrixBase *plus, - MatrixBase *minus); - - -} // end namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/matrix-functions.h b/speechx/speechx/kaldi/matrix/matrix-functions.h deleted file mode 100644 index ca50ddda..00000000 
--- a/speechx/speechx/kaldi/matrix/matrix-functions.h +++ /dev/null @@ -1,174 +0,0 @@ -// matrix/matrix-functions.h - -// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc.; Jan Silovsky; -// Yanmin Qian; 1991 Henrique (Rico) Malvar (*) -// -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. -// -// (*) incorporates, with permission, FFT code from his book -// "Signal Processing with Lapped Transforms", Artech, 1992. - - - -#ifndef KALDI_MATRIX_MATRIX_FUNCTIONS_H_ -#define KALDI_MATRIX_MATRIX_FUNCTIONS_H_ - -#include "matrix/kaldi-vector.h" -#include "matrix/kaldi-matrix.h" - -namespace kaldi { - -/// @addtogroup matrix_funcs_misc -/// @{ - -/** The function ComplexFft does an Fft on the vector argument v. - v is a vector of even dimension, interpreted for both input - and output as a vector of complex numbers i.e. - \f[ v = ( re_0, im_0, re_1, im_1, ... ) \f] - - If "forward == true" this routine does the Discrete Fourier Transform - (DFT), i.e.: - \f[ vout[m] \leftarrow \sum_{n = 0}^{N-1} vin[i] exp( -2pi m n / N ) \f] - - If "backward" it does the Inverse Discrete Fourier Transform (IDFT) - *WITHOUT THE FACTOR 1/N*, - i.e.: - \f[ vout[m] <-- \sum_{n = 0}^{N-1} vin[i] exp( 2pi m n / N ) \f] - [note the sign difference on the 2 pi for the backward one.] 
- - Note that this is the definition of the FT given in most texts, but - it differs from the Numerical Recipes version in which the forward - and backward algorithms are flipped. - - Note that you would have to multiply by 1/N after the IDFT to get - back to where you started from. We don't do this because - in some contexts, the transform is made symmetric by multiplying - by sqrt(N) in both passes. The user can do this by themselves. - - See also SplitRadixComplexFft, declared in srfft.h, which is more efficient - but only works if the length of the input is a power of 2. - */ -template void ComplexFft (VectorBase *v, bool forward, Vector *tmp_work = NULL); - -/// ComplexFt is the same as ComplexFft but it implements the Fourier -/// transform in an inefficient way. It is mainly included for testing purposes. -/// See comment for ComplexFft to describe the input and outputs and what it does. -template void ComplexFt (const VectorBase &in, - VectorBase *out, bool forward); - -/// RealFft is a fourier transform of real inputs. Internally it uses -/// ComplexFft. The input dimension N must be even. If forward == true, -/// it transforms from a sequence of N real points to its complex fourier -/// transform; otherwise it goes in the reverse direction. If you call it -/// in the forward and then reverse direction and multiply by 1.0/N, you -/// will get back the original data. -/// The interpretation of the complex-FFT data is as follows: the array -/// is a sequence of complex numbers C_n of length N/2 with (real, im) format, -/// i.e. [real0, real_{N/2}, real1, im1, real2, im2, real3, im3, ...]. -/// See also SplitRadixRealFft, declared in srfft.h, which is more efficient -/// but only works if the length of the input is a power of 2. - -template void RealFft (VectorBase *v, bool forward); - - -/// RealFt has the same input and output format as RealFft above, but it is -/// an inefficient implementation included for testing purposes. 
-template void RealFftInefficient (VectorBase *v, bool forward); - -/// ComputeDctMatrix computes a matrix corresponding to the DCT, such that -/// M * v equals the DCT of vector v. M must be square at input. -/// This is the type = III DCT with normalization, corresponding to the -/// following equations, where x is the signal and X is the DCT: -/// X_0 = 1/sqrt(2*N) \sum_{n = 0}^{N-1} x_n -/// X_k = 1/sqrt(N) \sum_{n = 0}^{N-1} x_n cos( \pi/N (n + 1/2) k ) -/// This matrix's transpose is its own inverse, so transposing this -/// matrix will give the inverse DCT. -/// Caution: the type III DCT is generally known as the "inverse DCT" (with the -/// type II being the actual DCT), so this function is somewhatd mis-named. It -/// was probably done this way for HTK compatibility. We don't change it -/// because it was this way from the start and changing it would affect the -/// feature generation. - -template void ComputeDctMatrix(Matrix *M); - - -/// ComplexMul implements, inline, the complex multiplication b *= a. -template inline void ComplexMul(const Real &a_re, const Real &a_im, - Real *b_re, Real *b_im); - -/// ComplexMul implements, inline, the complex operation c += (a * b). -template inline void ComplexAddProduct(const Real &a_re, const Real &a_im, - const Real &b_re, const Real &b_im, - Real *c_re, Real *c_im); - - -/// ComplexImExp implements a <-- exp(i x), inline. -template inline void ComplexImExp(Real x, Real *a_re, Real *a_im); - - - -/** - ComputePCA does a PCA computation, using either outer products - or inner products, whichever is more efficient. Let D be - the dimension of the data points, N be the number of data - points, and G be the PCA dimension we want to retain. We assume - G <= N and G <= D. - - @param X [in] An N x D matrix. Each row of X is a point x_i. - @param U [out] A G x D matrix. Each row of U is a basis element u_i. - @param A [out] An N x D matrix, or NULL. 
Each row of A is a set of coefficients - in the basis for a point x_i, so A(i, g) is the coefficient of u_i - in x_i. - @param print_eigs [in] If true, prints out diagnostic information about the - eigenvalues. - @param exact [in] If true, does the exact computation; if false, does - a much faster (but almost exact) computation based on the Lanczos - method. -*/ - -template -void ComputePca(const MatrixBase &X, - MatrixBase *U, - MatrixBase *A, - bool print_eigs = false, - bool exact = true); - - - -// This function does: *plus += max(0, a b^T), -// *minus += max(0, -(a b^T)). -template -void AddOuterProductPlusMinus(Real alpha, - const VectorBase &a, - const VectorBase &b, - MatrixBase *plus, - MatrixBase *minus); - -template -inline void AssertSameDim(const MatrixBase &mat1, const MatrixBase &mat2) { - KALDI_ASSERT(mat1.NumRows() == mat2.NumRows() - && mat1.NumCols() == mat2.NumCols()); -} - - -/// @} end of "addtogroup matrix_funcs_misc" - -} // end namespace kaldi - -#include "matrix/matrix-functions-inl.h" - - -#endif diff --git a/speechx/speechx/kaldi/matrix/matrix-lib.h b/speechx/speechx/kaldi/matrix/matrix-lib.h deleted file mode 100644 index 2a5ebad7..00000000 --- a/speechx/speechx/kaldi/matrix/matrix-lib.h +++ /dev/null @@ -1,37 +0,0 @@ -// matrix/matrix-lib.h - -// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Haihua Xu - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. 
-// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -// Include everything from this directory. -// These files include other stuff that we need. -#ifndef KALDI_MATRIX_MATRIX_LIB_H_ -#define KALDI_MATRIX_MATRIX_LIB_H_ - -#include "base/kaldi-common.h" -#include "matrix/kaldi-vector.h" -#include "matrix/kaldi-matrix.h" -#include "matrix/sp-matrix.h" -#include "matrix/tp-matrix.h" -#include "matrix/matrix-functions.h" -#include "matrix/srfft.h" -#include "matrix/compressed-matrix.h" -#include "matrix/sparse-matrix.h" -#include "matrix/optimization.h" - -#endif - diff --git a/speechx/speechx/kaldi/matrix/optimization.cc b/speechx/speechx/kaldi/matrix/optimization.cc deleted file mode 100644 index c17b5b94..00000000 --- a/speechx/speechx/kaldi/matrix/optimization.cc +++ /dev/null @@ -1,577 +0,0 @@ -// matrix/optimization.cc - -// Copyright 2012 Johns Hopkins University (author: Daniel Povey) - - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. -// -// (*) incorporates, with permission, FFT code from his book -// "Signal Processing with Lapped Transforms", Artech, 1992. - -#include - -#include "matrix/optimization.h" -#include "matrix/sp-matrix.h" - -namespace kaldi { - - -// Below, N&W refers to Nocedal and Wright, "Numerical Optimization", 2nd Ed. 
- -template -OptimizeLbfgs::OptimizeLbfgs(const VectorBase &x, - const LbfgsOptions &opts): - opts_(opts), k_(0), computation_state_(kBeforeStep), H_was_set_(false) { - KALDI_ASSERT(opts.m > 0); // dimension. - MatrixIndexT dim = x.Dim(); - KALDI_ASSERT(dim > 0); - x_ = x; // this is the value of x_k - new_x_ = x; // this is where we'll evaluate the function next. - deriv_.Resize(dim); - temp_.Resize(dim); - data_.Resize(2 * opts.m, dim); - rho_.Resize(opts.m); - // Just set f_ to some invalid value, as we haven't yet set it. - f_ = (opts.minimize ? 1 : -1 ) * std::numeric_limits::infinity(); - best_f_ = f_; - best_x_ = x_; -} - - -template -Real OptimizeLbfgs::RecentStepLength() const { - size_t n = step_lengths_.size(); - if (n == 0) return std::numeric_limits::infinity(); - else { - if (n >= 2 && step_lengths_[n-1] == 0.0 && step_lengths_[n-2] == 0.0) - return 0.0; // two zeros in a row means repeated restarts, which is - // a loop. Short-circuit this by returning zero. - Real avg = 0.0; - for (size_t i = 0; i < n; i++) - avg += step_lengths_[i] / n; - return avg; - } -} - -template -void OptimizeLbfgs::ComputeHifNeeded(const VectorBase &gradient) { - if (k_ == 0) { - if (H_.Dim() == 0) { - // H was never set up. Set it up for the first time. - Real learning_rate; - if (opts_.first_step_length > 0.0) { // this takes - // precedence over first_step_learning_rate, if set. - // We are setting up H for the first time. - Real gradient_length = gradient.Norm(2.0); - learning_rate = (gradient_length > 0.0 ? - opts_.first_step_length / gradient_length : - 1.0); - } else if (opts_.first_step_impr > 0.0) { - Real gradient_length = gradient.Norm(2.0); - learning_rate = (gradient_length > 0.0 ? - opts_.first_step_impr / (gradient_length * gradient_length) : - 1.0); - } else { - learning_rate = opts_.first_step_learning_rate; - } - H_.Resize(x_.Dim()); - KALDI_ASSERT(learning_rate > 0.0); - H_.Set(opts_.minimize ? 
learning_rate : -learning_rate); - } - } else { // k_ > 0 - if (!H_was_set_) { // The user never specified an approximate - // diagonal inverse Hessian. - // Set it using formula 7.20: H_k^{(0)} = \gamma_k I, where - // \gamma_k = s_{k-1}^T y_{k-1} / y_{k-1}^T y_{k-1} - SubVector y_km1 = Y(k_-1); - double gamma_k = VecVec(S(k_-1), y_km1) / VecVec(y_km1, y_km1); - if (KALDI_ISNAN(gamma_k) || KALDI_ISINF(gamma_k)) { - KALDI_WARN << "NaN encountered in L-BFGS (already converged?)"; - gamma_k = (opts_.minimize ? 1.0 : -1.0); - } - H_.Set(gamma_k); - } - } -} - -// This represents the first 2 lines of Algorithm 7.5 (N&W), which -// in fact is mostly a call to Algorithm 7.4. -// Note: this is valid whether we are minimizing or maximizing. -template -void OptimizeLbfgs::ComputeNewDirection(Real function_value, - const VectorBase &gradient) { - KALDI_ASSERT(computation_state_ == kBeforeStep); - SignedMatrixIndexT m = M(), k = k_; - ComputeHifNeeded(gradient); - // The rest of this is computing p_k <-- - H_k \nabla f_k using Algorithm - // 7.4 of N&W. - Vector &q(deriv_), &r(new_x_); // Use deriv_ as a temporary place to put - // q, and new_x_ as a temporay place to put r. - // The if-statement below is just to get rid of spurious warnings from - // valgrind about memcpy source and destination overlap, since sometimes q and - // gradient are the same variable. - if (&q != &gradient) - q.CopyFromVec(gradient); // q <-- \nabla f_k. - Vector alpha(m); - // for i = k - 1, k - 2, ... k - m - for (SignedMatrixIndexT i = k - 1; - i >= std::max(k - m, static_cast(0)); - i--) { - alpha(i % m) = rho_(i % m) * VecVec(S(i), q); // \alpha_i <-- \rho_i s_i^T q. - q.AddVec(-alpha(i % m), Y(i)); // q <-- q - \alpha_i y_i - } - r.SetZero(); - r.AddVecVec(1.0, H_, q, 0.0); // r <-- H_k^{(0)} q. - // for k = k - m, k - m + 1, ... 
, k - 1 - for (SignedMatrixIndexT i = std::max(k - m, static_cast(0)); - i < k; - i++) { - Real beta = rho_(i % m) * VecVec(Y(i), r); // \beta <-- \rho_i y_i^T r - r.AddVec(alpha(i % m) - beta, S(i)); // r <-- r + s_i (\alpha_i - \beta) - } - - { // TEST. Note, -r will be the direction. - Real dot = VecVec(gradient, r); - if ((opts_.minimize && dot < 0) || (!opts_.minimize && dot > 0)) - KALDI_WARN << "Step direction has the wrong sign! Routine will fail."; - } - - // Now we're out of Alg. 7.4 and back into Alg. 7.5. - // Alg. 7.4 returned r (using new_x_ as the location), and with \alpha_k = 1 - // as the initial guess, we're setting x_{k+1} = x_k + \alpha_k p_k, with - // p_k = -r [hence the statement new_x_.Scale(-1.0)]., and \alpha_k = 1. - // This is the first place we'll get the user to evaluate the function; - // any backtracking (or acceptance of that step) occurs inside StepSizeIteration. - // We're still within iteration k; we haven't yet finalized the step size. - new_x_.Scale(-1.0); - new_x_.AddVec(1.0, x_); - if (&deriv_ != &gradient) - deriv_.CopyFromVec(gradient); - f_ = function_value; - d_ = opts_.d; - num_wolfe_i_failures_ = 0; - num_wolfe_ii_failures_ = 0; - last_failure_type_ = kNone; - computation_state_ = kWithinStep; -} - - -template -bool OptimizeLbfgs::AcceptStep(Real function_value, - const VectorBase &gradient) { - // Save s_k = x_{k+1} - x_{k}, and y_k = \nabla f_{k+1} - \nabla f_k. - SubVector s = S(k_), y = Y(k_); - s.CopyFromVec(new_x_); - s.AddVec(-1.0, x_); // s = new_x_ - x_. - y.CopyFromVec(gradient); - y.AddVec(-1.0, deriv_); // y = gradient - deriv_. - - // Warning: there is a division in the next line. This could - // generate inf or nan, but this wouldn't necessarily be an error - // at this point because for zero step size or derivative we should - // terminate the iterations. But this is up to the calling code. 
- Real prod = VecVec(y, s); - rho_(k_ % opts_.m) = 1.0 / prod; - Real len = s.Norm(2.0); - - if ((opts_.minimize && prod <= 1.0e-20) || (!opts_.minimize && prod >= -1.0e-20) - || len == 0.0) - return false; // This will force restart. - - KALDI_VLOG(3) << "Accepted step; length was " << len - << ", prod was " << prod; - RecordStepLength(len); - - // store x_{k+1} and the function value f_{k+1}. - x_.CopyFromVec(new_x_); - f_ = function_value; - k_++; - - return true; // We successfully accepted the step. -} - -template -void OptimizeLbfgs::RecordStepLength(Real s) { - step_lengths_.push_back(s); - if (step_lengths_.size() > static_cast(opts_.avg_step_length)) - step_lengths_.erase(step_lengths_.begin(), step_lengths_.begin() + 1); -} - - -template -void OptimizeLbfgs::Restart(const VectorBase &x, - Real f, - const VectorBase &gradient) { - // Note: we will consider restarting (the transition of x_ -> x) - // as a step, even if it has zero step size. This is necessary in - // order for convergence to be detected. - { - Vector &diff(temp_); - diff.CopyFromVec(x); - diff.AddVec(-1.0, x_); - RecordStepLength(diff.Norm(2.0)); - } - k_ = 0; // Restart the iterations! [But note that the Hessian, - // whatever it was, stays as before.] - if (&x_ != &x) - x_.CopyFromVec(x); - new_x_.CopyFromVec(x); - f_ = f; - computation_state_ = kBeforeStep; - ComputeNewDirection(f, gradient); -} - -template -void OptimizeLbfgs::StepSizeIteration(Real function_value, - const VectorBase &gradient) { - KALDI_VLOG(3) << "In step size iteration, function value changed " - << f_ << " to " << function_value; - - // We're in some part of the backtracking, and the user is providing - // the objective function value and gradient. - // We're checking two conditions: Wolfe i) [the Armijo rule] and - // Wolfe ii). - - // The Armijo rule (when minimizing) is: - // f(k_k + \alpha_k p_k) <= f(x_k) + c_1 \alpha_k p_k^T \nabla f(x_k), where - // \nabla means the derivative. 
- // Below, "temp" is the RHS of this equation, where (\alpha_k p_k) equals - // (new_x_ - x_); we don't store \alpha or p_k separately, they are implicit - // as the difference new_x_ - x_. - - // Below, pf is \alpha_k p_k^T \nabla f(x_k). - Real pf = VecVec(new_x_, deriv_) - VecVec(x_, deriv_); - Real temp = f_ + opts_.c1 * pf; - - bool wolfe_i_ok; - if (opts_.minimize) wolfe_i_ok = (function_value <= temp); - else wolfe_i_ok = (function_value >= temp); - - // Wolfe condition ii) can be written as: - // p_k^T \nabla f(x_k + \alpha_k p_k) >= c_2 p_k^T \nabla f(x_k) - // p2f equals \alpha_k p_k^T \nabla f(x_k + \alpha_k p_k), where - // (\alpha_k p_k^T) is (new_x_ - x_). - // Note that in our version of Wolfe condition (ii) we have an extra - // factor alpha, which doesn't affect anything. - Real p2f = VecVec(new_x_, gradient) - VecVec(x_, gradient); - //eps = (sizeof(Real) == 4 ? 1.0e-05 : 1.0e-10) * - //(std::abs(p2f) + std::abs(pf)); - bool wolfe_ii_ok; - if (opts_.minimize) wolfe_ii_ok = (p2f >= opts_.c2 * pf); - else wolfe_ii_ok = (p2f <= opts_.c2 * pf); - - enum { kDecrease, kNoChange } d_action; // What do do with d_: leave it alone, - // or take the square root. - enum { kAccept, kDecreaseStep, kIncreaseStep, kRestart } iteration_action; - // What we'll do in the overall iteration: accept this value, DecreaseStep - // (reduce the step size), IncreaseStep (increase the step size), or kRestart - // (set k back to zero). Generally when we can't get both conditions to be - // true with a reasonable period of time, it makes sense to restart, because - // probably we've almost converged and got into numerical issues; from here - // we'll just produced NaN's. Restarting is a safe thing to do and the outer - // code will quickly detect convergence. - - d_action = kNoChange; // the default. - - if (wolfe_i_ok && wolfe_ii_ok) { - iteration_action = kAccept; - d_action = kNoChange; // actually doesn't matter, it'll get reset. 
- } else if (!wolfe_i_ok) { - // If wolfe i) [the Armijo rule] failed then we went too far (or are - // meeting numerical problems). - if (last_failure_type_ == kWolfeII) { // Last time we failed it was Wolfe ii). - // When we switch between them we decrease d. - d_action = kDecrease; - } - iteration_action = kDecreaseStep; - last_failure_type_ = kWolfeI; - num_wolfe_i_failures_++; - } else if (!wolfe_ii_ok) { - // Curvature condition failed -> we did not go far enough. - if (last_failure_type_ == kWolfeI) // switching between wolfe i and ii failures-> - d_action = kDecrease; // decrease value of d. - iteration_action = kIncreaseStep; - last_failure_type_ = kWolfeII; - num_wolfe_ii_failures_++; - } - - // Test whether we've been switching too many times betwen wolfe i) and ii) - // failures, or overall have an excessive number of failures. We just give up - // and restart L-BFGS. Probably we've almost converged. - if (num_wolfe_i_failures_ + num_wolfe_ii_failures_ > - opts_.max_line_search_iters) { - KALDI_VLOG(2) << "Too many steps in line search -> restarting."; - iteration_action = kRestart; - } - - if (d_action == kDecrease) - d_ = std::sqrt(d_); - - KALDI_VLOG(3) << "d = " << d_ << ", iter = " << k_ << ", action = " - << (iteration_action == kAccept ? "accept" : - (iteration_action == kDecreaseStep ? "decrease" : - (iteration_action == kIncreaseStep ? "increase" : - "reject"))); - - // Note: even if iteration_action != Restart at this point, - // some code below may set it to Restart. - if (iteration_action == kAccept) { - if (AcceptStep(function_value, gradient)) { // If we did - // not detect a problem while accepting the step.. - computation_state_ = kBeforeStep; - ComputeNewDirection(function_value, gradient); - } else { - KALDI_VLOG(2) << "Restarting L-BFGS computation; problem found while " - << "accepting step."; - iteration_action = kRestart; // We'll have to restart now. 
- } - } - if (iteration_action == kDecreaseStep || iteration_action == kIncreaseStep) { - Real scale = (iteration_action == kDecreaseStep ? 1.0 / d_ : d_); - temp_.CopyFromVec(new_x_); - new_x_.Scale(scale); - new_x_.AddVec(1.0 - scale, x_); - if (new_x_.ApproxEqual(temp_, 0.0)) { - // Value of new_x_ did not change at all --> we must restart. - KALDI_VLOG(3) << "Value of x did not change, when taking step; " - << "will restart computation."; - iteration_action = kRestart; - } - if (new_x_.ApproxEqual(temp_, 1.0e-08) && - std::abs(f_ - function_value) < 1.0e-08 * - std::abs(f_) && iteration_action == kDecreaseStep) { - // This is common and due to roundoff. - KALDI_VLOG(3) << "We appear to be backtracking while we are extremely " - << "close to the old value; restarting."; - iteration_action = kRestart; - } - - if (iteration_action == kDecreaseStep) { - num_wolfe_i_failures_++; - last_failure_type_ = kWolfeI; - } else { - num_wolfe_ii_failures_++; - last_failure_type_ = kWolfeII; - } - } - if (iteration_action == kRestart) { - // We want to restart the computation. If the objf at new_x_ is - // better than it was at x_, we'll start at new_x_, else at x_. - bool use_newx; - if (opts_.minimize) use_newx = (function_value < f_); - else use_newx = (function_value > f_); - KALDI_VLOG(3) << "Restarting computation."; - if (use_newx) Restart(new_x_, function_value, gradient); - else Restart(x_, f_, deriv_); - } -} - -template -void OptimizeLbfgs::DoStep(Real function_value, - const VectorBase &gradient) { - if (opts_.minimize ? function_value < best_f_ : function_value > best_f_) { - best_f_ = function_value; - best_x_.CopyFromVec(new_x_); - } - if (computation_state_ == kBeforeStep) - ComputeNewDirection(function_value, gradient); - else // kWithinStep{1,2,3} - StepSizeIteration(function_value, gradient); -} - -template -void OptimizeLbfgs::DoStep(Real function_value, - const VectorBase &gradient, - const VectorBase &diag_approx_2nd_deriv) { - if (opts_.minimize ? 
function_value < best_f_ : function_value > best_f_) { - best_f_ = function_value; - best_x_.CopyFromVec(new_x_); - } - if (opts_.minimize) { - KALDI_ASSERT(diag_approx_2nd_deriv.Min() > 0.0); - } else { - KALDI_ASSERT(diag_approx_2nd_deriv.Max() < 0.0); - } - H_was_set_ = true; - H_.CopyFromVec(diag_approx_2nd_deriv); - H_.InvertElements(); - DoStep(function_value, gradient); -} - -template -const VectorBase& -OptimizeLbfgs::GetValue(Real *objf_value) const { - if (objf_value != NULL) *objf_value = best_f_; - return best_x_; -} - -// to compute the alpha, we are minimizing f(x) = x^T b - 0.5 x_k^T A x_k along -// direction p_k... consider alpha -// d/dx of f(x) = b - A x_k = r. - -// Notation based on Sec. 5.1 of Nocedal and Wright -// Computation based on Alg. 5.2 of Nocedal and Wright (Pg. 112) -// Notation (replicated for convenience): -// To solve Ax=b for x -// k : current iteration -// x_k : estimate of x (at iteration k) -// r_k : residual ( r_k \eqdef A x_k - b ) -// \alpha_k : step size -// p_k : A-conjugate direction -// \beta_k : coefficient used in A-conjugate direction computation for next -// iteration -// -// Algo. LinearCG(A,b,x_0) -// ======================== -// r_0 = Ax_0 - b -// p_0 = -r_0 -// k = 0 -// -// while r_k != 0 -// \alpha_k = (r_k^T r_k) / (p_k^T A p_k) -// x_{k+1} = x_k + \alpha_k p_k; -// r_{k+1} = r_k + \alpha_k A p_k -// \beta_{k+1} = \frac{r_{k+1}^T r_{k+1}}{r_k^T r_K} -// p_{k+1} = -r_{k+1} + \beta_{k+1} p_k -// k = k + 1 -// end - -template -int32 LinearCgd(const LinearCgdOptions &opts, - const SpMatrix &A, - const VectorBase &b, - VectorBase *x) { - // Initialize the variables - // - int32 M = A.NumCols(); - - Matrix storage(4, M); - SubVector r(storage, 0), p(storage, 1), Ap(storage, 2), x_orig(storage, 3); - p.CopyFromVec(b); - p.AddSpVec(-1.0, A, *x, 1.0); // p_0 = b - A x_0 - r.AddVec(-1.0, p); // r_0 = - p_0 - x_orig.CopyFromVec(*x); // in case of failure. 
- - Real r_cur_norm_sq = VecVec(r, r), - r_initial_norm_sq = r_cur_norm_sq, - r_recompute_norm_sq = r_cur_norm_sq; - - KALDI_VLOG(5) << "In linear CG: initial norm-square of residual = " - << r_initial_norm_sq; - - KALDI_ASSERT(opts.recompute_residual_factor <= 1.0); - Real max_error_sq = std::max(opts.max_error * opts.max_error, - std::numeric_limits::min()), - residual_factor = opts.recompute_residual_factor * - opts.recompute_residual_factor, - inv_residual_factor = 1.0 / residual_factor; - - // Note: although from a mathematical point of view the method should converge - // after M iterations, in practice (due to roundoff) it does not always - // converge to good precision after that many iterations so we let the maximum - // be M + 5 instead. - int32 k = 0; - for (; k < M + 5 && k != opts.max_iters; k++) { - // Note: we'll break from this loop if we converge sooner due to - // max_error. - Ap.AddSpVec(1.0, A, p, 0.0); // Ap = A p - - // Below is how the code used to look. - // // next line: \alpha_k = (r_k^T r_k) / (p_k^T A p_k) - // Real alpha = r_cur_norm_sq / VecVec(p, Ap); - // - // We changed r_cur_norm_sq below to -VecVec(p, r). Although this is - // slightly less efficient, it seems to make the algorithm dramatically more - // robust. Note that -p^T r is the mathematically more natural quantity to - // use here, that corresponds to minimizing along that direction... r^T r is - // recommended in Nocedal and Wright only as a kind of optimization as it is - // supposed to be the same as -p^T r and we already have it computed. 
- Real alpha = -VecVec(p, r) / VecVec(p, Ap); - - // next line: x_{k+1} = x_k + \alpha_k p_k; - x->AddVec(alpha, p); - // next line: r_{k+1} = r_k + \alpha_k A p_k - r.AddVec(alpha, Ap); - Real r_next_norm_sq = VecVec(r, r); - - if (r_next_norm_sq < residual_factor * r_recompute_norm_sq || - r_next_norm_sq > inv_residual_factor * r_recompute_norm_sq) { - - // Recompute the residual from scratch if the residual norm has decreased - // a lot; this costs an extra matrix-vector multiply, but helps keep the - // residual accurate. - // Also do the same if the residual norm has increased a lot since - // the last time we recomputed... this shouldn't happen often, but - // it can indicate bad stuff is happening. - - // r_{k+1} = A x_{k+1} - b - r.AddSpVec(1.0, A, *x, 0.0); - r.AddVec(-1.0, b); - r_next_norm_sq = VecVec(r, r); - r_recompute_norm_sq = r_next_norm_sq; - - KALDI_VLOG(5) << "In linear CG: recomputing residual."; - } - KALDI_VLOG(5) << "In linear CG: k = " << k - << ", r_next_norm_sq = " << r_next_norm_sq; - // Check if converged. - if (r_next_norm_sq <= max_error_sq) - break; - - // next line: \beta_{k+1} = \frac{r_{k+1}^T r_{k+1}}{r_k^T r_K} - Real beta_next = r_next_norm_sq / r_cur_norm_sq; - // next lines: p_{k+1} = -r_{k+1} + \beta_{k+1} p_k - Vector p_old(p); - p.Scale(beta_next); - p.AddVec(-1.0, r); - r_cur_norm_sq = r_next_norm_sq; - } - - // note: the first element of the && is only there to save compute. - // the residual r is A x - b, and r_cur_norm_sq and r_initial_norm_sq are - // of the form r * r, so it's clear that b * b has the right dimension to - // compare with the residual. - if (r_cur_norm_sq > r_initial_norm_sq && - r_cur_norm_sq > r_initial_norm_sq + 1.0e-10 * VecVec(b, b)) { - KALDI_WARN << "Doing linear CGD in dimension " << A.NumRows() << ", after " << k - << " iterations the squared residual has got worse, " - << r_cur_norm_sq << " > " << r_initial_norm_sq - << ". 
Will do an exact optimization."; - SolverOptions opts("called-from-linearCGD"); - x->CopyFromVec(x_orig); - SolveQuadraticProblem(A, b, opts, x); - } - return k; -} - -// Instantiate the class for float and double. -template -class OptimizeLbfgs; -template -class OptimizeLbfgs; - - -template -int32 LinearCgd(const LinearCgdOptions &opts, - const SpMatrix &A, const VectorBase &b, - VectorBase *x); - -template -int32 LinearCgd(const LinearCgdOptions &opts, - const SpMatrix &A, const VectorBase &b, - VectorBase *x); - -} // end namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/optimization.h b/speechx/speechx/kaldi/matrix/optimization.h deleted file mode 100644 index 66309aca..00000000 --- a/speechx/speechx/kaldi/matrix/optimization.h +++ /dev/null @@ -1,248 +0,0 @@ -// matrix/optimization.h - -// Copyright 2012 Johns Hopkins University (author: Daniel Povey) -// -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. -// -// (*) incorporates, with permission, FFT code from his book -// "Signal Processing with Lapped Transforms", Artech, 1992. 
- - - -#ifndef KALDI_MATRIX_OPTIMIZATION_H_ -#define KALDI_MATRIX_OPTIMIZATION_H_ - -#include "matrix/kaldi-vector.h" -#include "matrix/kaldi-matrix.h" - -namespace kaldi { - - -/// @addtogroup matrix_optimization -/// @{ - -struct LinearCgdOptions { - int32 max_iters; // Maximum number of iters (if >= 0). - BaseFloat max_error; // Maximum 2-norm of the residual A x - b (convergence - // test) - // Every time the residual 2-norm decreases by this recompute_residual_factor - // since the last time it was computed from scratch, recompute it from - // scratch. This helps to keep the computed residual accurate even in the - // presence of roundoff. - BaseFloat recompute_residual_factor; - - LinearCgdOptions(): max_iters(-1), - max_error(0.0), - recompute_residual_factor(0.01) { } -}; - -/* - This function uses linear conjugate gradient descent to approximately solve - the system A x = b. The value of x at entry corresponds to the initial guess - of x. The algorithm continues until the number of iterations equals b.Dim(), - or until the 2-norm of (A x - b) is <= max_error, or until the number of - iterations equals max_iter, whichever happens sooner. It is a requirement - that A be positive definite. - It returns the number of iterations that were actually executed (this is - useful for testing purposes). -*/ -template -int32 LinearCgd(const LinearCgdOptions &opts, - const SpMatrix &A, const VectorBase &b, - VectorBase *x); - - - - - - -/** - This is an implementation of L-BFGS. It pushes responsibility for - determining when to stop, onto the user. There is no call-back here: - everything is done via calls to the class itself (see the example in - matrix-lib-test.cc). This does not implement constrained L-BFGS, but it will - handle constrained problems correctly as long as the function approaches - +infinity (or -infinity for maximization problems) when it gets close to the - bound of the constraint. 
In these types of problems, you just let the - function value be +infinity for minimization problems, or -infinity for - maximization problems, outside these bounds). -*/ - -struct LbfgsOptions { - bool minimize; // if true, we're minimizing, else maximizing. - int m; // m is the number of stored vectors L-BFGS keeps. - float first_step_learning_rate; // The very first step of L-BFGS is - // like gradient descent. If you want to configure the size of that step, - // you can do it using this variable. - float first_step_length; // If this variable is >0.0, it overrides - // first_step_learning_rate; on the first step we choose an approximate - // Hessian that is the multiple of the identity that would generate this - // step-length, or 1.0 if the gradient is zero. - float first_step_impr; // If this variable is >0.0, it overrides - // first_step_learning_rate; on the first step we choose an approximate - // Hessian that is the multiple of the identity that would generate this - // amount of objective function improvement (assuming the "real" objf - // was linear). - float c1; // A constant in Armijo rule = Wolfe condition i) - float c2; // A constant in Wolfe condition ii) - float d; // An amount > 1.0 (default 2.0) that we initially multiply or - // divide the step length by, in the line search. - int max_line_search_iters; // after this many iters we restart L-BFGS. - int avg_step_length; // number of iters to avg step length over, in - // RecentStepLength(). - - LbfgsOptions (bool minimize = true): - minimize(minimize), - m(10), - first_step_learning_rate(1.0), - first_step_length(0.0), - first_step_impr(0.0), - c1(1.0e-04), - c2(0.9), - d(2.0), - max_line_search_iters(50), - avg_step_length(4) { } -}; - -template -class OptimizeLbfgs { - public: - /// Initializer takes the starting value of x. 
- OptimizeLbfgs(const VectorBase &x, - const LbfgsOptions &opts); - - /// This returns the value of the variable x that has the best objective - /// function so far, and the corresponding objective function value if - /// requested. This would typically be called only at the end. - const VectorBase& GetValue(Real *objf_value = NULL) const; - - /// This returns the value at which the function wants us - /// to compute the objective function and gradient. - const VectorBase& GetProposedValue() const { return new_x_; } - - /// Returns the average magnitude of the last n steps (but not - /// more than the number we have stored). Before we have taken - /// any steps, returns +infinity. Note: if the most recent - /// step length was 0, it returns 0, regardless of the other - /// step lengths. This makes it suitable as a convergence test - /// (else we'd generate NaN's). - Real RecentStepLength() const; - - /// The user calls this function to provide the class with the - /// function and gradient info at the point GetProposedValue(). - /// If this point is outside the constraints you can set function_value - /// to {+infinity,-infinity} for {minimization,maximization} problems. - /// In this case the gradient, and also the second derivative (if you call - /// the second overloaded version of this function) will be ignored. - void DoStep(Real function_value, - const VectorBase &gradient); - - /// The user can call this version of DoStep() if it is desired to set some - /// kind of approximate Hessian on this iteration. Note: it is a prerequisite - /// that diag_approx_2nd_deriv must be strictly positive (minimizing), or - /// negative (maximizing). - void DoStep(Real function_value, - const VectorBase &gradient, - const VectorBase &diag_approx_2nd_deriv); - - private: - KALDI_DISALLOW_COPY_AND_ASSIGN(OptimizeLbfgs); - - - // The following variable says what stage of the computation we're at. 
- // Refer to Algorithm 7.5 (L-BFGS) of Nodecdal & Wright, "Numerical - // Optimization", 2nd edition. - // kBeforeStep means we're about to do - /// "compute p_k <-- - H_k \delta f_k" (i.e. Algorithm 7.4). - // kWithinStep means we're at some point within line search; note - // that line search is iterative so we can stay in this state more - // than one time on each iteration. - enum ComputationState { - kBeforeStep, - kWithinStep, // This means we're within the step-size computation, and - // have not yet done the 1st function evaluation. - }; - - inline MatrixIndexT Dim() { return x_.Dim(); } - inline MatrixIndexT M() { return opts_.m; } - SubVector Y(MatrixIndexT i) { - return SubVector(data_, (i % M()) * 2); // vector y_i - } - SubVector S(MatrixIndexT i) { - return SubVector(data_, (i % M()) * 2 + 1); // vector s_i - } - // The following are subroutines within DoStep(): - bool AcceptStep(Real function_value, - const VectorBase &gradient); - void Restart(const VectorBase &x, - Real function_value, - const VectorBase &gradient); - void ComputeNewDirection(Real function_value, - const VectorBase &gradient); - void ComputeHifNeeded(const VectorBase &gradient); - void StepSizeIteration(Real function_value, - const VectorBase &gradient); - void RecordStepLength(Real s); - - - LbfgsOptions opts_; - SignedMatrixIndexT k_; // Iteration number, starts from zero. Gets set back to zero - // when we restart. - - ComputationState computation_state_; - bool H_was_set_; // True if the user specified H_; if false, - // we'll use a heuristic to estimate it. - - - Vector x_; // current x. - Vector new_x_; // the x proposed in the line search. - Vector best_x_; // the x with the best objective function so far - // (either the same as x_ or something in the current line search.) - Vector deriv_; // The most recently evaluated derivative-- at x_k. - Vector temp_; - Real f_; // The function evaluated at x_k. - Real best_f_; // the best objective function so far. 
- Real d_; // a number d > 1.0, but during an iteration we may decrease this, when - // we switch between armijo and wolfe failures. - - int num_wolfe_i_failures_; // the num times we decreased step size. - int num_wolfe_ii_failures_; // the num times we increased step size. - enum { kWolfeI, kWolfeII, kNone } last_failure_type_; // last type of step-search - // failure on this iter. - - Vector H_; // Current inverse-Hessian estimate. May be computed by this class itself, - // or provided by user using 2nd form of SetGradientInfo(). - Matrix data_; // dimension (m*2) x dim. Even rows store - // gradients y_i, odd rows store steps s_i. - Vector rho_; // dimension m; rho_(m) = 1/(y_m^T s_m), Eq. 7.17. - - std::vector step_lengths_; // The step sizes we took on the last - // (up to m) iterations; these are not stored in a rotating buffer but - // are shifted by one each time (this is more convenient when we - // restart, as we keep this info past restarting). - - -}; - -/// @} - - -} // end namespace kaldi - - - -#endif - diff --git a/speechx/speechx/kaldi/matrix/packed-matrix.cc b/speechx/speechx/kaldi/matrix/packed-matrix.cc deleted file mode 100644 index 80bf5891..00000000 --- a/speechx/speechx/kaldi/matrix/packed-matrix.cc +++ /dev/null @@ -1,438 +0,0 @@ -// matrix/packed-matrix.cc - -// Copyright 2009-2012 Microsoft Corporation Saarland University -// Johns Hopkins University (Author: Daniel Povey); -// Haihua Xu - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. -/** - * @file packed-matrix.cc - * - * Implementation of specialized PackedMatrix template methods - */ -#include "matrix/cblas-wrappers.h" -#include "matrix/packed-matrix.h" -#include "matrix/kaldi-vector.h" - -namespace kaldi { - -template -void PackedMatrix::Scale(Real alpha) { - size_t nr = num_rows_, - sz = (nr * (nr + 1)) / 2; - cblas_Xscal(sz, alpha, data_, 1); -} - -template -void PackedMatrix::AddPacked(const Real alpha, const PackedMatrix &rMa) { - KALDI_ASSERT(num_rows_ == rMa.NumRows()); - size_t nr = num_rows_, - sz = (nr * (nr + 1)) / 2; - cblas_Xaxpy(sz, alpha, rMa.Data(), 1, data_, 1); -} - -template -void PackedMatrix::SetRandn() { - Real *data = data_; - size_t dim = num_rows_, size = ((dim*(dim+1))/2); - for (size_t i = 0; i < size; i++) - data[i] = RandGauss(); -} - -template -inline void PackedMatrix::Init(MatrixIndexT r) { - if (r == 0) { - num_rows_ = 0; - data_ = 0; - return; - } - size_t size = ((static_cast(r) * static_cast(r + 1)) / 2); - - if (static_cast(static_cast(size)) != size) { - KALDI_WARN << "Allocating packed matrix whose full dimension does not fit " - << "in MatrixIndexT: not all code is tested for this case."; - } - - void *data; // aligned memory block - void *temp; - - if ((data = KALDI_MEMALIGN(16, size * sizeof(Real), &temp)) != NULL) { - this->data_ = static_cast (data); - this->num_rows_ = r; - } else { - throw std::bad_alloc(); - } -} - -template -void PackedMatrix::Swap(PackedMatrix *other) { - std::swap(data_, other->data_); - 
std::swap(num_rows_, other->num_rows_); -} - -template -void PackedMatrix::Swap(Matrix *other) { - std::swap(data_, other->data_); - std::swap(num_rows_, other->num_rows_); -} - - -template -void PackedMatrix::Resize(MatrixIndexT r, MatrixResizeType resize_type) { - // the next block uses recursion to handle what we have to do if - // resize_type == kCopyData. - if (resize_type == kCopyData) { - if (this->data_ == NULL || r == 0) resize_type = kSetZero; // nothing to copy. - else if (this->num_rows_ == r) { return; } // nothing to do. - else { - // set tmp to a packed matrix of the desired size. - PackedMatrix tmp(r, kUndefined); - size_t r_min = std::min(r, num_rows_); - size_t mem_size_min = sizeof(Real) * (r_min*(r_min+1))/2, - mem_size_full = sizeof(Real) * (r*(r+1))/2; - // Copy the contents to tmp. - memcpy(tmp.data_, data_, mem_size_min); - char *ptr = static_cast(static_cast(tmp.data_)); - // Set the rest of the contents of tmp to zero. - memset(static_cast(ptr + mem_size_min), 0, mem_size_full-mem_size_min); - tmp.Swap(this); - return; - } - } - if (data_ != NULL) Destroy(); - Init(r); - if (resize_type == kSetZero) SetZero(); -} - - - -template -void PackedMatrix::AddToDiag(Real r) { - Real *ptr = data_; - for (MatrixIndexT i = 2; i <= num_rows_+1; i++) { - *ptr += r; - ptr += i; - } -} - -template -void PackedMatrix::ScaleDiag(Real alpha) { - Real *ptr = data_; - for (MatrixIndexT i = 2; i <= num_rows_+1; i++) { - *ptr *= alpha; - ptr += i; - } -} - -template -void PackedMatrix::SetDiag(Real alpha) { - Real *ptr = data_; - for (MatrixIndexT i = 2; i <= num_rows_+1; i++) { - *ptr = alpha; - ptr += i; - } -} - - - -template -template -void PackedMatrix::CopyFromPacked(const PackedMatrix &orig) { - KALDI_ASSERT(NumRows() == orig.NumRows()); - if (sizeof(Real) == sizeof(OtherReal)) { - memcpy(data_, orig.Data(), SizeInBytes()); - } else { - Real *dst = data_; - const OtherReal *src = orig.Data(); - size_t nr = NumRows(), - size = (nr * (nr + 1)) / 2; - for 
(size_t i = 0; i < size; i++, dst++, src++) - *dst = *src; - } -} - -// template instantiations. -template -void PackedMatrix::CopyFromPacked(const PackedMatrix &orig); -template -void PackedMatrix::CopyFromPacked(const PackedMatrix &orig); -template -void PackedMatrix::CopyFromPacked(const PackedMatrix &orig); -template -void PackedMatrix::CopyFromPacked(const PackedMatrix &orig); - - - -template -template -void PackedMatrix::CopyFromVec(const SubVector &vec) { - MatrixIndexT size = (NumRows()*(NumRows()+1)) / 2; - KALDI_ASSERT(vec.Dim() == size); - if (sizeof(Real) == sizeof(OtherReal)) { - memcpy(data_, vec.Data(), size * sizeof(Real)); - } else { - Real *dst = data_; - const OtherReal *src = vec.Data(); - for (MatrixIndexT i = 0; i < size; i++, dst++, src++) - *dst = *src; - } -} - -// template instantiations. -template -void PackedMatrix::CopyFromVec(const SubVector &orig); -template -void PackedMatrix::CopyFromVec(const SubVector &orig); -template -void PackedMatrix::CopyFromVec(const SubVector &orig); -template -void PackedMatrix::CopyFromVec(const SubVector &orig); - - - -template -void PackedMatrix::SetZero() { - memset(data_, 0, SizeInBytes()); -} - -template -void PackedMatrix::SetUnit() { - memset(data_, 0, SizeInBytes()); - for (MatrixIndexT row = 0;row < num_rows_;row++) - (*this)(row, row) = 1.0; -} - -template -Real PackedMatrix::Trace() const { - Real ans = 0.0; - for (MatrixIndexT row = 0;row < num_rows_;row++) - ans += (*this)(row, row); - return ans; -} - -template -void PackedMatrix::Destroy() { - // we need to free the data block if it was defined - if (data_ != NULL) KALDI_MEMALIGN_FREE(data_); - data_ = NULL; - num_rows_ = 0; -} - - -template -void PackedMatrix::Write(std::ostream &os, bool binary) const { - if (!os.good()) { - KALDI_ERR << "Failed to write vector to stream: stream not good"; - } - - int32 size = this->NumRows(); // make the size 32-bit on disk. 
- KALDI_ASSERT(this->NumRows() == (MatrixIndexT) size); - MatrixIndexT num_elems = ((size+1)*(MatrixIndexT)size)/2; - - if(binary) { - std::string my_token = (sizeof(Real) == 4 ? "FP" : "DP"); - WriteToken(os, binary, my_token); - WriteBasicType(os, binary, size); - // We don't use the built-in Kaldi write routines for the floats, as they are - // not efficient enough. - os.write((const char*) data_, sizeof(Real) * num_elems); - } - else { - if(size == 0) - os<<"[ ]\n"; - else { - os<<"[\n"; - MatrixIndexT i = 0; - for (int32 j = 0; j < size; j++) { - for (int32 k = 0; k < j + 1; k++) { - WriteBasicType(os, binary, data_[i++]); - } - os << ( (j==size-1)? "]\n" : "\n"); - } - KALDI_ASSERT(i == num_elems); - } - } - if (os.fail()) { - KALDI_ERR << "Failed to write packed matrix to stream"; - } -} - -// template -// void Save (std::ostream & os, const PackedMatrix& rM) -// { -// const Real* p_elem = rM.data(); -// for (MatrixIndexT i = 0; i < rM.NumRows(); i++) { -// for (MatrixIndexT j = 0; j <= i ; j++) { -// os << *p_elem; -// p_elem++; -// if (j == i) { -// os << '\n'; -// } -// else { -// os << ' '; -// } -// } -// } -// if (os.fail()) -// KALDI_ERR("Failed to write packed matrix to stream"); -// } - - - - - -template -void PackedMatrix::Read(std::istream& is, bool binary, bool add) { - if (add) { - PackedMatrix tmp; - tmp.Read(is, binary, false); // read without adding. - if (this->NumRows() == 0) this->Resize(tmp.NumRows()); - else { - if (this->NumRows() != tmp.NumRows()) { - if (tmp.NumRows() == 0) return; // do nothing in this case. - else KALDI_ERR << "PackedMatrix::Read, size mismatch " << this->NumRows() - << " vs. " << tmp.NumRows(); - } - } - this->AddPacked(1.0, tmp); - return; - } // now assume add == false. - - std::ostringstream specific_error; - MatrixIndexT pos_at_start = is.tellg(); - int peekval = Peek(is, binary); - const char *my_token = (sizeof(Real) == 4 ? 
"FP" : "DP"); - const char *new_format_token = "["; - bool is_new_format = false;//added by hxu - char other_token_start = (sizeof(Real) == 4 ? 'D' : 'F'); - int32 size; - MatrixIndexT num_elems; - - if (peekval == other_token_start) { // need to instantiate the other type to read it. - typedef typename OtherReal::Real OtherType; // if Real == float, OtherType == double, and vice versa. - PackedMatrix other(this->NumRows()); - other.Read(is, binary, false); // add is false at this point. - this->Resize(other.NumRows()); - this->CopyFromPacked(other); - return; - } - std::string token; - ReadToken(is, binary, &token); - if (token != my_token) { - if(token != new_format_token) { - specific_error << ": Expected token " << my_token << ", got " << token; - goto bad; - } - //new format it is - is_new_format = true; - } - if(!is_new_format) { - ReadBasicType(is, binary, &size); // throws on error. - if ((MatrixIndexT)size != this->NumRows()) { - KALDI_ASSERT(size>=0); - this->Resize(size); - } - num_elems = ((size+1)*(MatrixIndexT)size)/2; - if (!binary) { - for (MatrixIndexT i = 0; i < num_elems; i++) { - ReadBasicType(is, false, data_+i); // will throw on error. - } - } else { - if (num_elems) - is.read(reinterpret_cast(data_), sizeof(Real)*num_elems); - } - if (is.fail()) goto bad; - return; - } - else { - std::vector data; - while(1) { - int32 num_lines = 0; - int i = is.peek(); - if (i == -1) { specific_error << "Got EOF while reading matrix data"; goto bad; } - else if (static_cast(i) == ']') { // Finished reading matrix. - is.get(); // eat the "]". - i = is.peek(); - if (static_cast(i) == '\r') { - is.get(); - is.get(); // get \r\n (must eat what we wrote) - }// I don't actually understand what it's doing here - else if (static_cast(i) == '\n') { is.get(); } // get \n (must eat what we wrote) - - if (is.fail()) { - KALDI_WARN << "After end of matrix data, read error."; - // we got the data we needed, so just warn for this error. 
- } - //now process the data: - num_lines = int32(sqrt(data.size()*2)); - - KALDI_ASSERT(data.size() == num_lines*(num_lines+1)/2); - - this->Resize(num_lines); - - //std::cout<= '0' && i <= '9') || i == '-' ) { // A number... - Real r; - is >> r; - if (is.fail()) { - specific_error << "Stream failure/EOF while reading matrix data."; - goto bad; - } - data.push_back(r); - } - else if (isspace(i)) { - is.get(); // eat the space and do nothing. - } else { // NaN or inf or error. - std::string str; - is >> str; - if (!KALDI_STRCASECMP(str.c_str(), "inf") || - !KALDI_STRCASECMP(str.c_str(), "infinity")) { - data.push_back(std::numeric_limits::infinity()); - KALDI_WARN << "Reading infinite value into matrix."; - } else if (!KALDI_STRCASECMP(str.c_str(), "nan")) { - data.push_back(std::numeric_limits::quiet_NaN()); - KALDI_WARN << "Reading NaN value into matrix."; - } else { - specific_error << "Expecting numeric matrix data, got " << str; - goto bad; - } - } - } - } -bad: - KALDI_ERR << "Failed to read packed matrix from stream. " << specific_error.str() - << " File position at start is " - << pos_at_start << ", currently " << is.tellg(); -} - - -// Instantiate PackedMatrix for float and double. -template -class PackedMatrix; - -template -class PackedMatrix; - - -} // namespace kaldi - diff --git a/speechx/speechx/kaldi/matrix/packed-matrix.h b/speechx/speechx/kaldi/matrix/packed-matrix.h deleted file mode 100644 index 722d932b..00000000 --- a/speechx/speechx/kaldi/matrix/packed-matrix.h +++ /dev/null @@ -1,197 +0,0 @@ -// matrix/packed-matrix.h - -// Copyright 2009-2013 Ondrej Glembek; Lukas Burget; Microsoft Corporation; -// Saarland University; Yanmin Qian; -// Johns Hopkins University (Author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#ifndef KALDI_MATRIX_PACKED_MATRIX_H_ -#define KALDI_MATRIX_PACKED_MATRIX_H_ - -#include "matrix/matrix-common.h" -#include - -namespace kaldi { - -/// \addtogroup matrix_funcs_io -// we need to declare the friend << operator here -template -std::ostream & operator <<(std::ostream & out, const PackedMatrix& M); - - -/// \addtogroup matrix_group -/// @{ - -/// @brief Packed matrix: base class for triangular and symmetric matrices. -template class PackedMatrix { - friend class CuPackedMatrix; - public: - //friend class CuPackedMatrix; - - PackedMatrix() : data_(NULL), num_rows_(0) {} - - explicit PackedMatrix(MatrixIndexT r, MatrixResizeType resize_type = kSetZero): - data_(NULL) { Resize(r, resize_type); } - - explicit PackedMatrix(const PackedMatrix &orig) : data_(NULL) { - Resize(orig.num_rows_, kUndefined); - CopyFromPacked(orig); - } - - template - explicit PackedMatrix(const PackedMatrix &orig) : data_(NULL) { - Resize(orig.NumRows(), kUndefined); - CopyFromPacked(orig); - } - - void SetZero(); /// < Set to zero - void SetUnit(); /// < Set to unit matrix. - void SetRandn(); /// < Set to random values of a normal distribution - - Real Trace() const; - - // Needed for inclusion in std::vector - PackedMatrix & operator =(const PackedMatrix &other) { - Resize(other.NumRows()); - CopyFromPacked(other); - return *this; - } - - ~PackedMatrix() { - Destroy(); - } - - /// Set packed matrix to a specified size (can be zero). 
- /// The value of the new data depends on resize_type: - /// -if kSetZero, the new data will be zero - /// -if kUndefined, the new data will be undefined - /// -if kCopyData, the new data will be the same as the old data in any - /// shared positions, and zero elsewhere. - /// This function takes time proportional to the number of data elements. - void Resize(MatrixIndexT nRows, MatrixResizeType resize_type = kSetZero); - - void AddToDiag(const Real r); // Adds r to diaginal - - void ScaleDiag(const Real alpha); // Scales diagonal by alpha. - - void SetDiag(const Real alpha); // Sets diagonal to this value. - - template - void CopyFromPacked(const PackedMatrix &orig); - - /// CopyFromVec just interprets the vector as having the same layout - /// as the packed matrix. Must have the same dimension, i.e. - /// orig.Dim() == (NumRows()*(NumRows()+1)) / 2; - template - void CopyFromVec(const SubVector &orig); - - Real* Data() { return data_; } - const Real* Data() const { return data_; } - inline MatrixIndexT NumRows() const { return num_rows_; } - inline MatrixIndexT NumCols() const { return num_rows_; } - size_t SizeInBytes() const { - size_t nr = static_cast(num_rows_); - return ((nr * (nr+1)) / 2) * sizeof(Real); - } - - //MatrixIndexT Stride() const { return stride_; } - - // This code is duplicated in child classes to avoid extra levels of calls. - Real operator() (MatrixIndexT r, MatrixIndexT c) const { - KALDI_ASSERT(static_cast(r) < - static_cast(num_rows_) && - static_cast(c) < - static_cast(num_rows_) - && c <= r); - return *(data_ + (r * (r + 1)) / 2 + c); - } - - // This code is duplicated in child classes to avoid extra levels of calls. 
- Real &operator() (MatrixIndexT r, MatrixIndexT c) { - KALDI_ASSERT(static_cast(r) < - static_cast(num_rows_) && - static_cast(c) < - static_cast(num_rows_) - && c <= r); - return *(data_ + (r * (r + 1)) / 2 + c); - } - - Real Max() const { - KALDI_ASSERT(num_rows_ > 0); - return * (std::max_element(data_, data_ + ((num_rows_*(num_rows_+1))/2) )); - } - - Real Min() const { - KALDI_ASSERT(num_rows_ > 0); - return * (std::min_element(data_, data_ + ((num_rows_*(num_rows_+1))/2) )); - } - - void Scale(Real c); - - friend std::ostream & operator << <> (std::ostream & out, - const PackedMatrix &m); - // Use instead of stream<<*this, if you want to add to existing contents. - // Will throw exception on failure. - void Read(std::istream &in, bool binary, bool add = false); - - void Write(std::ostream &out, bool binary) const; - - void Destroy(); - - /// Swaps the contents of *this and *other. Shallow swap. - void Swap(PackedMatrix *other); - void Swap(Matrix *other); - - - protected: - // Will only be called from this class or derived classes. - void AddPacked(const Real alpha, const PackedMatrix& M); - Real *data_; - MatrixIndexT num_rows_; - //MatrixIndexT stride_; - private: - /// Init assumes the current contents of the class are is invalid (i.e. junk or - /// has already been freed), and it sets the matrixd to newly allocated memory - /// with the specified dimension. dim == 0 is acceptable. The memory contents - /// pointed to by data_ will be undefined. 
- void Init(MatrixIndexT dim); - -}; -/// @} end "addtogroup matrix_group" - - -/// \addtogroup matrix_funcs_io -/// @{ - -template -std::ostream & operator << (std::ostream & os, const PackedMatrix& M) { - M.Write(os, false); - return os; -} - -template -std::istream & operator >> (std::istream &is, PackedMatrix &M) { - M.Read(is, false); - return is; -} - -/// @} - -} // namespace kaldi - -#endif - diff --git a/speechx/speechx/kaldi/matrix/qr.cc b/speechx/speechx/kaldi/matrix/qr.cc deleted file mode 100644 index 861dead0..00000000 --- a/speechx/speechx/kaldi/matrix/qr.cc +++ /dev/null @@ -1,580 +0,0 @@ -// matrix/qr.cc - -// Copyright 2012 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "matrix/sp-matrix.h" -#include "matrix/kaldi-vector.h" -#include "matrix/kaldi-matrix.h" -#include "matrix/matrix-functions.h" -#include "matrix/cblas-wrappers.h" - -// This file contains an implementation of the Symmetric QR Algorithm -// for the symmetric eigenvalue problem. See Golub and Van Loan, -// 3rd ed., Algorithm 8.3.3. - -namespace kaldi { - - -/* This is from Golub and Van Loan 3rd ed., sec. 5.1.3, - p210. - x is the input of dimenson 'dim', v is the output of dimension - dim, and beta is a scalar. Note: we use zero-based - not one-based indexing. 
*/ -/* -// We are commenting out the function below ("House") because it's not -// needed, but we keep it just to show how we came up with HouseBackward. -template -void House(MatrixIndexT dim, const Real *x, Real *v, Real *beta) { - KALDI_ASSERT(dim > 0); - // To avoid overflow, we first compute the max of x_ (or - // one if that's zero, and we'll replace "x" by x/max(x_i) - // below. The householder vector is anyway invariant to - // the magnitude of x. We could actually avoid this extra loop - // over x if we wanted to be a bit smarter, but anyway this - // doesn't dominate the O(N) performance of the algorithm. - Real s; // s is a scale on x. - { - Real max_x = std::numeric_limits::min(); - for (MatrixIndexT i = 0; i < dim; i++) - max_x = std::max(max_x, (x[i] < 0 ? -x[i] : x[i])); - if (max_x == 0.0) max_x = 1.0; - s = 1.0 / max_x; - } - - Real sigma = 0.0; - v[0] = 1.0; - for (MatrixIndexT i = 1; i < dim; i++) { - sigma += (x[i]*s) * (x[i]*s); - v[i] = x[i]*s; - } - if (sigma == 0.0) *beta = 0.0; - else { - // When we say x1 = x[0], we reference the one-based indexing - // in Golub and Van Loan. - Real x1 = x[0] * s, mu = std::sqrt(x1*x1 + sigma); - if (x1 <= 0) { - v[0] = x1 - mu; - } else { - v[0] = -sigma / (x1 + mu); - KALDI_ASSERT(KALDI_ISFINITE(v[dim-1])); - } - Real v1 = v[0]; - Real v1sq = v1 * v1; - *beta = 2 * v1sq / (sigma + v1sq); - Real inv_v1 = 1.0 / v1; - if (KALDI_ISINF(inv_v1)) { - // can happen if v1 is denormal. - KALDI_ASSERT(v1 == v1 && v1 != 0.0); - for (MatrixIndexT i = 0; i < dim; i++) v[i] /= v1; - } else { - cblas_Xscal(dim, inv_v1, v, 1); - } - if (KALDI_ISNAN(inv_v1)) { - KALDI_ERR << "NaN encountered in HouseBackward"; - } - } -} -*/ - -// This is a backward version of the "House" routine above: -// backward because it's the last index, not the first index of -// the vector that is "special". 
This is convenient in -// the Tridiagonalize routine that uses reversed indexes for -// compatibility with the packed lower triangular format. -template -void HouseBackward(MatrixIndexT dim, const Real *x, Real *v, Real *beta) { - KALDI_ASSERT(dim > 0); - // To avoid overflow, we first compute the max of x_ (or - // one if that's zero, and we'll replace "x" by x/max(x_i) - // below. The householder vector is anyway invariant to - // the magnitude of x. We could actually avoid this extra loop - // over x if we wanted to be a bit smarter, but anyway this - // doesn't dominate the O(N) performance of the algorithm. - Real s; // s is a scale on x. - { - Real max_x = std::numeric_limits::min(); - for (MatrixIndexT i = 0; i < dim; i++) - max_x = std::max(max_x, (x[i] < 0 ? -x[i] : x[i])); - s = 1.0 / max_x; - } - Real sigma = 0.0; - v[dim-1] = 1.0; - for (MatrixIndexT i = 0; i + 1 < dim; i++) { - sigma += (x[i] * s) * (x[i] * s); - v[i] = x[i] * s; - } - KALDI_ASSERT(KALDI_ISFINITE(sigma) && - "Tridiagonalizing matrix that is too large or has NaNs."); - if (sigma == 0.0) *beta = 0.0; - else { - Real x1 = x[dim-1] * s, mu = std::sqrt(x1 * x1 + sigma); - if (x1 <= 0) { - v[dim-1] = x1 - mu; - } else { - v[dim-1] = -sigma / (x1 + mu); - KALDI_ASSERT(KALDI_ISFINITE(v[dim-1])); - } - Real v1 = v[dim-1]; - Real v1sq = v1 * v1; - *beta = 2 * v1sq / (sigma + v1sq); - Real inv_v1 = 1.0 / v1; - if (KALDI_ISINF(inv_v1)) { - // can happen if v1 is denormal. - KALDI_ASSERT(v1 == v1 && v1 != 0.0); - for (MatrixIndexT i = 0; i < dim; i++) v[i] /= v1; - } else { - cblas_Xscal(dim, inv_v1, v, 1); - } - if (KALDI_ISNAN(inv_v1)) { - KALDI_ERR << "NaN encountered in HouseBackward"; - } - } -} - - -/** - This routine tridiagonalizes *this. C.f. Golub and Van Loan 3rd ed., sec. - 8.3.1 (p415). We reverse the order of the indices as it's more natural - with packed lower-triangular matrices to do it this way. 
There's also - a shift from one-based to zero-based indexing, so the index - k is transformed k -> n - k, and a corresponding transpose... - - Let the original *this be A. This algorithms replaces *this with - a tridiagonal matrix T such that T = Q A Q^T for an orthogonal Q. - Caution: Q is transposed vs. Golub and Van Loan. - If Q != NULL it outputs Q. -*/ -template -void SpMatrix::Tridiagonalize(MatrixBase *Q) { - MatrixIndexT n = this->NumRows(); - KALDI_ASSERT(Q == NULL || (Q->NumRows() == n && - Q->NumCols() == n)); - if (Q != NULL) Q->SetUnit(); - Real *data = this->Data(); - Real *qdata = (Q == NULL ? NULL : Q->Data()); - MatrixIndexT qstride = (Q == NULL ? 0 : Q->Stride()); - Vector tmp_v(n-1), tmp_p(n); - Real beta, *v = tmp_v.Data(), *p = tmp_p.Data(), *w = p, *x = p; - for (MatrixIndexT k = n-1; k >= 2; k--) { - MatrixIndexT ksize = ((k+1)*k)/2; - // ksize is the packed size of the lower-triangular matrix of size k, - // which is the size of "all rows previous to this one." - Real *Arow = data + ksize; // In Golub+Van Loan it was A(k+1:n, k), we - // have Arow = A(k, 0:k-1). - HouseBackward(k, Arow, v, &beta); // sets v and beta. - cblas_Xspmv(k, beta, data, v, 1, 0.0, p, 1); // p = beta * A(0:k-1,0:k-1) v - Real minus_half_beta_pv = -0.5 * beta * cblas_Xdot(k, p, 1, v, 1); - cblas_Xaxpy(k, minus_half_beta_pv, v, 1, w, 1); // w = p - (beta p^T v/2) v; - // this relies on the fact that w and p are the same pointer. - // We're doing A(k, k-1) = ||Arow||. It happens that this element - // is indexed at ksize + k - 1 in the packed lower-triangular format. - data[ksize + k - 1] = std::sqrt(cblas_Xdot(k, Arow, 1, Arow, 1)); - for (MatrixIndexT i = 0; i + 1 < k; i++) - data[ksize + i] = 0; // This is not in Golub and Van Loan but is - // necessary if we're not using parts of A to store the Householder - // vectors. - // We're doing A(0:k-1,0:k-1) -= (v w' + w v') - cblas_Xspr2(k, -1.0, v, 1, w, 1, data); - if (Q != NULL) { // C.f. Golub, Q is H_1 .. H_n-2... 
in this - // case we apply them in the opposite order so it's H_n-1 .. H_1, - // but also Q is transposed so we really have Q = H_1 .. H_n-1. - // It's a double negative. - // Anyway, we left-multiply Q by each one. The H_n would each be - // diag(I + beta v v', I) but we don't ever touch the last dims. - // We do (in Matlab notation): - // Q(0:k-1,:) = (I - beta v v') * Q, i.e.: - // Q(:,0:i-1) += -beta v (v' Q(:,0:k-1)v .. let x = -beta Q(0:k-1,:)^T v. - cblas_Xgemv(kTrans, k, n, -beta, qdata, qstride, v, 1, 0.0, x, 1); - // now x = -beta Q(:,0:k-1) v. - // The next line does: Q(:,0:k-1) += v x'. - cblas_Xger(k, n, 1.0, v, 1, x, 1, qdata, qstride); - } - } -} - -// Instantiate these functions, as it wasn't implemented in sp-matrix.cc -// where we instantiated the whole class. -template -void SpMatrix::Tridiagonalize(MatrixBase *Q); -template -void SpMatrix::Tridiagonalize(MatrixBase *Q); - -/// Create Givens rotations, as in Golub and Van Loan 3rd ed., page 216. -template -inline void Givens(Real a, Real b, Real *c, Real *s) { - if (b == 0) { - *c = 1; - *s = 0; - } else { - if (std::abs(b) > std::abs(a)) { - Real tau = -a / b; - *s = 1 / std::sqrt(1 + tau*tau); - *c = *s * tau; - } else { - Real tau = -b / a; - *c = 1 / std::sqrt(1 + tau*tau); - *s = *c * tau; - } - } -} - - -// Some internal code for the QR algorithm: one "QR step". -// This is Golub and Van Loan 3rd ed., Algorithm 8.3.2 "Implicit Symmetric QR step -// with Wilkinson shift." A couple of differences: this code is -// in zero based arithmetic, and we represent Q transposed from -// their Q for memory locality with row-major-indexed matrices. -template -void QrStep(MatrixIndexT n, - Real *diag, - Real *off_diag, - MatrixBase *Q) { - KALDI_ASSERT(n >= 2); - // below, "scale" could be any number; we introduce it to keep the - // floating point quantities within a good range. 
- Real d = (diag[n-2] - diag[n-1]) / 2.0, - t = off_diag[n-2], - inv_scale = std::max(std::max(std::abs(d), std::abs(t)), - std::numeric_limits::min()), - scale = 1.0 / inv_scale, - d_scaled = d * scale, - off_diag_n2_scaled = off_diag[n-2] * scale, - t2_n_n1_scaled = off_diag_n2_scaled * off_diag_n2_scaled, - sgn_d = (d > 0.0 ? 1.0 : -1.0), - mu = diag[n-1] - inv_scale * t2_n_n1_scaled / - (d_scaled + sgn_d * std::sqrt(d_scaled * d_scaled + t2_n_n1_scaled)), - x = diag[0] - mu, - z = off_diag[0]; - KALDI_ASSERT(KALDI_ISFINITE(x)); - Real *Qdata = (Q == NULL ? NULL : Q->Data()); - MatrixIndexT Qstride = (Q == NULL ? 0 : Q->Stride()), - Qcols = (Q == NULL ? 0 : Q->NumCols()); - for (MatrixIndexT k = 0; k < n-1; k++) { - Real c, s; - Givens(x, z, &c, &s); - // Rotate dimensions k and k+1 with the Givens matrix G, as - // T <== G^T T G. - // In 2d, a Givens matrix is [ c s; -s c ]. Forget about - // the dimension-indexing issues and assume we have a 2x2 - // symmetric matrix [ p q ; q r ] - // We ask our friends at Wolfram Alpha about - // { { c, -s}, {s, c} } * { {p, q}, {q, r} } * { { c, s}, {-s, c} } - // Interpreting the result as [ p', q' ; q', r ] - // p' = c (c p - s q) - s (c q - s r) - // q' = s (c p - s q) + c (c q - s r) - // r' = s (s p + c q) + c (s q + c r) - Real p = diag[k], q = off_diag[k], r = diag[k+1]; - // p is element k,k; r is element k+1,k+1; q is element k,k+1 or k+1,k. - // We'll let the compiler optimize this. - diag[k] = c * (c*p - s*q) - s * (c*q - s*r); - off_diag[k] = s * (c*p - s*q) + c * (c*q - s*r); - diag[k+1] = s * (s*p + c*q) + c * (s*q + c*r); - - // We also have some other elements to think of that - // got rotated in a simpler way: if k>0, - // then element (k, k-1) and (k+1, k-1) get rotated. Here, - // element k+1, k-1 will be present as z; it's the out-of-band - // element that we remembered from last time. This is - // on the left as it's the row indexes that differ, so think of - // this as being premultiplied by G^T. 
In fact we're multiplying - // T by in some sense the opposite/transpose of the Givens rotation. - if (k > 0) { // Note, in rotations, going backward, (x,y) -> ((cx - sy), (sx + cy)) - Real &elem_k_km1 = off_diag[k-1], - elem_kp1_km1 = z; // , tmp = elem_k_km1; - elem_k_km1 = c*elem_k_km1 - s*elem_kp1_km1; - // The next line will set elem_kp1_km1 to zero and we'll never access this - // value, so we comment it out. - // elem_kp1_km1 = s*tmp + c*elem_kp1_km1; - } - if (Q != NULL) - cblas_Xrot(Qcols, Qdata + k*Qstride, 1, - Qdata + (k+1)*Qstride, 1, c, -s); - if (k < n-2) { - // Next is the elements (k+2, k) and (k+2, k-1), to be rotated, again - // backwards. - Real &elem_kp2_k = z, - &elem_kp2_kp1 = off_diag[k+1]; - // Note: elem_kp2_k == z would start off as zero because it's - // two off the diagonal, and not been touched yet. Therefore - // we eliminate it in expressions below, commenting it out. - // If we didn't do this we should set it to zero first. - elem_kp2_k = - s * elem_kp2_kp1; // + c*elem_kp2_k - elem_kp2_kp1 = c * elem_kp2_kp1; // + s*elem_kp2_k (original value). - // The next part is from the algorithm they describe: x = t_{k+1,k} - x = off_diag[k]; - } - } -} - - -// Internal code for the QR algorithm, where the diagonal -// and off-diagonal of the symmetric matrix are represented as -// vectors of length n and n-1. -template -void QrInternal(MatrixIndexT n, - Real *diag, - Real *off_diag, - MatrixBase *Q) { - KALDI_ASSERT(Q == NULL || Q->NumCols() == n); // We may - // later relax the condition that Q->NumCols() == n. - - MatrixIndexT counter = 0, max_iters = 500 + 4*n, // Should never take this many iters. - large_iters = 100 + 2*n; - Real epsilon = (pow(2.0, sizeof(Real) == 4 ? -23.0 : -52.0)); - - for (; counter < max_iters; counter++) { // this takes the place of "until - // q=n"... we'll break out of the - // loop when we converge. 
- if (counter == large_iters || - (counter > large_iters && (counter - large_iters) % 50 == 0)) { - KALDI_WARN << "Took " << counter - << " iterations in QR (dim is " << n << "), doubling epsilon."; - SubVector d(diag, n), o(off_diag, n-1); - KALDI_WARN << "Diag, off-diag are " << d << " and " << o; - epsilon *= 2.0; - } - for (MatrixIndexT i = 0; i+1 < n; i++) { - if (std::abs(off_diag[i]) <= epsilon * - (std::abs(diag[i]) + std::abs(diag[i+1]))) - off_diag[i] = 0.0; - } - // The next code works out p, q, and npq which is n - p - q. - // For the definitions of q and p, see Golub and Van Loan; we - // partition the n dims into pieces of size (p, n-p-q, q) where - // the part of size q is diagonal and the part of size n-p-p is - // "unreduced", i.e. has no zero off-diagonal elements. - MatrixIndexT q = 0; - // Note: below, "n-q < 2" should more clearly be "n-2-q < 0", but that - // causes problems if MatrixIndexT is unsigned. - while (q < n && (n-q < 2 || off_diag[n-2-q] == 0.0)) - q++; - if (q == n) break; // we're done. It's diagonal. - KALDI_ASSERT(n - q >= 2); - MatrixIndexT npq = 2; // Value of n - p - q, where n - p - q must be - // unreduced. This is the size of "middle" band of elements. If q != n, - // we must have hit a nonzero off-diag element, so the size of this - // band must be at least two. - while (npq + q < n && (n-q-npq-1 < 0 || off_diag[n-q-npq-1] != 0.0)) - npq++; - MatrixIndexT p = n - q - npq; - { // Checks. - for (MatrixIndexT i = 0; i+1 < npq; i++) - KALDI_ASSERT(off_diag[p + i] != 0.0); - for (MatrixIndexT i = 0; i+1 < q; i++) - KALDI_ASSERT(off_diag[p + npq - 1 + i] == 0.0); - if (p > 1) // Something must have stopped npq from growing further.. - KALDI_ASSERT(off_diag[p-1] == 0.0); // so last off-diag elem in - // group of size p must be zero. - } - - if (Q != NULL) { - // Do one QR step on the middle part of Q only. - // Qpart will be a subset of the rows of Q. 
- SubMatrix Qpart(*Q, p, npq, 0, Q->NumCols()); - QrStep(npq, diag + p, off_diag + p, &Qpart); - } else { - QrStep(npq, diag + p, off_diag + p, - static_cast*>(NULL)); - } - } - if (counter == max_iters) { - KALDI_WARN << "Failure to converge in QR algorithm. " - << "Exiting with partial output."; - } -} - - -/** - This is the symmetric QR algorithm, from Golub and Van Loan 3rd ed., Algorithm - 8.3.3. Q is transposed w.r.t. there, though. -*/ -template -void SpMatrix::Qr(MatrixBase *Q) { - KALDI_ASSERT(this->IsTridiagonal()); - // We envisage that Q would be square but we don't check for this, - // as there are situations where you might not want this. - KALDI_ASSERT(Q == NULL || Q->NumRows() == this->NumRows()); - // Note: the first couple of lines of the algorithm they give would be done - // outside of this function, by calling Tridiagonalize(). - - MatrixIndexT n = this->NumRows(); - Vector diag(n), off_diag(n-1); - for (MatrixIndexT i = 0; i < n; i++) { - diag(i) = (*this)(i, i); - if (i > 0) off_diag(i-1) = (*this)(i, i-1); - } - QrInternal(n, diag.Data(), off_diag.Data(), Q); - // Now set *this to the value represented by diag and off_diag. - this->SetZero(); - for (MatrixIndexT i = 0; i < n; i++) { - (*this)(i, i) = diag(i); - if (i > 0) (*this)(i, i-1) = off_diag(i-1); - } -} - -template -void SpMatrix::Eig(VectorBase *s, MatrixBase *P) const { - MatrixIndexT dim = this->NumRows(); - KALDI_ASSERT(s->Dim() == dim); - KALDI_ASSERT(P == NULL || (P->NumRows() == dim && P->NumCols() == dim)); - - SpMatrix A(*this); // Copy *this, since the tridiagonalization - // and QR decomposition are destructive. - // Note: for efficiency of memory access, the tridiagonalization - // algorithm makes the *rows* of P the eigenvectors, not the columns. - // We'll transpose P before we exit. - // Also note: P may be null if you don't want the eigenvectors. This - // will make this function more efficient. - - A.Tridiagonalize(P); // Tridiagonalizes. - A.Qr(P); // Diagonalizes. 
- if(P) P->Transpose(); - s->CopyDiagFromPacked(A); -} - - -template -void SpMatrix::TopEigs(VectorBase *s, MatrixBase *P, - MatrixIndexT lanczos_dim) const { - const SpMatrix &S(*this); // call this "S" for easy notation. - MatrixIndexT eig_dim = s->Dim(); // Space of dim we want to retain. - if (lanczos_dim <= 0) - lanczos_dim = std::max(eig_dim + 50, eig_dim + eig_dim/2); - MatrixIndexT dim = this->NumRows(); - if (lanczos_dim >= dim) { - // There would be no speed advantage in using this method, so just - // use the regular approach. - Vector s_tmp(dim); - Matrix P_tmp(dim, dim); - this->Eig(&s_tmp, &P_tmp); - SortSvd(&s_tmp, &P_tmp); - s->CopyFromVec(s_tmp.Range(0, eig_dim)); - P->CopyFromMat(P_tmp.Range(0, dim, 0, eig_dim)); - return; - } - KALDI_ASSERT(eig_dim <= dim && eig_dim > 0); - KALDI_ASSERT(P->NumRows() == dim && P->NumCols() == eig_dim); // each column - // is one eigenvector. - - Matrix Q(lanczos_dim, dim); // The rows of Q will be the - // orthogonal vectors of the Krylov subspace. - - SpMatrix T(lanczos_dim); // This will be equal to Q S Q^T, - // i.e. *this projected into the Krylov subspace. Note: only the - // diagonal and off-diagonal fo T are nonzero, i.e. it's tridiagonal, - // but we don't have access to the low-level algorithms that work - // on that type of matrix (since we want to use ATLAS). So we just - // do normal SVD, on a full matrix; it won't typically dominate. - - Q.Row(0).SetRandn(); - Q.Row(0).Scale(1.0 / Q.Row(0).Norm(2)); - for (MatrixIndexT d = 0; d < lanczos_dim; d++) { - Vector r(dim); - r.AddSpVec(1.0, S, Q.Row(d), 0.0); - // r = S * q_d - MatrixIndexT counter = 0; - Real end_prod; - while (1) { // Normally we'll do this loop only once: - // we repeat to handle cases where r gets very much smaller - // and we want to orthogonalize again. - // We do "full orthogonalization" to preserve stability, - // even though this is usually a waste of time. 
- Real start_prod = VecVec(r, r); - for (SignedMatrixIndexT e = d; e >= 0; e--) { // e must be signed! - SubVector q_e(Q, e); - Real prod = VecVec(r, q_e); - if (counter == 0 && static_cast(e) + 1 >= d) // Keep T tridiagonal, which - T(d, e) = prod; // mathematically speaking, it is. - r.AddVec(-prod, q_e); // Subtract component in q_e. - } - if (d+1 == lanczos_dim) break; - end_prod = VecVec(r, r); - if (end_prod <= 0.1 * start_prod) { - // also handles case where both are 0. - // We're not confident any more that it's completely - // orthogonal to the rest so we want to re-do. - if (end_prod == 0.0) - r.SetRandn(); // "Restarting". - counter++; - if (counter > 100) - KALDI_ERR << "Loop detected in Lanczos iteration."; - } else { - break; - } - } - if (d+1 != lanczos_dim) { - // OK, at this point we're satisfied that r is orthogonal - // to all previous rows. - KALDI_ASSERT(end_prod != 0.0); // should have looped. - r.Scale(1.0 / std::sqrt(end_prod)); // make it unit. - Q.Row(d+1).CopyFromVec(r); - } - } - - Matrix R(lanczos_dim, lanczos_dim); - R.SetUnit(); - T.Qr(&R); // Diagonalizes T. - Vector s_tmp(lanczos_dim); - s_tmp.CopyDiagFromSp(T); - - // Now T = R * diag(s_tmp) * R^T. - // The next call sorts the elements of s from greatest to least absolute value, - // and moves around the rows of R in the corresponding way. This picks out - // the largest (absolute) eigenvalues. - SortSvd(&s_tmp, static_cast*>(NULL), &R); - // Keep only the initial rows of R, those corresponding to greatest (absolute) - // eigenvalues. - SubMatrix Rsub(R, 0, eig_dim, 0, lanczos_dim); - SubVector s_sub(s_tmp, 0, eig_dim); - s->CopyFromVec(s_sub); - - // For working out what to do now, just assume the other eigenvalues were - // zero. This is just for purposes of knowing how to get the result, and - // not getting things wrongly transposed. - // We have T = Rsub^T * diag(s_sub) * Rsub. - // Now, T = Q S Q^T, with Q orthogonal, so S = Q^T T Q = Q^T Rsub^T * diag(s) * Rsub * Q. 
- // The output is P and we want S = P * diag(s) * P^T, so we need P = Q^T Rsub^T. - P->AddMatMat(1.0, Q, kTrans, Rsub, kTrans, 0.0); -} - - -// Instantiate the templates for Eig and TopEig. -template -void SpMatrix::Eig(VectorBase*, MatrixBase*) const; -template -void SpMatrix::Eig(VectorBase*, MatrixBase*) const; - -template -void SpMatrix::TopEigs(VectorBase*, MatrixBase*, MatrixIndexT) const; -template -void SpMatrix::TopEigs(VectorBase*, MatrixBase*, MatrixIndexT) const; - -// Someone had a problem with the Intel compiler with -O3, with Qr not being -// defined for some strange reason (should automatically happen when -// we instantiate Eig and TopEigs), so we explicitly instantiate it here. -template -void SpMatrix::Qr(MatrixBase *Q); -template -void SpMatrix::Qr(MatrixBase *Q); - - - -} -// namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/sp-matrix-inl.h b/speechx/speechx/kaldi/matrix/sp-matrix-inl.h deleted file mode 100644 index 15795923..00000000 --- a/speechx/speechx/kaldi/matrix/sp-matrix-inl.h +++ /dev/null @@ -1,42 +0,0 @@ -// matrix/sp-matrix-inl.h - -// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Haihua Xu - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef KALDI_MATRIX_SP_MATRIX_INL_H_ -#define KALDI_MATRIX_SP_MATRIX_INL_H_ - -#include "matrix/tp-matrix.h" - -namespace kaldi { - -// All the lines in this file seem to be declaring template specializations. -// These tell the compiler that we'll implement the templated function -// separately for the different template arguments (float, double). - -template<> -double SolveQuadraticProblem(const SpMatrix &H, const VectorBase &g, - const SolverOptions &opts, VectorBase *x); - -template<> -float SolveQuadraticProblem(const SpMatrix &H, const VectorBase &g, - const SolverOptions &opts, VectorBase *x); - -} // namespace kaldi - - -#endif // KALDI_MATRIX_SP_MATRIX_INL_H_ diff --git a/speechx/speechx/kaldi/matrix/sp-matrix.cc b/speechx/speechx/kaldi/matrix/sp-matrix.cc deleted file mode 100644 index 224ef39f..00000000 --- a/speechx/speechx/kaldi/matrix/sp-matrix.cc +++ /dev/null @@ -1,1216 +0,0 @@ -// matrix/sp-matrix.cc - -// Copyright 2009-2011 Lukas Burget; Ondrej Glembek; Microsoft Corporation -// Saarland University; Petr Schwarz; Yanmin Qian; -// Haihua Xu - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. 
- -#include - -#include "matrix/sp-matrix.h" -#include "matrix/kaldi-vector.h" -#include "matrix/kaldi-matrix.h" -#include "matrix/matrix-functions.h" -#include "matrix/cblas-wrappers.h" - -namespace kaldi { - -// **************************************************************************** -// Returns the log-determinant if +ve definite, else KALDI_ERR. -// **************************************************************************** -template -Real SpMatrix::LogPosDefDet() const { - TpMatrix chol(this->NumRows()); - double det = 0.0; - double diag; - chol.Cholesky(*this); // Will throw exception if not +ve definite! - - for (MatrixIndexT i = 0; i < this->NumRows(); i++) { - diag = static_cast(chol(i, i)); - det += kaldi::Log(diag); - } - return static_cast(2*det); -} - - -template -void SpMatrix::Swap(SpMatrix *other) { - std::swap(this->data_, other->data_); - std::swap(this->num_rows_, other->num_rows_); -} - -template -void SpMatrix::SymPosSemiDefEig(VectorBase *s, - MatrixBase *P, - Real tolerance) const { - Eig(s, P); - Real max = s->Max(), min = s->Min(); - KALDI_ASSERT(-min <= tolerance * max); - s->ApplyFloor(0.0); -} - -template -Real SpMatrix::MaxAbsEig() const { - Vector s(this->NumRows()); - this->Eig(&s, static_cast*>(NULL)); - return std::max(s.Max(), -s.Min()); -} - -// returns true if positive definite--uses cholesky. -template -bool SpMatrix::IsPosDef() const { - MatrixIndexT D = (*this).NumRows(); - KALDI_ASSERT(D > 0); - try { - TpMatrix C(D); - C.Cholesky(*this); - for (MatrixIndexT r = 0; r < D; r++) - if (C(r, r) == 0.0) return false; - return true; - } - catch(...) { // not positive semidefinite. - return false; - } -} - -template -void SpMatrix::ApplyPow(Real power) { - if (power == 1) return; // can do nothing. - MatrixIndexT D = this->NumRows(); - KALDI_ASSERT(D > 0); - Matrix U(D, D); - Vector l(D); - (*this).SymPosSemiDefEig(&l, &U); - - Vector l_copy(l); - try { - l.ApplyPow(power * 0.5); - } - catch(...) 
{ - KALDI_ERR << "Error taking power " << (power * 0.5) << " of vector " - << l_copy; - } - U.MulColsVec(l); - (*this).AddMat2(1.0, U, kNoTrans, 0.0); -} - -template -void SpMatrix::CopyFromMat(const MatrixBase &M, - SpCopyType copy_type) { - KALDI_ASSERT(this->NumRows() == M.NumRows() && M.NumRows() == M.NumCols()); - MatrixIndexT D = this->NumRows(); - - switch (copy_type) { - case kTakeMeanAndCheck: - { - Real good_sum = 0.0, bad_sum = 0.0; - for (MatrixIndexT i = 0; i < D; i++) { - for (MatrixIndexT j = 0; j < i; j++) { - Real a = M(i, j), b = M(j, i), avg = 0.5*(a+b), diff = 0.5*(a-b); - (*this)(i, j) = avg; - good_sum += std::abs(avg); - bad_sum += std::abs(diff); - } - good_sum += std::abs(M(i, i)); - (*this)(i, i) = M(i, i); - } - if (bad_sum > 0.01 * good_sum) { - KALDI_ERR << "SpMatrix::Copy(), source matrix is not symmetric: " - << bad_sum << ">" << good_sum; - } - break; - } - case kTakeMean: - { - for (MatrixIndexT i = 0; i < D; i++) { - for (MatrixIndexT j = 0; j < i; j++) { - (*this)(i, j) = 0.5*(M(i, j) + M(j, i)); - } - (*this)(i, i) = M(i, i); - } - break; - } - case kTakeLower: - { // making this one a bit more efficient. 
- const Real *src = M.Data(); - Real *dest = this->data_; - MatrixIndexT stride = M.Stride(); - for (MatrixIndexT i = 0; i < D; i++) { - for (MatrixIndexT j = 0; j <= i; j++) - dest[j] = src[j]; - dest += i + 1; - src += stride; - } - } - break; - case kTakeUpper: - for (MatrixIndexT i = 0; i < D; i++) - for (MatrixIndexT j = 0; j <= i; j++) - (*this)(i, j) = M(j, i); - break; - default: - KALDI_ASSERT("Invalid argument to SpMatrix::CopyFromMat"); - } -} - -template -Real SpMatrix::Trace() const { - const Real *data = this->data_; - MatrixIndexT num_rows = this->num_rows_; - Real ans = 0.0; - for (int32 i = 1; i <= num_rows; i++, data += i) - ans += *data; - return ans; -} - -// diagonal update, this <-- this + diag(v) -template -template -void SpMatrix::AddDiagVec(const Real alpha, const VectorBase &v) { - int32 num_rows = this->num_rows_; - KALDI_ASSERT(num_rows == v.Dim() && num_rows > 0); - const OtherReal *src = v.Data(); - Real *dst = this->data_; - if (alpha == 1.0) - for (int32 i = 1; i <= num_rows; i++, src++, dst += i) - *dst += *src; - else - for (int32 i = 1; i <= num_rows; i++, src++, dst += i) - *dst += alpha * *src; -} - -// instantiate the template above. 
-template -void SpMatrix::AddDiagVec(const float alpha, - const VectorBase &v); - -template -void SpMatrix::AddDiagVec(const double alpha, - const VectorBase &v); - -template -void SpMatrix::AddDiagVec(const float alpha, - const VectorBase &v); - -template -void SpMatrix::AddDiagVec(const double alpha, - const VectorBase &v); - -template<> -template<> -void SpMatrix::AddVec2(const double alpha, const VectorBase &v); - -#ifndef HAVE_ATLAS -template -void SpMatrix::Invert(Real *logdet, Real *det_sign, bool need_inverse) { - // these are CLAPACK types - KaldiBlasInt result; - KaldiBlasInt rows = static_cast(this->num_rows_); - KaldiBlasInt* p_ipiv = new KaldiBlasInt[rows]; - Real *p_work; // workspace for the lapack function - void *temp; - if ((p_work = static_cast( - KALDI_MEMALIGN(16, sizeof(Real) * rows, &temp))) == NULL) { - delete[] p_ipiv; - throw std::bad_alloc(); - } -#ifdef HAVE_OPENBLAS - memset(p_work, 0, sizeof(Real) * rows); // gets rid of a probably - // spurious Valgrind warning about jumps depending upon uninitialized values. -#endif - - - // NOTE: Even though "U" is for upper, lapack assumes column-wise storage - // of the data. We have a row-wise storage, therefore, we need to "invert" - clapack_Xsptrf(&rows, this->data_, p_ipiv, &result); - - - KALDI_ASSERT(result >= 0 && "Call to CLAPACK ssptrf_ called with wrong arguments"); - - if (result > 0) { // Singular... - if (det_sign) *det_sign = 0; - if (logdet) *logdet = -std::numeric_limits::infinity(); - if (need_inverse) KALDI_ERR << "CLAPACK stptrf_ : factorization failed"; - } else { // Not singular.. compute log-determinant if needed. - if (logdet != NULL || det_sign != NULL) { - Real prod = 1.0, log_prod = 0.0; - int sign = 1; - for (int i = 0; i < (int)this->num_rows_; i++) { - if (p_ipiv[i] > 0) { // not a 2x2 block... - // if (p_ipiv[i] != i+1) sign *= -1; // row swap. - Real diag = (*this)(i, i); - prod *= diag; - } else { // negative: 2x2 block. [we are in first of the two]. 
- i++; // skip over the first of the pair. - // each 2x2 block... - Real diag1 = (*this)(i, i), diag2 = (*this)(i-1, i-1), - offdiag = (*this)(i, i-1); - Real thisdet = diag1*diag2 - offdiag*offdiag; - // thisdet == determinant of 2x2 block. - // The following line is more complex than it looks: there are 2 offsets of - // 1 that cancel. - prod *= thisdet; - } - if (i == (int)(this->num_rows_-1) || fabs(prod) < 1.0e-10 || fabs(prod) > 1.0e+10) { - if (prod < 0) { prod = -prod; sign *= -1; } - log_prod += kaldi::Log(std::abs(prod)); - prod = 1.0; - } - } - if (logdet != NULL) *logdet = log_prod; - if (det_sign != NULL) *det_sign = sign; - } - } - if (!need_inverse) { - delete [] p_ipiv; - KALDI_MEMALIGN_FREE(p_work); - return; // Don't need what is computed next. - } - // NOTE: Even though "U" is for upper, lapack assumes column-wise storage - // of the data. We have a row-wise storage, therefore, we need to "invert" - clapack_Xsptri(&rows, this->data_, p_ipiv, p_work, &result); - - KALDI_ASSERT(result >=0 && - "Call to CLAPACK ssptri_ called with wrong arguments"); - - if (result != 0) { - KALDI_ERR << "CLAPACK ssptrf_ : Matrix is singular"; - } - - delete [] p_ipiv; - KALDI_MEMALIGN_FREE(p_work); -} -#else -// in the ATLAS case, these are not implemented using a library and we back off to something else. -template -void SpMatrix::Invert(Real *logdet, Real *det_sign, bool need_inverse) { - Matrix M(this->NumRows(), this->NumCols()); - M.CopyFromSp(*this); - M.Invert(logdet, det_sign, need_inverse); - if (need_inverse) - for (MatrixIndexT i = 0; i < this->NumRows(); i++) - for (MatrixIndexT j = 0; j <= i; j++) - (*this)(i, j) = M(i, j); -} -#endif - -template -void SpMatrix::InvertDouble(Real *logdet, Real *det_sign, - bool inverse_needed) { - SpMatrix dmat(*this); - double logdet_tmp, det_sign_tmp; - dmat.Invert(logdet ? &logdet_tmp : NULL, - det_sign ? 
&det_sign_tmp : NULL, - inverse_needed); - if (logdet) *logdet = logdet_tmp; - if (det_sign) *det_sign = det_sign_tmp; - (*this).CopyFromSp(dmat); -} - - - -double TraceSpSp(const SpMatrix &A, const SpMatrix &B) { - KALDI_ASSERT(A.NumRows() == B.NumRows()); - const double *Aptr = A.Data(); - const double *Bptr = B.Data(); - MatrixIndexT R = A.NumRows(); - MatrixIndexT RR = (R * (R + 1)) / 2; - double all_twice = 2.0 * cblas_Xdot(RR, Aptr, 1, Bptr, 1); - // "all_twice" contains twice the vector-wise dot-product... this is - // what we want except the diagonal elements are represented - // twice. - double diag_once = 0.0; - for (MatrixIndexT row_plus_two = 2; row_plus_two <= R + 1; row_plus_two++) { - diag_once += *Aptr * *Bptr; - Aptr += row_plus_two; - Bptr += row_plus_two; - } - return all_twice - diag_once; -} - - -float TraceSpSp(const SpMatrix &A, const SpMatrix &B) { - KALDI_ASSERT(A.NumRows() == B.NumRows()); - const float *Aptr = A.Data(); - const float *Bptr = B.Data(); - MatrixIndexT R = A.NumRows(); - MatrixIndexT RR = (R * (R + 1)) / 2; - float all_twice = 2.0 * cblas_Xdot(RR, Aptr, 1, Bptr, 1); - // "all_twice" contains twice the vector-wise dot-product... this is - // what we want except the diagonal elements are represented - // twice. - float diag_once = 0.0; - for (MatrixIndexT row_plus_two = 2; row_plus_two <= R + 1; row_plus_two++) { - diag_once += *Aptr * *Bptr; - Aptr += row_plus_two; - Bptr += row_plus_two; - } - return all_twice - diag_once; -} - - -template -Real TraceSpSp(const SpMatrix &A, const SpMatrix &B) { - KALDI_ASSERT(A.NumRows() == B.NumRows()); - Real ans = 0.0; - const Real *Aptr = A.Data(); - const OtherReal *Bptr = B.Data(); - MatrixIndexT row, col, R = A.NumRows(); - for (row = 0; row < R; row++) { - for (col = 0; col < row; col++) - ans += 2.0 * *(Aptr++) * *(Bptr++); - ans += *(Aptr++) * *(Bptr++); // Diagonal. 
- } - return ans; -} - -template -float TraceSpSp(const SpMatrix &A, const SpMatrix &B); - -template -double TraceSpSp(const SpMatrix &A, const SpMatrix &B); - - -template -Real TraceSpMat(const SpMatrix &A, const MatrixBase &B) { - KALDI_ASSERT(A.NumRows() == B.NumRows() && A.NumCols() == B.NumCols() && - "KALDI_ERR: TraceSpMat: arguments have mismatched dimension"); - MatrixIndexT R = A.NumRows(); - Real ans = (Real)0.0; - const Real *Aptr = A.Data(), *Bptr = B.Data(); - MatrixIndexT bStride = B.Stride(); - for (MatrixIndexT r = 0;r < R;r++) { - for (MatrixIndexT c = 0;c < r;c++) { - // ans += A(r, c) * (B(r, c) + B(c, r)); - ans += *(Aptr++) * (Bptr[r*bStride + c] + Bptr[c*bStride + r]); - } - // ans += A(r, r) * B(r, r); - ans += *(Aptr++) * Bptr[r*bStride + r]; - } - return ans; -} - -template -float TraceSpMat(const SpMatrix &A, const MatrixBase &B); - -template -double TraceSpMat(const SpMatrix &A, const MatrixBase &B); - - -template -Real TraceMatSpMat(const MatrixBase &A, MatrixTransposeType transA, - const SpMatrix &B, const MatrixBase &C, - MatrixTransposeType transC) { - KALDI_ASSERT((transA == kTrans?A.NumCols():A.NumRows()) == - (transC == kTrans?C.NumRows():C.NumCols()) && - (transA == kTrans?A.NumRows():A.NumCols()) == B.NumRows() && - (transC == kTrans?C.NumCols():C.NumRows()) == B.NumRows() && - "TraceMatSpMat: arguments have wrong dimension."); - Matrix tmp(B.NumRows(), B.NumRows()); - tmp.AddMatMat(1.0, C, transC, A, transA, 0.0); // tmp = C * A. 
- return TraceSpMat(B, tmp); -} - -template -float TraceMatSpMat(const MatrixBase &A, MatrixTransposeType transA, - const SpMatrix &B, const MatrixBase &C, - MatrixTransposeType transC); -template -double TraceMatSpMat(const MatrixBase &A, MatrixTransposeType transA, - const SpMatrix &B, const MatrixBase &C, - MatrixTransposeType transC); - -template -Real TraceMatSpMatSp(const MatrixBase &A, MatrixTransposeType transA, - const SpMatrix &B, const MatrixBase &C, - MatrixTransposeType transC, const SpMatrix &D) { - KALDI_ASSERT((transA == kTrans ?A.NumCols():A.NumRows() == D.NumCols()) && - (transA == kTrans ? A.NumRows():A.NumCols() == B.NumRows()) && - (transC == kTrans ? A.NumCols():A.NumRows() == B.NumCols()) && - (transC == kTrans ? A.NumRows():A.NumCols() == D.NumRows()) && - "KALDI_ERR: TraceMatSpMatSp: arguments have mismatched dimension."); - // Could perhaps optimize this more depending on dimensions of quantities. - Matrix tmpAB(transA == kTrans ? A.NumCols():A.NumRows(), B.NumCols()); - tmpAB.AddMatSp(1.0, A, transA, B, 0.0); - Matrix tmpCD(transC == kTrans ? 
C.NumCols():C.NumRows(), D.NumCols()); - tmpCD.AddMatSp(1.0, C, transC, D, 0.0); - return TraceMatMat(tmpAB, tmpCD, kNoTrans); -} - -template -float TraceMatSpMatSp(const MatrixBase &A, MatrixTransposeType transA, - const SpMatrix &B, const MatrixBase &C, - MatrixTransposeType transC, const SpMatrix &D); -template -double TraceMatSpMatSp(const MatrixBase &A, MatrixTransposeType transA, - const SpMatrix &B, const MatrixBase &C, - MatrixTransposeType transC, const SpMatrix &D); - - -template -bool SpMatrix::IsDiagonal(Real cutoff) const { - MatrixIndexT R = this->NumRows(); - Real bad_sum = 0.0, good_sum = 0.0; - for (MatrixIndexT i = 0; i < R; i++) { - for (MatrixIndexT j = 0; j <= i; j++) { - if (i == j) - good_sum += std::abs((*this)(i, j)); - else - bad_sum += std::abs((*this)(i, j)); - } - } - return (!(bad_sum > good_sum * cutoff)); -} - -template -bool SpMatrix::IsUnit(Real cutoff) const { - MatrixIndexT R = this->NumRows(); - Real max = 0.0; // max error - for (MatrixIndexT i = 0; i < R; i++) - for (MatrixIndexT j = 0; j <= i; j++) - max = std::max(max, static_cast(std::abs((*this)(i, j) - - (i == j ? 
1.0 : 0.0)))); - return (max <= cutoff); -} - -template -bool SpMatrix::IsTridiagonal(Real cutoff) const { - MatrixIndexT R = this->NumRows(); - Real max_abs_2diag = 0.0, max_abs_offdiag = 0.0; - for (MatrixIndexT i = 0; i < R; i++) - for (MatrixIndexT j = 0; j <= i; j++) { - if (j+1 < i) - max_abs_offdiag = std::max(max_abs_offdiag, - std::abs((*this)(i, j))); - else - max_abs_2diag = std::max(max_abs_2diag, - std::abs((*this)(i, j))); - } - return (max_abs_offdiag <= cutoff * max_abs_2diag); -} - -template -bool SpMatrix::IsZero(Real cutoff) const { - if (this->num_rows_ == 0) return true; - return (this->Max() <= cutoff && this->Min() >= -cutoff); -} - -template -Real SpMatrix::FrobeniusNorm() const { - Real sum = 0.0; - MatrixIndexT R = this->NumRows(); - for (MatrixIndexT i = 0; i < R; i++) { - for (MatrixIndexT j = 0; j < i; j++) - sum += (*this)(i, j) * (*this)(i, j) * 2; - sum += (*this)(i, i) * (*this)(i, i); - } - return std::sqrt(sum); -} - -template -bool SpMatrix::ApproxEqual(const SpMatrix &other, float tol) const { - if (this->NumRows() != other.NumRows()) - KALDI_ERR << "SpMatrix::AproxEqual, size mismatch, " - << this->NumRows() << " vs. " << other.NumRows(); - SpMatrix tmp(*this); - tmp.AddSp(-1.0, other); - return (tmp.FrobeniusNorm() <= tol * std::max(this->FrobeniusNorm(), other.FrobeniusNorm())); -} - -// function Floor: A = Floor(B, alpha * C) ... see tutorial document. -template -int SpMatrix::ApplyFloor(const SpMatrix &C, Real alpha, - bool verbose) { - MatrixIndexT dim = this->NumRows(); - int nfloored = 0; - KALDI_ASSERT(C.NumRows() == dim); - KALDI_ASSERT(alpha > 0); - TpMatrix L(dim); - L.Cholesky(C); - L.Scale(std::sqrt(alpha)); // equivalent to scaling C by alpha. 
- TpMatrix LInv(L); - LInv.Invert(); - - SpMatrix D(dim); - { // D = L^{-1} * (*this) * L^{-T} - Matrix LInvFull(LInv); - D.AddMat2Sp(1.0, LInvFull, kNoTrans, (*this), 0.0); - } - - Vector l(dim); - Matrix U(dim, dim); - - D.Eig(&l, &U); - - if (verbose) { - KALDI_LOG << "ApplyFloor: flooring following diagonal to 1: " << l; - } - for (MatrixIndexT i = 0; i < l.Dim(); i++) { - if (l(i) < 1.0) { - nfloored++; - l(i) = 1.0; - } - } - l.ApplyPow(0.5); - U.MulColsVec(l); - D.AddMat2(1.0, U, kNoTrans, 0.0); - { // D' := U * diag(l') * U^T ... l'=floor(l, 1) - Matrix LFull(L); - (*this).AddMat2Sp(1.0, LFull, kNoTrans, D, 0.0); // A := L * D' * L^T - } - return nfloored; -} - -template -Real SpMatrix::LogDet(Real *det_sign) const { - Real log_det; - SpMatrix tmp(*this); - // false== output not needed (saves some computation). - tmp.Invert(&log_det, det_sign, false); - return log_det; -} - - -template -int SpMatrix::ApplyFloor(Real floor) { - MatrixIndexT Dim = this->NumRows(); - int nfloored = 0; - Vector s(Dim); - Matrix P(Dim, Dim); - (*this).Eig(&s, &P); - for (MatrixIndexT i = 0; i < Dim; i++) { - if (s(i) < floor) { - nfloored++; - s(i) = floor; - } - } - (*this).AddMat2Vec(1.0, P, kNoTrans, s, 0.0); - return nfloored; -} - -template -MatrixIndexT SpMatrix::LimitCond(Real maxCond, bool invert) { // e.g. maxCond = 1.0e+05. - MatrixIndexT Dim = this->NumRows(); - Vector s(Dim); - Matrix P(Dim, Dim); - (*this).SymPosSemiDefEig(&s, &P); - KALDI_ASSERT(maxCond > 1); - Real floor = s.Max() / maxCond; - if (floor < 0) floor = 0; - if (floor < 1.0e-40) { - KALDI_WARN << "LimitCond: limiting " << floor << " to 1.0e-40"; - floor = 1.0e-40; - } - MatrixIndexT nfloored = 0; - for (MatrixIndexT i = 0; i < Dim; i++) { - if (s(i) <= floor) nfloored++; - if (invert) - s(i) = 1.0 / std::sqrt(std::max(s(i), floor)); - else - s(i) = std::sqrt(std::max(s(i), floor)); - } - P.MulColsVec(s); - (*this).AddMat2(1.0, P, kNoTrans, 0.0); // (*this) = P*P^T. ... (*this) = P * floor(s) * P^T ... 
if P was original P. - return nfloored; -} - -void SolverOptions::Check() const { - KALDI_ASSERT(K>10 && eps<1.0e-10); -} - -template<> double SolveQuadraticProblem(const SpMatrix &H, - const VectorBase &g, - const SolverOptions &opts, - VectorBase *x) { - KALDI_ASSERT(H.NumRows() == g.Dim() && g.Dim() == x->Dim() && x->Dim() != 0); - opts.Check(); - MatrixIndexT dim = x->Dim(); - if (H.IsZero(0.0)) { - KALDI_WARN << "Zero quadratic term in quadratic vector problem for " - << opts.name << ": leaving it unchanged."; - return 0.0; - } - if (opts.diagonal_precondition) { - // We can re-cast the problem with a diagonal preconditioner to - // make H better-conditioned. - Vector H_diag(dim); - H_diag.CopyDiagFromSp(H); - H_diag.ApplyFloor(std::numeric_limits::min() * 1.0E+3); - Vector H_diag_sqrt(H_diag); - H_diag_sqrt.ApplyPow(0.5); - Vector H_diag_inv_sqrt(H_diag_sqrt); - H_diag_inv_sqrt.InvertElements(); - Vector x_scaled(*x); - x_scaled.MulElements(H_diag_sqrt); - Vector g_scaled(g); - g_scaled.MulElements(H_diag_inv_sqrt); - SpMatrix H_scaled(dim); - H_scaled.AddVec2Sp(1.0, H_diag_inv_sqrt, H, 0.0); - double ans; - SolverOptions new_opts(opts); - new_opts.diagonal_precondition = false; - ans = SolveQuadraticProblem(H_scaled, g_scaled, new_opts, &x_scaled); - x->CopyFromVec(x_scaled); - x->MulElements(H_diag_inv_sqrt); - return ans; - } - Vector gbar(g); - if (opts.optimize_delta) gbar.AddSpVec(-1.0, H, *x, 1.0); // gbar = g - H x - Matrix U(dim, dim); - Vector l(dim); - H.SymPosSemiDefEig(&l, &U); // does svd H = U L V^T and checks that H == U L U^T to within a tolerance. - // floor l. - double f = std::max(static_cast(opts.eps), l.Max() / opts.K); - MatrixIndexT nfloored = 0; - for (MatrixIndexT i = 0; i < dim; i++) { // floor l. - if (l(i) < f) { - nfloored++; - l(i) = f; - } - } - if (nfloored != 0 && opts.print_debug_output) { - KALDI_LOG << "Solving quadratic problem for " << opts.name - << ": floored " << nfloored<< " eigenvalues. 
"; - } - Vector tmp(dim); - tmp.AddMatVec(1.0, U, kTrans, gbar, 0.0); // tmp = U^T \bar{g} - tmp.DivElements(l); // divide each element of tmp by l: tmp = \tilde{L}^{-1} U^T \bar{g} - Vector delta(dim); - delta.AddMatVec(1.0, U, kNoTrans, tmp, 0.0); // delta = U tmp = U \tilde{L}^{-1} U^T \bar{g} - Vector &xhat(tmp); - xhat.CopyFromVec(delta); - if (opts.optimize_delta) xhat.AddVec(1.0, *x); // xhat = x + delta. - double auxf_before = VecVec(g, *x) - 0.5 * VecSpVec(*x, H, *x), - auxf_after = VecVec(g, xhat) - 0.5 * VecSpVec(xhat, H, xhat); - if (auxf_after < auxf_before) { // Reject change. - if (auxf_after < auxf_before - 1.0e-10 && opts.print_debug_output) - KALDI_WARN << "Optimizing vector auxiliary function for " - << opts.name<< ": auxf decreased " << auxf_before - << " to " << auxf_after << ", change is " - << (auxf_after-auxf_before); - return 0.0; - } else { - x->CopyFromVec(xhat); - return auxf_after - auxf_before; - } -} - -template<> float SolveQuadraticProblem(const SpMatrix &H, - const VectorBase &g, - const SolverOptions &opts, - VectorBase *x) { - KALDI_ASSERT(H.NumRows() == g.Dim() && g.Dim() == x->Dim() && x->Dim() != 0); - SpMatrix Hd(H); - Vector gd(g); - Vector xd(*x); - float ans = static_cast(SolveQuadraticProblem(Hd, gd, opts, &xd)); - x->CopyFromVec(xd); - return ans; -} - -// Maximizes the auxiliary function Q(x) = tr(M^T SigmaInv Y) - 0.5 tr(SigmaInv M Q M^T). -// Like a numerically stable version of M := Y Q^{-1}. 
-template -Real -SolveQuadraticMatrixProblem(const SpMatrix &Q, - const MatrixBase &Y, - const SpMatrix &SigmaInv, - const SolverOptions &opts, - MatrixBase *M) { - KALDI_ASSERT(Q.NumRows() == M->NumCols() && - SigmaInv.NumRows() == M->NumRows() && Y.NumRows() == M->NumRows() - && Y.NumCols() == M->NumCols() && M->NumCols() != 0); - opts.Check(); - MatrixIndexT rows = M->NumRows(), cols = M->NumCols(); - if (Q.IsZero(0.0)) { - KALDI_WARN << "Zero quadratic term in quadratic matrix problem for " - << opts.name << ": leaving it unchanged."; - return 0.0; - } - - if (opts.diagonal_precondition) { - // We can re-cast the problem with a diagonal preconditioner in the space - // of Q (columns of M). Helps to improve the condition of Q. - Vector Q_diag(cols); - Q_diag.CopyDiagFromSp(Q); - Q_diag.ApplyFloor(std::numeric_limits::min() * 1.0E+3); - Vector Q_diag_sqrt(Q_diag); - Q_diag_sqrt.ApplyPow(0.5); - Vector Q_diag_inv_sqrt(Q_diag_sqrt); - Q_diag_inv_sqrt.InvertElements(); - Matrix M_scaled(*M); - M_scaled.MulColsVec(Q_diag_sqrt); - Matrix Y_scaled(Y); - Y_scaled.MulColsVec(Q_diag_inv_sqrt); - SpMatrix Q_scaled(cols); - Q_scaled.AddVec2Sp(1.0, Q_diag_inv_sqrt, Q, 0.0); - Real ans; - SolverOptions new_opts(opts); - new_opts.diagonal_precondition = false; - ans = SolveQuadraticMatrixProblem(Q_scaled, Y_scaled, SigmaInv, - new_opts, &M_scaled); - M->CopyFromMat(M_scaled); - M->MulColsVec(Q_diag_inv_sqrt); - return ans; - } - - Matrix Ybar(Y); - if (opts.optimize_delta) { - Matrix Qfull(Q); - Ybar.AddMatMat(-1.0, *M, kNoTrans, Qfull, kNoTrans, 1.0); - } // Ybar = Y - M Q. - Matrix U(cols, cols); - Vector l(cols); - Q.SymPosSemiDefEig(&l, &U); // does svd Q = U L V^T and checks that Q == U L U^T to within a tolerance. - // floor l. - Real f = std::max(static_cast(opts.eps), l.Max() / opts.K); - MatrixIndexT nfloored = 0; - for (MatrixIndexT i = 0; i < cols; i++) { // floor l. 
- if (l(i) < f) { nfloored++; l(i) = f; } - } - if (nfloored != 0 && opts.print_debug_output) - KALDI_LOG << "Solving matrix problem for " << opts.name - << ": floored " << nfloored << " eigenvalues. "; - Matrix tmpDelta(rows, cols); - tmpDelta.AddMatMat(1.0, Ybar, kNoTrans, U, kNoTrans, 0.0); // tmpDelta = Ybar * U. - l.InvertElements(); KALDI_ASSERT(1.0/l.Max() != 0); // check not infinite. eps should take care of this. - tmpDelta.MulColsVec(l); // tmpDelta = Ybar * U * \tilde{L}^{-1} - - Matrix Delta(rows, cols); - Delta.AddMatMat(1.0, tmpDelta, kNoTrans, U, kTrans, 0.0); // Delta = Ybar * U * \tilde{L}^{-1} * U^T - - Real auxf_before, auxf_after; - SpMatrix MQM(rows); - Matrix &SigmaInvY(tmpDelta); - { Matrix SigmaInvFull(SigmaInv); SigmaInvY.AddMatMat(1.0, SigmaInvFull, kNoTrans, Y, kNoTrans, 0.0); } - { // get auxf_before. Q(x) = tr(M^T SigmaInv Y) - 0.5 tr(SigmaInv M Q M^T). - MQM.AddMat2Sp(1.0, *M, kNoTrans, Q, 0.0); - auxf_before = TraceMatMat(*M, SigmaInvY, kaldi::kTrans) - 0.5*TraceSpSp(SigmaInv, MQM); - } - - Matrix Mhat(Delta); - if (opts.optimize_delta) Mhat.AddMat(1.0, *M); // Mhat = Delta + M. - - { // get auxf_after. 
- MQM.AddMat2Sp(1.0, Mhat, kNoTrans, Q, 0.0); - auxf_after = TraceMatMat(Mhat, SigmaInvY, kaldi::kTrans) - 0.5*TraceSpSp(SigmaInv, MQM); - } - - if (auxf_after < auxf_before) { - if (auxf_after < auxf_before - 1.0e-10) - KALDI_WARN << "Optimizing matrix auxiliary function for " - << opts.name << ", auxf decreased " - << auxf_before << " to " << auxf_after << ", change is " - << (auxf_after-auxf_before); - return 0.0; - } else { - M->CopyFromMat(Mhat); - return auxf_after - auxf_before; - } -} - -template -Real SolveDoubleQuadraticMatrixProblem(const MatrixBase &G, - const SpMatrix &P1, - const SpMatrix &P2, - const SpMatrix &Q1, - const SpMatrix &Q2, - const SolverOptions &opts, - MatrixBase *M) { - KALDI_ASSERT(Q1.NumRows() == M->NumCols() && P1.NumRows() == M->NumRows() && - G.NumRows() == M->NumRows() && G.NumCols() == M->NumCols() && - M->NumCols() != 0 && Q2.NumRows() == M->NumCols() && - P2.NumRows() == M->NumRows()); - MatrixIndexT rows = M->NumRows(), cols = M->NumCols(); - // The following check should not fail as we stipulate P1, P2 and one of Q1 - // or Q2 must be +ve def and other Q1 or Q2 must be +ve semidef. - TpMatrix LInv(rows); - LInv.Cholesky(P1); - LInv.Invert(); // Will throw exception if fails. - SpMatrix S(rows); - Matrix LInvFull(LInv); - S.AddMat2Sp(1.0, LInvFull, kNoTrans, P2, 0.0); // S := L^{-1} P_2 L^{-T} - Matrix U(rows, rows); - Vector d(rows); - S.SymPosSemiDefEig(&d, &U); - Matrix T(rows, rows); - T.AddMatMat(1.0, U, kTrans, LInvFull, kNoTrans, 0.0); // T := U^T * L^{-1} - -#ifdef KALDI_PARANOID // checking mainly for errors in the code or math. 
- { - SpMatrix P1Trans(rows); - P1Trans.AddMat2Sp(1.0, T, kNoTrans, P1, 0.0); - KALDI_ASSERT(P1Trans.IsUnit(0.01)); - } - { - SpMatrix P2Trans(rows); - P2Trans.AddMat2Sp(1.0, T, kNoTrans, P2, 0.0); - KALDI_ASSERT(P2Trans.IsDiagonal(0.01)); - } -#endif - - Matrix TInv(T); - TInv.Invert(); - Matrix Gdash(rows, cols); - Gdash.AddMatMat(1.0, T, kNoTrans, G, kNoTrans, 0.0); // G' = T G - Matrix MdashOld(rows, cols); - MdashOld.AddMatMat(1.0, TInv, kTrans, *M, kNoTrans, 0.0); // M' = T^{-T} M - Matrix MdashNew(MdashOld); - Real objf_impr = 0.0; - for (MatrixIndexT n = 0; n < rows; n++) { - SpMatrix Qsum(Q1); - Qsum.AddSp(d(n), Q2); - SubVector mdash_n = MdashNew.Row(n); - SubVector gdash_n = Gdash.Row(n); - - Matrix QsumInv(Qsum); - try { - QsumInv.Invert(); - Real old_objf = VecVec(mdash_n, gdash_n) - - 0.5 * VecSpVec(mdash_n, Qsum, mdash_n); - mdash_n.AddMatVec(1.0, QsumInv, kNoTrans, gdash_n, 0.0); // m'_n := g'_n * (Q_1 + d_n Q_2)^{-1} - Real new_objf = VecVec(mdash_n, gdash_n) - - 0.5 * VecSpVec(mdash_n, Qsum, mdash_n); - if (new_objf < old_objf) { - if (new_objf < old_objf - 1.0e-05) { - KALDI_WARN << "In double quadratic matrix problem: objective " - "function decreasing during optimization of " << opts.name - << ", " << old_objf << "->" << new_objf << ", change is " - << (new_objf - old_objf); - KALDI_ERR << "Auxiliary function decreasing."; // Will be caught. - } else { // Reset to old value, didn't improve (very close to optimum). - MdashNew.Row(n).CopyFromVec(MdashOld.Row(n)); - } - } - objf_impr += new_objf - old_objf; - } - catch (...) { - KALDI_WARN << "Matrix inversion or optimization failed during double " - "quadratic problem, solving for" << opts.name - << ": trying more stable approach."; - objf_impr += SolveQuadraticProblem(Qsum, gdash_n, opts, &mdash_n); - } - } - M->AddMatMat(1.0, T, kTrans, MdashNew, kNoTrans, 0.0); // M := T^T M'. 
- return objf_impr; -} - -// rank-one update, this <-- this + alpha V V' -template<> -template<> -void SpMatrix::AddVec2(const float alpha, const VectorBase &v) { - KALDI_ASSERT(v.Dim() == this->NumRows()); - cblas_Xspr(v.Dim(), alpha, v.Data(), 1, - this->data_); -} - -template -void SpMatrix::AddVec2Sp(const Real alpha, const VectorBase &v, - const SpMatrix &S, const Real beta) { - KALDI_ASSERT(v.Dim() == this->NumRows() && S.NumRows() == this->NumRows()); - const Real *Sdata = S.Data(); - const Real *vdata = v.Data(); - Real *data = this->data_; - MatrixIndexT dim = this->num_rows_; - for (MatrixIndexT r = 0; r < dim; r++) - for (MatrixIndexT c = 0; c <= r; c++, Sdata++, data++) - *data = beta * *data + alpha * vdata[r] * vdata[c] * *Sdata; -} - - -// rank-one update, this <-- this + alpha V V' -template<> -template<> -void SpMatrix::AddVec2(const double alpha, const VectorBase &v) { - KALDI_ASSERT(v.Dim() == num_rows_); - cblas_Xspr(v.Dim(), alpha, v.Data(), 1, data_); -} - - -template -template -void SpMatrix::AddVec2(const Real alpha, const VectorBase &v) { - KALDI_ASSERT(v.Dim() == this->NumRows()); - Real *data = this->data_; - const OtherReal *v_data = v.Data(); - MatrixIndexT nr = this->num_rows_; - for (MatrixIndexT i = 0; i < nr; i++) - for (MatrixIndexT j = 0; j <= i; j++, data++) - *data += alpha * v_data[i] * v_data[j]; -} - -// instantiate the template above. 
-template -void SpMatrix::AddVec2(const float alpha, const VectorBase &v); -template -void SpMatrix::AddVec2(const double alpha, const VectorBase &v); - - -template -Real VecSpVec(const VectorBase &v1, const SpMatrix &M, - const VectorBase &v2) { - MatrixIndexT D = M.NumRows(); - KALDI_ASSERT(v1.Dim() == D && v1.Dim() == v2.Dim()); - Vector tmp_vec(D); - cblas_Xspmv(D, 1.0, M.Data(), v1.Data(), 1, 0.0, tmp_vec.Data(), 1); - return VecVec(tmp_vec, v2); -} - -template -float VecSpVec(const VectorBase &v1, const SpMatrix &M, - const VectorBase &v2); -template -double VecSpVec(const VectorBase &v1, const SpMatrix &M, - const VectorBase &v2); - - -template -void SpMatrix::AddMat2Sp( - const Real alpha, const MatrixBase &M, - MatrixTransposeType transM, const SpMatrix &A, const Real beta) { - if (transM == kNoTrans) { - KALDI_ASSERT(M.NumCols() == A.NumRows() && M.NumRows() == this->num_rows_); - } else { - KALDI_ASSERT(M.NumRows() == A.NumRows() && M.NumCols() == this->num_rows_); - } - Vector tmp_vec(A.NumRows()); - Real *tmp_vec_data = tmp_vec.Data(); - SpMatrix tmp_A; - const Real *p_A_data = A.Data(); - Real *p_row_data = this->Data(); - MatrixIndexT M_other_dim = (transM == kNoTrans ? M.NumCols() : M.NumRows()), - M_same_dim = (transM == kNoTrans ? M.NumRows() : M.NumCols()), - M_stride = M.Stride(), dim = this->NumRows(); - KALDI_ASSERT(M_same_dim == dim); - - const Real *M_data = M.Data(); - - if (this->Data() <= A.Data() + A.SizeInBytes() && - this->Data() + this->SizeInBytes() >= A.Data()) { - // Matrices A and *this overlap. 
Make copy of A - tmp_A.Resize(A.NumRows()); - tmp_A.CopyFromSp(A); - p_A_data = tmp_A.Data(); - } - - if (transM == kNoTrans) { - for (MatrixIndexT r = 0; r < dim; r++, p_row_data += r) { - cblas_Xspmv(A.NumRows(), 1.0, p_A_data, M.RowData(r), 1, 0.0, tmp_vec_data, 1); - cblas_Xgemv(transM, r+1, M_other_dim, alpha, M_data, M_stride, - tmp_vec_data, 1, beta, p_row_data, 1); - } - } else { - for (MatrixIndexT r = 0; r < dim; r++, p_row_data += r) { - cblas_Xspmv(A.NumRows(), 1.0, p_A_data, M.Data() + r, M.Stride(), 0.0, tmp_vec_data, 1); - cblas_Xgemv(transM, M_other_dim, r+1, alpha, M_data, M_stride, - tmp_vec_data, 1, beta, p_row_data, 1); - } - } -} - -template -void SpMatrix::AddSmat2Sp( - const Real alpha, const MatrixBase &M, - MatrixTransposeType transM, const SpMatrix &A, - const Real beta) { - KALDI_ASSERT((transM == kNoTrans && M.NumCols() == A.NumRows()) || - (transM == kTrans && M.NumRows() == A.NumRows())); - if (transM == kNoTrans) { - KALDI_ASSERT(M.NumCols() == A.NumRows() && M.NumRows() == this->num_rows_); - } else { - KALDI_ASSERT(M.NumRows() == A.NumRows() && M.NumCols() == this->num_rows_); - } - MatrixIndexT Adim = A.NumRows(), dim = this->num_rows_; - - Matrix temp_A(A); // represent A as full matrix. - Matrix temp_MA(dim, Adim); - temp_MA.AddSmatMat(1.0, M, transM, temp_A, kNoTrans, 0.0); - - // Next-- we want to do *this = alpha * temp_MA * M^T + beta * *this. - // To make it sparse vector multiplies, since M is sparse, we'd like - // to do: for each column c, (*this column c) += temp_MA * (M^T's column c.) - // [ignoring the alpha and beta here.] - // It's not convenient to process columns in the symmetric - // packed format because they don't have a constant stride. However, - // we can use the fact that temp_MA * M is symmetric, to just assign - // each row of *this instead of each column. - // So the final iteration is: - // for i = 0... 
dim-1, - // [the i'th row of *this] = beta * [the i'th row of *this] + alpha * - // temp_MA * [the i'th column of M]. - // Of course, we only process the first 0 ... i elements of this row, - // as that's all that are kept in the symmetric packed format. - - Matrix temp_this(*this); - Real *data = this->data_; - const Real *Mdata = M.Data(), *MAdata = temp_MA.Data(); - MatrixIndexT temp_MA_stride = temp_MA.Stride(), Mstride = M.Stride(); - - if (transM == kNoTrans) { - // The column of M^T corresponds to the rows of the supplied matrix. - for (MatrixIndexT i = 0; i < dim; i++, data += i) { - MatrixIndexT num_rows = i + 1, num_cols = Adim; - Xgemv_sparsevec(kNoTrans, num_rows, num_cols, alpha, MAdata, - temp_MA_stride, Mdata + (i * Mstride), 1, beta, data, 1); - } - } else { - // The column of M^T corresponds to the columns of the supplied matrix. - for (MatrixIndexT i = 0; i < dim; i++, data += i) { - MatrixIndexT num_rows = i + 1, num_cols = Adim; - Xgemv_sparsevec(kNoTrans, num_rows, num_cols, alpha, MAdata, - temp_MA_stride, Mdata + i, Mstride, beta, data, 1); - } - } -} - -template -void SpMatrix::AddMat2Vec(const Real alpha, - const MatrixBase &M, - MatrixTransposeType transM, - const VectorBase &v, - const Real beta) { - this->Scale(beta); - KALDI_ASSERT((transM == kNoTrans && this->NumRows() == M.NumRows() && - M.NumCols() == v.Dim()) || - (transM == kTrans && this->NumRows() == M.NumCols() && - M.NumRows() == v.Dim())); - - if (transM == kNoTrans) { - const Real *Mdata = M.Data(), *vdata = v.Data(); - Real *data = this->data_; - MatrixIndexT dim = this->NumRows(), mcols = M.NumCols(), - mstride = M.Stride(); - for (MatrixIndexT col = 0; col < mcols; col++, vdata++, Mdata += 1) - cblas_Xspr(dim, *vdata*alpha, Mdata, mstride, data); - } else { - const Real *Mdata = M.Data(), *vdata = v.Data(); - Real *data = this->data_; - MatrixIndexT dim = this->NumRows(), mrows = M.NumRows(), - mstride = M.Stride(); - for (MatrixIndexT row = 0; row < mrows; row++, vdata++, 
Mdata += mstride) - cblas_Xspr(dim, *vdata*alpha, Mdata, 1, data); - } -} - -template -void SpMatrix::AddMat2(const Real alpha, const MatrixBase &M, - MatrixTransposeType transM, const Real beta) { - KALDI_ASSERT((transM == kNoTrans && this->NumRows() == M.NumRows()) - || (transM == kTrans && this->NumRows() == M.NumCols())); - - // Cblas has no function *sprk (i.e. symmetric packed rank-k update), so we - // use as temporary storage a regular matrix of which we only access its lower - // triangle - - MatrixIndexT this_dim = this->NumRows(), - m_other_dim = (transM == kNoTrans ? M.NumCols() : M.NumRows()); - - if (this_dim == 0) return; - if (alpha == 0.0) { - if (beta != 1.0) this->Scale(beta); - return; - } - - Matrix temp_mat(*this); // wastefully copies upper triangle too, but this - // doesn't dominate O(N) time. - - // This function call is hard-coded to update the lower triangle. - cblas_Xsyrk(transM, this_dim, m_other_dim, alpha, M.Data(), - M.Stride(), beta, temp_mat.Data(), temp_mat.Stride()); - - this->CopyFromMat(temp_mat, kTakeLower); -} - -template -void SpMatrix::AddTp2Sp(const Real alpha, const TpMatrix &T, - MatrixTransposeType transM, const SpMatrix &A, - const Real beta) { - Matrix Tmat(T); - AddMat2Sp(alpha, Tmat, transM, A, beta); -} - -template -void SpMatrix::AddVecVec(const Real alpha, const VectorBase &v, - const VectorBase &w) { - int32 dim = this->NumRows(); - KALDI_ASSERT(dim == v.Dim() && dim == w.Dim() && dim > 0); - cblas_Xspr2(dim, alpha, v.Data(), 1, w.Data(), 1, this->data_); -} - - -template -void SpMatrix::AddTp2(const Real alpha, const TpMatrix &T, - MatrixTransposeType transM, const Real beta) { - Matrix Tmat(T); - AddMat2(alpha, Tmat, transM, beta); -} - - -// Explicit instantiation of the class. -// This needs to be after the definition of all the class member functions. 
- -template class SpMatrix; -template class SpMatrix; - - -template -Real TraceSpSpLower(const SpMatrix &A, const SpMatrix &B) { - MatrixIndexT adim = A.NumRows(); - KALDI_ASSERT(adim == B.NumRows()); - MatrixIndexT dim = (adim*(adim+1))/2; - return cblas_Xdot(dim, A.Data(), 1, B.Data(), 1); -} -// Instantiate the template above. -template -double TraceSpSpLower(const SpMatrix &A, const SpMatrix &B); -template -float TraceSpSpLower(const SpMatrix &A, const SpMatrix &B); - -// Instantiate the template above. -template float SolveQuadraticMatrixProblem(const SpMatrix &Q, - const MatrixBase &Y, - const SpMatrix &SigmaInv, - const SolverOptions &opts, - MatrixBase *M); -template double SolveQuadraticMatrixProblem(const SpMatrix &Q, - const MatrixBase &Y, - const SpMatrix &SigmaInv, - const SolverOptions &opts, - MatrixBase *M); - -// Instantiate the template above. -template float SolveDoubleQuadraticMatrixProblem( - const MatrixBase &G, - const SpMatrix &P1, - const SpMatrix &P2, - const SpMatrix &Q1, - const SpMatrix &Q2, - const SolverOptions &opts, - MatrixBase *M); - -template double SolveDoubleQuadraticMatrixProblem( - const MatrixBase &G, - const SpMatrix &P1, - const SpMatrix &P2, - const SpMatrix &Q1, - const SpMatrix &Q2, - const SolverOptions &opts, - MatrixBase *M); - - - -} // namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/sp-matrix.h b/speechx/speechx/kaldi/matrix/sp-matrix.h deleted file mode 100644 index 26d9ad6f..00000000 --- a/speechx/speechx/kaldi/matrix/sp-matrix.h +++ /dev/null @@ -1,517 +0,0 @@ -// matrix/sp-matrix.h - -// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Lukas Burget; -// Saarland University; Ariya Rastrow; Yanmin Qian; -// Jan Silovsky - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. -#ifndef KALDI_MATRIX_SP_MATRIX_H_ -#define KALDI_MATRIX_SP_MATRIX_H_ - -#include -#include - -#include "matrix/packed-matrix.h" - -namespace kaldi { - - -/// \addtogroup matrix_group -/// @{ -template class SpMatrix; - - -/** - * @brief Packed symetric matrix class -*/ -template -class SpMatrix : public PackedMatrix { - friend class CuSpMatrix; - public: - // so it can use our assignment operator. - friend class std::vector >; - - SpMatrix(): PackedMatrix() {} - - /// Copy constructor from CUDA version of SpMatrix - /// This is defined in ../cudamatrix/cu-sp-matrix.h - - explicit SpMatrix(const CuSpMatrix &cu); - - explicit SpMatrix(MatrixIndexT r, MatrixResizeType resize_type = kSetZero) - : PackedMatrix(r, resize_type) {} - - SpMatrix(const SpMatrix &orig) - : PackedMatrix(orig) {} - - template - explicit SpMatrix(const SpMatrix &orig) - : PackedMatrix(orig) {} - -#ifdef KALDI_PARANOID - explicit SpMatrix(const MatrixBase & orig, - SpCopyType copy_type = kTakeMeanAndCheck) - : PackedMatrix(orig.NumRows(), kUndefined) { - CopyFromMat(orig, copy_type); - } -#else - explicit SpMatrix(const MatrixBase & orig, - SpCopyType copy_type = kTakeMean) - : PackedMatrix(orig.NumRows(), kUndefined) { - CopyFromMat(orig, copy_type); - } -#endif - - /// Shallow swap. 
- void Swap(SpMatrix *other); - - inline void Resize(MatrixIndexT nRows, MatrixResizeType resize_type = kSetZero) { - PackedMatrix::Resize(nRows, resize_type); - } - - void CopyFromSp(const SpMatrix &other) { - PackedMatrix::CopyFromPacked(other); - } - - template - void CopyFromSp(const SpMatrix &other) { - PackedMatrix::CopyFromPacked(other); - } - -#ifdef KALDI_PARANOID - void CopyFromMat(const MatrixBase &orig, - SpCopyType copy_type = kTakeMeanAndCheck); -#else // different default arg if non-paranoid mode. - void CopyFromMat(const MatrixBase &orig, - SpCopyType copy_type = kTakeMean); -#endif - - inline Real operator() (MatrixIndexT r, MatrixIndexT c) const { - // if column is less than row, then swap these as matrix is stored - // as upper-triangular... only allowed for const matrix object. - if (static_cast(c) > - static_cast(r)) - std::swap(c, r); - // c<=r now so don't have to check c. - KALDI_ASSERT(static_cast(r) < - static_cast(this->num_rows_)); - return *(this->data_ + (r*(r+1)) / 2 + c); - // Duplicating code from PackedMatrix.h - } - - inline Real &operator() (MatrixIndexT r, MatrixIndexT c) { - if (static_cast(c) > - static_cast(r)) - std::swap(c, r); - // c<=r now so don't have to check c. - KALDI_ASSERT(static_cast(r) < - static_cast(this->num_rows_)); - return *(this->data_ + (r * (r + 1)) / 2 + c); - // Duplicating code from PackedMatrix.h - } - - SpMatrix& operator=(const SpMatrix &other) { - PackedMatrix::operator=(other); - return *this; - } - - using PackedMatrix::Scale; - - /// matrix inverse. - /// if inverse_needed = false, will fill matrix with garbage. - /// (only useful if logdet wanted). - void Invert(Real *logdet = NULL, Real *det_sign= NULL, - bool inverse_needed = true); - - // Below routine does inversion in double precision, - // even for single-precision object. - void InvertDouble(Real *logdet = NULL, Real *det_sign = NULL, - bool inverse_needed = true); - - /// Returns maximum ratio of singular values. 
- inline Real Cond() const { - Matrix tmp(*this); - return tmp.Cond(); - } - - /// Takes matrix to a fraction power via Svd. - /// Will throw exception if matrix is not positive semidefinite - /// (to within a tolerance) - void ApplyPow(Real exponent); - - /// This is the version of SVD that we implement for symmetric positive - /// definite matrices. This exists for historical reasons; right now its - /// internal implementation is the same as Eig(). It computes the eigenvalue - /// decomposition (*this) = P * diag(s) * P^T with P orthogonal. Will throw - /// exception if input is not positive semidefinite to within a tolerance. - void SymPosSemiDefEig(VectorBase *s, MatrixBase *P, - Real tolerance = 0.001) const; - - /// Solves the symmetric eigenvalue problem: at end we should have (*this) = P - /// * diag(s) * P^T. We solve the problem using the symmetric QR method. - /// P may be NULL. - /// Implemented in qr.cc. - /// If you need the eigenvalues sorted, the function SortSvd declared in - /// kaldi-matrix is suitable. - void Eig(VectorBase *s, MatrixBase *P = NULL) const; - - /// This function gives you, approximately, the largest eigenvalues of the - /// symmetric matrix and the corresponding eigenvectors. (largest meaning, - /// further from zero). It does this by doing a SVD within the Krylov - /// subspace generated by this matrix and a random vector. This is - /// a form of the Lanczos method with complete reorthogonalization, followed - /// by SVD within a smaller dimension ("lanczos_dim"). - /// - /// If *this is m by m, s should be of dimension n and P should be of - /// dimension m by n, with n <= m. The *columns* of P are the approximate - /// eigenvectors; P * diag(s) * P^T would be a low-rank reconstruction of - /// *this. The columns of P will be orthogonal, and the elements of s will be - /// the eigenvalues of *this projected into that subspace, but beyond that - /// there are no exact guarantees. 
(This is because the convergence of this - /// method is statistical). Note: it only makes sense to use this - /// method if you are in very high dimension and n is substantially smaller - /// than m: for example, if you want the 100 top eigenvalues of a 10k by 10k - /// matrix. This function calls Rand() to initialize the lanczos - /// iterations and also for restarting. - /// If lanczos_dim is zero, it will default to the greater of: - /// s->Dim() + 50 or s->Dim() + s->Dim()/2, but not more than this->Dim(). - /// If lanczos_dim == this->Dim(), you might as well just call the function - /// Eig() since the result will be the same, and Eig() would be faster; the - /// whole point of this function is to reduce the dimension of the SVD - /// computation. - void TopEigs(VectorBase *s, MatrixBase *P, - MatrixIndexT lanczos_dim = 0) const; - - - /// Returns the maximum of the absolute values of any of the - /// eigenvalues. - Real MaxAbsEig() const; - - void PrintEigs(const char *name) { - Vector s((*this).NumRows()); - Matrix P((*this).NumRows(), (*this).NumCols()); - SymPosSemiDefEig(&s, &P); - KALDI_LOG << "PrintEigs: " << name << ": " << s; - } - - bool IsPosDef() const; // returns true if Cholesky succeeds. - void AddSp(const Real alpha, const SpMatrix &Ma) { - this->AddPacked(alpha, Ma); - } - - /// Computes log determinant but only for +ve-def matrices - /// (it uses Cholesky). - /// If matrix is not +ve-def, it will throw an exception - /// was LogPDDeterminant() - Real LogPosDefDet() const; - - Real LogDet(Real *det_sign = NULL) const; - - /// rank-one update, this <-- this + alpha v v' - template - void AddVec2(const Real alpha, const VectorBase &v); - - /// rank-two update, this <-- this + alpha (v w' + w v'). 
- void AddVecVec(const Real alpha, const VectorBase &v, - const VectorBase &w); - - /// Does *this = beta * *thi + alpha * diag(v) * S * diag(v) - void AddVec2Sp(const Real alpha, const VectorBase &v, - const SpMatrix &S, const Real beta); - - /// diagonal update, this <-- this + diag(v) - template - void AddDiagVec(const Real alpha, const VectorBase &v); - - /// rank-N update: - /// if (transM == kNoTrans) - /// (*this) = beta*(*this) + alpha * M * M^T, - /// or (if transM == kTrans) - /// (*this) = beta*(*this) + alpha * M^T * M - /// Note: beta used to default to 0.0. - void AddMat2(const Real alpha, const MatrixBase &M, - MatrixTransposeType transM, const Real beta); - - /// Extension of rank-N update: - /// this <-- beta*this + alpha * M * A * M^T. - /// (*this) and A are allowed to be the same. - /// If transM == kTrans, then we do it as M^T * A * M. - void AddMat2Sp(const Real alpha, const MatrixBase &M, - MatrixTransposeType transM, const SpMatrix &A, - const Real beta = 0.0); - - /// This is a version of AddMat2Sp specialized for when M is fairly sparse. - /// This was required for making the raw-fMLLR code efficient. - void AddSmat2Sp(const Real alpha, const MatrixBase &M, - MatrixTransposeType transM, const SpMatrix &A, - const Real beta = 0.0); - - /// The following function does: - /// this <-- beta*this + alpha * T * A * T^T. - /// (*this) and A are allowed to be the same. - /// If transM == kTrans, then we do it as alpha * T^T * A * T. - /// Currently it just calls AddMat2Sp, but if needed we - /// can implement it more efficiently. - void AddTp2Sp(const Real alpha, const TpMatrix &T, - MatrixTransposeType transM, const SpMatrix &A, - const Real beta = 0.0); - - /// The following function does: - /// this <-- beta*this + alpha * T * T^T. - /// (*this) and A are allowed to be the same. - /// If transM == kTrans, then we do it as alpha * T^T * T - /// Currently it just calls AddMat2, but if needed we - /// can implement it more efficiently. 
- void AddTp2(const Real alpha, const TpMatrix &T, - MatrixTransposeType transM, const Real beta = 0.0); - - /// Extension of rank-N update: - /// this <-- beta*this + alpha * M * diag(v) * M^T. - /// if transM == kTrans, then - /// this <-- beta*this + alpha * M^T * diag(v) * M. - void AddMat2Vec(const Real alpha, const MatrixBase &M, - MatrixTransposeType transM, const VectorBase &v, - const Real beta = 0.0); - - - /// Floors this symmetric matrix to the matrix - /// alpha * Floor, where the matrix Floor is positive - /// definite. - /// It is floored in the sense that after flooring, - /// x^T (*this) x >= x^T (alpha*Floor) x. - /// This is accomplished using an Svd. It will crash - /// if Floor is not positive definite. Returns the number of - /// elements that were floored. - int ApplyFloor(const SpMatrix &Floor, Real alpha = 1.0, - bool verbose = false); - - /// Floor: Given a positive semidefinite matrix, floors the eigenvalues - /// to the specified quantity. A previous version of this function had - /// a tolerance which is now no longer needed since we have code to - /// do the symmetric eigenvalue decomposition and no longer use the SVD - /// code for that purose. - int ApplyFloor(Real floor); - - bool IsDiagonal(Real cutoff = 1.0e-05) const; - bool IsUnit(Real cutoff = 1.0e-05) const; - bool IsZero(Real cutoff = 1.0e-05) const; - bool IsTridiagonal(Real cutoff = 1.0e-05) const; - - /// sqrt of sum of square elements. - Real FrobeniusNorm() const; - - /// Returns true if ((*this)-other).FrobeniusNorm() <= - /// tol*(*this).FrobeniusNorma() - bool ApproxEqual(const SpMatrix &other, float tol = 0.01) const; - - // LimitCond: - // Limits the condition of symmetric positive semidefinite matrix to - // a specified value - // by flooring all eigenvalues to a positive number which is some multiple - // of the largest one (or zero if there are no positive eigenvalues). 
- // Takes the condition number we are willing to accept, and floors - // eigenvalues to the largest eigenvalue divided by this. - // Returns #eigs floored or already equal to the floor. - // Throws exception if input is not positive definite. - // returns #floored. - MatrixIndexT LimitCond(Real maxCond = 1.0e+5, bool invert = false); - - // as LimitCond but all done in double precision. // returns #floored. - MatrixIndexT LimitCondDouble(Real maxCond = 1.0e+5, bool invert = false) { - SpMatrix dmat(*this); - MatrixIndexT ans = dmat.LimitCond(maxCond, invert); - (*this).CopyFromSp(dmat); - return ans; - } - Real Trace() const; - - /// Tridiagonalize the matrix with an orthogonal transformation. If - /// *this starts as S, produce T (and Q, if non-NULL) such that - /// T = Q A Q^T, i.e. S = Q^T T Q. Caution: this is the other way - /// round from most authors (it's more efficient in row-major indexing). - void Tridiagonalize(MatrixBase *Q); - - /// The symmetric QR algorithm. This will mostly be useful in internal code. - /// Typically, you will call this after Tridiagonalize(), on the same object. - /// When called, *this (call it A at this point) must be tridiagonal; at exit, - /// *this will be a diagonal matrix D that is similar to A via orthogonal - /// transformations. This algorithm right-multiplies Q by orthogonal - /// transformations. It turns *this from a tridiagonal into a diagonal matrix - /// while maintaining that (Q *this Q^T) has the same value at entry and exit. - /// At entry Q should probably be either NULL or orthogonal, but we don't check - /// this. - void Qr(MatrixBase *Q); - - private: - void EigInternal(VectorBase *s, MatrixBase *P, - Real tolerance, int recurse) const; -}; - -/// @} end of "addtogroup matrix_group" - -/// \addtogroup matrix_funcs_scalar -/// @{ - - -/// Returns tr(A B). 
-float TraceSpSp(const SpMatrix &A, const SpMatrix &B); -double TraceSpSp(const SpMatrix &A, const SpMatrix &B); - - -template -inline bool ApproxEqual(const SpMatrix &A, - const SpMatrix &B, Real tol = 0.01) { - return A.ApproxEqual(B, tol); -} - -template -inline void AssertEqual(const SpMatrix &A, - const SpMatrix &B, Real tol = 0.01) { - KALDI_ASSERT(ApproxEqual(A, B, tol)); -} - - - -/// Returns tr(A B). -template -Real TraceSpSp(const SpMatrix &A, const SpMatrix &B); - - - -// TraceSpSpLower is the same as Trace(A B) except the lower-diagonal elements -// are counted only once not twice as they should be. It is useful in certain -// optimizations. -template -Real TraceSpSpLower(const SpMatrix &A, const SpMatrix &B); - - -/// Returns tr(A B). -/// No option to transpose B because would make no difference. -template -Real TraceSpMat(const SpMatrix &A, const MatrixBase &B); - -/// Returns tr(A B C) -/// (A and C may be transposed as specified by transA and transC). -template -Real TraceMatSpMat(const MatrixBase &A, MatrixTransposeType transA, - const SpMatrix &B, const MatrixBase &C, - MatrixTransposeType transC); - -/// Returns tr (A B C D) -/// (A and C may be transposed as specified by transA and transB). -template -Real TraceMatSpMatSp(const MatrixBase &A, MatrixTransposeType transA, - const SpMatrix &B, const MatrixBase &C, - MatrixTransposeType transC, const SpMatrix &D); - -/** Computes v1^T * M * v2. Not as efficient as it could be where v1 == v2 - * (but no suitable blas routines available). - */ - -/// Returns \f$ v_1^T M v_2 \f$ -/// Not as efficient as it could be where v1 == v2. -template -Real VecSpVec(const VectorBase &v1, const SpMatrix &M, - const VectorBase &v2); - - -/// @} \addtogroup matrix_funcs_scalar - -/// \addtogroup matrix_funcs_misc -/// @{ - - -/// This class describes the options for maximizing various quadratic objective -/// functions. 
It's mostly as described in the SGMM paper "the subspace -/// Gaussian mixture model -- a structured model for speech recognition", but -/// the diagonal_precondition option is newly added, to handle problems where -/// different dimensions have very different scaling (we recommend to use the -/// option but it's set false for back compatibility). -struct SolverOptions { - BaseFloat K; // maximum condition number - BaseFloat eps; - std::string name; - bool optimize_delta; - bool diagonal_precondition; - bool print_debug_output; - explicit SolverOptions(const std::string &name): - K(1.0e+4), eps(1.0e-40), name(name), - optimize_delta(true), diagonal_precondition(false), - print_debug_output(true) { } - SolverOptions(): K(1.0e+4), eps(1.0e-40), name("[unknown]"), - optimize_delta(true), diagonal_precondition(false), - print_debug_output(true) { } - void Check() const; -}; - - -/// Maximizes the auxiliary function -/// \f[ Q(x) = x.g - 0.5 x^T H x \f] -/// using a numerically stable method. Like a numerically stable version of -/// \f$ x := Q^{-1} g. \f$ -/// Assumes H positive semidefinite. -/// Returns the objective-function change. - -template -Real SolveQuadraticProblem(const SpMatrix &H, - const VectorBase &g, - const SolverOptions &opts, - VectorBase *x); - - - -/// Maximizes the auxiliary function : -/// \f[ Q(x) = tr(M^T P Y) - 0.5 tr(P M Q M^T) \f] -/// Like a numerically stable version of \f$ M := Y Q^{-1} \f$. -/// Assumes Q and P positive semidefinite, and matrix dimensions match -/// enough to make expressions meaningful. -/// This is mostly as described in the SGMM paper "the subspace Gaussian mixture -/// model -- a structured model for speech recognition", but the -/// diagonal_precondition option is newly added, to handle problems -/// where different dimensions have very different scaling (we recommend to use -/// the option but it's set false for back compatibility). 
-template -Real SolveQuadraticMatrixProblem(const SpMatrix &Q, - const MatrixBase &Y, - const SpMatrix &P, - const SolverOptions &opts, - MatrixBase *M); - -/// Maximizes the auxiliary function : -/// \f[ Q(M) = tr(M^T G) -0.5 tr(P_1 M Q_1 M^T) -0.5 tr(P_2 M Q_2 M^T). \f] -/// Encountered in matrix update with a prior. We also apply a limit on the -/// condition but it should be less frequently necessary, and can be set larger. -template -Real SolveDoubleQuadraticMatrixProblem(const MatrixBase &G, - const SpMatrix &P1, - const SpMatrix &P2, - const SpMatrix &Q1, - const SpMatrix &Q2, - const SolverOptions &opts, - MatrixBase *M); - - -/// @} End of "addtogroup matrix_funcs_misc" - -} // namespace kaldi - - -// Including the implementation (now actually just includes some -// template specializations). -#include "matrix/sp-matrix-inl.h" - - -#endif // KALDI_MATRIX_SP_MATRIX_H_ diff --git a/speechx/speechx/kaldi/matrix/sparse-matrix.cc b/speechx/speechx/kaldi/matrix/sparse-matrix.cc deleted file mode 100644 index 68a61e17..00000000 --- a/speechx/speechx/kaldi/matrix/sparse-matrix.cc +++ /dev/null @@ -1,1296 +0,0 @@ -// matrix/sparse-matrix.cc - -// Copyright 2015 Johns Hopkins University (author: Daniel Povey) -// 2015 Guoguo Chen -// 2017 Shiyin Kang - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include - -#include "matrix/sparse-matrix.h" -#include "matrix/kaldi-matrix.h" - -namespace kaldi { - -template -std::pair* SparseVector::Data() { - if (pairs_.empty()) - return NULL; - else - return &(pairs_[0]); -} - -template -const std::pair* SparseVector::Data() const { - if (pairs_.empty()) - return NULL; - else - return &(pairs_[0]); -} - -template -Real SparseVector::Sum() const { - Real sum = 0; - for (int32 i = 0; i < pairs_.size(); ++i) { - sum += pairs_[i].second; - } - return sum; -} - -template -void SparseVector::Scale(Real alpha) { - for (int32 i = 0; i < pairs_.size(); ++i) - pairs_[i].second *= alpha; -} - -template -template -void SparseVector::CopyElementsToVec(VectorBase *vec) const { - KALDI_ASSERT(vec->Dim() == this->dim_); - vec->SetZero(); - OtherReal *other_data = vec->Data(); - typename std::vector >::const_iterator - iter = pairs_.begin(), end = pairs_.end(); - for (; iter != end; ++iter) - other_data[iter->first] = iter->second; -} -template -void SparseVector::CopyElementsToVec(VectorBase *vec) const; -template -void SparseVector::CopyElementsToVec(VectorBase *vec) const; -template -void SparseVector::CopyElementsToVec(VectorBase *vec) const; -template -void SparseVector::CopyElementsToVec(VectorBase *vec) const; - -template -template -void SparseVector::AddToVec(Real alpha, - VectorBase *vec) const { - KALDI_ASSERT(vec->Dim() == dim_); - OtherReal *other_data = vec->Data(); - typename std::vector >::const_iterator - iter = pairs_.begin(), end = pairs_.end(); - if (alpha == 1.0) { // treat alpha==1.0 case specially. 
- for (; iter != end; ++iter) - other_data[iter->first] += iter->second; - } else { - for (; iter != end; ++iter) - other_data[iter->first] += alpha * iter->second; - } -} - -template -void SparseVector::AddToVec(float alpha, VectorBase *vec) const; -template -void SparseVector::AddToVec(float alpha, VectorBase *vec) const; -template -void SparseVector::AddToVec(double alpha, VectorBase *vec) const; -template -void SparseVector::AddToVec(double alpha, - VectorBase *vec) const; - -template -template -void SparseVector::CopyFromSvec(const SparseVector &other) { - dim_ = other.Dim(); - pairs_.clear(); - if (dim_ == 0) return; - for (int32 i = 0; i < other.NumElements(); ++i) { - pairs_.push_back(std::make_pair( - other.GetElement(i).first, - static_cast(other.GetElement(i).second))); - } -} -template -void SparseVector::CopyFromSvec(const SparseVector &svec); -template -void SparseVector::CopyFromSvec(const SparseVector &svec); -template -void SparseVector::CopyFromSvec(const SparseVector &svec); -template -void SparseVector::CopyFromSvec(const SparseVector &svec); - - -template -SparseVector& SparseVector::operator = ( - const SparseVector &other) { - this->CopyFromSvec(other); - dim_ = other.dim_; - pairs_ = other.pairs_; - return *this; -} - -template -void SparseVector::Swap(SparseVector *other) { - pairs_.swap(other->pairs_); - std::swap(dim_, other->dim_); -} - -template -void SparseVector::Write(std::ostream &os, bool binary) const { - if (binary) { - WriteToken(os, binary, "SV"); - WriteBasicType(os, binary, dim_); - MatrixIndexT num_elems = pairs_.size(); - WriteBasicType(os, binary, num_elems); - typename std::vector >::const_iterator - iter = pairs_.begin(), end = pairs_.end(); - for (; iter != end; ++iter) { - WriteBasicType(os, binary, iter->first); - WriteBasicType(os, binary, iter->second); - } - } else { - // In text-mode, use a human-friendly, script-friendly format; - // format is "dim=5 [ 0 0.2 3 0.9 ] " - os << "dim=" << dim_ << " [ "; - typename 
std::vector >::const_iterator - iter = pairs_.begin(), end = pairs_.end(); - for (; iter != end; ++iter) - os << iter->first << ' ' << iter->second << ' '; - os << "] "; - } -} - - -template -void SparseVector::Read(std::istream &is, bool binary) { - if (binary) { - ExpectToken(is, binary, "SV"); - ReadBasicType(is, binary, &dim_); - KALDI_ASSERT(dim_ >= 0); - int32 num_elems; - ReadBasicType(is, binary, &num_elems); - KALDI_ASSERT(num_elems >= 0 && num_elems <= dim_); - pairs_.resize(num_elems); - typename std::vector >::iterator - iter = pairs_.begin(), end = pairs_.end(); - for (; iter != end; ++iter) { - ReadBasicType(is, binary, &(iter->first)); - ReadBasicType(is, binary, &(iter->second)); - } - } else { - // In text-mode, format is "dim=5 [ 0 0.2 3 0.9 ] - std::string str; - is >> str; - if (str.substr(0, 4) != "dim=") - KALDI_ERR << "Reading sparse vector, expected 'dim=xxx', got " << str; - std::string dim_str = str.substr(4, std::string::npos); - std::istringstream dim_istr(dim_str); - int32 dim = -1; - dim_istr >> dim; - if (dim < 0 || dim_istr.fail()) { - KALDI_ERR << "Reading sparse vector, expected 'dim=[int]', got " << str; - } - dim_ = dim; - is >> std::ws; - is >> str; - if (str != "[") - KALDI_ERR << "Reading sparse vector, expected '[', got " << str; - pairs_.clear(); - while (1) { - is >> std::ws; - if (is.peek() == ']') { - is.get(); - break; - } - MatrixIndexT i; - BaseFloat p; - is >> i >> p; - if (is.fail()) - KALDI_ERR << "Error reading sparse vector, expecting numbers."; - KALDI_ASSERT(i >= 0 && i < dim - && (pairs_.empty() || i > pairs_.back().first)); - pairs_.push_back(std::pair(i, p)); - } - } -} - - -namespace sparse_vector_utils { -template -struct CompareFirst { - inline bool operator() (const std::pair &p1, - const std::pair &p2) const { - return p1.first < p2.first; - } -}; -} - -template -SparseVector::SparseVector( - MatrixIndexT dim, const std::vector > &pairs): - dim_(dim), - pairs_(pairs) { - std::sort(pairs_.begin(), 
pairs_.end(), - sparse_vector_utils::CompareFirst()); - typename std::vector >::iterator - out = pairs_.begin(), in = out, end = pairs_.end(); - // special case: while there is nothing to be changed, skip over - // initial input (avoids unnecessary copying). - while (in + 1 < end && in[0].first != in[1].first && in[0].second != 0.0) { - in++; - out++; - } - while (in < end) { - // We reach this point only at the first element of - // each stretch of identical .first elements. - *out = *in; - ++in; - while (in < end && in->first == out->first) { - out->second += in->second; // this is the merge operation. - ++in; - } - if (out->second != Real(0.0)) // Don't keep zero elements. - out++; - } - pairs_.erase(out, end); - if (!pairs_.empty()) { - // range check. - KALDI_ASSERT(pairs_.front().first >= 0 && pairs_.back().first < dim_); - } -} - -template -void SparseVector::SetRandn(BaseFloat zero_prob) { - pairs_.clear(); - KALDI_ASSERT(zero_prob >= 0 && zero_prob <= 1.0); - for (MatrixIndexT i = 0; i < dim_; i++) - if (WithProb(1.0 - zero_prob)) - pairs_.push_back(std::pair(i, RandGauss())); -} - -template -void SparseVector::Resize(MatrixIndexT dim, - MatrixResizeType resize_type) { - if (resize_type != kCopyData || dim == 0) - pairs_.clear(); - KALDI_ASSERT(dim >= 0); - if (dim < dim_ && resize_type == kCopyData) - while (!pairs_.empty() && pairs_.back().first >= dim) - pairs_.pop_back(); - dim_ = dim; -} - -template -MatrixIndexT SparseMatrix::NumRows() const { - return rows_.size(); -} - -template -MatrixIndexT SparseMatrix::NumCols() const { - if (rows_.empty()) - return 0.0; - else - return rows_[0].Dim(); -} - -template -MatrixIndexT SparseMatrix::NumElements() const { - int32 num_elements = 0; - for (int32 i = 0; i < rows_.size(); ++i) { - num_elements += rows_[i].NumElements(); - } - return num_elements; -} - -template -SparseVector* SparseMatrix::Data() { - if (rows_.empty()) - return NULL; - else - return rows_.data(); -} - -template -const SparseVector* 
SparseMatrix::Data() const { - if (rows_.empty()) - return NULL; - else - return rows_.data(); -} - -template -Real SparseMatrix::Sum() const { - Real sum = 0; - for (int32 i = 0; i < rows_.size(); ++i) { - sum += rows_[i].Sum(); - } - return sum; -} - -template -Real SparseMatrix::FrobeniusNorm() const { - Real squared_sum = 0; - for (int32 i = 0; i < rows_.size(); ++i) { - const std::pair *row_data = rows_[i].Data(); - for (int32 j = 0; j < rows_[i].NumElements(); ++j) { - squared_sum += row_data[j].second * row_data[j].second; - } - } - return std::sqrt(squared_sum); -} - -template -template -void SparseMatrix::CopyToMat(MatrixBase *other, - MatrixTransposeType trans) const { - if (trans == kNoTrans) { - MatrixIndexT num_rows = rows_.size(); - KALDI_ASSERT(other->NumRows() == num_rows); - for (MatrixIndexT i = 0; i < num_rows; i++) { - SubVector vec(*other, i); - rows_[i].CopyElementsToVec(&vec); - } - } else { - OtherReal *other_col_data = other->Data(); - MatrixIndexT other_stride = other->Stride(), - num_rows = NumRows(), num_cols = NumCols(); - KALDI_ASSERT(num_rows == other->NumCols() && num_cols == other->NumRows()); - other->SetZero(); - for (MatrixIndexT row = 0; row < num_rows; row++, other_col_data++) { - const SparseVector &svec = rows_[row]; - MatrixIndexT num_elems = svec.NumElements(); - const std::pair *sdata = svec.Data(); - for (MatrixIndexT e = 0; e < num_elems; e++) - other_col_data[sdata[e].first * other_stride] = sdata[e].second; - } - } -} - -template -void SparseMatrix::CopyToMat(MatrixBase *other, - MatrixTransposeType trans) const; -template -void SparseMatrix::CopyToMat(MatrixBase *other, - MatrixTransposeType trans) const; -template -void SparseMatrix::CopyToMat(MatrixBase *other, - MatrixTransposeType trans) const; -template -void SparseMatrix::CopyToMat(MatrixBase *other, - MatrixTransposeType trans) const; - -template -void SparseMatrix::CopyElementsToVec(VectorBase *other) const { - KALDI_ASSERT(other->Dim() == NumElements()); - 
Real *dst_data = other->Data(); - int32 dst_index = 0; - for (int32 i = 0; i < rows_.size(); ++i) { - for (int32 j = 0; j < rows_[i].NumElements(); ++j) { - dst_data[dst_index] = - static_cast(rows_[i].GetElement(j).second); - dst_index++; - } - } -} - -template -template -void SparseMatrix::CopyFromSmat(const SparseMatrix &other, - MatrixTransposeType trans) { - if (trans == kNoTrans) { - rows_.resize(other.NumRows()); - if (rows_.size() == 0) - return; - for (int32 r = 0; r < rows_.size(); ++r) { - rows_[r].CopyFromSvec(other.Row(r)); - } - } else { - std::vector > > pairs( - other.NumCols()); - for (MatrixIndexT i = 0; i < other.NumRows(); ++i) { - for (int id = 0; id < other.Row(i).NumElements(); ++id) { - MatrixIndexT j = other.Row(i).GetElement(id).first; - Real v = static_cast(other.Row(i).GetElement(id).second); - pairs[j].push_back( { i, v }); - } - } - SparseMatrix temp(other.NumRows(), pairs); - Swap(&temp); - } -} -template -void SparseMatrix::CopyFromSmat(const SparseMatrix &other, - MatrixTransposeType trans); -template -void SparseMatrix::CopyFromSmat(const SparseMatrix &other, - MatrixTransposeType trans); -template -void SparseMatrix::CopyFromSmat(const SparseMatrix &other, - MatrixTransposeType trans); -template -void SparseMatrix::CopyFromSmat(const SparseMatrix &other, - MatrixTransposeType trans); - -template -void SparseMatrix::Write(std::ostream &os, bool binary) const { - if (binary) { - // Note: we can use the same marker for float and double SparseMatrix, - // because internally we use WriteBasicType and ReadBasicType to read the - // floats and doubles, and this will automatically take care of type - // conversion. - WriteToken(os, binary, "SM"); - int32 num_rows = rows_.size(); - WriteBasicType(os, binary, num_rows); - for (int32 row = 0; row < num_rows; row++) - rows_[row].Write(os, binary); - } else { - // The format is "rows=10 dim=20 [ 1 0.4 9 1.2 ] dim=20 [ 3 1.7 19 0.6 ] .. 
- // not 100% efficient, but easy to work with, and we can re-use the - // read/write code from SparseVector. - int32 num_rows = rows_.size(); - os << "rows=" << num_rows << " "; - for (int32 row = 0; row < num_rows; row++) - rows_[row].Write(os, binary); - os << "\n"; // Might make it a little more readable. - } -} - -template -void SparseMatrix::Read(std::istream &is, bool binary) { - if (binary) { - ExpectToken(is, binary, "SM"); - int32 num_rows; - ReadBasicType(is, binary, &num_rows); - KALDI_ASSERT(num_rows >= 0 && num_rows < 10000000); - rows_.resize(num_rows); - for (int32 row = 0; row < num_rows; row++) - rows_[row].Read(is, binary); - } else { - std::string str; - is >> str; - if (str.substr(0, 5) != "rows=") - KALDI_ERR << "Reading sparse matrix, expected 'rows=xxx', got " << str; - std::string rows_str = str.substr(5, std::string::npos); - std::istringstream rows_istr(rows_str); - int32 num_rows = -1; - rows_istr >> num_rows; - if (num_rows < 0 || rows_istr.fail()) { - KALDI_ERR << "Reading sparse vector, expected 'rows=[int]', got " << str; - } - rows_.resize(num_rows); - for (int32 row = 0; row < num_rows; row++) - rows_[row].Read(is, binary); - } -} - - -template -void SparseMatrix::AddToMat(BaseFloat alpha, - MatrixBase *other, - MatrixTransposeType trans) const { - if (trans == kNoTrans) { - MatrixIndexT num_rows = rows_.size(); - KALDI_ASSERT(other->NumRows() == num_rows); - for (MatrixIndexT i = 0; i < num_rows; i++) { - SubVector vec(*other, i); - rows_[i].AddToVec(alpha, &vec); - } - } else { - Real *other_col_data = other->Data(); - MatrixIndexT other_stride = other->Stride(), - num_rows = NumRows(), num_cols = NumCols(); - KALDI_ASSERT(num_rows == other->NumCols() && num_cols == other->NumRows()); - for (MatrixIndexT row = 0; row < num_rows; row++, other_col_data++) { - const SparseVector &svec = rows_[row]; - MatrixIndexT num_elems = svec.NumElements(); - const std::pair *sdata = svec.Data(); - for (MatrixIndexT e = 0; e < num_elems; e++) - 
other_col_data[sdata[e].first * other_stride] += - alpha * sdata[e].second; - } - } -} - -template -Real VecSvec(const VectorBase &vec, - const SparseVector &svec) { - KALDI_ASSERT(vec.Dim() == svec.Dim()); - MatrixIndexT n = svec.NumElements(); - const std::pair *sdata = svec.Data(); - const Real *data = vec.Data(); - Real ans = 0.0; - for (MatrixIndexT i = 0; i < n; i++) - ans += data[sdata[i].first] * sdata[i].second; - return ans; -} - -template -float VecSvec(const VectorBase &vec, - const SparseVector &svec); -template -double VecSvec(const VectorBase &vec, - const SparseVector &svec); - -template -const SparseVector &SparseMatrix::Row(MatrixIndexT r) const { - KALDI_ASSERT(static_cast(r) < rows_.size()); - return rows_[r]; -} - -template -void SparseMatrix::SetRow(int32 r, const SparseVector &vec) { - KALDI_ASSERT(static_cast(r) < rows_.size() && - vec.Dim() == rows_[0].Dim()); - rows_[r] = vec; -} - - -template -void SparseMatrix::SelectRows(const std::vector &row_indexes, - const SparseMatrix &smat_other) { - Resize(row_indexes.size(), smat_other.NumCols()); - for (int i = 0; i < row_indexes.size(); ++i) { - SetRow(i, smat_other.Row(row_indexes[i])); - } -} - -template -SparseMatrix::SparseMatrix(const std::vector &indexes, int32 dim, - MatrixTransposeType trans) { - const std::vector& idx = indexes; - std::vector > > pair(idx.size()); - for (int i = 0; i < idx.size(); ++i) { - if (idx[i] >= 0) { - pair[i].push_back( { idx[i], Real(1) }); - } - } - SparseMatrix smat_cpu(dim, pair); - if (trans == kNoTrans) { - this->Swap(&smat_cpu); - } else { - SparseMatrix tmp(smat_cpu, kTrans); - this->Swap(&tmp); - } -} - -template -SparseMatrix::SparseMatrix(const std::vector &indexes, - const VectorBase &weights, int32 dim, - MatrixTransposeType trans) { - const std::vector& idx = indexes; - const VectorBase& w = weights; - std::vector > > pair(idx.size()); - for (int i = 0; i < idx.size(); ++i) { - if (idx[i] >= 0) { - pair[i].push_back( { idx[i], w(i) }); - } - } - 
SparseMatrix smat_cpu(dim, pair); - if (trans == kNoTrans) { - this->Swap(&smat_cpu); - } else { - SparseMatrix tmp(smat_cpu, kTrans); - this->Swap(&tmp); - } -} - -template -SparseMatrix& SparseMatrix::operator = ( - const SparseMatrix &other) { - rows_ = other.rows_; - return *this; -} - -template -void SparseMatrix::Swap(SparseMatrix *other) { - rows_.swap(other->rows_); -} - -template -SparseMatrix::SparseMatrix( - MatrixIndexT dim, - const std::vector > > &pairs): - rows_(pairs.size()) { - MatrixIndexT num_rows = pairs.size(); - for (MatrixIndexT row = 0; row < num_rows; row++) { - SparseVector svec(dim, pairs[row]); - rows_[row].Swap(&svec); - } -} - -template -void SparseMatrix::SetRandn(BaseFloat zero_prob) { - MatrixIndexT num_rows = rows_.size(); - for (MatrixIndexT row = 0; row < num_rows; row++) - rows_[row].SetRandn(zero_prob); -} - -template -void SparseMatrix::Resize(MatrixIndexT num_rows, - MatrixIndexT num_cols, - MatrixResizeType resize_type) { - KALDI_ASSERT(num_rows >= 0 && num_cols >= 0); - if (resize_type == kSetZero || resize_type == kUndefined) { - rows_.clear(); - Resize(num_rows, num_cols, kCopyData); - } else { - // Assume resize_type == kCopyData from here. 
- int32 old_num_rows = rows_.size(), old_num_cols = NumCols(); - SparseVector initializer(num_cols); - rows_.resize(num_rows, initializer); - if (num_cols != old_num_cols) - for (int32 row = 0; row < old_num_rows; row++) - rows_[row].Resize(num_cols, kCopyData); - } -} - -template -void SparseMatrix::AppendSparseMatrixRows( - std::vector > *inputs) { - rows_.clear(); - size_t num_rows = 0; - typename std::vector >::iterator - input_iter = inputs->begin(), - input_end = inputs->end(); - for (; input_iter != input_end; ++input_iter) - num_rows += input_iter->rows_.size(); - rows_.resize(num_rows); - typename std::vector >::iterator - row_iter = rows_.begin(), - row_end = rows_.end(); - for (input_iter = inputs->begin(); input_iter != input_end; ++input_iter) { - typename std::vector >::iterator - input_row_iter = input_iter->rows_.begin(), - input_row_end = input_iter->rows_.end(); - for (; input_row_iter != input_row_end; ++input_row_iter, ++row_iter) - row_iter->Swap(&(*input_row_iter)); - } - KALDI_ASSERT(row_iter == row_end); - int32 num_cols = NumCols(); - for (row_iter = rows_.begin(); row_iter != row_end; ++row_iter) { - if (row_iter->Dim() != num_cols) - KALDI_ERR << "Appending rows with inconsistent dimensions, " - << row_iter->Dim() << " vs. 
" << num_cols; - } - inputs->clear(); -} - -template -void SparseMatrix::Scale(Real alpha) { - MatrixIndexT num_rows = rows_.size(); - for (MatrixIndexT row = 0; row < num_rows; row++) - rows_[row].Scale(alpha); -} - -template -SparseMatrix::SparseMatrix(const MatrixBase &mat) { - MatrixIndexT num_rows = mat.NumRows(); - rows_.resize(num_rows); - for (int32 row = 0; row < num_rows; row++) { - SparseVector this_row(mat.Row(row)); - rows_[row].Swap(&this_row); - } -} - -template -Real TraceMatSmat(const MatrixBase &A, - const SparseMatrix &B, - MatrixTransposeType trans) { - Real sum = 0.0; - if (trans == kTrans) { - MatrixIndexT num_rows = A.NumRows(); - KALDI_ASSERT(B.NumRows() == num_rows); - for (MatrixIndexT r = 0; r < num_rows; r++) - sum += VecSvec(A.Row(r), B.Row(r)); - } else { - const Real *A_col_data = A.Data(); - MatrixIndexT Astride = A.Stride(), Acols = A.NumCols(), Arows = A.NumRows(); - KALDI_ASSERT(Arows == B.NumCols() && Acols == B.NumRows()); - sum = 0.0; - for (MatrixIndexT i = 0; i < Acols; i++, A_col_data++) { - Real col_sum = 0.0; - const SparseVector &svec = B.Row(i); - MatrixIndexT num_elems = svec.NumElements(); - const std::pair *sdata = svec.Data(); - for (MatrixIndexT e = 0; e < num_elems; e++) - col_sum += A_col_data[Astride * sdata[e].first] * sdata[e].second; - sum += col_sum; - } - } - return sum; -} - -template -float TraceMatSmat(const MatrixBase &A, - const SparseMatrix &B, - MatrixTransposeType trans); -template -double TraceMatSmat(const MatrixBase &A, - const SparseMatrix &B, - MatrixTransposeType trans); - -void GeneralMatrix::Clear() { - mat_.Resize(0, 0); - cmat_.Clear(); - smat_.Resize(0, 0); -} - -GeneralMatrix& GeneralMatrix::operator= (const MatrixBase &mat) { - Clear(); - mat_ = mat; - return *this; -} - -GeneralMatrix& GeneralMatrix::operator= (const CompressedMatrix &cmat) { - Clear(); - cmat_ = cmat; - return *this; -} - -GeneralMatrix& GeneralMatrix::operator= (const SparseMatrix &smat) { - Clear(); - smat_ = smat; - 
return *this; -} - -GeneralMatrix& GeneralMatrix::operator= (const GeneralMatrix &gmat) { - mat_ = gmat.mat_; - smat_ = gmat.smat_; - cmat_ = gmat.cmat_; - return *this; -} - - -GeneralMatrixType GeneralMatrix::Type() const { - if (smat_.NumRows() != 0) - return kSparseMatrix; - else if (cmat_.NumRows() != 0) - return kCompressedMatrix; - else - return kFullMatrix; -} - -MatrixIndexT GeneralMatrix::NumRows() const { - MatrixIndexT r = smat_.NumRows(); - if (r != 0) - return r; - r = cmat_.NumRows(); - if (r != 0) - return r; - return mat_.NumRows(); -} - -MatrixIndexT GeneralMatrix::NumCols() const { - MatrixIndexT r = smat_.NumCols(); - if (r != 0) - return r; - r = cmat_.NumCols(); - if (r != 0) - return r; - return mat_.NumCols(); -} - - -void GeneralMatrix::Compress() { - if (mat_.NumRows() != 0) { - cmat_.CopyFromMat(mat_); - mat_.Resize(0, 0); - } -} - -void GeneralMatrix::Uncompress() { - if (cmat_.NumRows() != 0) { - mat_.Resize(cmat_.NumRows(), cmat_.NumCols(), kUndefined); - cmat_.CopyToMat(&mat_); - cmat_.Clear(); - } -} - -void GeneralMatrix::GetMatrix(Matrix *mat) const { - if (mat_.NumRows() !=0) { - *mat = mat_; - } else if (cmat_.NumRows() != 0) { - mat->Resize(cmat_.NumRows(), cmat_.NumCols(), kUndefined); - cmat_.CopyToMat(mat); - } else if (smat_.NumRows() != 0) { - mat->Resize(smat_.NumRows(), smat_.NumCols(), kUndefined); - smat_.CopyToMat(mat); - } else { - mat->Resize(0, 0); - } -} - -void GeneralMatrix::CopyToMat(MatrixBase *mat, - MatrixTransposeType trans) const { - if (mat_.NumRows() !=0) { - mat->CopyFromMat(mat_, trans); - } else if (cmat_.NumRows() != 0) { - cmat_.CopyToMat(mat, trans); - } else if (smat_.NumRows() != 0) { - smat_.CopyToMat(mat, trans); - } else { - KALDI_ASSERT(mat->NumRows() == 0); - } -} - -void GeneralMatrix::Scale(BaseFloat alpha) { - if (mat_.NumRows() != 0) { - mat_.Scale(alpha); - } else if (cmat_.NumRows() != 0) { - cmat_.Scale(alpha); - } else if (smat_.NumRows() != 0) { - smat_.Scale(alpha); - } - -} -const 
SparseMatrix& GeneralMatrix::GetSparseMatrix() const { - if (mat_.NumRows() != 0 || cmat_.NumRows() != 0) - KALDI_ERR << "GetSparseMatrix called on GeneralMatrix of wrong type."; - return smat_; -} - -void GeneralMatrix::SwapSparseMatrix(SparseMatrix *smat) { - if (mat_.NumRows() != 0 || cmat_.NumRows() != 0) - KALDI_ERR << "GetSparseMatrix called on GeneralMatrix of wrong type."; - smat->Swap(&smat_); -} - -void GeneralMatrix::SwapCompressedMatrix(CompressedMatrix *cmat) { - if (mat_.NumRows() != 0 || smat_.NumRows() != 0) - KALDI_ERR << "GetSparseMatrix called on GeneralMatrix of wrong type."; - cmat->Swap(&cmat_); -} - -const CompressedMatrix &GeneralMatrix::GetCompressedMatrix() const { - if (mat_.NumRows() != 0 || smat_.NumRows() != 0) - KALDI_ERR << "GetCompressedMatrix called on GeneralMatrix of wrong type."; - return cmat_; -} - -const Matrix &GeneralMatrix::GetFullMatrix() const { - if (smat_.NumRows() != 0 || cmat_.NumRows() != 0) - KALDI_ERR << "GetFullMatrix called on GeneralMatrix of wrong type."; - return mat_; -} - - -void GeneralMatrix::SwapFullMatrix(Matrix *mat) { - if (cmat_.NumRows() != 0 || smat_.NumRows() != 0) - KALDI_ERR << "SwapMatrix called on GeneralMatrix of wrong type."; - mat->Swap(&mat_); -} - -void GeneralMatrix::Write(std::ostream &os, bool binary) const { - if (smat_.NumRows() != 0) { - smat_.Write(os, binary); - } else if (cmat_.NumRows() != 0) { - cmat_.Write(os, binary); - } else { - mat_.Write(os, binary); - } -} - -void GeneralMatrix::Read(std::istream &is, bool binary) { - Clear(); - if (binary) { - int peekval = is.peek(); - if (peekval == 'C') { - // Token CM for compressed matrix - cmat_.Read(is, binary); - } else if (peekval == 'S') { - // Token SM for sparse matrix - smat_.Read(is, binary); - } else { - mat_.Read(is, binary); - } - } else { - // note: in text mode we will only ever read regular - // or sparse matrices, because the compressed-matrix format just - // gets written as a regular matrix in text mode. 
- is >> std::ws; // Eat up white space. - int peekval = is.peek(); - if (peekval == 'r') { // sparse format starts rows=[int]. - smat_.Read(is, binary); - } else { - mat_.Read(is, binary); - } - } -} - - -void AppendGeneralMatrixRows(const std::vector &src, - GeneralMatrix *mat) { - mat->Clear(); - int32 size = src.size(); - if (size == 0) - return; - bool all_sparse = true; - for (int32 i = 0; i < size; i++) { - if (src[i]->Type() != kSparseMatrix && src[i]->NumRows() != 0) { - all_sparse = false; - break; - } - } - if (all_sparse) { - std::vector > sparse_mats(size); - for (int32 i = 0; i < size; i++) - sparse_mats[i] = src[i]->GetSparseMatrix(); - SparseMatrix appended_mat; - appended_mat.AppendSparseMatrixRows(&sparse_mats); - mat->SwapSparseMatrix(&appended_mat); - } else { - int32 tot_rows = 0, num_cols = -1; - for (int32 i = 0; i < size; i++) { - const GeneralMatrix &src_mat = *(src[i]); - int32 src_rows = src_mat.NumRows(), src_cols = src_mat.NumCols(); - if (src_rows != 0) { - tot_rows += src_rows; - if (num_cols == -1) num_cols = src_cols; - else if (num_cols != src_cols) - KALDI_ERR << "Appending rows of matrices with inconsistent num-cols: " - << num_cols << " vs. 
" << src_cols; - } - } - Matrix appended_mat(tot_rows, num_cols, kUndefined); - int32 row_offset = 0; - for (int32 i = 0; i < size; i++) { - const GeneralMatrix &src_mat = *(src[i]); - int32 src_rows = src_mat.NumRows(); - if (src_rows != 0) { - SubMatrix dest_submat(appended_mat, row_offset, src_rows, - 0, num_cols); - src_mat.CopyToMat(&dest_submat); - row_offset += src_rows; - } - } - KALDI_ASSERT(row_offset == tot_rows); - mat->SwapFullMatrix(&appended_mat); - } -} - -void FilterCompressedMatrixRows(const CompressedMatrix &in, - const std::vector &keep_rows, - Matrix *out) { - KALDI_ASSERT(keep_rows.size() == static_cast(in.NumRows())); - int32 num_kept_rows = 0; - std::vector::const_iterator iter = keep_rows.begin(), - end = keep_rows.end(); - for (; iter != end; ++iter) - if (*iter) - num_kept_rows++; - if (num_kept_rows == 0) - KALDI_ERR << "No kept rows"; - if (num_kept_rows == static_cast(keep_rows.size())) { - out->Resize(in.NumRows(), in.NumCols(), kUndefined); - in.CopyToMat(out); - return; - } - const BaseFloat heuristic = 0.33; - // should be > 0 and < 1.0. represents the performance hit we get from - // iterating row-wise versus column-wise in compressed-matrix uncompression. - - if (num_kept_rows > heuristic * in.NumRows()) { - // if quite a few of the the rows are kept, it may be more efficient - // to uncompress the entire compressed matrix, since per-column operation - // is more efficient. 
- Matrix full_mat(in); - FilterMatrixRows(full_mat, keep_rows, out); - } else { - out->Resize(num_kept_rows, in.NumCols(), kUndefined); - - iter = keep_rows.begin(); - int32 out_row = 0; - for (int32 in_row = 0; iter != end; ++iter, ++in_row) { - if (*iter) { - SubVector dest(*out, out_row); - in.CopyRowToVec(in_row, &dest); - out_row++; - } - } - KALDI_ASSERT(out_row == num_kept_rows); - } -} - -template -void FilterMatrixRows(const Matrix &in, - const std::vector &keep_rows, - Matrix *out) { - KALDI_ASSERT(keep_rows.size() == static_cast(in.NumRows())); - int32 num_kept_rows = 0; - std::vector::const_iterator iter = keep_rows.begin(), - end = keep_rows.end(); - for (; iter != end; ++iter) - if (*iter) - num_kept_rows++; - if (num_kept_rows == 0) - KALDI_ERR << "No kept rows"; - if (num_kept_rows == static_cast(keep_rows.size())) { - *out = in; - return; - } - out->Resize(num_kept_rows, in.NumCols(), kUndefined); - iter = keep_rows.begin(); - int32 out_row = 0; - for (int32 in_row = 0; iter != end; ++iter, ++in_row) { - if (*iter) { - SubVector src(in, in_row); - SubVector dest(*out, out_row); - dest.CopyFromVec(src); - out_row++; - } - } - KALDI_ASSERT(out_row == num_kept_rows); -} - -template -void FilterMatrixRows(const Matrix &in, - const std::vector &keep_rows, - Matrix *out); -template -void FilterMatrixRows(const Matrix &in, - const std::vector &keep_rows, - Matrix *out); - -template -void FilterSparseMatrixRows(const SparseMatrix &in, - const std::vector &keep_rows, - SparseMatrix *out) { - KALDI_ASSERT(keep_rows.size() == static_cast(in.NumRows())); - int32 num_kept_rows = 0; - std::vector::const_iterator iter = keep_rows.begin(), - end = keep_rows.end(); - for (; iter != end; ++iter) - if (*iter) - num_kept_rows++; - if (num_kept_rows == 0) - KALDI_ERR << "No kept rows"; - if (num_kept_rows == static_cast(keep_rows.size())) { - *out = in; - return; - } - out->Resize(num_kept_rows, in.NumCols(), kUndefined); - iter = keep_rows.begin(); - int32 out_row = 
0; - for (int32 in_row = 0; iter != end; ++iter, ++in_row) { - if (*iter) { - out->SetRow(out_row, in.Row(in_row)); - out_row++; - } - } - KALDI_ASSERT(out_row == num_kept_rows); -} - -template -void FilterSparseMatrixRows(const SparseMatrix &in, - const std::vector &keep_rows, - SparseMatrix *out); -template -void FilterSparseMatrixRows(const SparseMatrix &in, - const std::vector &keep_rows, - SparseMatrix *out); - - -void FilterGeneralMatrixRows(const GeneralMatrix &in, - const std::vector &keep_rows, - GeneralMatrix *out) { - out->Clear(); - KALDI_ASSERT(keep_rows.size() == static_cast(in.NumRows())); - int32 num_kept_rows = 0; - std::vector::const_iterator iter = keep_rows.begin(), - end = keep_rows.end(); - for (; iter != end; ++iter) - if (*iter) - num_kept_rows++; - if (num_kept_rows == 0) - KALDI_ERR << "No kept rows"; - if (num_kept_rows == static_cast(keep_rows.size())) { - *out = in; - return; - } - switch (in.Type()) { - case kCompressedMatrix: { - const CompressedMatrix &cmat = in.GetCompressedMatrix(); - Matrix full_mat; - FilterCompressedMatrixRows(cmat, keep_rows, &full_mat); - out->SwapFullMatrix(&full_mat); - return; - } - case kSparseMatrix: { - const SparseMatrix &smat = in.GetSparseMatrix(); - SparseMatrix smat_out; - FilterSparseMatrixRows(smat, keep_rows, &smat_out); - out->SwapSparseMatrix(&smat_out); - return; - } - case kFullMatrix: { - const Matrix &full_mat = in.GetFullMatrix(); - Matrix full_mat_out; - FilterMatrixRows(full_mat, keep_rows, &full_mat_out); - out->SwapFullMatrix(&full_mat_out); - return; - } - default: - KALDI_ERR << "Invalid general-matrix type."; - } -} - -void GeneralMatrix::AddToMat(BaseFloat alpha, MatrixBase *mat, - MatrixTransposeType trans) const { - switch (this->Type()) { - case kFullMatrix: { - mat->AddMat(alpha, mat_, trans); - break; - } - case kSparseMatrix: { - smat_.AddToMat(alpha, mat, trans); - break; - } - case kCompressedMatrix: { - Matrix temp_mat(cmat_); - mat->AddMat(alpha, temp_mat, trans); - 
break; - } - default: - KALDI_ERR << "Invalid general-matrix type."; - } -} - -template -Real SparseVector::Max(int32 *index_out) const { - KALDI_ASSERT(dim_ > 0 && pairs_.size() <= static_cast(dim_)); - Real ans = -std::numeric_limits::infinity(); - int32 index = 0; - typename std::vector >::const_iterator - iter = pairs_.begin(), end = pairs_.end(); - for (; iter != end; ++iter) { - if (iter->second > ans) { - ans = iter->second; - index = iter->first; - } - } - if (ans >= 0 || pairs_.size() == dim_) { - // ans >= 0 will be the normal case. - // if pairs_.size() == dim_ then we need to return - // even a negative answer as there are no spaces (hence no unlisted zeros). - *index_out = index; - return ans; - } - // all the stored elements are < 0, but there are unlisted - // elements -> pick the first unlisted element. - // Note that this class requires that the indexes are sorted - // and unique. - index = 0; // "index" will always be the next index, that - // we haven't seen listed yet. - iter = pairs_.begin(); - for (; iter != end; ++iter) { - if (iter->first > index) { // index "index" is not listed. - *index_out = index; - return 0.0; - } else { - // index is the next potential gap in the indexes. - index = iter->first + 1; - } - } - // we can reach here if either pairs_.empty(), or - // pairs_ is nonempty but contains a sequence (0, 1, 2,...). 
- if (!pairs_.empty()) - index = pairs_.back().first + 1; - // else leave index at zero - KALDI_ASSERT(index < dim_); - *index_out = index; - return 0.0; -} - -template -SparseVector::SparseVector(const VectorBase &vec) { - MatrixIndexT dim = vec.Dim(); - dim_ = dim; - if (dim == 0) - return; - const Real *ptr = vec.Data(); - for (MatrixIndexT i = 0; i < dim; i++) { - Real val = ptr[i]; - if (val != 0.0) - pairs_.push_back(std::pair(i,val)); - } -} - -void GeneralMatrix::Swap(GeneralMatrix *other) { - mat_.Swap(&(other->mat_)); - cmat_.Swap(&(other->cmat_)); - smat_.Swap(&(other->smat_)); -} - - -void ExtractRowRangeWithPadding( - const GeneralMatrix &in, - int32 row_offset, - int32 num_rows, - GeneralMatrix *out) { - // make sure 'out' is empty to start with. - Matrix empty_mat; - *out = empty_mat; - if (num_rows == 0) return; - switch (in.Type()) { - case kFullMatrix: { - const Matrix &mat_in = in.GetFullMatrix(); - int32 num_rows_in = mat_in.NumRows(), num_cols = mat_in.NumCols(); - KALDI_ASSERT(num_rows_in > 0); // we can't extract >0 rows from an empty - // matrix. - Matrix mat_out(num_rows, num_cols, kUndefined); - for (int32 row = 0; row < num_rows; row++) { - int32 row_in = row + row_offset; - if (row_in < 0) row_in = 0; - else if (row_in >= num_rows_in) row_in = num_rows_in - 1; - SubVector vec_in(mat_in, row_in), - vec_out(mat_out, row); - vec_out.CopyFromVec(vec_in); - } - out->SwapFullMatrix(&mat_out); - break; - } - case kSparseMatrix: { - const SparseMatrix &smat_in = in.GetSparseMatrix(); - int32 num_rows_in = smat_in.NumRows(), - num_cols = smat_in.NumCols(); - KALDI_ASSERT(num_rows_in > 0); // we can't extract >0 rows from an empty - // matrix. 
- SparseMatrix smat_out(num_rows, num_cols); - for (int32 row = 0; row < num_rows; row++) { - int32 row_in = row + row_offset; - if (row_in < 0) row_in = 0; - else if (row_in >= num_rows_in) row_in = num_rows_in - 1; - smat_out.SetRow(row, smat_in.Row(row_in)); - } - out->SwapSparseMatrix(&smat_out); - break; - } - case kCompressedMatrix: { - const CompressedMatrix &cmat_in = in.GetCompressedMatrix(); - bool allow_padding = true; - CompressedMatrix cmat_out(cmat_in, row_offset, num_rows, - 0, cmat_in.NumCols(), allow_padding); - out->SwapCompressedMatrix(&cmat_out); - break; - } - default: - KALDI_ERR << "Bad matrix type."; - } -} - - - -template class SparseVector; -template class SparseVector; -template class SparseMatrix; -template class SparseMatrix; - -} // namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/sparse-matrix.h b/speechx/speechx/kaldi/matrix/sparse-matrix.h deleted file mode 100644 index 76f77f53..00000000 --- a/speechx/speechx/kaldi/matrix/sparse-matrix.h +++ /dev/null @@ -1,452 +0,0 @@ -// matrix/sparse-matrix.h - -// Copyright 2015 Johns Hopkins University (author: Daniel Povey) -// 2015 Guoguo Chen -// 2017 Shiyin Kang - - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef KALDI_MATRIX_SPARSE_MATRIX_H_ -#define KALDI_MATRIX_SPARSE_MATRIX_H_ 1 - -#include -#include - -#include "matrix/matrix-common.h" -#include "matrix/kaldi-matrix.h" -#include "matrix/kaldi-vector.h" -#include "matrix/compressed-matrix.h" - -namespace kaldi { - - -/// \addtogroup matrix_group -/// @{ - -template -class SparseVector { - public: - MatrixIndexT Dim() const { return dim_; } - - Real Sum() const; - - template - void CopyElementsToVec(VectorBase *vec) const; - - // *vec += alpha * *this. - template - void AddToVec(Real alpha, - VectorBase *vec) const; - - template - void CopyFromSvec(const SparseVector &other); - - SparseVector &operator = (const SparseVector &other); - - SparseVector(const SparseVector &other) { *this = other; } - - void Swap(SparseVector *other); - - // Returns the maximum value in this row and outputs the index associated with - // it. This is not the index into the Data() pointer, it is the index into - // the vector it represents, i.e. the .first value in the pair. - // If this vector's Dim() is zero it is an error to call this function. - // If all the elements stored were negative and there underlying vector had - // zero indexes not listed in the elements, or if no elements are stored, it - // will return the first un-listed index, whose value (implicitly) is zero. - Real Max(int32 *index) const; - - /// Returns the number of nonzero elements. - MatrixIndexT NumElements() const { return pairs_.size(); } - - /// get an indexed element (0 <= i < NumElements()). - const std::pair &GetElement(MatrixIndexT i) const { - return pairs_[i]; - } - - // returns pointer to element data, or NULL if empty (use with NumElements()). - std::pair *Data(); - - // returns pointer to element data, or NULL if empty (use with NumElements()); - // const version - const std::pair *Data() const; - - /// Sets elements to zero with probability zero_prob, else normally - /// distributed. Useful in testing. 
- void SetRandn(BaseFloat zero_prob); - - SparseVector(): dim_(0) { } - - explicit SparseVector(MatrixIndexT dim): dim_(dim) { KALDI_ASSERT(dim >= 0); } - - // constructor from pairs; does not assume input pairs are sorted and uniq - SparseVector(MatrixIndexT dim, - const std::vector > &pairs); - - // constructor from a VectorBase that keeps only the nonzero elements of 'vec'. - explicit SparseVector(const VectorBase &vec); - - /// Resizes to this dimension. resize_type == kUndefined - /// behaves the same as kSetZero. - void Resize(MatrixIndexT dim, MatrixResizeType resize_type = kSetZero); - - void Write(std::ostream &os, bool binary) const; - - void Read(std::istream &os, bool binary); - - /// Scale all elements of sparse vector. - void Scale(Real alpha); - - private: - MatrixIndexT dim_; - // pairs of (row-index, value). Stored in sorted order with no duplicates. - // For now we use std::vector, but we could change this. - std::vector > pairs_; -}; - - -template -Real VecSvec(const VectorBase &vec, - const SparseVector &svec); - - - -template -class SparseMatrix { - public: - MatrixIndexT NumRows() const; - - MatrixIndexT NumCols() const; - - MatrixIndexT NumElements() const; - - Real Sum() const; - - Real FrobeniusNorm() const; - - - /// This constructor creates a SparseMatrix that just contains the nonzero - /// elements of 'mat'. - explicit SparseMatrix(const MatrixBase &mat); - - /// Copy to matrix. It must already have the correct size. - template - void CopyToMat(MatrixBase *other, - MatrixTransposeType t = kNoTrans) const; - - /// Copies the values of all the elements in SparseMatrix into a VectorBase - /// object. - void CopyElementsToVec(VectorBase *other) const; - - /// Copies data from another sparse matrix. - template - void CopyFromSmat(const SparseMatrix &other, - MatrixTransposeType trans = kNoTrans); - - /// Does *other = *other + alpha * *this. 
- void AddToMat(BaseFloat alpha, MatrixBase *other, - MatrixTransposeType t = kNoTrans) const; - - SparseMatrix &operator = (const SparseMatrix &other); - - SparseMatrix(const SparseMatrix &other, MatrixTransposeType trans = - kNoTrans) { - this->CopyFromSmat(other, trans); - } - - void Swap(SparseMatrix *other); - - // returns pointer to element data, or NULL if empty (use with NumElements()). - SparseVector *Data(); - - // returns pointer to element data, or NULL if empty (use with NumElements()); - // const version - const SparseVector *Data() const; - - // initializer from the type that elsewhere in Kaldi is referred to as type - // Posterior. indexed first by row-index; the pairs are (column-index, value), - // and the constructor does not require them to be sorted and uniq. - SparseMatrix( - int32 dim, - const std::vector > > &pairs); - - /// Sets up to a pseudo-randomly initialized matrix, with each element zero - /// with probability zero_prob and else normally distributed- mostly for - /// purposes of testing. - void SetRandn(BaseFloat zero_prob); - - void Write(std::ostream &os, bool binary) const; - - void Read(std::istream &os, bool binary); - - const SparseVector &Row(MatrixIndexT r) const; - - /// Sets row r to "vec"; makes sure it has the correct dimension. - void SetRow(int32 r, const SparseVector &vec); - - /// Select a subset of the rows of a SparseMatrix. - /// Sets *this to only the rows of 'smat_other' that are listed - /// in 'row_indexes'. - /// 'row_indexes' must satisfy 0 <= row_indexes[i] < smat_other.NumRows(). - void SelectRows(const std::vector &row_indexes, - const SparseMatrix &smat_other); - - - /// Sets *this to all the rows of *inputs appended together; this - /// function is destructive of the inputs. Requires, obviously, - /// that the inputs all have the same dimension (although some may be - /// empty). 
- void AppendSparseMatrixRows(std::vector > *inputs); - - SparseMatrix() { } - - SparseMatrix(int32 num_rows, int32 num_cols) { Resize(num_rows, num_cols); } - - /// Constructor from an array of indexes. - /// If trans == kNoTrans, construct a sparse matrix - /// with num-rows == indexes.Dim() and num-cols = 'dim'. - /// 'indexes' is expected to contain elements in the - /// range [0, dim - 1]. Each row 'i' of *this after - /// calling the constructor will contain a single - /// element at column-index indexes[i] with value 1.0. - /// - /// If trans == kTrans, the result will be the transpose - /// of the sparse matrix described above. - SparseMatrix(const std::vector &indexes, int32 dim, - MatrixTransposeType trans = kNoTrans); - - /// Constructor from an array of indexes and an array of - /// weights; requires indexes.Dim() == weights.Dim(). - /// If trans == kNoTrans, construct a sparse matrix - /// with num-rows == indexes.Dim() and num-cols = 'dim'. - /// 'indexes' is expected to contain elements in the - /// range [0, dim - 1]. Each row 'i' of *this after - /// calling the constructor will contain a single - /// element at column-index indexes[i] with value weights[i]. - /// If trans == kTrans, the result will be the transpose - /// of the sparse matrix described above. - SparseMatrix(const std::vector &indexes, - const VectorBase &weights, int32 dim, - MatrixTransposeType trans = kNoTrans); - - /// Resizes the matrix; analogous to Matrix::Resize(). resize_type == - /// kUndefined behaves the same as kSetZero. - void Resize(MatrixIndexT rows, MatrixIndexT cols, - MatrixResizeType resize_type = kSetZero); - - /// Scale all elements in sparse matrix. - void Scale(Real alpha); - - // Use the Matrix::CopyFromSmat() function to copy from this to Matrix. Also - // see Matrix::AddSmat(). There is not very extensive functionality for - // SparseMat just yet (e.g. no matrix multiply); we will add things as needed - // and as it seems necessary. 
- private: - // vector of SparseVectors, all of same dime (use an stl vector for now; this - // could change). - std::vector > rows_; -}; - - -template -Real TraceMatSmat(const MatrixBase &A, - const SparseMatrix &B, - MatrixTransposeType trans = kNoTrans); - - -enum GeneralMatrixType { - kFullMatrix, - kCompressedMatrix, - kSparseMatrix -}; - -/// This class is a wrapper that enables you to store a matrix -/// in one of three forms: either as a Matrix, or a CompressedMatrix, -/// or a SparseMatrix. It handles the I/O for you, i.e. you read -/// and write a single object type. It is useful for neural-net training -/// targets which might be sparse or not, and might be compressed or not. -class GeneralMatrix { - public: - /// Returns the type of the matrix: kSparseMatrix, kCompressedMatrix or - /// kFullMatrix. If this matrix is empty, returns kFullMatrix. - GeneralMatrixType Type() const; - - void Compress(); // If it was a full matrix, compresses, changing Type() to - // kCompressedMatrix; otherwise does nothing. - - void Uncompress(); // If it was a compressed matrix, uncompresses, changing - // Type() to kFullMatrix; otherwise does nothing. - - void Write(std::ostream &os, bool binary) const; - - - /// Note: if you write a compressed matrix in text form, it will be read as - /// a regular full matrix. - void Read(std::istream &is, bool binary); - - /// Returns the contents as a SparseMatrix. This will only work if - /// Type() returns kSparseMatrix, or NumRows() == 0; otherwise it will crash. - const SparseMatrix &GetSparseMatrix() const; - - /// Swaps the with the given SparseMatrix. This will only work if - /// Type() returns kSparseMatrix, or NumRows() == 0. - void SwapSparseMatrix(SparseMatrix *smat); - - /// Returns the contents as a compressed matrix. This will only work if - /// Type() returns kCompressedMatrix, or NumRows() == 0; otherwise it will - /// crash. 
- const CompressedMatrix &GetCompressedMatrix() const; - - /// Swaps the with the given CompressedMatrix. This will only work if - /// Type() returns kCompressedMatrix, or NumRows() == 0. - void SwapCompressedMatrix(CompressedMatrix *cmat); - - /// Returns the contents as a Matrix. This will only work if - /// Type() returns kFullMatrix, or NumRows() == 0; otherwise it will crash. - const Matrix& GetFullMatrix() const; - - /// Outputs the contents as a matrix. This will work regardless of - /// Type(). Sizes its output, unlike CopyToMat(). - void GetMatrix(Matrix *mat) const; - - /// Swaps the with the given Matrix. This will only work if - /// Type() returns kFullMatrix, or NumRows() == 0. - void SwapFullMatrix(Matrix *mat); - - /// Copies contents, regardless of type, to "mat", which must be correctly - /// sized. See also GetMatrix(), which will size its output for you. - void CopyToMat(MatrixBase *mat, - MatrixTransposeType trans = kNoTrans) const; - - /// Copies contents, regardless of type, to "cu_mat", which must be - /// correctly sized. Implemented in ../cudamatrix/cu-sparse-matrix.cc - void CopyToMat(CuMatrixBase *cu_mat, - MatrixTransposeType trans = kNoTrans) const; - - /// Adds alpha times *this to mat. - void AddToMat(BaseFloat alpha, MatrixBase *mat, - MatrixTransposeType trans = kNoTrans) const; - - /// Adds alpha times *this to cu_mat. - /// Implemented in ../cudamatrix/cu-sparse-matrix.cc - void AddToMat(BaseFloat alpha, CuMatrixBase *cu_mat, - MatrixTransposeType trans = kNoTrans) const; - - /// Scale each element of matrix by alpha. - void Scale(BaseFloat alpha); - - /// Assignment from regular matrix. - GeneralMatrix &operator= (const MatrixBase &mat); - - /// Assignment from compressed matrix. 
- GeneralMatrix &operator= (const CompressedMatrix &mat); - - /// Assignment from SparseMatrix - GeneralMatrix &operator= (const SparseMatrix &smat); - - MatrixIndexT NumRows() const; - - MatrixIndexT NumCols() const; - - explicit GeneralMatrix(const MatrixBase &mat) { *this = mat; } - - explicit GeneralMatrix(const CompressedMatrix &cmat) { *this = cmat; } - - explicit GeneralMatrix(const SparseMatrix &smat) { *this = smat; } - - GeneralMatrix() { } - // Assignment operator. - GeneralMatrix &operator =(const GeneralMatrix &other); - // Copy constructor - GeneralMatrix(const GeneralMatrix &other) { *this = other; } - // Sets to the empty matrix. - void Clear(); - // shallow swap - void Swap(GeneralMatrix *other); - private: - // We don't explicitly store the type of the matrix. Rather, we make - // sure that only one of the matrices is ever nonempty, and the Type() - // returns that one, or kFullMatrix if all are empty. - Matrix mat_; - CompressedMatrix cmat_; - SparseMatrix smat_; -}; - - -/// Appends all the matrix rows of a list of GeneralMatrixes, to get a single -/// GeneralMatrix. Preserves sparsity if all inputs were sparse (or empty). -/// Does not preserve compression, if inputs were compressed; you have to -/// re-compress manually, if that's what you need. -void AppendGeneralMatrixRows(const std::vector &src, - GeneralMatrix *mat); - - -/// Outputs a SparseMatrix containing only the rows r of "in" such that -/// keep_rows[r] == true. keep_rows.size() must equal in.NumRows(), and rows -/// must contain at least one "true" element. -template -void FilterSparseMatrixRows(const SparseMatrix &in, - const std::vector &keep_rows, - SparseMatrix *out); - -/// Outputs a Matrix containing only the rows r of "in" such that -/// keep_keep_rows[r] == true. keep_rows.size() must equal in.NumRows(), and -/// keep_rows must contain at least one "true" element. 
-template -void FilterMatrixRows(const Matrix &in, - const std::vector &keep_rows, - Matrix *out); - -/// Outputs a Matrix containing only the rows r of "in" such that -/// keep_rows[r] == true. keep_rows.size() must equal in.NumRows(), and rows -/// must contain at least one "true" element. -void FilterCompressedMatrixRows(const CompressedMatrix &in, - const std::vector &keep_rows, - Matrix *out); - - -/// Outputs a GeneralMatrix containing only the rows r of "in" such that -/// keep_rows[r] == true. keep_rows.size() must equal in.NumRows(), and -/// keep_rows must contain at least one "true" element. If in.Type() is -/// kCompressedMatrix, the result will not be compressed; otherwise, the type -/// is preserved. -void FilterGeneralMatrixRows(const GeneralMatrix &in, - const std::vector &keep_rows, - GeneralMatrix *out); - -/// This function extracts a row-range of a GeneralMatrix and writes -/// as a GeneralMatrix containing the same type of underlying -/// matrix. If the row-range is partly outside the row-range of 'in' -/// (i.e. if row_offset < 0 or row_offset + num_rows > in.NumRows()) -/// then it will pad with copies of the first and last row as -/// needed. -/// This is more efficient than un-compressing and -/// re-compressing the underlying CompressedMatrix, and causes -/// less accuracy loss due to re-compression (no loss in most cases). -void ExtractRowRangeWithPadding( - const GeneralMatrix &in, - int32 row_offset, - int32 num_rows, - GeneralMatrix *out); - - -/// @} end of \addtogroup matrix_group - - -} // namespace kaldi - -#endif // KALDI_MATRIX_SPARSE_MATRIX_H_ diff --git a/speechx/speechx/kaldi/matrix/srfft.cc b/speechx/speechx/kaldi/matrix/srfft.cc deleted file mode 100644 index f6189496..00000000 --- a/speechx/speechx/kaldi/matrix/srfft.cc +++ /dev/null @@ -1,440 +0,0 @@ -// matrix/srfft.cc - -// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc. 
- -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. -// - -// This file includes a modified version of code originally published in Malvar, -// H., "Signal processing with lapped transforms, " Artech House, Inc., 1992. The -// current copyright holder of the original code, Henrique S. Malvar, has given -// his permission for the release of this modified version under the Apache -// License v2.0. - - -#include "matrix/srfft.h" -#include "matrix/matrix-functions.h" - -namespace kaldi { - - -template -SplitRadixComplexFft::SplitRadixComplexFft(MatrixIndexT N) { - if ( (N & (N-1)) != 0 || N <= 1) - KALDI_ERR << "SplitRadixComplexFft called with invalid number of points " - << N; - N_ = N; - logn_ = 0; - while (N > 1) { - N >>= 1; - logn_ ++; - } - ComputeTables(); -} - -template -SplitRadixComplexFft::SplitRadixComplexFft( - const SplitRadixComplexFft &other): - N_(other.N_), logn_(other.logn_) { - // This code duplicates tables from a previously computed object. - // Compare with the code in ComputeTables(). 
- MatrixIndexT lg2 = logn_ >> 1; - if (logn_ & 1) lg2++; - MatrixIndexT brseed_size = 1 << lg2; - brseed_ = new MatrixIndexT[brseed_size]; - std::memcpy(brseed_, other.brseed_, sizeof(MatrixIndexT) * brseed_size); - - if (logn_ < 4) { - tab_ = NULL; - } else { - tab_ = new Real*[logn_ - 3]; - for (MatrixIndexT i = logn_; i >= 4 ; i--) { - MatrixIndexT m = 1 << i, m2 = m / 2, m4 = m2 / 2; - MatrixIndexT this_array_size = 6 * (m4 - 2); - tab_[i-4] = new Real[this_array_size]; - std::memcpy(tab_[i-4], other.tab_[i-4], - sizeof(Real) * this_array_size); - } - } -} - -template -void SplitRadixComplexFft::ComputeTables() { - MatrixIndexT imax, lg2, i, j; - MatrixIndexT m, m2, m4, m8, nel, n; - Real *cn, *spcn, *smcn, *c3n, *spc3n, *smc3n; - Real ang, c, s; - - lg2 = logn_ >> 1; - if (logn_ & 1) lg2++; - brseed_ = new MatrixIndexT[1 << lg2]; - brseed_[0] = 0; - brseed_[1] = 1; - for (j = 2; j <= lg2; j++) { - imax = 1 << (j - 1); - for (i = 0; i < imax; i++) { - brseed_[i] <<= 1; - brseed_[i + imax] = brseed_[i] + 1; - } - } - - if (logn_ < 4) { - tab_ = NULL; - } else { - tab_ = new Real* [logn_-3]; - for (i = logn_; i>=4 ; i--) { - /* Compute a few constants */ - m = 1 << i; m2 = m / 2; m4 = m2 / 2; m8 = m4 /2; - - /* Allocate memory for tables */ - nel = m4 - 2; - - tab_[i-4] = new Real[6*nel]; - - /* Initialize pointers */ - cn = tab_[i-4]; spcn = cn + nel; smcn = spcn + nel; - c3n = smcn + nel; spc3n = c3n + nel; smc3n = spc3n + nel; - - /* Compute tables */ - for (n = 1; n < m4; n++) { - if (n == m8) continue; - ang = n * M_2PI / m; - c = std::cos(ang); s = std::sin(ang); - *cn++ = c; *spcn++ = - (s + c); *smcn++ = s - c; - ang = 3 * n * M_2PI / m; - c = std::cos(ang); s = std::sin(ang); - *c3n++ = c; *spc3n++ = - (s + c); *smc3n++ = s - c; - } - } - } -} - -template -SplitRadixComplexFft::~SplitRadixComplexFft() { - delete [] brseed_; - if (tab_ != NULL) { - for (MatrixIndexT i = 0; i < logn_-3; i++) - delete [] tab_[i]; - delete [] tab_; - } -} - -template -void 
SplitRadixComplexFft::Compute(Real *xr, Real *xi, bool forward) const { - if (!forward) { // reverse real and imaginary parts for complex FFT. - Real *tmp = xr; - xr = xi; - xi = tmp; - } - ComputeRecursive(xr, xi, logn_); - if (logn_ > 1) { - BitReversePermute(xr, logn_); - BitReversePermute(xi, logn_); - } -} - -template -void SplitRadixComplexFft::Compute(Real *x, bool forward, - std::vector *temp_buffer) const { - KALDI_ASSERT(temp_buffer != NULL); - if (temp_buffer->size() != N_) - temp_buffer->resize(N_); - Real *temp_ptr = &((*temp_buffer)[0]); - for (MatrixIndexT i = 0; i < N_; i++) { - x[i] = x[i * 2]; // put the real part in the first half of x. - temp_ptr[i] = x[i * 2 + 1]; // put the imaginary part in temp_buffer. - } - // copy the imaginary part back to the second half of x. - memcpy(static_cast(x + N_), - static_cast(temp_ptr), - sizeof(Real) * N_); - - Compute(x, x + N_, forward); - // Now change the format back to interleaved. - memcpy(static_cast(temp_ptr), - static_cast(x + N_), - sizeof(Real) * N_); - for (MatrixIndexT i = N_-1; i > 0; i--) { // don't include 0, - // in case MatrixIndexT is unsigned, the loop would not terminate. - // Treat it as a special case. - x[i*2] = x[i]; - x[i*2 + 1] = temp_ptr[i]; - } - x[1] = temp_ptr[0]; // special case of i = 0. 
-} - -template -void SplitRadixComplexFft::Compute(Real *x, bool forward) { - this->Compute(x, forward, &temp_buffer_); -} - -template -void SplitRadixComplexFft::BitReversePermute(Real *x, MatrixIndexT logn) const { - MatrixIndexT i, j, lg2, n; - MatrixIndexT off, fj, gno, *brp; - Real tmp, *xp, *xq; - - lg2 = logn >> 1; - n = 1 << lg2; - if (logn & 1) lg2++; - - /* Unshuffling loop */ - for (off = 1; off < n; off++) { - fj = n * brseed_[off]; i = off; j = fj; - tmp = x[i]; x[i] = x[j]; x[j] = tmp; - xp = &x[i]; - brp = &(brseed_[1]); - for (gno = 1; gno < brseed_[off]; gno++) { - xp += n; - j = fj + *brp++; - xq = x + j; - tmp = *xp; *xp = *xq; *xq = tmp; - } - } -} - - -template -void SplitRadixComplexFft::ComputeRecursive(Real *xr, Real *xi, MatrixIndexT logn) const { - - MatrixIndexT m, m2, m4, m8, nel, n; - Real *xr1, *xr2, *xi1, *xi2; - Real *cn = nullptr, *spcn = nullptr, *smcn = nullptr, *c3n = nullptr, - *spc3n = nullptr, *smc3n = nullptr; - Real tmp1, tmp2; - Real sqhalf = M_SQRT1_2; - - /* Check range of logn */ - if (logn < 0) - KALDI_ERR << "Error: logn is out of bounds in SRFFT"; - - /* Compute trivial cases */ - if (logn < 3) { - if (logn == 2) { /* length m = 4 */ - xr2 = xr + 2; - xi2 = xi + 2; - tmp1 = *xr + *xr2; - *xr2 = *xr - *xr2; - *xr = tmp1; - tmp1 = *xi + *xi2; - *xi2 = *xi - *xi2; - *xi = tmp1; - xr1 = xr + 1; - xi1 = xi + 1; - xr2++; - xi2++; - tmp1 = *xr1 + *xr2; - *xr2 = *xr1 - *xr2; - *xr1 = tmp1; - tmp1 = *xi1 + *xi2; - *xi2 = *xi1 - *xi2; - *xi1 = tmp1; - xr2 = xr + 1; - xi2 = xi + 1; - tmp1 = *xr + *xr2; - *xr2 = *xr - *xr2; - *xr = tmp1; - tmp1 = *xi + *xi2; - *xi2 = *xi - *xi2; - *xi = tmp1; - xr1 = xr + 2; - xi1 = xi + 2; - xr2 = xr + 3; - xi2 = xi + 3; - tmp1 = *xr1 + *xi2; - tmp2 = *xi1 + *xr2; - *xi1 = *xi1 - *xr2; - *xr2 = *xr1 - *xi2; - *xr1 = tmp1; - *xi2 = tmp2; - return; - } - else if (logn == 1) { /* length m = 2 */ - xr2 = xr + 1; - xi2 = xi + 1; - tmp1 = *xr + *xr2; - *xr2 = *xr - *xr2; - *xr = tmp1; - tmp1 = *xi + 
*xi2; - *xi2 = *xi - *xi2; - *xi = tmp1; - return; - } - else if (logn == 0) return; /* length m = 1 */ - } - - /* Compute a few constants */ - m = 1 << logn; m2 = m / 2; m4 = m2 / 2; m8 = m4 /2; - - - /* Step 1 */ - xr1 = xr; xr2 = xr1 + m2; - xi1 = xi; xi2 = xi1 + m2; - for (n = 0; n < m2; n++) { - tmp1 = *xr1 + *xr2; - *xr2 = *xr1 - *xr2; - xr2++; - *xr1++ = tmp1; - tmp2 = *xi1 + *xi2; - *xi2 = *xi1 - *xi2; - xi2++; - *xi1++ = tmp2; - } - - /* Step 2 */ - xr1 = xr + m2; xr2 = xr1 + m4; - xi1 = xi + m2; xi2 = xi1 + m4; - for (n = 0; n < m4; n++) { - tmp1 = *xr1 + *xi2; - tmp2 = *xi1 + *xr2; - *xi1 = *xi1 - *xr2; - xi1++; - *xr2++ = *xr1 - *xi2; - *xr1++ = tmp1; - *xi2++ = tmp2; - // xr1++; xr2++; xi1++; xi2++; - } - - /* Steps 3 & 4 */ - xr1 = xr + m2; xr2 = xr1 + m4; - xi1 = xi + m2; xi2 = xi1 + m4; - if (logn >= 4) { - nel = m4 - 2; - cn = tab_[logn-4]; spcn = cn + nel; smcn = spcn + nel; - c3n = smcn + nel; spc3n = c3n + nel; smc3n = spc3n + nel; - } - xr1++; xr2++; xi1++; xi2++; - // xr1++; xi1++; - for (n = 1; n < m4; n++) { - if (n == m8) { - tmp1 = sqhalf * (*xr1 + *xi1); - *xi1 = sqhalf * (*xi1 - *xr1); - *xr1 = tmp1; - tmp2 = sqhalf * (*xi2 - *xr2); - *xi2 = -sqhalf * (*xr2 + *xi2); - *xr2 = tmp2; - } else { - tmp2 = *cn++ * (*xr1 + *xi1); - tmp1 = *spcn++ * *xr1 + tmp2; - *xr1 = *smcn++ * *xi1 + tmp2; - *xi1 = tmp1; - tmp2 = *c3n++ * (*xr2 + *xi2); - tmp1 = *spc3n++ * *xr2 + tmp2; - *xr2 = *smc3n++ * *xi2 + tmp2; - *xi2 = tmp1; - } - xr1++; xr2++; xi1++; xi2++; - } - - /* Call ssrec again with half DFT length */ - ComputeRecursive(xr, xi, logn-1); - - /* Call ssrec again twice with one quarter DFT length. - Constants have to be recomputed, because they are static! 
*/ - // m = 1 << logn; m2 = m / 2; - ComputeRecursive(xr + m2, xi + m2, logn - 2); - // m = 1 << logn; - m4 = 3 * (m / 4); - ComputeRecursive(xr + m4, xi + m4, logn - 2); -} - - -template -void SplitRadixRealFft::Compute(Real *data, bool forward) { - Compute(data, forward, &this->temp_buffer_); -} - - -// This code is mostly the same as the RealFft function. It would be -// possible to replace it with more efficient code from Rico's book. -template -void SplitRadixRealFft::Compute(Real *data, bool forward, - std::vector *temp_buffer) const { - MatrixIndexT N = N_, N2 = N/2; - KALDI_ASSERT(N%2 == 0); - if (forward) // call to base class - SplitRadixComplexFft::Compute(data, true, temp_buffer); - - Real rootN_re, rootN_im; // exp(-2pi/N), forward; exp(2pi/N), backward - int forward_sign = forward ? -1 : 1; - ComplexImExp(static_cast(M_2PI/N *forward_sign), &rootN_re, &rootN_im); - Real kN_re = -forward_sign, kN_im = 0.0; // exp(-2pik/N), forward; exp(-2pik/N), backward - // kN starts out as 1.0 for forward algorithm but -1.0 for backward. - for (MatrixIndexT k = 1; 2*k <= N2; k++) { - ComplexMul(rootN_re, rootN_im, &kN_re, &kN_im); - - Real Ck_re, Ck_im, Dk_re, Dk_im; - // C_k = 1/2 (B_k + B_{N/2 - k}^*) : - Ck_re = 0.5 * (data[2*k] + data[N - 2*k]); - Ck_im = 0.5 * (data[2*k + 1] - data[N - 2*k + 1]); - // re(D_k)= 1/2 (im(B_k) + im(B_{N/2-k})): - Dk_re = 0.5 * (data[2*k + 1] + data[N - 2*k + 1]); - // im(D_k) = -1/2 (re(B_k) - re(B_{N/2-k})) - Dk_im =-0.5 * (data[2*k] - data[N - 2*k]); - // A_k = C_k + 1^(k/N) D_k: - data[2*k] = Ck_re; // A_k <-- C_k - data[2*k+1] = Ck_im; - // now A_k += D_k 1^(k/N) - ComplexAddProduct(Dk_re, Dk_im, kN_re, kN_im, &(data[2*k]), &(data[2*k+1])); - - MatrixIndexT kdash = N2 - k; - if (kdash != k) { - // Next we handle the index k' = N/2 - k. This is necessary - // to do now, to avoid invalidating data that we will later need. 
- // The quantities C_{k'} and D_{k'} are just the conjugates of C_k - // and D_k, so the equations are simple modifications of the above, - // replacing Ck_im and Dk_im with their negatives. - data[2*kdash] = Ck_re; // A_k' <-- C_k' - data[2*kdash+1] = -Ck_im; - // now A_k' += D_k' 1^(k'/N) - // We use 1^(k'/N) = 1^((N/2 - k) / N) = 1^(1/2) 1^(-k/N) = -1 * (1^(k/N))^* - // so it's the same as 1^(k/N) but with the real part negated. - ComplexAddProduct(Dk_re, -Dk_im, -kN_re, kN_im, &(data[2*kdash]), &(data[2*kdash+1])); - } - } - - { // Now handle k = 0. - // In simple terms: after the complex fft, data[0] becomes the sum of real - // parts input[0], input[2]... and data[1] becomes the sum of imaginary - // pats input[1], input[3]... - // "zeroth" [A_0] is just the sum of input[0]+input[1]+input[2].. - // and "n2th" [A_{N/2}] is input[0]-input[1]+input[2]... . - Real zeroth = data[0] + data[1], - n2th = data[0] - data[1]; - data[0] = zeroth; - data[1] = n2th; - if (!forward) { - data[0] /= 2; - data[1] /= 2; - } - } - if (!forward) { // call to base class - SplitRadixComplexFft::Compute(data, false, temp_buffer); - for (MatrixIndexT i = 0; i < N; i++) - data[i] *= 2.0; - // This is so we get a factor of N increase, rather than N/2 which we would - // otherwise get from [ComplexFft, forward] + [ComplexFft, backward] in dimension N/2. - // It's for consistency with our normal FFT convensions. - } -} - -template class SplitRadixComplexFft; -template class SplitRadixComplexFft; -template class SplitRadixRealFft; -template class SplitRadixRealFft; - - -} // end namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/srfft.h b/speechx/speechx/kaldi/matrix/srfft.h deleted file mode 100644 index 98ff782a..00000000 --- a/speechx/speechx/kaldi/matrix/srfft.h +++ /dev/null @@ -1,141 +0,0 @@ -// matrix/srfft.h - -// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc. 
-// 2014 Daniel Povey -// -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. -// -// This file includes a modified version of code originally published in Malvar, -// H., "Signal processing with lapped transforms, " Artech House, Inc., 1992. The -// current copyright holder of the original code, Henrique S. Malvar, has given -// his permission for the release of this modified version under the Apache -// License v2.0. - -#ifndef KALDI_MATRIX_SRFFT_H_ -#define KALDI_MATRIX_SRFFT_H_ - -#include "matrix/kaldi-vector.h" -#include "matrix/kaldi-matrix.h" - -namespace kaldi { - -/// @addtogroup matrix_funcs_misc -/// @{ - - -// This class is based on code by Henrique (Rico) Malvar, from his book -// "Signal Processing with Lapped Transforms" (1992). Copied with -// permission, optimized by Go Vivace Inc., and converted into C++ by -// Microsoft Corporation -// This is a more efficient way of doing the complex FFT than ComplexFft -// (declared in matrix-functios.h), but it only works for powers of 2. -// Note: in multi-threaded code, you would need to have one of these objects per -// thread, because multiple calls to Compute in parallel would not work. -template -class SplitRadixComplexFft { - public: - typedef MatrixIndexT Integer; - - // N is the number of complex points (must be a power of two, or this - // will crash). 
Note that the constructor does some work so it's best to - // initialize the object once and do the computation many times. - SplitRadixComplexFft(Integer N); - - // Copy constructor - SplitRadixComplexFft(const SplitRadixComplexFft &other); - - // Does the FFT computation, given pointers to the real and - // imaginary parts. If "forward", do the forward FFT; else - // do the inverse FFT (without the 1/N factor). - // xr and xi are pointers to zero-based arrays of size N, - // containing the real and imaginary parts - // respectively. - void Compute(Real *xr, Real *xi, bool forward) const; - - // This version of Compute takes a single array of size N*2, - // containing [ r0 im0 r1 im1 ... ]. Otherwise its behavior is the - // same as the version above. - void Compute(Real *x, bool forward); - - - // This version of Compute is const; it operates on an array of size N*2 - // containing [ r0 im0 r1 im1 ... ], but it uses the argument "temp_buffer" as - // temporary storage instead of a class-member variable. It will allocate it if - // needed. - void Compute(Real *x, bool forward, std::vector *temp_buffer) const; - - ~SplitRadixComplexFft(); - - protected: - // temp_buffer_ is allocated only if someone calls Compute with only one Real* - // argument and we need a temporary buffer while creating interleaved data. - std::vector temp_buffer_; - private: - void ComputeTables(); - void ComputeRecursive(Real *xr, Real *xi, Integer logn) const; - void BitReversePermute(Real *x, Integer logn) const; - - Integer N_; - Integer logn_; // log(N) - - Integer *brseed_; - // brseed is Evans' seed table, ref: (Ref: D. M. W. - // Evans, "An improved digit-reversal permutation algorithm ...", - // IEEE Trans. ASSP, Aug. 1987, pp. 1120-1125). - Real **tab_; // Tables of butterfly coefficients. - - // Disallow assignment. 
- SplitRadixComplexFft &operator =(const SplitRadixComplexFft &other); -}; - -template -class SplitRadixRealFft: private SplitRadixComplexFft { - public: - SplitRadixRealFft(MatrixIndexT N): // will fail unless N>=4 and N is a power of 2. - SplitRadixComplexFft (N/2), N_(N) { } - - // Copy constructor - SplitRadixRealFft(const SplitRadixRealFft &other): - SplitRadixComplexFft(other), N_(other.N_) { } - - /// If forward == true, this function transforms from a sequence of N real points to its complex fourier - /// transform; otherwise it goes in the reverse direction. If you call it - /// in the forward and then reverse direction and multiply by 1.0/N, you - /// will get back the original data. - /// The interpretation of the complex-FFT data is as follows: the array - /// is a sequence of complex numbers C_n of length N/2 with (real, im) format, - /// i.e. [real0, real_{N/2}, real1, im1, real2, im2, real3, im3, ...]. - void Compute(Real *x, bool forward); - - - /// This is as the other Compute() function, but it is a const version that - /// uses a user-supplied buffer. - void Compute(Real *x, bool forward, std::vector *temp_buffer) const; - - private: - // Disallow assignment. - SplitRadixRealFft &operator =(const SplitRadixRealFft &other); - int N_; -}; - - -/// @} end of "addtogroup matrix_funcs_misc" - -} // end namespace kaldi - - -#endif - diff --git a/speechx/speechx/kaldi/matrix/tp-matrix.cc b/speechx/speechx/kaldi/matrix/tp-matrix.cc deleted file mode 100644 index 6e34dc64..00000000 --- a/speechx/speechx/kaldi/matrix/tp-matrix.cc +++ /dev/null @@ -1,145 +0,0 @@ -// matrix/tp-matrix.cc - -// Copyright 2009-2011 Ondrej Glembek; Lukas Burget; Microsoft Corporation -// Saarland University; Yanmin Qian; Haihua Xu - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include "matrix/tp-matrix.h" -#include "matrix/sp-matrix.h" -#include "matrix/kaldi-matrix.h" -#include "matrix/cblas-wrappers.h" - - -namespace kaldi { - -#ifndef HAVE_ATLAS -template -void TpMatrix::Invert() { - // these are CLAPACK types - KaldiBlasInt result; - KaldiBlasInt rows = static_cast(this->num_rows_); - - // clapack call - // NOTE: Even though "U" is for upper, lapack assumes column-wise storage - // of the data. We have a row-wise storage, therefore, we need to "invert" - clapack_Xtptri(&rows, this->data_, &result); - - if (result < 0) { - KALDI_ERR << "Call to CLAPACK stptri_ function failed"; - } else if (result > 0) { - KALDI_ERR << "Matrix is singular"; - } -} -#else -template -void TpMatrix::Invert() { - // ATLAS doesn't implement triangular matrix inversion in packed - // format, so we temporarily put in non-packed format. - Matrix tmp(*this); - int rows = static_cast(this->num_rows_); - - // ATLAS call. It's really row-major ordering and a lower triangular matrix, - // but there is some weirdness with Fortran-style indexing that we need to - // take account of, so everything gets swapped. - int result = clapack_Xtrtri( rows, tmp.Data(), tmp.Stride()); - // Let's hope ATLAS has the same return value conventions as clapack. - // I couldn't find any documentation online. 
- if (result < 0) { - KALDI_ERR << "Call to ATLAS strtri function failed"; - } else if (result > 0) { - KALDI_ERR << "Matrix is singular"; - } - (*this).CopyFromMat(tmp); -} -#endif - -template -Real TpMatrix::Determinant() { - double det = 1.0; - for (MatrixIndexT i = 0; iNumRows(); i++) { - det *= (*this)(i, i); - } - return static_cast(det); -} - - -template -void TpMatrix::Swap(TpMatrix *other) { - std::swap(this->data_, other->data_); - std::swap(this->num_rows_, other->num_rows_); -} - - -template -void TpMatrix::Cholesky(const SpMatrix &orig) { - KALDI_ASSERT(orig.NumRows() == this->NumRows()); - MatrixIndexT n = this->NumRows(); - this->SetZero(); - Real *data = this->data_, *jdata = data; // start of j'th row of matrix. - const Real *orig_jdata = orig.Data(); // start of j'th row of matrix. - for (MatrixIndexT j = 0; j < n; j++, jdata += j, orig_jdata += j) { - Real *kdata = data; // start of k'th row of matrix. - Real d(0.0); - for (MatrixIndexT k = 0; k < j; k++, kdata += k) { - Real s = cblas_Xdot(k, kdata, 1, jdata, 1); - // (*this)(j, k) = s = (orig(j, k) - s)/(*this)(k, k); - jdata[k] = s = (orig_jdata[k] - s)/kdata[k]; - d = d + s*s; - } - // d = orig(j, j) - d; - d = orig_jdata[j] - d; - - if (d >= 0.0) { - // (*this)(j, j) = std::sqrt(d); - jdata[j] = std::sqrt(d); - } else { - KALDI_ERR << "Cholesky decomposition failed. 
Maybe matrix " - "is not positive definite."; - } - } -} - -template -void TpMatrix::CopyFromMat(const MatrixBase &M, - MatrixTransposeType Trans) { - if (Trans == kNoTrans) { - KALDI_ASSERT(this->NumRows() == M.NumRows() && M.NumRows() == M.NumCols()); - MatrixIndexT D = this->NumRows(); - const Real *in_i = M.Data(); - MatrixIndexT stride = M.Stride(); - Real *out_i = this->data_; - for (MatrixIndexT i = 0; i < D; i++, in_i += stride, out_i += i) - for (MatrixIndexT j = 0; j <= i; j++) - out_i[j] = in_i[j]; - } else { - KALDI_ASSERT(this->NumRows() == M.NumRows() && M.NumRows() == M.NumCols()); - MatrixIndexT D = this->NumRows(); - const Real *in_i = M.Data(); - MatrixIndexT stride = M.Stride(); - Real *out_i = this->data_; - for (MatrixIndexT i = 0; i < D; i++, in_i++, out_i += i) { - for (MatrixIndexT j = 0; j <= i; j++) - out_i[j] = in_i[stride*j]; - } - } -} - - -template class TpMatrix; -template class TpMatrix; - -} // namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/tp-matrix.h b/speechx/speechx/kaldi/matrix/tp-matrix.h deleted file mode 100644 index e3b08701..00000000 --- a/speechx/speechx/kaldi/matrix/tp-matrix.h +++ /dev/null @@ -1,134 +0,0 @@ -// matrix/tp-matrix.h - -// Copyright 2009-2011 Ondrej Glembek; Lukas Burget; Microsoft Corporation; -// Saarland University; Yanmin Qian; Haihua Xu -// 2013 Johns Hopkins Universith (author: Daniel Povey) - - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. 
-// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. -#ifndef KALDI_MATRIX_TP_MATRIX_H_ -#define KALDI_MATRIX_TP_MATRIX_H_ - - -#include "matrix/packed-matrix.h" - -namespace kaldi { -/// \addtogroup matrix_group -/// @{ - -template class TpMatrix; - -/// @brief Packed symetric matrix class -template -class TpMatrix : public PackedMatrix { - friend class CuTpMatrix; - friend class CuTpMatrix; - public: - TpMatrix() : PackedMatrix() {} - explicit TpMatrix(MatrixIndexT r, MatrixResizeType resize_type = kSetZero) - : PackedMatrix(r, resize_type) {} - TpMatrix(const TpMatrix& orig) : PackedMatrix(orig) {} - - /// Copy constructor from CUDA TpMatrix - /// This is defined in ../cudamatrix/cu-tp-matrix.cc - explicit TpMatrix(const CuTpMatrix &cu); - - - template explicit TpMatrix(const TpMatrix& orig) - : PackedMatrix(orig) {} - - Real operator() (MatrixIndexT r, MatrixIndexT c) const { - if (static_cast(c) > - static_cast(r)) { - KALDI_ASSERT(static_cast(c) < - static_cast(this->num_rows_)); - return 0; - } - KALDI_ASSERT(static_cast(r) < - static_cast(this->num_rows_)); - // c<=r now so don't have to check c. - return *(this->data_ + (r*(r+1)) / 2 + c); - // Duplicating code from PackedMatrix.h - } - - Real &operator() (MatrixIndexT r, MatrixIndexT c) { - KALDI_ASSERT(static_cast(r) < - static_cast(this->num_rows_)); - KALDI_ASSERT(static_cast(c) <= - static_cast(r) && - "you cannot access the upper triangle of TpMatrix using " - "a non-const matrix object."); - return *(this->data_ + (r*(r+1)) / 2 + c); - // Duplicating code from PackedMatrix.h - } - // Note: Cholesky may throw KaldiFatalError. - void Cholesky(const SpMatrix& orig); - - void Invert(); - - // Inverts in double precision. 
- void InvertDouble() { - TpMatrix dmat(*this); - dmat.Invert(); - (*this).CopyFromTp(dmat); - } - - /// Shallow swap - void Swap(TpMatrix *other); - - /// Returns the determinant of the matrix (product of diagonals) - Real Determinant(); - - /// CopyFromMat copies the lower triangle of M into *this - /// (or the upper triangle, if Trans == kTrans). - void CopyFromMat(const MatrixBase &M, - MatrixTransposeType Trans = kNoTrans); - - /// This is implemented in ../cudamatrix/cu-tp-matrix.cc - void CopyFromMat(const CuTpMatrix &other); - - /// CopyFromTp copies another triangular matrix into this one. - void CopyFromTp(const TpMatrix &other) { - PackedMatrix::CopyFromPacked(other); - } - - template void CopyFromTp(const TpMatrix &other) { - PackedMatrix::CopyFromPacked(other); - } - - /// AddTp does *this += alpha * M. - void AddTp(const Real alpha, const TpMatrix &M) { - this->AddPacked(alpha, M); - } - - TpMatrix& operator=(const TpMatrix &other) { - PackedMatrix::operator=(other); - return *this; - } - - using PackedMatrix::Scale; - - void Resize(MatrixIndexT nRows, MatrixResizeType resize_type = kSetZero) { - PackedMatrix::Resize(nRows, resize_type); - } -}; - -/// @} end of "addtogroup matrix_group". 
- -} // namespace kaldi - - -#endif diff --git a/speechx/speechx/kaldi/util/kaldi-holder-inl.h b/speechx/speechx/kaldi/util/kaldi-holder-inl.h index 134cdd93..9b441ad4 100644 --- a/speechx/speechx/kaldi/util/kaldi-holder-inl.h +++ b/speechx/speechx/kaldi/util/kaldi-holder-inl.h @@ -754,53 +754,53 @@ class TokenVectorHolder { }; -class HtkMatrixHolder { - public: - typedef std::pair, HtkHeader> T; - - HtkMatrixHolder() {} - - static bool Write(std::ostream &os, bool binary, const T &t) { - if (!binary) - KALDI_ERR << "Non-binary HTK-format write not supported."; - bool ans = WriteHtk(os, t.first, t.second); - if (!ans) - KALDI_WARN << "Error detected writing HTK-format matrix."; - return ans; - } - - void Clear() { t_.first.Resize(0, 0); } - - // Reads into the holder. - bool Read(std::istream &is) { - bool ans = ReadHtk(is, &t_.first, &t_.second); - if (!ans) { - KALDI_WARN << "Error detected reading HTK-format matrix."; - return false; - } - return ans; - } - - // HTK-format matrices only read in binary. - static bool IsReadInBinary() { return true; } - - T &Value() { return t_; } - - void Swap(HtkMatrixHolder *other) { - t_.first.Swap(&(other->t_.first)); - std::swap(t_.second, other->t_.second); - } - - bool ExtractRange(const HtkMatrixHolder &other, - const std::string &range) { - KALDI_ERR << "ExtractRange is not defined for this type of holder."; - return false; - } - // Default destructor. - private: - KALDI_DISALLOW_COPY_AND_ASSIGN(HtkMatrixHolder); - T t_; -}; +//class HtkMatrixHolder { + //public: + //typedef std::pair, HtkHeader> T; + + //HtkMatrixHolder() {} + + //static bool Write(std::ostream &os, bool binary, const T &t) { + //if (!binary) + //KALDI_ERR << "Non-binary HTK-format write not supported."; + //bool ans = WriteHtk(os, t.first, t.second); + //if (!ans) + //KALDI_WARN << "Error detected writing HTK-format matrix."; + //return ans; + //} + + //void Clear() { t_.first.Resize(0, 0); } + + //// Reads into the holder. 
+ //bool Read(std::istream &is) { + //bool ans = ReadHtk(is, &t_.first, &t_.second); + //if (!ans) { + //KALDI_WARN << "Error detected reading HTK-format matrix."; + //return false; + //} + //return ans; + //} + + //// HTK-format matrices only read in binary. + //static bool IsReadInBinary() { return true; } + + //T &Value() { return t_; } + + //void Swap(HtkMatrixHolder *other) { + //t_.first.Swap(&(other->t_.first)); + //std::swap(t_.second, other->t_.second); + //} + + //bool ExtractRange(const HtkMatrixHolder &other, + //const std::string &range) { + //KALDI_ERR << "ExtractRange is not defined for this type of holder."; + //return false; + //} + //// Default destructor. + //private: + //KALDI_DISALLOW_COPY_AND_ASSIGN(HtkMatrixHolder); + //T t_; +//}; // SphinxMatrixHolder can be used to read and write feature files in // CMU Sphinx format. 13-dimensional big-endian features are assumed. @@ -813,104 +813,104 @@ class HtkMatrixHolder { // be no problem, because the usage help of Sphinx' "wave2feat" for example // says that Sphinx features are always big endian. 
// Note: the kFeatDim defaults to 13, see forward declaration in kaldi-holder.h -template class SphinxMatrixHolder { - public: - typedef Matrix T; - - SphinxMatrixHolder() {} - - void Clear() { feats_.Resize(0, 0); } - - // Writes Sphinx-format features - static bool Write(std::ostream &os, bool binary, const T &m) { - if (!binary) { - KALDI_WARN << "SphinxMatrixHolder can't write Sphinx features in text "; - return false; - } - - int32 size = m.NumRows() * m.NumCols(); - if (MachineIsLittleEndian()) - KALDI_SWAP4(size); - // write the header - os.write(reinterpret_cast (&size), sizeof(size)); - - for (MatrixIndexT i = 0; i < m.NumRows(); i++) { - std::vector tmp(m.NumCols()); - for (MatrixIndexT j = 0; j < m.NumCols(); j++) { - tmp[j] = static_cast(m(i, j)); - if (MachineIsLittleEndian()) - KALDI_SWAP4(tmp[j]); - } - os.write(reinterpret_cast(&(tmp[0])), - tmp.size() * 4); - } - return true; - } - - // Reads the features into a Kaldi Matrix - bool Read(std::istream &is) { - int32 nmfcc; - - is.read(reinterpret_cast (&nmfcc), sizeof(nmfcc)); - if (MachineIsLittleEndian()) - KALDI_SWAP4(nmfcc); - KALDI_VLOG(2) << "#feats: " << nmfcc; - int32 nfvec = nmfcc / kFeatDim; - if ((nmfcc % kFeatDim) != 0) { - KALDI_WARN << "Sphinx feature count is inconsistent with vector length "; - return false; - } - - feats_.Resize(nfvec, kFeatDim); - for (MatrixIndexT i = 0; i < feats_.NumRows(); i++) { - if (sizeof(BaseFloat) == sizeof(float32)) { - is.read(reinterpret_cast (feats_.RowData(i)), - kFeatDim * sizeof(float32)); - if (!is.good()) { - KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; - return false; - } - if (MachineIsLittleEndian()) { - for (MatrixIndexT j = 0; j < kFeatDim; j++) - KALDI_SWAP4(feats_(i, j)); - } - } else { // KALDI_DOUBLEPRECISION=1 - float32 tmp[kFeatDim]; - is.read(reinterpret_cast (tmp), sizeof(tmp)); - if (!is.good()) { - KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; - return false; - } - for (MatrixIndexT j = 
0; j < kFeatDim; j++) { - if (MachineIsLittleEndian()) - KALDI_SWAP4(tmp[j]); - feats_(i, j) = static_cast(tmp[j]); - } - } - } - - return true; - } - - // Only read in binary - static bool IsReadInBinary() { return true; } - - T &Value() { return feats_; } - - void Swap(SphinxMatrixHolder *other) { - feats_.Swap(&(other->feats_)); - } - - bool ExtractRange(const SphinxMatrixHolder &other, - const std::string &range) { - KALDI_ERR << "ExtractRange is not defined for this type of holder."; - return false; - } - - private: - KALDI_DISALLOW_COPY_AND_ASSIGN(SphinxMatrixHolder); - T feats_; -}; +//template class SphinxMatrixHolder { + //public: + //typedef Matrix T; + + //SphinxMatrixHolder() {} + + //void Clear() { feats_.Resize(0, 0); } + + //// Writes Sphinx-format features + //static bool Write(std::ostream &os, bool binary, const T &m) { + //if (!binary) { + //KALDI_WARN << "SphinxMatrixHolder can't write Sphinx features in text "; + //return false; + //} + + //int32 size = m.NumRows() * m.NumCols(); + //if (MachineIsLittleEndian()) + //KALDI_SWAP4(size); + //// write the header + //os.write(reinterpret_cast (&size), sizeof(size)); + + //for (MatrixIndexT i = 0; i < m.NumRows(); i++) { + //std::vector tmp(m.NumCols()); + //for (MatrixIndexT j = 0; j < m.NumCols(); j++) { + //tmp[j] = static_cast(m(i, j)); + //if (MachineIsLittleEndian()) + //KALDI_SWAP4(tmp[j]); + //} + //os.write(reinterpret_cast(&(tmp[0])), + //tmp.size() * 4); + //} + //return true; + //} + + //// Reads the features into a Kaldi Matrix + //bool Read(std::istream &is) { + //int32 nmfcc; + + //is.read(reinterpret_cast (&nmfcc), sizeof(nmfcc)); + //if (MachineIsLittleEndian()) + //KALDI_SWAP4(nmfcc); + //KALDI_VLOG(2) << "#feats: " << nmfcc; + //int32 nfvec = nmfcc / kFeatDim; + //if ((nmfcc % kFeatDim) != 0) { + //KALDI_WARN << "Sphinx feature count is inconsistent with vector length "; + //return false; + //} + + //feats_.Resize(nfvec, kFeatDim); + //for (MatrixIndexT i = 0; i < feats_.NumRows(); 
i++) { + //if (sizeof(BaseFloat) == sizeof(float32)) { + //is.read(reinterpret_cast (feats_.RowData(i)), + //kFeatDim * sizeof(float32)); + //if (!is.good()) { + //KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; + //return false; + //} + //if (MachineIsLittleEndian()) { + //for (MatrixIndexT j = 0; j < kFeatDim; j++) + //KALDI_SWAP4(feats_(i, j)); + //} + //} else { // KALDI_DOUBLEPRECISION=1 + //float32 tmp[kFeatDim]; + //is.read(reinterpret_cast (tmp), sizeof(tmp)); + //if (!is.good()) { + //KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; + //return false; + //} + //for (MatrixIndexT j = 0; j < kFeatDim; j++) { + //if (MachineIsLittleEndian()) + //KALDI_SWAP4(tmp[j]); + //feats_(i, j) = static_cast(tmp[j]); + //} + //} + //} + + //return true; + //} + + //// Only read in binary + //static bool IsReadInBinary() { return true; } + + //T &Value() { return feats_; } + + //void Swap(SphinxMatrixHolder *other) { + //feats_.Swap(&(other->feats_)); + //} + + //bool ExtractRange(const SphinxMatrixHolder &other, + //const std::string &range) { + //KALDI_ERR << "ExtractRange is not defined for this type of holder."; + //return false; + //} + + //private: + //KALDI_DISALLOW_COPY_AND_ASSIGN(SphinxMatrixHolder); + //T feats_; +//}; /// @} end "addtogroup holders" diff --git a/speechx/speechx/kaldi/util/kaldi-holder.cc b/speechx/speechx/kaldi/util/kaldi-holder.cc index 577679ef..6b0eebb9 100644 --- a/speechx/speechx/kaldi/util/kaldi-holder.cc +++ b/speechx/speechx/kaldi/util/kaldi-holder.cc @@ -85,7 +85,7 @@ bool ParseMatrixRangeSpecifier(const std::string &range, return status; } -bool ExtractObjectRange(const GeneralMatrix &input, const std::string &range, +/*bool ExtractObjectRange(const GeneralMatrix &input, const std::string &range, GeneralMatrix *output) { // We just inspect input's type and forward to the correct implementation // if available. 
For kSparseMatrix, we do just fairly inefficient conversion @@ -135,6 +135,7 @@ template bool ExtractObjectRange(const CompressedMatrix &, const std::string &, template bool ExtractObjectRange(const CompressedMatrix &, const std::string &, Matrix *); +*/ template bool ExtractObjectRange(const Matrix &input, const std::string &range, Matrix *output) { diff --git a/speechx/speechx/kaldi/util/kaldi-holder.h b/speechx/speechx/kaldi/util/kaldi-holder.h index f495f27f..a8c42c9f 100644 --- a/speechx/speechx/kaldi/util/kaldi-holder.h +++ b/speechx/speechx/kaldi/util/kaldi-holder.h @@ -27,7 +27,6 @@ #include "util/kaldi-io.h" #include "util/text-utils.h" #include "matrix/kaldi-vector.h" -#include "matrix/sparse-matrix.h" namespace kaldi { @@ -214,10 +213,10 @@ class TokenVectorHolder; /// A class for reading/writing HTK-format matrices. /// T == std::pair, HtkHeader> -class HtkMatrixHolder; +//class HtkMatrixHolder; /// A class for reading/writing Sphinx format matrices. -template class SphinxMatrixHolder; +//template class SphinxMatrixHolder; /// This templated function exists so that we can write .scp files with /// 'object ranges' specified: the canonical example is a [first:last] range @@ -249,15 +248,15 @@ bool ExtractObjectRange(const Vector &input, const std::string &range, Vector *output); /// GeneralMatrix is always of type BaseFloat -bool ExtractObjectRange(const GeneralMatrix &input, const std::string &range, - GeneralMatrix *output); +//bool ExtractObjectRange(const GeneralMatrix &input, const std::string &range, + // GeneralMatrix *output); /// CompressedMatrix is always of the type BaseFloat but it is more /// efficient to provide template as it uses CompressedMatrix's own /// conversion to Matrix -template -bool ExtractObjectRange(const CompressedMatrix &input, const std::string &range, - Matrix *output); +//template +//bool ExtractObjectRange(const CompressedMatrix &input, const std::string &range, + // Matrix *output); // In SequentialTableReaderScriptImpl 
and RandomAccessTableReaderScriptImpl, for // cases where the scp contained 'range specifiers' (things in square brackets diff --git a/speechx/speechx/kaldi/util/table-types.h b/speechx/speechx/kaldi/util/table-types.h index efcdf1b5..665a1327 100644 --- a/speechx/speechx/kaldi/util/table-types.h +++ b/speechx/speechx/kaldi/util/table-types.h @@ -23,7 +23,8 @@ #include "base/kaldi-common.h" #include "util/kaldi-table.h" #include "util/kaldi-holder.h" -#include "matrix/matrix-lib.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/kaldi-vector.h" namespace kaldi { @@ -51,8 +52,8 @@ typedef RandomAccessTableReader > > typedef RandomAccessTableReaderMapped > > RandomAccessDoubleMatrixReaderMapped; -typedef TableWriter > - CompressedMatrixWriter; +//typedef TableWriter > + //CompressedMatrixWriter; typedef TableWriter > > BaseFloatVectorWriter; @@ -70,39 +71,39 @@ typedef SequentialTableReader > > typedef RandomAccessTableReader > > RandomAccessDoubleVectorReader; -typedef TableWriter > > - BaseFloatCuMatrixWriter; -typedef SequentialTableReader > > - SequentialBaseFloatCuMatrixReader; -typedef RandomAccessTableReader > > - RandomAccessBaseFloatCuMatrixReader; -typedef RandomAccessTableReaderMapped > > - RandomAccessBaseFloatCuMatrixReaderMapped; - -typedef TableWriter > > - DoubleCuMatrixWriter; -typedef SequentialTableReader > > - SequentialDoubleCuMatrixReader; -typedef RandomAccessTableReader > > - RandomAccessDoubleCuMatrixReader; -typedef RandomAccessTableReaderMapped > > - RandomAccessDoubleCuMatrixReaderMapped; - -typedef TableWriter > > - BaseFloatCuVectorWriter; -typedef SequentialTableReader > > - SequentialBaseFloatCuVectorReader; -typedef RandomAccessTableReader > > - RandomAccessBaseFloatCuVectorReader; -typedef RandomAccessTableReaderMapped > > - RandomAccessBaseFloatCuVectorReaderMapped; - -typedef TableWriter > > - DoubleCuVectorWriter; -typedef SequentialTableReader > > - SequentialDoubleCuVectorReader; -typedef RandomAccessTableReader > > - 
RandomAccessDoubleCuVectorReader; +//typedef TableWriter > > + //BaseFloatCuMatrixWriter; +//typedef SequentialTableReader > > + //SequentialBaseFloatCuMatrixReader; +//typedef RandomAccessTableReader > > + //RandomAccessBaseFloatCuMatrixReader; +//typedef RandomAccessTableReaderMapped > > + //RandomAccessBaseFloatCuMatrixReaderMapped; + +//typedef TableWriter > > + //DoubleCuMatrixWriter; +//typedef SequentialTableReader > > + //SequentialDoubleCuMatrixReader; +//typedef RandomAccessTableReader > > + //RandomAccessDoubleCuMatrixReader; +//typedef RandomAccessTableReaderMapped > > + //RandomAccessDoubleCuMatrixReaderMapped; + +//typedef TableWriter > > + //BaseFloatCuVectorWriter; +//typedef SequentialTableReader > > + //SequentialBaseFloatCuVectorReader; +//typedef RandomAccessTableReader > > + //RandomAccessBaseFloatCuVectorReader; +//typedef RandomAccessTableReaderMapped > > + //RandomAccessBaseFloatCuVectorReaderMapped; + +//typedef TableWriter > > + //DoubleCuVectorWriter; +//typedef SequentialTableReader > > + //SequentialDoubleCuVectorReader; +//typedef RandomAccessTableReader > > + //RandomAccessDoubleCuVectorReader; typedef TableWriter > Int32Writer; @@ -150,8 +151,6 @@ typedef TableWriter > BoolWriter; typedef SequentialTableReader > SequentialBoolReader; typedef RandomAccessTableReader > RandomAccessBoolReader; - - /// TokenWriter is a writer specialized for std::string where the strings /// are nonempty and whitespace-free. 
T == std::string typedef TableWriter TokenWriter; @@ -169,14 +168,14 @@ typedef RandomAccessTableReader RandomAccessTokenVectorReader; -typedef TableWriter > - GeneralMatrixWriter; -typedef SequentialTableReader > - SequentialGeneralMatrixReader; -typedef RandomAccessTableReader > - RandomAccessGeneralMatrixReader; -typedef RandomAccessTableReaderMapped > - RandomAccessGeneralMatrixReaderMapped; +//typedef TableWriter > +// GeneralMatrixWriter; +//typedef SequentialTableReader > + // SequentialGeneralMatrixReader; +//typedef RandomAccessTableReader > + // RandomAccessGeneralMatrixReader; +//typedef RandomAccessTableReaderMapped > + // RandomAccessGeneralMatrixReaderMapped; From 8a225b1708507e873e29b62559bb0756419d3ebe Mon Sep 17 00:00:00 2001 From: YangZhou <56786796+SmileGoat@users.noreply.github.com> Date: Wed, 18 Jan 2023 16:11:26 +0800 Subject: [PATCH 07/50] [speechx] thread decode (#2839) * fix nnet thread crash && rescore cost time * add nnet thread main --- .../decoder/ctc_prefix_beam_search_decoder.cc | 9 +- speechx/speechx/asr/nnet/CMakeLists.txt | 27 ++-- speechx/speechx/asr/nnet/decodable.cc | 1 - speechx/speechx/asr/nnet/nnet_producer.cc | 54 ++++++- speechx/speechx/asr/nnet/nnet_producer.h | 34 ++++- .../speechx/asr/nnet/u2_nnet_thread_main.cc | 137 ++++++++++++++++++ .../speechx/asr/recognizer/u2_recognizer.cc | 39 ++++- .../speechx/asr/recognizer/u2_recognizer.h | 16 +- .../asr/recognizer/u2_recognizer_main.cc | 4 +- .../recognizer/u2_recognizer_thread_main.cc | 26 ++-- .../common/frontend/compute_fbank_main.cc | 3 +- .../speechx/common/frontend/feature_cache.cc | 40 ++--- .../speechx/common/frontend/feature_cache.h | 20 +-- .../common/frontend/feature_pipeline.cc | 2 +- .../common/frontend/feature_pipeline.h | 1 - 15 files changed, 303 insertions(+), 110 deletions(-) create mode 100644 speechx/speechx/asr/nnet/u2_nnet_thread_main.cc diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc 
b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc index 2cef4972..8361f06d 100644 --- a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc +++ b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc @@ -63,8 +63,9 @@ void CTCPrefixBeamSearch::Reset() { times_.emplace_back(empty); } -void CTCPrefixBeamSearch::InitDecoder() { Reset(); } - +void CTCPrefixBeamSearch::InitDecoder() { + Reset(); +} void CTCPrefixBeamSearch::AdvanceDecode( const std::shared_ptr& decodable) { @@ -77,7 +78,7 @@ void CTCPrefixBeamSearch::AdvanceDecode( bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); feat_nnet_cost += timer.Elapsed(); if (flag == false) { - VLOG(3) << "decoder advance decode exit." << frame_prob.size(); + VLOG(2) << "decoder advance decode exit." << frame_prob.size(); break; } @@ -87,7 +88,7 @@ void CTCPrefixBeamSearch::AdvanceDecode( AdvanceDecoding(likelihood); search_cost += timer.Elapsed(); - VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_; + VLOG(1) << "num_frame_decoded_: " << num_frame_decoded_; } VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost << " sec."; diff --git a/speechx/speechx/asr/nnet/CMakeLists.txt b/speechx/speechx/asr/nnet/CMakeLists.txt index 819cc2e8..7306ebf8 100644 --- a/speechx/speechx/asr/nnet/CMakeLists.txt +++ b/speechx/speechx/asr/nnet/CMakeLists.txt @@ -8,14 +8,21 @@ target_link_libraries(nnet utils) target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS}) target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) -# test bin -#if(USING_U2) -# set(bin_name u2_nnet_main) -# add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -# target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -# target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) +# test bin +#set(bin_name u2_nnet_main) +#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) 
+#target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +#target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) -# target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) -# target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) -# target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) -#endif() +#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) +#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) +#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) + +set(bin_name u2_nnet_thread_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet frontend) + +target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) +target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) +target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) diff --git a/speechx/speechx/asr/nnet/decodable.cc b/speechx/speechx/asr/nnet/decodable.cc index f01e9049..a140c376 100644 --- a/speechx/speechx/asr/nnet/decodable.cc +++ b/speechx/speechx/asr/nnet/decodable.cc @@ -33,7 +33,6 @@ void Decodable::Acceptlikelihood(const Matrix& likelihood) { nnet_producer_->Acceptlikelihood(likelihood); } - // return the size of frame have computed. 
int32 Decodable::NumFramesReady() const { return frames_ready_; } diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/speechx/speechx/asr/nnet/nnet_producer.cc index 6207a6b5..b83b5976 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.cc +++ b/speechx/speechx/asr/nnet/nnet_producer.cc @@ -22,14 +22,43 @@ using kaldi::BaseFloat; NnetProducer::NnetProducer(std::shared_ptr nnet, std::shared_ptr frontend) - : nnet_(nnet), frontend_(frontend) {} + : nnet_(nnet), frontend_(frontend) { + abort_ = false; + Reset(); + thread_ = std::thread(RunNnetEvaluation, this); + } void NnetProducer::Accept(const std::vector& inputs) { frontend_->Accept(inputs); + condition_variable_.notify_one(); +} + +void NnetProducer::UnLock() { + std::unique_lock lock(read_mutex_); + while (frontend_->IsFinished() == false && cache_.empty()) { + condition_read_ready_.wait(lock); + } + return; +} + +void NnetProducer::RunNnetEvaluation(NnetProducer *me) { + me->RunNnetEvaluationInteral(); +} + +void NnetProducer::RunNnetEvaluationInteral() { bool result = false; - do { - result = Compute(); - } while (result); + LOG(INFO) << "NnetEvaluationInteral begin"; + while (!abort_) { + std::unique_lock lock(mutex_); + condition_variable_.wait(lock); + do { + result = Compute(); + } while (result); + if (frontend_->IsFinished() == true) { + if (cache_.empty()) finished_ = true; + } + } + LOG(INFO) << "NnetEvaluationInteral exit"; } void NnetProducer::Acceptlikelihood( @@ -39,12 +68,20 @@ void NnetProducer::Acceptlikelihood( for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) { for (size_t col = 0; col < likelihood.NumCols(); ++col) { prob[col] = likelihood(idx, col); - cache_.push_back(prob); } + cache_.push_back(prob); } } bool NnetProducer::Read(std::vector* nnet_prob) { + bool flag = cache_.pop(nnet_prob); + condition_variable_.notify_one(); + return flag; +} + +bool NnetProducer::ReadandCompute(std::vector* nnet_prob) { + Compute(); + if (frontend_->IsFinished() && cache_.empty()) finished_ = 
true; bool flag = cache_.pop(nnet_prob); return flag; } @@ -53,22 +90,23 @@ bool NnetProducer::Compute() { vector features; if (frontend_ == NULL || frontend_->Read(&features) == false) { // no feat or frontend_ not init. - VLOG(3) << "no feat avalible"; + VLOG(2) << "no feat avalible"; return false; } CHECK_GE(frontend_->Dim(), 0); - VLOG(2) << "Forward in " << features.size() / frontend_->Dim() << " feats."; + VLOG(1) << "Forward in " << features.size() / frontend_->Dim() << " feats."; NnetOut out; nnet_->FeedForward(features, frontend_->Dim(), &out); int32& vocab_dim = out.vocab_dim; size_t nframes = out.logprobs.size() / vocab_dim; - VLOG(2) << "Forward out " << nframes << " decoder frames."; + VLOG(1) << "Forward out " << nframes << " decoder frames."; for (size_t idx = 0; idx < nframes; ++idx) { std::vector logprob( out.logprobs.data() + idx * vocab_dim, out.logprobs.data() + (idx + 1) * vocab_dim); cache_.push_back(logprob); + condition_read_ready_.notify_one(); } return true; } diff --git a/speechx/speechx/asr/nnet/nnet_producer.h b/speechx/speechx/asr/nnet/nnet_producer.h index dd356f95..14c74d04 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.h +++ b/speechx/speechx/asr/nnet/nnet_producer.h @@ -33,27 +33,38 @@ class NnetProducer { // nnet bool Read(std::vector* nnet_prob); + bool ReadandCompute(std::vector* nnet_prob); + static void RunNnetEvaluation(NnetProducer *me); + void RunNnetEvaluationInteral(); + void UnLock(); + + void Wait() { + abort_ = true; + condition_variable_.notify_one(); + if (thread_.joinable()) thread_.join(); + } bool Empty() const { return cache_.empty(); } - void SetFinished() { + void SetInputFinished() { LOG(INFO) << "set finished"; - // std::unique_lock lock(mutex_); frontend_->SetFinished(); - - // read the last chunk data - Compute(); - // ready_feed_condition_.notify_one(); - LOG(INFO) << "compute last feats done."; + condition_variable_.notify_one(); } - bool IsFinished() const { return frontend_->IsFinished(); } + // 
the compute thread exit + bool IsFinished() const { return finished_; } + + ~NnetProducer() { + if (thread_.joinable()) thread_.join(); + } void Reset() { frontend_->Reset(); nnet_->Reset(); VLOG(3) << "feature cache reset: cache size: " << cache_.size(); cache_.clear(); + finished_ = false; } void AttentionRescoring(const std::vector>& hyps, @@ -66,6 +77,13 @@ class NnetProducer { std::shared_ptr frontend_; std::shared_ptr nnet_; SafeQueue> cache_; + std::mutex mutex_; + std::mutex read_mutex_; + std::condition_variable condition_variable_; + std::condition_variable condition_read_ready_; + std::thread thread_; + bool finished_; + bool abort_; DISALLOW_COPY_AND_ASSIGN(NnetProducer); }; diff --git a/speechx/speechx/asr/nnet/u2_nnet_thread_main.cc b/speechx/speechx/asr/nnet/u2_nnet_thread_main.cc new file mode 100644 index 00000000..ce523e59 --- /dev/null +++ b/speechx/speechx/asr/nnet/u2_nnet_thread_main.cc @@ -0,0 +1,137 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "base/common.h" +#include "decoder/param.h" +#include "frontend/wave-reader.h" +#include "frontend/feature_pipeline.h" +#include "kaldi/util/table-types.h" +#include "nnet/decodable.h" +#include "nnet/u2_nnet.h" +#include "nnet/nnet_producer.h" + +DEFINE_string(wav_rspecifier, "", "test wav rspecifier"); +DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); + +using kaldi::BaseFloat; +using kaldi::Matrix; +using std::vector; + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int32 num_done = 0, num_err = 0; + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + + CHECK_GT(FLAGS_wav_rspecifier.size(), 0); + CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0); + CHECK_GT(FLAGS_model_path.size(), 0); + LOG(INFO) << "input rspecifier: " << FLAGS_wav_rspecifier; + LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier; + LOG(INFO) << "model path: " << FLAGS_model_path; + + kaldi::SequentialTableReader wav_reader( + FLAGS_wav_rspecifier); + kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier); + + ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); + ppspeech::FeaturePipelineOptions feature_opts = + ppspeech::FeaturePipelineOptions::InitFromFlags(); + feature_opts.assembler_opts.fill_zero = false; + + std::shared_ptr nnet(new ppspeech::U2Nnet(model_opts)); + std::shared_ptr feature_pipeline( + new ppspeech::FeaturePipeline(feature_opts)); + std::shared_ptr nnet_producer( + new ppspeech::NnetProducer(nnet, feature_pipeline)); + kaldi::Timer timer; + float tot_wav_duration = 0; + + for (; 
!wav_reader.Done(); wav_reader.Next()) { + std::string utt = wav_reader.Key(); + const kaldi::WaveData& wave_data = wav_reader.Value(); + LOG(INFO) << "utt: " << utt; + LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; + double dur = wave_data.Duration(); + tot_wav_duration += dur; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + kaldi::Timer timer; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + std::vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk[i] = waveform(sample_offset + i); + } + + nnet_producer->Accept(wav_chunk); + if (cur_chunk_size < chunk_sample_size) { + nnet_producer->SetInputFinished(); + } + + // no overlap + sample_offset += cur_chunk_size; + } + CHECK(sample_offset == tot_samples); + + std::vector> prob_vec; + while(1) { + std::vector logprobs; + bool isok = nnet_producer->Read(&logprobs); + if (nnet_producer->IsFinished()) break; + if (isok == false) continue; + prob_vec.push_back(logprobs); + } + { + // writer nnet output + kaldi::MatrixIndexT nrow = prob_vec.size(); + kaldi::MatrixIndexT ncol = prob_vec[0].size(); + LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol; + kaldi::Matrix nnet_out(nrow, ncol); + for (int32 row_idx = 0; row_idx < nrow; ++row_idx) { + for (int32 col_idx = 0; col_idx < ncol; ++col_idx) { + nnet_out(row_idx, col_idx) = prob_vec[row_idx][col_idx]; + } + } + nnet_out_writer.Write(utt, nnet_out); + } + nnet_producer->Reset(); + } + + nnet_producer->Wait(); + double elapsed = timer.Elapsed(); + LOG(INFO) << "Program cost:" << elapsed << " sec"; + + LOG(INFO) << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 
0 : 1); +} diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.cc b/speechx/speechx/asr/recognizer/u2_recognizer.cc index a7644430..0c5a8941 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer.cc @@ -39,12 +39,28 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) unit_table_ = decoder_->VocabTable(); symbol_table_ = unit_table_; + global_frame_offset_ = 0; input_finished_ = false; + num_frames_ = 0; + result_.clear(); + +} + +U2Recognizer::~U2Recognizer() { + SetInputFinished(); + WaitDecodeFinished(); +} - Reset(); +void U2Recognizer::WaitDecodeFinished() { + if (thread_.joinable()) thread_.join(); } -void U2Recognizer::Reset() { +void U2Recognizer::WaitFinished() { + if (thread_.joinable()) thread_.join(); + nnet_producer_->Wait(); +} + +void U2Recognizer::InitDecoder() { global_frame_offset_ = 0; input_finished_ = false; num_frames_ = 0; @@ -52,6 +68,7 @@ void U2Recognizer::Reset() { decodable_->Reset(); decoder_->Reset(); + thread_ = std::thread(RunDecoderSearch, this); } void U2Recognizer::ResetContinuousDecoding() { @@ -63,6 +80,19 @@ void U2Recognizer::ResetContinuousDecoding() { decoder_->Reset(); } +void U2Recognizer::RunDecoderSearch(U2Recognizer* me) { + me->RunDecoderSearchInternal(); +} + +void U2Recognizer::RunDecoderSearchInternal() { + LOG(INFO) << "DecoderSearchInteral begin"; + while (!nnet_producer_->IsFinished()) { + nnet_producer_->UnLock(); + decoder_->AdvanceDecode(decodable_); + } + Decode(); + LOG(INFO) << "DecoderSearchInteral exit"; +} void U2Recognizer::Accept(const vector& waves) { kaldi::Timer timer; @@ -71,7 +101,6 @@ void U2Recognizer::Accept(const vector& waves) { << " samples."; } - void U2Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); UpdateResult(false); @@ -207,8 +236,8 @@ std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; } std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; } -void 
U2Recognizer::SetFinished() { - nnet_producer_->SetFinished(); +void U2Recognizer::SetInputFinished() { + nnet_producer_->SetInputFinished(); input_finished_ = true; } diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h index a3bf8aea..57f2c9c5 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/speechx/speechx/asr/recognizer/u2_recognizer.h @@ -112,19 +112,21 @@ struct U2RecognizerResource { class U2Recognizer { public: explicit U2Recognizer(const U2RecognizerResource& resouce); - void Reset(); + ~U2Recognizer(); + void InitDecoder(); void ResetContinuousDecoding(); void Accept(const std::vector& waves); void Decode(); void Rescoring(); - std::string GetFinalResult(); std::string GetPartialResult(); - void SetFinished(); + void SetInputFinished(); bool IsFinished() { return input_finished_; } + void WaitDecodeFinished(); + void WaitFinished(); bool DecodedSomething() const { return !result_.empty() && !result_[0].sentence.empty(); @@ -137,18 +139,17 @@ class U2Recognizer { // feature_pipeline_->FrameShift(); } - const std::vector& Result() const { return result_; } + void AttentionRescoring(); private: - void AttentionRescoring(); + static void RunDecoderSearch(U2Recognizer *me); + void RunDecoderSearchInternal(); void UpdateResult(bool finish = false); private: U2RecognizerResource opts_; - // std::shared_ptr resource_; - // U2RecognizerResource resource_; std::shared_ptr nnet_producer_; std::shared_ptr decodable_; std::unique_ptr decoder_; @@ -167,6 +168,7 @@ class U2Recognizer { const int time_stamp_gap_ = 100; bool input_finished_; + std::thread thread_; }; } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_main.cc index 90c7cc06..178c91db 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_main.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer_main.cc @@ -49,6 +49,7 
@@ int main(int argc, char* argv[]) { ppspeech::U2Recognizer recognizer(resource); for (; !wav_reader.Done(); wav_reader.Next()) { + recognizer.InitDecoder(); std::string utt = wav_reader.Key(); const kaldi::WaveData& wave_data = wav_reader.Value(); LOG(INFO) << "utt: " << utt; @@ -79,7 +80,7 @@ int main(int argc, char* argv[]) { recognizer.Accept(wav_chunk); if (cur_chunk_size < chunk_sample_size) { - recognizer.SetFinished(); + recognizer.SetInputFinished(); } recognizer.Decode(); if (recognizer.DecodedSomething()) { @@ -100,7 +101,6 @@ int main(int argc, char* argv[]) { std::string result = recognizer.GetFinalResult(); - recognizer.Reset(); if (result.empty()) { // the TokenWriter can not write empty string. diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc index a53b4541..3f45294d 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc @@ -22,15 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_int32(sample_rate, 16000, "sample rate"); -void decode_func(std::shared_ptr recognizer) { - while (!recognizer->IsFinished()) { - recognizer->Decode(); - usleep(100); - } - recognizer->Decode(); - recognizer->Rescoring(); -} - int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -40,6 +31,7 @@ int main(int argc, char* argv[]) { int32 num_done = 0, num_err = 0; double tot_wav_duration = 0.0; + double tot_attention_rescore_time = 0.0; double tot_decode_time = 0.0; kaldi::SequentialTableReader wav_reader( @@ -59,7 +51,7 @@ int main(int argc, char* argv[]) { new ppspeech::U2Recognizer(resource)); for (; !wav_reader.Done(); wav_reader.Next()) { - std::thread recognizer_thread(decode_func, recognizer_ptr); + recognizer_ptr->InitDecoder(); 
std::string utt = wav_reader.Key(); const kaldi::WaveData& wave_data = wav_reader.Value(); LOG(INFO) << "utt: " << utt; @@ -74,7 +66,6 @@ int main(int argc, char* argv[]) { LOG(INFO) << "wav len (sample): " << tot_samples; int sample_offset = 0; - kaldi::Timer timer; kaldi::Timer local_timer; while (sample_offset < tot_samples) { @@ -85,21 +76,23 @@ int main(int argc, char* argv[]) { for (int i = 0; i < cur_chunk_size; ++i) { wav_chunk[i] = waveform(sample_offset + i); } - // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); recognizer_ptr->Accept(wav_chunk); if (cur_chunk_size < chunk_sample_size) { - recognizer_ptr->SetFinished(); + recognizer_ptr->SetInputFinished(); } // no overlap sample_offset += cur_chunk_size; } CHECK(sample_offset == tot_samples); + recognizer_ptr->WaitDecodeFinished(); + + kaldi::Timer timer; + recognizer_ptr->AttentionRescoring(); + tot_attention_rescore_time += timer.Elapsed(); - recognizer_thread.join(); std::string result = recognizer_ptr->GetFinalResult(); - recognizer_ptr->Reset(); if (result.empty()) { // the TokenWriter can not write empty string. 
++num_err; @@ -107,6 +100,7 @@ int main(int argc, char* argv[]) { continue; } + tot_decode_time += local_timer.Elapsed(); LOG(INFO) << utt << " " << result; LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur << " cost: " << local_timer.Elapsed(); @@ -115,9 +109,11 @@ int main(int argc, char* argv[]) { ++num_done; } + recognizer_ptr->WaitFinished(); LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec"; LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; } diff --git a/speechx/speechx/common/frontend/compute_fbank_main.cc b/speechx/speechx/common/frontend/compute_fbank_main.cc index d7d5165c..e022207d 100644 --- a/speechx/speechx/common/frontend/compute_fbank_main.cc +++ b/speechx/speechx/common/frontend/compute_fbank_main.cc @@ -73,8 +73,7 @@ int main(int argc, char* argv[]) { new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank))); // the feature cache output feature chunk by chunk. 
- ppspeech::FeatureCacheOptions feat_cache_opts; - ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); + ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn)); LOG(INFO) << "fbank: " << true; LOG(INFO) << "feat dim: " << feature_cache.Dim(); diff --git a/speechx/speechx/common/frontend/feature_cache.cc b/speechx/speechx/common/frontend/feature_cache.cc index e6ac3c23..c166bd64 100644 --- a/speechx/speechx/common/frontend/feature_cache.cc +++ b/speechx/speechx/common/frontend/feature_cache.cc @@ -20,10 +20,9 @@ using kaldi::BaseFloat; using std::unique_ptr; using std::vector; -FeatureCache::FeatureCache(FeatureCacheOptions opts, +FeatureCache::FeatureCache(size_t max_size, unique_ptr base_extractor) { - max_size_ = opts.max_size; - timeout_ = opts.timeout; // ms + max_size_ = max_size; base_extractor_ = std::move(base_extractor); dim_ = base_extractor_->Dim(); } @@ -31,34 +30,25 @@ FeatureCache::FeatureCache(FeatureCacheOptions opts, void FeatureCache::Accept(const std::vector& inputs) { // read inputs base_extractor_->Accept(inputs); - - // feed current data - bool result = false; - do { - result = Compute(); - } while (result); } // pop feature chunk bool FeatureCache::Read(std::vector* feats) { kaldi::Timer timer; - std::unique_lock lock(mutex_); - while (cache_.empty() && base_extractor_->IsFinished() == false) { - // todo refactor: wait - // ready_read_condition_.wait(lock); - int32 elapsed = static_cast(timer.Elapsed() * 1000); // ms - if (elapsed > timeout_) { - return false; - } - usleep(100); // sleep 0.1 ms + // feed current data + if (cache_.empty()) { + bool result = false; + do { + result = Compute(); + } while (result); } + if (cache_.empty()) return false; // read from cache *feats = cache_.front(); cache_.pop(); - ready_feed_condition_.notify_one(); VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec."; return true; } @@ -73,23 +63,15 @@ bool FeatureCache::Compute() { kaldi::Timer timer; int32 num_chunk = 
feature.size() / dim_; - nframe_ += num_chunk; VLOG(3) << "nframe computed: " << nframe_; for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { int32 start = chunk_idx * dim_; vector feature_chunk(feature.data() + start, feature.data() + start + dim_); - - std::unique_lock lock(mutex_); - while (cache_.size() >= max_size_) { - // cache full, wait - ready_feed_condition_.wait(lock); - } - // feed cache cache_.push(feature_chunk); - ready_read_condition_.notify_one(); + ++nframe_; } VLOG(1) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. " @@ -97,4 +79,4 @@ bool FeatureCache::Compute() { return true; } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/common/frontend/feature_cache.h b/speechx/speechx/common/frontend/feature_cache.h index 51816a1d..b87612d6 100644 --- a/speechx/speechx/common/frontend/feature_cache.h +++ b/speechx/speechx/common/frontend/feature_cache.h @@ -19,16 +19,10 @@ namespace ppspeech { -struct FeatureCacheOptions { - int32 max_size; - int32 timeout; // ms - FeatureCacheOptions() : max_size(kint16max), timeout(1) {} -}; - class FeatureCache : public FrontendInterface { public: explicit FeatureCache( - FeatureCacheOptions opts, + size_t max_size = kint16max, std::unique_ptr base_extractor = NULL); // Feed feats or waves @@ -41,13 +35,11 @@ class FeatureCache : public FrontendInterface { virtual size_t Dim() const { return dim_; } virtual void SetFinished() { + std::unique_lock lock(mutex_); LOG(INFO) << "set finished"; - // std::unique_lock lock(mutex_); - base_extractor_->SetFinished(); - // read the last chunk data Compute(); - // ready_feed_condition_.notify_one(); + base_extractor_->SetFinished(); LOG(INFO) << "compute last feats done."; } @@ -66,16 +58,10 @@ class FeatureCache : public FrontendInterface { int32 dim_; size_t max_size_; // cache capacity - int32 frame_chunk_size_; // window - int32 frame_chunk_stride_; // stride std::unique_ptr base_extractor_; - 
kaldi::int32 timeout_; // ms - std::vector remained_feature_; std::queue> cache_; // feature cache std::mutex mutex_; - std::condition_variable ready_feed_condition_; - std::condition_variable ready_read_condition_; int32 nframe_; // num of feature computed DISALLOW_COPY_AND_ASSIGN(FeatureCache); diff --git a/speechx/speechx/common/frontend/feature_pipeline.cc b/speechx/speechx/common/frontend/feature_pipeline.cc index 34e55a10..f37b4180 100644 --- a/speechx/speechx/common/frontend/feature_pipeline.cc +++ b/speechx/speechx/common/frontend/feature_pipeline.cc @@ -33,7 +33,7 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature))); unique_ptr cache( - new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn))); + new ppspeech::FeatureCache(kint16max, std::move(cmvn))); base_extractor_.reset( new ppspeech::Assembler(opts.assembler_opts, std::move(cache))); diff --git a/speechx/speechx/common/frontend/feature_pipeline.h b/speechx/speechx/common/frontend/feature_pipeline.h index ea7e2bba..c9a649fd 100644 --- a/speechx/speechx/common/frontend/feature_pipeline.h +++ b/speechx/speechx/common/frontend/feature_pipeline.h @@ -39,7 +39,6 @@ namespace ppspeech { struct FeaturePipelineOptions { std::string cmvn_file{}; knf::FbankOptions fbank_opts{}; - FeatureCacheOptions feature_cache_opts{}; AssemblerOptions assembler_opts{}; static FeaturePipelineOptions InitFromFlags() { From 5042a1686a66679dabb5b25e0b42824db0762ad5 Mon Sep 17 00:00:00 2001 From: YangZhou <56786796+SmileGoat@users.noreply.github.com> Date: Wed, 1 Feb 2023 17:19:52 +0800 Subject: [PATCH 08/50] [speechx] add batch recognizer decode. 
(#2866) * add recognizer_batch --- speechx/speechx/asr/nnet/u2_nnet.cc | 35 ++-- speechx/speechx/asr/nnet/u2_nnet.h | 4 +- speechx/speechx/asr/recognizer/CMakeLists.txt | 1 + .../speechx/asr/recognizer/u2_recognizer.cc | 30 ++- .../speechx/asr/recognizer/u2_recognizer.h | 4 +- .../recognizer/u2_recognizer_batch_main.cc | 185 ++++++++++++++++++ speechx/speechx/common/base/common.h | 2 + 7 files changed, 242 insertions(+), 19 deletions(-) create mode 100644 speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc diff --git a/speechx/speechx/asr/nnet/u2_nnet.cc b/speechx/speechx/asr/nnet/u2_nnet.cc index e3277a38..0795c836 100644 --- a/speechx/speechx/asr/nnet/u2_nnet.cc +++ b/speechx/speechx/asr/nnet/u2_nnet.cc @@ -118,27 +118,38 @@ U2Nnet::U2Nnet(const ModelOptions& opts) : opts_(opts) { // shallow copy U2Nnet::U2Nnet(const U2Nnet& other) { // copy meta - right_context_ = other.right_context_; - subsampling_rate_ = other.subsampling_rate_; - sos_ = other.sos_; - eos_ = other.eos_; - is_bidecoder_ = other.is_bidecoder_; chunk_size_ = other.chunk_size_; num_left_chunks_ = other.num_left_chunks_; - - forward_encoder_chunk_ = other.forward_encoder_chunk_; - forward_attention_decoder_ = other.forward_attention_decoder_; - ctc_activation_ = other.ctc_activation_; - offset_ = other.offset_; // copy model ptr - model_ = other.model_; + model_ = other.model_->Clone(); + ctc_activation_ = model_->Function("ctc_activation"); + subsampling_rate_ = model_->Attribute("subsampling_rate"); + right_context_ = model_->Attribute("right_context"); + sos_ = model_->Attribute("sos_symbol"); + eos_ = model_->Attribute("eos_symbol"); + is_bidecoder_ = model_->Attribute("is_bidirectional_decoder"); + + forward_encoder_chunk_ = model_->Function("forward_encoder_chunk"); + forward_attention_decoder_ = model_->Function("forward_attention_decoder"); + ctc_activation_ = model_->Function("ctc_activation"); + CHECK(forward_encoder_chunk_.IsValid()); + 
CHECK(forward_attention_decoder_.IsValid()); + CHECK(ctc_activation_.IsValid()); + + LOG(INFO) << "Paddle Model Info: "; + LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_; + LOG(INFO) << "\tright context " << right_context_; + LOG(INFO) << "\tsos " << sos_; + LOG(INFO) << "\teos " << eos_; + LOG(INFO) << "\tis bidecoder " << is_bidecoder_ << std::endl; + // ignore inner states } -std::shared_ptr U2Nnet::Copy() const { +std::shared_ptr U2Nnet::Clone() const { auto asr_model = std::make_shared(*this); // reset inner state for new decoding asr_model->Reset(); diff --git a/speechx/speechx/asr/nnet/u2_nnet.h b/speechx/speechx/asr/nnet/u2_nnet.h index 127d84db..35a15707 100644 --- a/speechx/speechx/asr/nnet/u2_nnet.h +++ b/speechx/speechx/asr/nnet/u2_nnet.h @@ -42,7 +42,7 @@ class U2NnetBase : public NnetBase { num_left_chunks_ = num_left_chunks; } - virtual std::shared_ptr Copy() const = 0; + virtual std::shared_ptr Clone() const = 0; protected: virtual void ForwardEncoderChunkImpl( @@ -91,7 +91,7 @@ class U2Nnet : public U2NnetBase { std::shared_ptr model() const { return model_; } - std::shared_ptr Copy() const override; + std::shared_ptr Clone() const override; void ForwardEncoderChunkImpl( const std::vector& chunk_feats, diff --git a/speechx/speechx/asr/recognizer/CMakeLists.txt b/speechx/speechx/asr/recognizer/CMakeLists.txt index 8f9117e4..f28c5fea 100644 --- a/speechx/speechx/asr/recognizer/CMakeLists.txt +++ b/speechx/speechx/asr/recognizer/CMakeLists.txt @@ -10,6 +10,7 @@ target_link_libraries(recognizer PUBLIC decoder) set(TEST_BINS u2_recognizer_main u2_recognizer_thread_main + u2_recognizer_batch_main ) foreach(bin_name IN LISTS TEST_BINS) diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.cc b/speechx/speechx/asr/recognizer/u2_recognizer.cc index 0c5a8941..30595d79 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer.cc @@ -43,12 +43,34 @@ U2Recognizer::U2Recognizer(const 
U2RecognizerResource& resource) input_finished_ = false; num_frames_ = 0; result_.clear(); +} + +U2Recognizer::U2Recognizer(const U2RecognizerResource& resource, + std::shared_ptr nnet) + : opts_(resource) { + BaseFloat am_scale = resource.acoustic_scale; + const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; + std::shared_ptr feature_pipeline = + std::make_shared(feature_opts); + nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline)); + decodable_.reset(new Decodable(nnet_producer_, am_scale)); + + CHECK_NE(resource.vocab_path, ""); + decoder_.reset(new CTCPrefixBeamSearch( + resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); + + unit_table_ = decoder_->VocabTable(); + symbol_table_ = unit_table_; + global_frame_offset_ = 0; + input_finished_ = false; + num_frames_ = 0; + result_.clear(); } U2Recognizer::~U2Recognizer() { - SetInputFinished(); - WaitDecodeFinished(); + SetInputFinished(); + WaitDecodeFinished(); } void U2Recognizer::WaitDecodeFinished() { @@ -97,8 +119,8 @@ void U2Recognizer::RunDecoderSearchInternal() { void U2Recognizer::Accept(const vector& waves) { kaldi::Timer timer; nnet_producer_->Accept(waves); - VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " << waves.size() - << " samples."; + VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. 
" + << waves.size() << " samples."; } void U2Recognizer::Decode() { diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h index 57f2c9c5..5d628e3a 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/speechx/speechx/asr/recognizer/u2_recognizer.h @@ -112,6 +112,8 @@ struct U2RecognizerResource { class U2Recognizer { public: explicit U2Recognizer(const U2RecognizerResource& resouce); + explicit U2Recognizer(const U2RecognizerResource& resource, + std::shared_ptr nnet); ~U2Recognizer(); void InitDecoder(); void ResetContinuousDecoding(); @@ -143,7 +145,7 @@ class U2Recognizer { void AttentionRescoring(); private: - static void RunDecoderSearch(U2Recognizer *me); + static void RunDecoderSearch(U2Recognizer* me); void RunDecoderSearchInternal(); void UpdateResult(bool finish = false); diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc new file mode 100644 index 00000000..709e5aa6 --- /dev/null +++ b/speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc @@ -0,0 +1,185 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "recognizer/u2_recognizer.h" +#include "common/base/thread_pool.h" +#include "common/utils/file_utils.h" +#include "common/utils/strings.h" +#include "decoder/param.h" +#include "frontend/wave-reader.h" +#include "kaldi/util/table-types.h" +#include "nnet/u2_nnet.h" + +DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); +DEFINE_int32(njob, 3, "njob"); + +using std::string; +using std::vector; + +void SplitUtt(string wavlist_file, + vector>* uttlists, + vector>* wavlists, + int njob) { + vector wavlist; + wavlists->resize(njob); + uttlists->resize(njob); + ppspeech::ReadFileToVector(wavlist_file, &wavlist); + for (size_t idx = 0; idx < wavlist.size(); ++idx) { + string utt_str = wavlist[idx]; + vector utt_wav = ppspeech::StrSplit(utt_str, " \t"); + LOG(INFO) << utt_wav[0]; + CHECK_EQ(utt_wav.size(), size_t(2)); + uttlists->at(idx % njob).push_back(utt_wav[0]); + wavlists->at(idx % njob).push_back(utt_wav[1]); + } +} + +void recognizer_func(const ppspeech::U2RecognizerResource& resource, + std::shared_ptr nnet, + std::vector wavlist, + std::vector uttlist, + std::vector* results) { + int32 num_done = 0, num_err = 0; + double tot_wav_duration = 0.0; + double tot_attention_rescore_time = 0.0; + double tot_decode_time = 0.0; + int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate; + if (wavlist.empty()) return; + + std::shared_ptr recognizer_ptr = + std::make_shared(resource, nnet); + + results->reserve(wavlist.size()); + for (size_t idx = 0; idx < wavlist.size(); ++idx) { + std::string utt = uttlist[idx]; + std::string wav_file = wavlist[idx]; + std::ifstream infile; + infile.open(wav_file, std::ifstream::in); + kaldi::WaveData wave_data; + wave_data.Read(infile); + recognizer_ptr->InitDecoder(); + LOG(INFO) << "utt: " << utt; + LOG(INFO) << "wav dur: " 
<< wave_data.Duration() << " sec."; + double dur = wave_data.Duration(); + tot_wav_duration += dur; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + kaldi::Timer local_timer; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + std::vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk[i] = waveform(sample_offset + i); + } + + recognizer_ptr->Accept(wav_chunk); + if (cur_chunk_size < chunk_sample_size) { + recognizer_ptr->SetInputFinished(); + } + + // no overlap + sample_offset += cur_chunk_size; + } + CHECK(sample_offset == tot_samples); + recognizer_ptr->WaitDecodeFinished(); + + kaldi::Timer timer; + recognizer_ptr->AttentionRescoring(); + tot_attention_rescore_time += timer.Elapsed(); + + std::string result = recognizer_ptr->GetFinalResult(); + if (result.empty()) { + // the TokenWriter can not write empty string. 
+ ++num_err; + LOG(INFO) << " the result of " << utt << " is empty"; + result = " "; + } + + tot_decode_time += local_timer.Elapsed(); + LOG(INFO) << utt << " " << result; + LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur + << " cost: " << local_timer.Elapsed(); + + results->push_back(result); + ++num_done; + } + recognizer_ptr->WaitFinished(); + LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); + LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; + LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec"; + LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; +} + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + int njob = FLAGS_njob; + LOG(INFO) << "sr: " << sample_rate; + LOG(INFO) << "chunk size (s): " << streaming_chunk; + LOG(INFO) << "chunk size (sample): " << chunk_sample_size; + + ppspeech::U2RecognizerResource resource = + ppspeech::U2RecognizerResource::InitFromFlags(); + ThreadPool threadpool(njob); + vector> wavlist; + vector> uttlist; + vector> resultlist(njob); + vector> futurelist; + std::shared_ptr nnet( + new ppspeech::U2Nnet(resource.model_opts)); + SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob); + for (size_t i = 0; i < njob; ++i) { + std::future f = threadpool.enqueue(recognizer_func, + resource, + nnet->Clone(), + wavlist[i], + uttlist[i], + &resultlist[i]); + futurelist.push_back(std::move(f)); + } + + for (size_t i = 0; i < njob; ++i) { + futurelist[i].get(); + } + + for (size_t idx = 0; idx 
< njob; ++idx) { + for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) { + string utt = uttlist[idx][utt_idx]; + string result = resultlist[idx][utt_idx]; + result_writer.Write(utt, result); + } + } + return 0; +} diff --git a/speechx/speechx/common/base/common.h b/speechx/speechx/common/base/common.h index 2a066ee6..06fcd9fd 100644 --- a/speechx/speechx/common/base/common.h +++ b/speechx/speechx/common/base/common.h @@ -42,6 +42,8 @@ #include #include #include +#include +#include #include "base/basic_types.h" #include "base/flags.h" From 21183d48b63009e49729da6e6864ad666c09ae4b Mon Sep 17 00:00:00 2001 From: YangZhou <56786796+SmileGoat@users.noreply.github.com> Date: Tue, 7 Feb 2023 16:46:45 +0800 Subject: [PATCH 09/50] add wfst decoder (#2886) --- speechx/CMakeLists.txt | 2 +- speechx/speechx/asr/decoder/CMakeLists.txt | 2 + .../decoder/ctc_prefix_beam_search_decoder.h | 3 +- .../speechx/asr/decoder/ctc_tlg_decoder.cc | 44 ++++++++++- speechx/speechx/asr/decoder/ctc_tlg_decoder.h | 37 +++++++-- .../asr/decoder/ctc_tlg_decoder_main.cc | 77 ++++--------------- speechx/speechx/asr/decoder/decoder_itf.h | 10 ++- speechx/speechx/asr/decoder/param.h | 4 +- speechx/speechx/asr/nnet/decodable.h | 2 - speechx/speechx/asr/nnet/nnet_producer.cc | 16 ++-- speechx/speechx/asr/nnet/nnet_producer.h | 10 +-- .../speechx/asr/recognizer/u2_recognizer.cc | 46 ++++++----- .../speechx/asr/recognizer/u2_recognizer.h | 16 +++- 13 files changed, 154 insertions(+), 115 deletions(-) diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index e24744d6..d056ebbc 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -33,7 +33,7 @@ set(FETCHCONTENT_BASE_DIR ${fc_patch}) # compiler option # Keep the same with openfst, -fPIC or -fpic -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ldl") SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 
-pthread -fPIC -O0 -Wall -g -ggdb") SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall") diff --git a/speechx/speechx/asr/decoder/CMakeLists.txt b/speechx/speechx/asr/decoder/CMakeLists.txt index b2f50708..07adda95 100644 --- a/speechx/speechx/asr/decoder/CMakeLists.txt +++ b/speechx/speechx/asr/decoder/CMakeLists.txt @@ -1,6 +1,7 @@ set(srcs) list(APPEND srcs ctc_prefix_beam_search_decoder.cc + ctc_tlg_decoder.cc ) add_library(decoder STATIC ${srcs}) @@ -9,6 +10,7 @@ target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder) # test set(TEST_BINS ctc_prefix_beam_search_decoder_main + ctc_tlg_decoder_main ) foreach(bin_name IN LISTS TEST_BINS) diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h index 5013246a..3fe1944c 100644 --- a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h +++ b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h @@ -45,7 +45,7 @@ class CTCPrefixBeamSearch : public DecoderBase { void FinalizeSearch(); - const std::shared_ptr VocabTable() const { + const std::shared_ptr WordSymbolTable() const override { return unit_table_; } @@ -57,7 +57,6 @@ class CTCPrefixBeamSearch : public DecoderBase { } const std::vector>& Times() const { return times_; } - protected: std::string GetBestPath() override; std::vector> GetNBestPath() override; diff --git a/speechx/speechx/asr/decoder/ctc_tlg_decoder.cc b/speechx/speechx/asr/decoder/ctc_tlg_decoder.cc index 2c2b6d3c..ca7d65c8 100644 --- a/speechx/speechx/asr/decoder/ctc_tlg_decoder.cc +++ b/speechx/speechx/asr/decoder/ctc_tlg_decoder.cc @@ -15,7 +15,7 @@ #include "decoder/ctc_tlg_decoder.h" namespace ppspeech { -TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { +TLGDecoder::TLGDecoder(TLGDecoderOptions opts) : opts_(opts) { fst_.reset(fst::Fst::Read(opts.fst_path)); CHECK(fst_ != nullptr); @@ -68,14 +68,52 @@ std::string TLGDecoder::GetPartialResult() { 
return words; } +void TLGDecoder::FinalizeSearch() { + decoder_->FinalizeDecoding(); + kaldi::CompactLattice clat; + decoder_->GetLattice(&clat, true); + kaldi::Lattice lat, nbest_lat; + fst::ConvertLattice(clat, &lat); + fst::ShortestPath(lat, &nbest_lat, opts_.nbest); + std::vector nbest_lats; + fst::ConvertNbestToVector(nbest_lat, &nbest_lats); + + hypotheses_.clear(); + hypotheses_.reserve(nbest_lats.size()); + likelihood_.clear(); + likelihood_.reserve(nbest_lats.size()); + times_.clear(); + times_.reserve(nbest_lats.size()); + for (auto lat : nbest_lats) { + kaldi::LatticeWeight weight; + std::vector hypothese; + std::vector time; + std::vector alignment; + std::vector words_id; + fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight); + int idx = 0; + for (; idx < alignment.size() - 1; ++idx) { + if (alignment[idx] == 0) continue; + if (alignment[idx] != alignment[idx + 1]) { + hypothese.push_back(alignment[idx] - 1); + time.push_back(idx); // fake time, todo later + } + } + hypothese.push_back(alignment[idx] - 1); + time.push_back(idx); // fake time, todo later + hypotheses_.push_back(hypothese); + times_.push_back(time); + olabels.push_back(words_id); + likelihood_.push_back(-(weight.Value2() + weight.Value1())); + } +} + std::string TLGDecoder::GetFinalBestPath() { if (num_frame_decoded_ == 0) { // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call // BestPathEnd if no frames were decoded.") return std::string(""); } - - decoder_->FinalizeDecoding(); kaldi::Lattice lat; kaldi::LatticeWeight weight; std::vector alignment; diff --git a/speechx/speechx/asr/decoder/ctc_tlg_decoder.h b/speechx/speechx/asr/decoder/ctc_tlg_decoder.h index 8be69dad..1ea6d634 100644 --- a/speechx/speechx/asr/decoder/ctc_tlg_decoder.h +++ b/speechx/speechx/asr/decoder/ctc_tlg_decoder.h @@ -19,9 +19,8 @@ #include "kaldi/decoder/lattice-faster-online-decoder.h" #include "util/parse-options.h" - -DECLARE_string(graph_path); 
DECLARE_string(word_symbol_table); +DECLARE_string(graph_path); DECLARE_int32(max_active); DECLARE_double(beam); DECLARE_double(lattice_beam); @@ -33,6 +32,9 @@ struct TLGDecoderOptions { // todo remove later, add into decode resource std::string word_symbol_table; std::string fst_path; + int nbest; + + TLGDecoderOptions() : word_symbol_table(""), fst_path(""), nbest(10) {} static TLGDecoderOptions InitFromFlags() { TLGDecoderOptions decoder_opts; @@ -44,6 +46,7 @@ struct TLGDecoderOptions { decoder_opts.opts.max_active = FLAGS_max_active; decoder_opts.opts.beam = FLAGS_beam; decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; + // decoder_opts.nbest = FLAGS_lattice_nbest; LOG(INFO) << "LatticeFasterDecoder max active: " << decoder_opts.opts.max_active; LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam; @@ -59,20 +62,38 @@ class TLGDecoder : public DecoderBase { explicit TLGDecoder(TLGDecoderOptions opts); ~TLGDecoder() = default; - void InitDecoder(); - void Reset(); + void InitDecoder() override; + void Reset() override; void AdvanceDecode( - const std::shared_ptr& decodable); + const std::shared_ptr& decodable) override; void Decode(); std::string GetFinalBestPath() override; std::string GetPartialResult() override; + const std::shared_ptr WordSymbolTable() const override { + return word_symbol_table_; + } + int DecodeLikelihoods(const std::vector>& probs, const std::vector& nbest_words); + void FinalizeSearch() override; + const std::vector>& Inputs() const override { + return hypotheses_; + } + const std::vector>& Outputs() const override { + return olabels; + } // outputs_; } + const std::vector& Likelihood() const override { + return likelihood_; + } + const std::vector>& Times() const override { + return times_; + } + protected: std::string GetBestPath() override { CHECK(false); @@ -90,9 +111,15 @@ class TLGDecoder : public DecoderBase { private: void AdvanceDecoding(kaldi::DecodableInterface* decodable); + std::vector> hypotheses_; + 
std::vector> olabels; + std::vector likelihood_; + std::vector> times_; + std::shared_ptr decoder_; std::shared_ptr> fst_; std::shared_ptr word_symbol_table_; + TLGDecoderOptions opts_; }; diff --git a/speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc b/speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc index e9bd8a3f..148ee15e 100644 --- a/speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc +++ b/speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc @@ -14,16 +14,16 @@ // todo refactor, repalce with gtest -#include "base/common.h" #include "decoder/ctc_tlg_decoder.h" +#include "base/common.h" #include "decoder/param.h" -#include "frontend/audio/data_cache.h" +#include "frontend/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" -#include "nnet/ds2_nnet.h" +#include "nnet/nnet_producer.h" -DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); +DEFINE_string(nnet_prob_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); @@ -39,8 +39,8 @@ int main(int argc, char* argv[]) { google::InstallFailureSignalHandler(); FLAGS_logtostderr = 1; - kaldi::SequentialBaseFloatMatrixReader feature_reader( - FLAGS_feature_rspecifier); + kaldi::SequentialBaseFloatMatrixReader nnet_prob_reader( + FLAGS_nnet_prob_rspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); int32 num_done = 0, num_err = 0; @@ -53,66 +53,19 @@ int main(int argc, char* argv[]) { ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); - std::shared_ptr nnet( - new ppspeech::PaddleNnet(model_opts)); - std::shared_ptr raw_data(new ppspeech::DataCache()); + std::shared_ptr nnet_producer = + std::make_shared(nullptr); std::shared_ptr decodable( - new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); - - int32 chunk_size = FLAGS_receptive_field_length + - (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate; - int32 chunk_stride = FLAGS_subsampling_rate * 
FLAGS_nnet_decoder_chunk; - int32 receptive_field_length = FLAGS_receptive_field_length; - LOG(INFO) << "chunk size (frame): " << chunk_size; - LOG(INFO) << "chunk stride (frame): " << chunk_stride; - LOG(INFO) << "receptive field (frame): " << receptive_field_length; + new ppspeech::Decodable(nnet_producer, FLAGS_acoustic_scale)); decoder.InitDecoder(); kaldi::Timer timer; - for (; !feature_reader.Done(); feature_reader.Next()) { - string utt = feature_reader.Key(); - kaldi::Matrix feature = feature_reader.Value(); - raw_data->SetDim(feature.NumCols()); - LOG(INFO) << "process utt: " << utt; - LOG(INFO) << "rows: " << feature.NumRows(); - LOG(INFO) << "cols: " << feature.NumCols(); - - int32 row_idx = 0; - int32 padding_len = 0; - int32 ori_feature_len = feature.NumRows(); - if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { - padding_len = - chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; - feature.Resize(feature.NumRows() + padding_len, - feature.NumCols(), - kaldi::kCopyData); - } - int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1; - for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { - kaldi::Vector feature_chunk(chunk_size * - feature.NumCols()); - int32 feature_chunk_size = 0; - if (ori_feature_len > chunk_idx * chunk_stride) { - feature_chunk_size = std::min( - ori_feature_len - chunk_idx * chunk_stride, chunk_size); - } - if (feature_chunk_size < receptive_field_length) break; - - int32 start = chunk_idx * chunk_stride; - for (int row_id = 0; row_id < chunk_size; ++row_id) { - kaldi::SubVector tmp(feature, start); - kaldi::SubVector f_chunk_tmp( - feature_chunk.Data() + row_id * feature.NumCols(), - feature.NumCols()); - f_chunk_tmp.CopyFromVec(tmp); - ++start; - } - raw_data->Accept(feature_chunk); - if (chunk_idx == num_chunks - 1) { - raw_data->SetFinished(); - } - decoder.AdvanceDecode(decodable); - } + + for (; !nnet_prob_reader.Done(); nnet_prob_reader.Next()) { + string utt = 
nnet_prob_reader.Key(); + kaldi::Matrix prob = nnet_prob_reader.Value(); + decodable->Acceptlikelihood(prob); + decoder.AdvanceDecode(decodable); std::string result; result = decoder.GetFinalBestPath(); decodable->Reset(); diff --git a/speechx/speechx/asr/decoder/decoder_itf.h b/speechx/speechx/asr/decoder/decoder_itf.h index 2289b317..cb7717e8 100644 --- a/speechx/speechx/asr/decoder/decoder_itf.h +++ b/speechx/speechx/asr/decoder/decoder_itf.h @@ -1,4 +1,3 @@ - // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +15,7 @@ #pragma once #include "base/common.h" +#include "fst/symbol-table.h" #include "kaldi/decoder/decodable-itf.h" namespace ppspeech { @@ -41,6 +41,14 @@ class DecoderInterface { virtual std::string GetPartialResult() = 0; + virtual const std::shared_ptr WordSymbolTable() const = 0; + virtual void FinalizeSearch() = 0; + + virtual const std::vector>& Inputs() const = 0; + virtual const std::vector>& Outputs() const = 0; + virtual const std::vector& Likelihood() const = 0; + virtual const std::vector>& Times() const = 0; + protected: // virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0; diff --git a/speechx/speechx/asr/decoder/param.h b/speechx/speechx/asr/decoder/param.h index cad6dbd8..83e2c7fb 100644 --- a/speechx/speechx/asr/decoder/param.h +++ b/speechx/speechx/asr/decoder/param.h @@ -57,8 +57,8 @@ DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); // decoder DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); -DEFINE_string(graph_path, "TLG", "decoder graph"); -DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); +DEFINE_string(graph_path, "", "decoder graph"); +DEFINE_string(word_symbol_table, "", "word symbol table"); DEFINE_int32(max_active, 7500, "max active"); DEFINE_double(beam, 15.0, "decoder beam"); DEFINE_double(lattice_beam, 7.5, "decoder beam"); diff --git 
a/speechx/speechx/asr/nnet/decodable.h b/speechx/speechx/asr/nnet/decodable.h index 44c7a0c3..c1dbb4b8 100644 --- a/speechx/speechx/asr/nnet/decodable.h +++ b/speechx/speechx/asr/nnet/decodable.h @@ -27,8 +27,6 @@ class Decodable : public kaldi::DecodableInterface { explicit Decodable(const std::shared_ptr& nnet_producer, kaldi::BaseFloat acoustic_scale = 1.0); - // void Init(DecodableOpts config); - // nnet logprob output, used by wfst virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/speechx/speechx/asr/nnet/nnet_producer.cc index b83b5976..29daa709 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.cc +++ b/speechx/speechx/asr/nnet/nnet_producer.cc @@ -23,25 +23,25 @@ using kaldi::BaseFloat; NnetProducer::NnetProducer(std::shared_ptr nnet, std::shared_ptr frontend) : nnet_(nnet), frontend_(frontend) { - abort_ = false; - Reset(); - thread_ = std::thread(RunNnetEvaluation, this); - } + abort_ = false; + Reset(); + if (nnet_ != nullptr) thread_ = std::thread(RunNnetEvaluation, this); +} void NnetProducer::Accept(const std::vector& inputs) { frontend_->Accept(inputs); condition_variable_.notify_one(); } -void NnetProducer::UnLock() { +void NnetProducer::WaitProduce() { std::unique_lock lock(read_mutex_); while (frontend_->IsFinished() == false && cache_.empty()) { - condition_read_ready_.wait(lock); + condition_read_ready_.wait(lock); } return; } -void NnetProducer::RunNnetEvaluation(NnetProducer *me) { +void NnetProducer::RunNnetEvaluation(NnetProducer* me) { me->RunNnetEvaluationInteral(); } @@ -55,7 +55,7 @@ void NnetProducer::RunNnetEvaluationInteral() { result = Compute(); } while (result); if (frontend_->IsFinished() == true) { - if (cache_.empty()) finished_ = true; + if (cache_.empty()) finished_ = true; } } LOG(INFO) << "NnetEvaluationInteral exit"; diff --git a/speechx/speechx/asr/nnet/nnet_producer.h b/speechx/speechx/asr/nnet/nnet_producer.h index 14c74d04..9eb3a4f7 100644 --- 
a/speechx/speechx/asr/nnet/nnet_producer.h +++ b/speechx/speechx/asr/nnet/nnet_producer.h @@ -34,9 +34,9 @@ class NnetProducer { // nnet bool Read(std::vector* nnet_prob); bool ReadandCompute(std::vector* nnet_prob); - static void RunNnetEvaluation(NnetProducer *me); + static void RunNnetEvaluation(NnetProducer* me); void RunNnetEvaluationInteral(); - void UnLock(); + void WaitProduce(); void Wait() { abort_ = true; @@ -56,12 +56,12 @@ class NnetProducer { bool IsFinished() const { return finished_; } ~NnetProducer() { - if (thread_.joinable()) thread_.join(); + if (thread_.joinable()) thread_.join(); } void Reset() { - frontend_->Reset(); - nnet_->Reset(); + if (frontend_ != NULL) frontend_->Reset(); + if (nnet_ != NULL) nnet_->Reset(); VLOG(3) << "feature cache reset: cache size: " << cache_.size(); cache_.clear(); finished_ = false; diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.cc b/speechx/speechx/asr/recognizer/u2_recognizer.cc index 30595d79..f31ceb3b 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer.cc @@ -33,11 +33,15 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) decodable_.reset(new Decodable(nnet_producer_, am_scale)); CHECK_NE(resource.vocab_path, ""); - decoder_.reset(new CTCPrefixBeamSearch( - resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); + if (resource.decoder_opts.tlg_decoder_opts.fst_path == "") { + LOG(INFO) << resource.decoder_opts.tlg_decoder_opts.fst_path; + decoder_.reset(new CTCPrefixBeamSearch( + resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); + } else { + decoder_.reset(new TLGDecoder(resource.decoder_opts.tlg_decoder_opts)); + } - unit_table_ = decoder_->VocabTable(); - symbol_table_ = unit_table_; + symbol_table_ = decoder_->WordSymbolTable(); global_frame_offset_ = 0; input_finished_ = false; @@ -56,11 +60,14 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource, decodable_.reset(new 
Decodable(nnet_producer_, am_scale)); CHECK_NE(resource.vocab_path, ""); - decoder_.reset(new CTCPrefixBeamSearch( - resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); + if (resource.decoder_opts.tlg_decoder_opts.fst_path == "") { + decoder_.reset(new CTCPrefixBeamSearch( + resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); + } else { + decoder_.reset(new TLGDecoder(resource.decoder_opts.tlg_decoder_opts)); + } - unit_table_ = decoder_->VocabTable(); - symbol_table_ = unit_table_; + symbol_table_ = decoder_->WordSymbolTable(); global_frame_offset_ = 0; input_finished_ = false; @@ -109,10 +116,11 @@ void U2Recognizer::RunDecoderSearch(U2Recognizer* me) { void U2Recognizer::RunDecoderSearchInternal() { LOG(INFO) << "DecoderSearchInteral begin"; while (!nnet_producer_->IsFinished()) { - nnet_producer_->UnLock(); + nnet_producer_->WaitProduce(); decoder_->AdvanceDecode(decodable_); } - Decode(); + decoder_->AdvanceDecode(decodable_); + UpdateResult(false); LOG(INFO) << "DecoderSearchInteral exit"; } @@ -140,7 +148,7 @@ void U2Recognizer::UpdateResult(bool finish) { const auto& times = decoder_->Times(); result_.clear(); - CHECK_EQ(hypotheses.size(), likelihood.size()); + CHECK_EQ(inputs.size(), likelihood.size()); for (size_t i = 0; i < hypotheses.size(); i++) { const std::vector& hypothesis = hypotheses[i]; @@ -148,13 +156,9 @@ void U2Recognizer::UpdateResult(bool finish) { path.score = likelihood[i]; for (size_t j = 0; j < hypothesis.size(); j++) { std::string word = symbol_table_->Find(hypothesis[j]); - // A detailed explanation of this if-else branch can be found in - // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058 - if (decoder_->Type() == kWfstBeamSearch) { - path.sentence += (" " + word); - } else { - path.sentence += (word); - } + // path.sentence += (" " + word); // todo SmileGoat: add blank + // processor + path.sentence += word; // todo SmileGoat: add blank processor } // TimeStamp is only supported 
in final result @@ -162,7 +166,7 @@ void U2Recognizer::UpdateResult(bool finish) { // various FST operations when building the decoding graph. So here we // use time stamp of the input(e2e model unit), which is more accurate, // and it requires the symbol table of the e2e model used in training. - if (unit_table_ != nullptr && finish) { + if (symbol_table_ != nullptr && finish) { int offset = global_frame_offset_ * FrameShiftInMs(); const std::vector& input = inputs[i]; @@ -170,7 +174,7 @@ void U2Recognizer::UpdateResult(bool finish) { CHECK_EQ(input.size(), time_stamp.size()); for (size_t j = 0; j < input.size(); j++) { - std::string word = unit_table_->Find(input[j]); + std::string word = symbol_table_->Find(input[j]); int start = time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0 @@ -214,7 +218,7 @@ void U2Recognizer::UpdateResult(bool finish) { void U2Recognizer::AttentionRescoring() { decoder_->FinalizeSearch(); - UpdateResult(true); + UpdateResult(false); // No need to do rescoring if (0.0 == opts_.decoder_opts.rescoring_weight) { diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h index 5d628e3a..889da85b 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/speechx/speechx/asr/recognizer/u2_recognizer.h @@ -17,6 +17,7 @@ #include "decoder/common.h" #include "decoder/ctc_beam_search_opt.h" #include "decoder/ctc_prefix_beam_search_decoder.h" +#include "decoder/ctc_tlg_decoder.h" #include "decoder/decoder_itf.h" #include "frontend/feature_pipeline.h" #include "fst/fstlib.h" @@ -33,6 +34,8 @@ DECLARE_int32(blank); DECLARE_double(acoustic_scale); DECLARE_string(vocab_path); +DECLARE_string(word_symbol_table); +// DECLARE_string(fst_path); namespace ppspeech { @@ -59,6 +62,7 @@ struct DecodeOptions { // CtcEndpointConfig ctc_endpoint_opts; CTCBeamSearchOptions ctc_prefix_search_opts{}; + TLGDecoderOptions tlg_decoder_opts{}; static DecodeOptions InitFromFlags() { DecodeOptions decoder_opts; 
@@ -70,6 +74,13 @@ struct DecodeOptions { decoder_opts.ctc_prefix_search_opts.blank = FLAGS_blank; decoder_opts.ctc_prefix_search_opts.first_beam_size = FLAGS_nbest; decoder_opts.ctc_prefix_search_opts.second_beam_size = FLAGS_nbest; + // decoder_opts.tlg_decoder_opts.fst_path = "";//FLAGS_fst_path; + // decoder_opts.tlg_decoder_opts.word_symbol_table = + // FLAGS_word_symbol_table; + // decoder_opts.tlg_decoder_opts.nbest = FLAGS_nbest; + decoder_opts.tlg_decoder_opts = + ppspeech::TLGDecoderOptions::InitFromFlags(); + LOG(INFO) << "chunk_size: " << decoder_opts.chunk_size; LOG(INFO) << "num_left_chunks: " << decoder_opts.num_left_chunks; LOG(INFO) << "ctc_weight: " << decoder_opts.ctc_weight; @@ -113,7 +124,7 @@ class U2Recognizer { public: explicit U2Recognizer(const U2RecognizerResource& resouce); explicit U2Recognizer(const U2RecognizerResource& resource, - std::shared_ptr nnet); + std::shared_ptr nnet); ~U2Recognizer(); void InitDecoder(); void ResetContinuousDecoding(); @@ -154,10 +165,9 @@ class U2Recognizer { std::shared_ptr nnet_producer_; std::shared_ptr decodable_; - std::unique_ptr decoder_; + std::unique_ptr decoder_; // e2e unit symbol table - std::shared_ptr unit_table_ = nullptr; std::shared_ptr symbol_table_ = nullptr; std::vector result_; From 8e1b4cd51301d667a0dcf73d9e945924380134f1 Mon Sep 17 00:00:00 2001 From: YangZhou <56786796+SmileGoat@users.noreply.github.com> Date: Wed, 8 Feb 2023 09:49:07 +0800 Subject: [PATCH 10/50] [engine] rename speechx (#2892) * rename speechx * fix wfst decode error * replace reset with make_unique --- .pre-commit-config.yaml | 10 +- README.md | 2 +- {speechx => runtime}/.clang-format | 0 {speechx => runtime}/.gitignore | 1 + {speechx => runtime}/CMakeLists.txt | 8 +- {speechx => runtime}/README.md | 5 +- {speechx => runtime}/build.sh | 0 .../cmake/EnableCMP0048.cmake | 0 .../cmake/FindGFortranLibs.cmake | 0 {speechx => runtime}/cmake/absl.cmake | 0 {speechx => runtime}/cmake/boost.cmake | 0 {speechx => 
runtime}/cmake/eigen.cmake | 0 {speechx => runtime}/cmake/gflags.cmake | 4 +- {speechx => runtime}/cmake/glog.cmake | 2 +- {speechx => runtime}/cmake/gtest.cmake | 4 +- {speechx => runtime}/cmake/kenlm.cmake | 0 {speechx => runtime}/cmake/libsndfile.cmake | 0 {speechx => runtime}/cmake/openblas.cmake | 0 {speechx => runtime}/cmake/openfst.cmake | 6 +- .../cmake/paddleinference.cmake | 0 {speechx => runtime}/cmake/system.cmake | 0 {speechx => runtime}/docker/.gitkeep | 0 .../speechx => runtime/engine}/CMakeLists.txt | 0 .../engine}/asr/CMakeLists.txt | 0 .../engine}/asr/decoder/CMakeLists.txt | 0 .../engine}/asr/decoder/common.h | 0 .../engine}/asr/decoder/ctc_beam_search_opt.h | 0 .../decoder/ctc_prefix_beam_search_decoder.cc | 4 +- .../decoder/ctc_prefix_beam_search_decoder.h | 0 .../ctc_prefix_beam_search_decoder_main.cc | 0 .../decoder/ctc_prefix_beam_search_score.h | 0 .../engine}/asr/decoder/ctc_tlg_decoder.cc | 7 +- .../engine}/asr/decoder/ctc_tlg_decoder.h | 10 +- .../asr/decoder/ctc_tlg_decoder_main.cc | 0 .../engine}/asr/decoder/decoder_itf.h | 0 .../engine}/asr/decoder/param.h | 1 - .../engine}/asr/nnet/CMakeLists.txt | 0 .../engine}/asr/nnet/decodable.cc | 0 .../engine}/asr/nnet/decodable.h | 0 .../engine}/asr/nnet/nnet_itf.h | 0 .../engine}/asr/nnet/nnet_producer.cc | 0 .../engine}/asr/nnet/nnet_producer.h | 0 .../engine}/asr/nnet/u2_nnet.cc | 0 .../engine}/asr/nnet/u2_nnet.h | 0 .../engine}/asr/nnet/u2_nnet_main.cc | 8 +- .../engine}/asr/nnet/u2_nnet_thread_main.cc | 6 +- .../engine}/asr/recognizer/CMakeLists.txt | 0 .../engine}/asr/recognizer/u2_recognizer.cc | 10 +- .../engine}/asr/recognizer/u2_recognizer.h | 8 +- .../recognizer/u2_recognizer_batch_main.cc | 0 .../asr/recognizer/u2_recognizer_main.cc | 2 +- .../recognizer/u2_recognizer_thread_main.cc | 2 +- .../engine}/asr/server/CMakeLists.txt | 0 .../asr/server/websocket/CMakeLists.txt | 0 .../asr/server/websocket/websocket_client.cc | 0 .../asr/server/websocket/websocket_client.h | 0 
.../server/websocket/websocket_client_main.cc | 2 +- .../asr/server/websocket/websocket_server.cc | 0 .../asr/server/websocket/websocket_server.h | 0 .../server/websocket/websocket_server_main.cc | 2 +- .../engine}/codelab/CMakeLists.txt | 0 .../engine}/codelab/README.md | 0 .../engine}/codelab/glog/CMakeLists.txt | 0 .../engine}/codelab/glog/README.md | 0 .../codelab/glog/glog_logtostderr_main.cc | 0 .../engine}/codelab/glog/glog_main.cc | 0 .../engine}/common/CMakeLists.txt | 0 .../engine}/common/base/basic_types.h | 2 +- .../engine}/common/base/common.h | 4 +- .../engine}/common/base/flags.h | 0 .../engine}/common/base/log.h | 0 .../engine}/common/base/macros.h | 0 .../engine}/common/base/safe_queue.h | 0 .../engine/common/base/safe_queue_inl.h | 0 .../engine}/common/base/thread_pool.h | 0 .../engine}/common/frontend/CMakeLists.txt | 0 .../engine}/common/frontend/assembler.cc | 4 +- .../engine}/common/frontend/assembler.h | 0 .../engine}/common/frontend/audio_cache.cc | 0 .../engine}/common/frontend/audio_cache.h | 0 .../engine}/common/frontend/cmvn.cc | 22 +- .../engine}/common/frontend/cmvn.h | 0 .../common/frontend/compute_fbank_main.cc | 0 .../compute_linear_spectrogram_main.cc | 0 .../engine}/common/frontend/data_cache.h | 0 .../engine}/common/frontend/db_norm.cc | 0 .../engine}/common/frontend/db_norm.h | 0 .../engine}/common/frontend/fbank.cc | 0 .../engine}/common/frontend/fbank.h | 2 +- .../engine}/common/frontend/feature-fbank.cc | 0 .../engine}/common/frontend/feature-fbank.h | 0 .../common/frontend/feature-functions.cc | 0 .../common/frontend/feature-functions.h | 0 .../engine}/common/frontend/feature-window.cc | 0 .../engine}/common/frontend/feature-window.h | 0 .../engine}/common/frontend/feature_cache.cc | 2 +- .../engine}/common/frontend/feature_cache.h | 2 +- .../engine}/common/frontend/feature_common.h | 2 +- .../common/frontend/feature_common_inl.h | 0 .../common/frontend/feature_pipeline.cc | 0 .../common/frontend/feature_pipeline.h | 2 +- 
.../engine}/common/frontend/fftsg.c | 0 .../engine}/common/frontend/frontend_itf.h | 0 .../common/frontend/linear_spectrogram.cc | 0 .../common/frontend/linear_spectrogram.h | 0 .../common/frontend/mel-computations.cc | 0 .../common/frontend/mel-computations.h | 0 .../engine}/common/frontend/mfcc.cc | 0 .../engine}/common/frontend/mfcc.h | 0 .../engine}/common/frontend/normalizer.h | 0 .../engine}/common/frontend/rfft.cc | 0 .../engine}/common/frontend/rfft.h | 0 runtime/engine/common/frontend/wave-reader.cc | 376 +++++ runtime/engine/common/frontend/wave-reader.h | 248 ++++ .../engine}/common/matrix/CMakeLists.txt | 0 .../engine}/common/matrix/kaldi-matrix-inl.h | 0 .../engine}/common/matrix/kaldi-matrix.cc | 0 .../engine}/common/matrix/kaldi-matrix.h | 0 .../engine}/common/matrix/kaldi-vector-inl.h | 0 .../engine}/common/matrix/kaldi-vector.cc | 0 .../engine}/common/matrix/kaldi-vector.h | 0 .../engine}/common/matrix/matrix-common.h | 0 .../engine}/common/utils/CMakeLists.txt | 0 .../engine}/common/utils/file_utils.cc | 0 .../engine}/common/utils/file_utils.h | 0 .../engine}/common/utils/math.cc | 2 +- .../engine}/common/utils/math.h | 0 runtime/engine/common/utils/picojson.h | 1230 +++++++++++++++++ .../engine}/common/utils/strings.cc | 14 +- .../engine}/common/utils/strings.h | 8 +- .../engine}/common/utils/strings_test.cc | 10 +- .../engine}/kaldi/CMakeLists.txt | 0 .../engine}/kaldi/base/CMakeLists.txt | 0 .../engine}/kaldi/base/io-funcs-inl.h | 0 .../engine}/kaldi/base/io-funcs.cc | 0 .../engine}/kaldi/base/io-funcs.h | 0 .../engine}/kaldi/base/kaldi-common.h | 0 .../engine}/kaldi/base/kaldi-error.cc | 0 .../engine}/kaldi/base/kaldi-error.h | 0 .../engine}/kaldi/base/kaldi-math.cc | 0 .../engine}/kaldi/base/kaldi-math.h | 0 .../engine}/kaldi/base/kaldi-types.h | 0 .../engine}/kaldi/base/kaldi-utils.cc | 0 .../engine}/kaldi/base/kaldi-utils.h | 0 .../engine}/kaldi/base/timer.cc | 0 .../engine}/kaldi/base/timer.h | 0 .../engine}/kaldi/base/version.h | 0 
.../engine}/kaldi/decoder/CMakeLists.txt | 0 .../engine}/kaldi/decoder/decodable-itf.h | 0 .../kaldi/decoder/lattice-faster-decoder.cc | 0 .../kaldi/decoder/lattice-faster-decoder.h | 0 .../decoder/lattice-faster-online-decoder.cc | 0 .../decoder/lattice-faster-online-decoder.h | 0 .../engine}/kaldi/fstbin/CMakeLists.txt | 0 .../engine}/kaldi/fstbin/fstaddselfloops.cc | 0 .../kaldi/fstbin/fstdeterminizestar.cc | 0 .../engine}/kaldi/fstbin/fstisstochastic.cc | 0 .../kaldi/fstbin/fstminimizeencoded.cc | 0 .../engine}/kaldi/fstbin/fsttablecompose.cc | 0 .../engine}/kaldi/fstext/CMakeLists.txt | 0 .../kaldi/fstext/determinize-lattice-inl.h | 0 .../kaldi/fstext/determinize-lattice.h | 0 .../kaldi/fstext/determinize-star-inl.h | 0 .../engine}/kaldi/fstext/determinize-star.h | 0 .../engine}/kaldi/fstext/fstext-lib.h | 0 .../engine}/kaldi/fstext/fstext-utils-inl.h | 0 .../engine}/kaldi/fstext/fstext-utils.h | 0 .../engine}/kaldi/fstext/kaldi-fst-io-inl.h | 0 .../engine}/kaldi/fstext/kaldi-fst-io.cc | 0 .../engine}/kaldi/fstext/kaldi-fst-io.h | 0 .../engine}/kaldi/fstext/lattice-utils-inl.h | 0 .../engine}/kaldi/fstext/lattice-utils.h | 0 .../engine}/kaldi/fstext/lattice-weight.h | 0 .../kaldi/fstext/pre-determinize-inl.h | 0 .../engine}/kaldi/fstext/pre-determinize.h | 0 .../kaldi/fstext/remove-eps-local-inl.h | 0 .../engine}/kaldi/fstext/remove-eps-local.h | 0 .../engine}/kaldi/fstext/table-matcher.h | 0 .../engine}/kaldi/lat/CMakeLists.txt | 0 .../kaldi/lat/determinize-lattice-pruned.cc | 0 .../kaldi/lat/determinize-lattice-pruned.h | 0 .../engine}/kaldi/lat/kaldi-lattice.cc | 0 .../engine}/kaldi/lat/kaldi-lattice.h | 0 .../engine}/kaldi/lat/lattice-functions.cc | 0 .../engine}/kaldi/lat/lattice-functions.h | 0 .../engine}/kaldi/lm/CMakeLists.txt | 0 .../engine}/kaldi/lm/arpa-file-parser.cc | 0 .../engine}/kaldi/lm/arpa-file-parser.h | 0 .../engine}/kaldi/lm/arpa-lm-compiler.cc | 0 .../engine}/kaldi/lm/arpa-lm-compiler.h | 0 .../engine}/kaldi/lmbin/CMakeLists.txt | 0 
.../engine}/kaldi/lmbin/arpa2fst.cc | 0 .../engine}/kaldi/util/CMakeLists.txt | 0 .../engine}/kaldi/util/basic-filebuf.h | 0 .../engine}/kaldi/util/common-utils.h | 0 .../kaldi/util/const-integer-set-inl.h | 0 .../engine}/kaldi/util/const-integer-set.h | 0 .../engine}/kaldi/util/edit-distance-inl.h | 0 .../engine}/kaldi/util/edit-distance.h | 0 .../engine}/kaldi/util/hash-list-inl.h | 0 .../engine}/kaldi/util/hash-list.h | 0 .../engine}/kaldi/util/kaldi-cygwin-io-inl.h | 0 .../engine}/kaldi/util/kaldi-holder-inl.h | 0 .../engine}/kaldi/util/kaldi-holder.cc | 0 .../engine}/kaldi/util/kaldi-holder.h | 0 .../engine}/kaldi/util/kaldi-io-inl.h | 0 .../engine}/kaldi/util/kaldi-io.cc | 0 .../engine}/kaldi/util/kaldi-io.h | 0 .../engine}/kaldi/util/kaldi-pipebuf.h | 0 .../engine}/kaldi/util/kaldi-semaphore.cc | 0 .../engine}/kaldi/util/kaldi-semaphore.h | 0 .../engine}/kaldi/util/kaldi-table-inl.h | 0 .../engine}/kaldi/util/kaldi-table.cc | 0 .../engine}/kaldi/util/kaldi-table.h | 0 .../engine}/kaldi/util/kaldi-thread.cc | 0 .../engine}/kaldi/util/kaldi-thread.h | 0 .../engine}/kaldi/util/options-itf.h | 0 .../engine}/kaldi/util/parse-options.cc | 0 .../engine}/kaldi/util/parse-options.h | 0 .../engine}/kaldi/util/simple-io-funcs.cc | 0 .../engine}/kaldi/util/simple-io-funcs.h | 0 .../engine}/kaldi/util/simple-options.cc | 0 .../engine}/kaldi/util/simple-options.h | 0 .../engine}/kaldi/util/stl-utils.h | 0 .../engine}/kaldi/util/table-types.h | 0 .../engine}/kaldi/util/text-utils.cc | 0 .../engine}/kaldi/util/text-utils.h | 0 {speechx => runtime}/examples/.gitignore | 1 + {speechx => runtime}/examples/README.md | 0 .../examples/codelab/README.md | 0 .../examples/codelab/decoder/.gitignore | 0 .../examples/codelab/decoder/README.md | 0 .../examples/codelab/decoder/path.sh | 0 .../examples/codelab/decoder/run.sh | 0 .../examples/codelab/decoder/valgrind.sh | 0 .../examples/codelab/feat/.gitignore | 0 .../examples/codelab/feat/README.md | 0 .../examples/codelab/feat/path.sh | 
0 .../examples/codelab/feat/run.sh | 0 .../examples/codelab/feat/valgrind.sh | 0 .../examples/codelab/nnet/.gitignore | 0 .../examples/codelab/nnet/README.md | 0 .../examples/codelab/nnet/path.sh | 0 .../examples/codelab/nnet/run.sh | 0 .../examples/codelab/nnet/valgrind.sh | 0 .../examples/codelab/u2/.gitignore | 0 .../examples/codelab/u2/README.md | 0 .../examples/codelab/u2/local/decode.sh | 0 .../examples/codelab/u2/local/feat.sh | 0 .../examples/codelab/u2/local/nnet.sh | 0 .../examples/codelab/u2/local/recognizer.sh | 0 .../examples/codelab/u2/path.sh | 0 .../examples/codelab/u2/run.sh | 0 .../examples/codelab/u2/utils | 0 .../examples/custom_asr/README.md | 0 .../local/compile_lexicon_token_fst.sh | 0 .../custom_asr/local/mk_slot_graph.sh | 0 .../custom_asr/local/mk_tlg_with_slot.sh | 0 .../custom_asr/local/train_lm_with_slot.sh | 0 .../examples/custom_asr/path.sh | 8 +- .../examples/custom_asr/run.sh | 0 .../examples/custom_asr/utils | 0 .../examples/text_lm/.gitignore | 0 .../examples/text_lm/README.md | 0 .../examples/text_lm/local/data/chars.dic | 0 .../examples/text_lm/local/data/words.dic | 0 .../examples/text_lm/local/mmseg.py | 0 runtime/examples/text_lm/path.sh | 4 + {speechx => runtime}/examples/text_lm/run.sh | 0 {speechx => runtime}/examples/text_lm/utils | 0 .../examples/u2pp_ol/README.md | 0 .../examples/u2pp_ol/wenetspeech/.gitignore | 0 .../examples/u2pp_ol/wenetspeech/README.md | 6 +- .../examples/u2pp_ol/wenetspeech/RESULTS.md | 0 .../wenetspeech/local/aishell_train_lms.sh | 0 .../u2pp_ol/wenetspeech/local/decode.sh | 0 .../u2pp_ol/wenetspeech/local/feat.sh | 0 .../u2pp_ol/wenetspeech/local/nnet.sh | 0 .../u2pp_ol/wenetspeech/local/recognizer.sh | 0 .../wenetspeech/local/recognizer_quant.sh | 0 .../wenetspeech/local/run_build_tlg.sh | 84 ++ .../u2pp_ol/wenetspeech/local/split_data.sh | 0 runtime/examples/u2pp_ol/wenetspeech/path.sh | 18 + .../examples/u2pp_ol/wenetspeech/run.sh | 0 .../examples/u2pp_ol/wenetspeech/utils | 0 {speechx => 
runtime}/patch/CPPLINT.cfg | 0 {speechx => runtime}/patch/README.md | 0 .../patch/openfst/src/include/fst/flags.h | 0 .../patch/openfst/src/include/fst/log.h | 0 .../patch/openfst/src/lib/flags.cc | 0 {speechx => runtime}/tools/clang-format.sh | 0 {speechx => runtime}/tools/setup_valgrind.sh | 0 {speechx => runtime}/tools/venv.sh | 0 speechx/examples/text_lm/path.sh | 4 - speechx/examples/u2pp_ol/wenetspeech/path.sh | 18 - .../speechx/common/frontend/wave-reader.cc | 387 ------ speechx/speechx/common/frontend/wave-reader.h | 248 ---- speechx/speechx/common/utils/picojson.h | 1202 ---------------- 298 files changed, 2061 insertions(+), 1953 deletions(-) rename {speechx => runtime}/.clang-format (100%) rename {speechx => runtime}/.gitignore (65%) rename {speechx => runtime}/CMakeLists.txt (92%) rename {speechx => runtime}/README.md (96%) rename {speechx => runtime}/build.sh (100%) rename {speechx => runtime}/cmake/EnableCMP0048.cmake (100%) rename {speechx => runtime}/cmake/FindGFortranLibs.cmake (100%) rename {speechx => runtime}/cmake/absl.cmake (100%) rename {speechx => runtime}/cmake/boost.cmake (100%) rename {speechx => runtime}/cmake/eigen.cmake (100%) rename {speechx => runtime}/cmake/gflags.cmake (61%) rename {speechx => runtime}/cmake/glog.cmake (77%) rename {speechx => runtime}/cmake/gtest.cmake (76%) rename {speechx => runtime}/cmake/kenlm.cmake (100%) rename {speechx => runtime}/cmake/libsndfile.cmake (100%) rename {speechx => runtime}/cmake/openblas.cmake (100%) rename {speechx => runtime}/cmake/openfst.cmake (96%) rename {speechx => runtime}/cmake/paddleinference.cmake (100%) rename {speechx => runtime}/cmake/system.cmake (100%) rename {speechx => runtime}/docker/.gitkeep (100%) rename {speechx/speechx => runtime/engine}/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/asr/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/asr/decoder/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/asr/decoder/common.h 
(100%) rename {speechx/speechx => runtime/engine}/asr/decoder/ctc_beam_search_opt.h (100%) rename {speechx/speechx => runtime/engine}/asr/decoder/ctc_prefix_beam_search_decoder.cc (99%) rename {speechx/speechx => runtime/engine}/asr/decoder/ctc_prefix_beam_search_decoder.h (100%) rename {speechx/speechx => runtime/engine}/asr/decoder/ctc_prefix_beam_search_decoder_main.cc (100%) rename {speechx/speechx => runtime/engine}/asr/decoder/ctc_prefix_beam_search_score.h (100%) rename {speechx/speechx => runtime/engine}/asr/decoder/ctc_tlg_decoder.cc (97%) rename {speechx/speechx => runtime/engine}/asr/decoder/ctc_tlg_decoder.h (95%) rename {speechx/speechx => runtime/engine}/asr/decoder/ctc_tlg_decoder_main.cc (100%) rename {speechx/speechx => runtime/engine}/asr/decoder/decoder_itf.h (100%) rename {speechx/speechx => runtime/engine}/asr/decoder/param.h (98%) rename {speechx/speechx => runtime/engine}/asr/nnet/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/asr/nnet/decodable.cc (100%) rename {speechx/speechx => runtime/engine}/asr/nnet/decodable.h (100%) rename {speechx/speechx => runtime/engine}/asr/nnet/nnet_itf.h (100%) rename {speechx/speechx => runtime/engine}/asr/nnet/nnet_producer.cc (100%) rename {speechx/speechx => runtime/engine}/asr/nnet/nnet_producer.h (100%) rename {speechx/speechx => runtime/engine}/asr/nnet/u2_nnet.cc (100%) rename {speechx/speechx => runtime/engine}/asr/nnet/u2_nnet.h (100%) rename {speechx/speechx => runtime/engine}/asr/nnet/u2_nnet_main.cc (97%) rename {speechx/speechx => runtime/engine}/asr/nnet/u2_nnet_thread_main.cc (99%) rename {speechx/speechx => runtime/engine}/asr/recognizer/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/asr/recognizer/u2_recognizer.cc (97%) rename {speechx/speechx => runtime/engine}/asr/recognizer/u2_recognizer.h (95%) rename {speechx/speechx => runtime/engine}/asr/recognizer/u2_recognizer_batch_main.cc (100%) rename {speechx/speechx => 
runtime/engine}/asr/recognizer/u2_recognizer_main.cc (100%) rename {speechx/speechx => runtime/engine}/asr/recognizer/u2_recognizer_thread_main.cc (98%) rename {speechx/speechx => runtime/engine}/asr/server/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/asr/server/websocket/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/asr/server/websocket/websocket_client.cc (100%) rename {speechx/speechx => runtime/engine}/asr/server/websocket/websocket_client.h (100%) rename {speechx/speechx => runtime/engine}/asr/server/websocket/websocket_client_main.cc (100%) rename {speechx/speechx => runtime/engine}/asr/server/websocket/websocket_server.cc (100%) rename {speechx/speechx => runtime/engine}/asr/server/websocket/websocket_server.h (100%) rename {speechx/speechx => runtime/engine}/asr/server/websocket/websocket_server_main.cc (100%) rename {speechx/speechx => runtime/engine}/codelab/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/codelab/README.md (100%) rename {speechx/speechx => runtime/engine}/codelab/glog/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/codelab/glog/README.md (100%) rename {speechx/speechx => runtime/engine}/codelab/glog/glog_logtostderr_main.cc (100%) rename {speechx/speechx => runtime/engine}/codelab/glog/glog_main.cc (100%) rename {speechx/speechx => runtime/engine}/common/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/common/base/basic_types.h (97%) rename {speechx/speechx => runtime/engine}/common/base/common.h (100%) rename {speechx/speechx => runtime/engine}/common/base/flags.h (100%) rename {speechx/speechx => runtime/engine}/common/base/log.h (100%) rename {speechx/speechx => runtime/engine}/common/base/macros.h (100%) rename {speechx/speechx => runtime/engine}/common/base/safe_queue.h (100%) rename speechx/speechx/kaldi/.gitkeep => runtime/engine/common/base/safe_queue_inl.h (100%) rename {speechx/speechx => runtime/engine}/common/base/thread_pool.h (100%) 
rename {speechx/speechx => runtime/engine}/common/frontend/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/common/frontend/assembler.cc (97%) rename {speechx/speechx => runtime/engine}/common/frontend/assembler.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/audio_cache.cc (100%) rename {speechx/speechx => runtime/engine}/common/frontend/audio_cache.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/cmvn.cc (90%) rename {speechx/speechx => runtime/engine}/common/frontend/cmvn.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/compute_fbank_main.cc (100%) rename {speechx/speechx => runtime/engine}/common/frontend/compute_linear_spectrogram_main.cc (100%) rename {speechx/speechx => runtime/engine}/common/frontend/data_cache.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/db_norm.cc (100%) rename {speechx/speechx => runtime/engine}/common/frontend/db_norm.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/fbank.cc (100%) rename {speechx/speechx => runtime/engine}/common/frontend/fbank.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/feature-fbank.cc (100%) rename {speechx/speechx => runtime/engine}/common/frontend/feature-fbank.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/feature-functions.cc (100%) rename {speechx/speechx => runtime/engine}/common/frontend/feature-functions.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/feature-window.cc (100%) rename {speechx/speechx => runtime/engine}/common/frontend/feature-window.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/feature_cache.cc (97%) rename {speechx/speechx => runtime/engine}/common/frontend/feature_cache.h (97%) rename {speechx/speechx => runtime/engine}/common/frontend/feature_common.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/feature_common_inl.h (100%) rename {speechx/speechx => 
runtime/engine}/common/frontend/feature_pipeline.cc (100%) rename {speechx/speechx => runtime/engine}/common/frontend/feature_pipeline.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/fftsg.c (100%) rename {speechx/speechx => runtime/engine}/common/frontend/frontend_itf.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/linear_spectrogram.cc (100%) rename {speechx/speechx => runtime/engine}/common/frontend/linear_spectrogram.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/mel-computations.cc (100%) rename {speechx/speechx => runtime/engine}/common/frontend/mel-computations.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/mfcc.cc (100%) rename {speechx/speechx => runtime/engine}/common/frontend/mfcc.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/normalizer.h (100%) rename {speechx/speechx => runtime/engine}/common/frontend/rfft.cc (100%) rename {speechx/speechx => runtime/engine}/common/frontend/rfft.h (100%) create mode 100644 runtime/engine/common/frontend/wave-reader.cc create mode 100644 runtime/engine/common/frontend/wave-reader.h rename {speechx/speechx => runtime/engine}/common/matrix/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/common/matrix/kaldi-matrix-inl.h (100%) rename {speechx/speechx => runtime/engine}/common/matrix/kaldi-matrix.cc (100%) rename {speechx/speechx => runtime/engine}/common/matrix/kaldi-matrix.h (100%) rename {speechx/speechx => runtime/engine}/common/matrix/kaldi-vector-inl.h (100%) rename {speechx/speechx => runtime/engine}/common/matrix/kaldi-vector.cc (100%) rename {speechx/speechx => runtime/engine}/common/matrix/kaldi-vector.h (100%) rename {speechx/speechx => runtime/engine}/common/matrix/matrix-common.h (100%) rename {speechx/speechx => runtime/engine}/common/utils/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/common/utils/file_utils.cc (100%) rename {speechx/speechx => 
runtime/engine}/common/utils/file_utils.h (100%) rename {speechx/speechx => runtime/engine}/common/utils/math.cc (100%) rename {speechx/speechx => runtime/engine}/common/utils/math.h (100%) create mode 100644 runtime/engine/common/utils/picojson.h rename {speechx/speechx => runtime/engine}/common/utils/strings.cc (80%) rename {speechx/speechx => runtime/engine}/common/utils/strings.h (79%) rename {speechx/speechx => runtime/engine}/common/utils/strings_test.cc (81%) rename {speechx/speechx => runtime/engine}/kaldi/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/io-funcs-inl.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/io-funcs.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/io-funcs.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/kaldi-common.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/kaldi-error.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/kaldi-error.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/kaldi-math.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/kaldi-math.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/kaldi-types.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/kaldi-utils.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/kaldi-utils.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/timer.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/timer.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/base/version.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/decoder/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/kaldi/decoder/decodable-itf.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/decoder/lattice-faster-decoder.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/decoder/lattice-faster-decoder.h (100%) rename 
{speechx/speechx => runtime/engine}/kaldi/decoder/lattice-faster-online-decoder.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/decoder/lattice-faster-online-decoder.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstbin/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstbin/fstaddselfloops.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstbin/fstdeterminizestar.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstbin/fstisstochastic.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstbin/fstminimizeencoded.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstbin/fsttablecompose.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/determinize-lattice-inl.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/determinize-lattice.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/determinize-star-inl.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/determinize-star.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/fstext-lib.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/fstext-utils-inl.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/fstext-utils.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/kaldi-fst-io-inl.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/kaldi-fst-io.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/kaldi-fst-io.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/lattice-utils-inl.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/lattice-utils.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/lattice-weight.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/pre-determinize-inl.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/pre-determinize.h (100%) rename 
{speechx/speechx => runtime/engine}/kaldi/fstext/remove-eps-local-inl.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/remove-eps-local.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/fstext/table-matcher.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/lat/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/kaldi/lat/determinize-lattice-pruned.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/lat/determinize-lattice-pruned.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/lat/kaldi-lattice.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/lat/kaldi-lattice.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/lat/lattice-functions.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/lat/lattice-functions.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/lm/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/kaldi/lm/arpa-file-parser.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/lm/arpa-file-parser.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/lm/arpa-lm-compiler.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/lm/arpa-lm-compiler.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/lmbin/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/kaldi/lmbin/arpa2fst.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/CMakeLists.txt (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/basic-filebuf.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/common-utils.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/const-integer-set-inl.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/const-integer-set.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/edit-distance-inl.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/edit-distance.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/hash-list-inl.h (100%) rename 
{speechx/speechx => runtime/engine}/kaldi/util/hash-list.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-cygwin-io-inl.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-holder-inl.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-holder.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-holder.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-io-inl.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-io.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-io.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-pipebuf.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-semaphore.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-semaphore.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-table-inl.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-table.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-table.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-thread.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/kaldi-thread.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/options-itf.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/parse-options.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/parse-options.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/simple-io-funcs.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/simple-io-funcs.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/simple-options.cc (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/simple-options.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/stl-utils.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/table-types.h (100%) rename {speechx/speechx => runtime/engine}/kaldi/util/text-utils.cc (100%) rename 
{speechx/speechx => runtime/engine}/kaldi/util/text-utils.h (100%) rename {speechx => runtime}/examples/.gitignore (80%) rename {speechx => runtime}/examples/README.md (100%) rename {speechx => runtime}/examples/codelab/README.md (100%) rename {speechx => runtime}/examples/codelab/decoder/.gitignore (100%) rename {speechx => runtime}/examples/codelab/decoder/README.md (100%) rename {speechx => runtime}/examples/codelab/decoder/path.sh (100%) rename {speechx => runtime}/examples/codelab/decoder/run.sh (100%) rename {speechx => runtime}/examples/codelab/decoder/valgrind.sh (100%) rename {speechx => runtime}/examples/codelab/feat/.gitignore (100%) rename {speechx => runtime}/examples/codelab/feat/README.md (100%) rename {speechx => runtime}/examples/codelab/feat/path.sh (100%) rename {speechx => runtime}/examples/codelab/feat/run.sh (100%) rename {speechx => runtime}/examples/codelab/feat/valgrind.sh (100%) rename {speechx => runtime}/examples/codelab/nnet/.gitignore (100%) rename {speechx => runtime}/examples/codelab/nnet/README.md (100%) rename {speechx => runtime}/examples/codelab/nnet/path.sh (100%) rename {speechx => runtime}/examples/codelab/nnet/run.sh (100%) rename {speechx => runtime}/examples/codelab/nnet/valgrind.sh (100%) rename {speechx => runtime}/examples/codelab/u2/.gitignore (100%) rename {speechx => runtime}/examples/codelab/u2/README.md (100%) rename {speechx => runtime}/examples/codelab/u2/local/decode.sh (100%) rename {speechx => runtime}/examples/codelab/u2/local/feat.sh (100%) rename {speechx => runtime}/examples/codelab/u2/local/nnet.sh (100%) rename {speechx => runtime}/examples/codelab/u2/local/recognizer.sh (100%) rename {speechx => runtime}/examples/codelab/u2/path.sh (100%) rename {speechx => runtime}/examples/codelab/u2/run.sh (100%) rename {speechx => runtime}/examples/codelab/u2/utils (100%) rename {speechx => runtime}/examples/custom_asr/README.md (100%) rename {speechx => runtime}/examples/custom_asr/local/compile_lexicon_token_fst.sh 
(100%) rename {speechx => runtime}/examples/custom_asr/local/mk_slot_graph.sh (100%) rename {speechx => runtime}/examples/custom_asr/local/mk_tlg_with_slot.sh (100%) rename {speechx => runtime}/examples/custom_asr/local/train_lm_with_slot.sh (100%) rename {speechx => runtime}/examples/custom_asr/path.sh (70%) rename {speechx => runtime}/examples/custom_asr/run.sh (100%) rename {speechx => runtime}/examples/custom_asr/utils (100%) rename {speechx => runtime}/examples/text_lm/.gitignore (100%) rename {speechx => runtime}/examples/text_lm/README.md (100%) rename {speechx => runtime}/examples/text_lm/local/data/chars.dic (100%) rename {speechx => runtime}/examples/text_lm/local/data/words.dic (100%) rename {speechx => runtime}/examples/text_lm/local/mmseg.py (100%) create mode 100644 runtime/examples/text_lm/path.sh rename {speechx => runtime}/examples/text_lm/run.sh (100%) rename {speechx => runtime}/examples/text_lm/utils (100%) rename {speechx => runtime}/examples/u2pp_ol/README.md (100%) rename {speechx => runtime}/examples/u2pp_ol/wenetspeech/.gitignore (100%) rename {speechx => runtime}/examples/u2pp_ol/wenetspeech/README.md (90%) rename {speechx => runtime}/examples/u2pp_ol/wenetspeech/RESULTS.md (100%) rename {speechx => runtime}/examples/u2pp_ol/wenetspeech/local/aishell_train_lms.sh (100%) rename {speechx => runtime}/examples/u2pp_ol/wenetspeech/local/decode.sh (100%) rename {speechx => runtime}/examples/u2pp_ol/wenetspeech/local/feat.sh (100%) rename {speechx => runtime}/examples/u2pp_ol/wenetspeech/local/nnet.sh (100%) rename {speechx => runtime}/examples/u2pp_ol/wenetspeech/local/recognizer.sh (100%) rename {speechx => runtime}/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh (100%) create mode 100755 runtime/examples/u2pp_ol/wenetspeech/local/run_build_tlg.sh rename {speechx => runtime}/examples/u2pp_ol/wenetspeech/local/split_data.sh (100%) create mode 100644 runtime/examples/u2pp_ol/wenetspeech/path.sh rename {speechx => 
runtime}/examples/u2pp_ol/wenetspeech/run.sh (100%) rename {speechx => runtime}/examples/u2pp_ol/wenetspeech/utils (100%) rename {speechx => runtime}/patch/CPPLINT.cfg (100%) rename {speechx => runtime}/patch/README.md (100%) rename {speechx => runtime}/patch/openfst/src/include/fst/flags.h (100%) rename {speechx => runtime}/patch/openfst/src/include/fst/log.h (100%) rename {speechx => runtime}/patch/openfst/src/lib/flags.cc (100%) rename {speechx => runtime}/tools/clang-format.sh (100%) rename {speechx => runtime}/tools/setup_valgrind.sh (100%) rename {speechx => runtime}/tools/venv.sh (100%) delete mode 100644 speechx/examples/text_lm/path.sh delete mode 100644 speechx/examples/u2pp_ol/wenetspeech/path.sh delete mode 100644 speechx/speechx/common/frontend/wave-reader.cc delete mode 100644 speechx/speechx/common/frontend/wave-reader.h delete mode 100644 speechx/speechx/common/utils/picojson.h diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 99461947..6afa7c9c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ repos: entry: yapf args: [-i, -vv] types: [python] - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$ + exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$ - repo: https://github.com/pre-commit/pre-commit-hooks rev: a11d9314b22d8f8c7556443875b731ef05965464 @@ -35,7 +35,7 @@ repos: - --ignore=E501,E228,E226,E261,E266,E128,E402,W503 - --builtins=G,request - --jobs=1 - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$ + exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$ - repo : https://github.com/Lucas-C/pre-commit-hooks rev: v1.0.1 @@ -57,16 +57,16 @@ repos: entry: bash .pre-commit-hooks/clang-format.hook -i language: system files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$ - exclude: 
(?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders|speechx/speechx/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ + exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders|runtime/engine/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ - id: cpplint name: cpplint description: Static code analysis of C/C++ files language: python files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$ - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders|speechx/speechx/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ + exclude: (?=runtime/engine/kaldi|runtime/engine/common/matrix|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders|runtime/engine/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent - repo: https://github.com/asottile/reorder_python_imports rev: v2.4.0 hooks: - id: reorder-python-imports - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h\.hpp|\.py)$ + exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h\.hpp|\.py)$ diff --git a/README.md b/README.md index dbdf6a4f..0a12ec04 100644 --- a/README.md +++ b/README.md @@ -164,7 +164,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision - 👑 2022.11.18: Add [Whisper CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/pull/2640), support multi language recognition and translation. - 🔥 2022.11.18: Add [Wav2vec2 CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/demos/speech_ssl), Support ASR and Feature Extraction. 
- 🎉 2022.11.17: Add [male voice for TTS](https://github.com/PaddlePaddle/PaddleSpeech/pull/2660). -- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/speechx/examples/u2pp_ol/wenetspeech). +- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/runtime/examples/u2pp_ol/wenetspeech). - 👑 2022.11.01: Add [Adversarial Loss](https://arxiv.org/pdf/1907.04448.pdf) for [Chinese English mixed TTS](./examples/zh_en_tts/tts3). - 🔥 2022.10.26: Add [Prosody Prediction](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/rhy) for TTS. - 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend. diff --git a/speechx/.clang-format b/runtime/.clang-format similarity index 100% rename from speechx/.clang-format rename to runtime/.clang-format diff --git a/speechx/.gitignore b/runtime/.gitignore similarity index 65% rename from speechx/.gitignore rename to runtime/.gitignore index 9a93805c..0783b138 100644 --- a/speechx/.gitignore +++ b/runtime/.gitignore @@ -1,2 +1,3 @@ tools/valgrind* *log +fc_patch/* diff --git a/speechx/CMakeLists.txt b/runtime/CMakeLists.txt similarity index 92% rename from speechx/CMakeLists.txt rename to runtime/CMakeLists.txt index d056ebbc..8bd3f28c 100644 --- a/speechx/CMakeLists.txt +++ b/runtime/CMakeLists.txt @@ -93,7 +93,7 @@ endif() # paddle libpaddle.so # paddle include and link option -# -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so +# -L/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so 
execute_process( COMMAND python -c "\ import os;\ @@ -112,7 +112,7 @@ message(STATUS PADDLE_LINK_FLAGS= ${PADDLE_LINK_FLAGS}) string(STRIP ${PADDLE_LINK_FLAGS} PADDLE_LINK_FLAGS) # paddle compile option -# -I/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/include +# -I/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/include execute_process( COMMAND python -c "\ import paddle; \ @@ -143,6 +143,6 @@ message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS}) ############################################################################### # Add local library ############################################################################### -set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx) +set(ENGINE_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/engine) -add_subdirectory(speechx) +add_subdirectory(engine) diff --git a/speechx/README.md b/runtime/README.md similarity index 96% rename from speechx/README.md rename to runtime/README.md index 70136ea0..40aa9444 100644 --- a/speechx/README.md +++ b/runtime/README.md @@ -1,4 +1,3 @@ -# SpeechX -- All in One Speech Task Inference ## Environment @@ -9,7 +8,7 @@ We develop under: * gcc/g++/gfortran - 8.2.0 * cmake - 3.16.0 -> Please use `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build speechx. +> Please use `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build engine. > We make sure all things work fun under docker, and recommend using it to develop and deploy. @@ -33,7 +32,7 @@ docker run --privileged --net=host --ipc=host -it --rm -v /path/to/paddlespeech bash tools/venv.sh ``` -2. Build `speechx` and `examples`. +2. Build `engine` and `examples`. For now we are using feature under `develop` branch of paddle, so we need to install `paddlepaddle` nightly build version. 
For example: diff --git a/speechx/build.sh b/runtime/build.sh similarity index 100% rename from speechx/build.sh rename to runtime/build.sh diff --git a/speechx/cmake/EnableCMP0048.cmake b/runtime/cmake/EnableCMP0048.cmake similarity index 100% rename from speechx/cmake/EnableCMP0048.cmake rename to runtime/cmake/EnableCMP0048.cmake diff --git a/speechx/cmake/FindGFortranLibs.cmake b/runtime/cmake/FindGFortranLibs.cmake similarity index 100% rename from speechx/cmake/FindGFortranLibs.cmake rename to runtime/cmake/FindGFortranLibs.cmake diff --git a/speechx/cmake/absl.cmake b/runtime/cmake/absl.cmake similarity index 100% rename from speechx/cmake/absl.cmake rename to runtime/cmake/absl.cmake diff --git a/speechx/cmake/boost.cmake b/runtime/cmake/boost.cmake similarity index 100% rename from speechx/cmake/boost.cmake rename to runtime/cmake/boost.cmake diff --git a/speechx/cmake/eigen.cmake b/runtime/cmake/eigen.cmake similarity index 100% rename from speechx/cmake/eigen.cmake rename to runtime/cmake/eigen.cmake diff --git a/speechx/cmake/gflags.cmake b/runtime/cmake/gflags.cmake similarity index 61% rename from speechx/cmake/gflags.cmake rename to runtime/cmake/gflags.cmake index 36bebc87..d01eaf60 100644 --- a/speechx/cmake/gflags.cmake +++ b/runtime/cmake/gflags.cmake @@ -2,10 +2,10 @@ include(FetchContent) FetchContent_Declare( gflags - URL https://github.com/gflags/gflags/archive/v2.2.2.zip + URL https://paddleaudio.bj.bcebos.com/build/gflag-2.2.2.zip URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5 ) FetchContent_MakeAvailable(gflags) # openfst need -include_directories(${gflags_BINARY_DIR}/include) \ No newline at end of file +include_directories(${gflags_BINARY_DIR}/include) diff --git a/speechx/cmake/glog.cmake b/runtime/cmake/glog.cmake similarity index 77% rename from speechx/cmake/glog.cmake rename to runtime/cmake/glog.cmake index dcfd86c3..8cc9999b 100644 --- a/speechx/cmake/glog.cmake +++ b/runtime/cmake/glog.cmake @@ 
-1,7 +1,7 @@ include(FetchContent) FetchContent_Declare( glog - URL https://github.com/google/glog/archive/v0.4.0.zip + URL https://paddleaudio.bj.bcebos.com/build/glog-0.4.0.zip URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc ) FetchContent_MakeAvailable(glog) diff --git a/speechx/cmake/gtest.cmake b/runtime/cmake/gtest.cmake similarity index 76% rename from speechx/cmake/gtest.cmake rename to runtime/cmake/gtest.cmake index 365f25cf..f3e72d26 100644 --- a/speechx/cmake/gtest.cmake +++ b/runtime/cmake/gtest.cmake @@ -2,7 +2,7 @@ include(FetchContent) FetchContent_Declare( gtest - URL https://github.com/google/googletest/archive/release-1.11.0.zip + URL https://paddleaudio.bj.bcebos.com/build/gtest-release-1.11.0.zip URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a ) FetchContent_MakeAvailable(gtest) @@ -12,4 +12,4 @@ include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src) if(WITH_TESTING) enable_testing() -endif() \ No newline at end of file +endif() diff --git a/speechx/cmake/kenlm.cmake b/runtime/cmake/kenlm.cmake similarity index 100% rename from speechx/cmake/kenlm.cmake rename to runtime/cmake/kenlm.cmake diff --git a/speechx/cmake/libsndfile.cmake b/runtime/cmake/libsndfile.cmake similarity index 100% rename from speechx/cmake/libsndfile.cmake rename to runtime/cmake/libsndfile.cmake diff --git a/speechx/cmake/openblas.cmake b/runtime/cmake/openblas.cmake similarity index 100% rename from speechx/cmake/openblas.cmake rename to runtime/cmake/openblas.cmake diff --git a/speechx/cmake/openfst.cmake b/runtime/cmake/openfst.cmake similarity index 96% rename from speechx/cmake/openfst.cmake rename to runtime/cmake/openfst.cmake index 8861f4f4..2e2f82f2 100644 --- a/speechx/cmake/openfst.cmake +++ b/runtime/cmake/openfst.cmake @@ -1,8 +1,8 @@ -include(FetchContent) set(openfst_PREFIX_DIR ${fc_patch}/openfst) set(openfst_SOURCE_DIR ${fc_patch}/openfst-src) set(openfst_BINARY_DIR 
${fc_patch}/openfst-build) +include(FetchContent) # openfst Acknowledgments: #Cyril Allauzen, Michael Riley, Johan Schalkwyk, Wojciech Skut and Mehryar Mohri, #"OpenFst: A General and Efficient Weighted Finite-State Transducer Library", @@ -25,5 +25,7 @@ ExternalProject_Add(openfst ) link_directories(${openfst_PREFIX_DIR}/lib) include_directories(${openfst_PREFIX_DIR}/include) + + message(STATUS "OpenFST inc dir: ${openfst_PREFIX_DIR}/include") -message(STATUS "OpenFST lib dir: ${openfst_PREFIX_DIR}/lib") \ No newline at end of file +message(STATUS "OpenFST lib dir: ${openfst_PREFIX_DIR}/lib") diff --git a/speechx/cmake/paddleinference.cmake b/runtime/cmake/paddleinference.cmake similarity index 100% rename from speechx/cmake/paddleinference.cmake rename to runtime/cmake/paddleinference.cmake diff --git a/speechx/cmake/system.cmake b/runtime/cmake/system.cmake similarity index 100% rename from speechx/cmake/system.cmake rename to runtime/cmake/system.cmake diff --git a/speechx/docker/.gitkeep b/runtime/docker/.gitkeep similarity index 100% rename from speechx/docker/.gitkeep rename to runtime/docker/.gitkeep diff --git a/speechx/speechx/CMakeLists.txt b/runtime/engine/CMakeLists.txt similarity index 100% rename from speechx/speechx/CMakeLists.txt rename to runtime/engine/CMakeLists.txt diff --git a/speechx/speechx/asr/CMakeLists.txt b/runtime/engine/asr/CMakeLists.txt similarity index 100% rename from speechx/speechx/asr/CMakeLists.txt rename to runtime/engine/asr/CMakeLists.txt diff --git a/speechx/speechx/asr/decoder/CMakeLists.txt b/runtime/engine/asr/decoder/CMakeLists.txt similarity index 100% rename from speechx/speechx/asr/decoder/CMakeLists.txt rename to runtime/engine/asr/decoder/CMakeLists.txt diff --git a/speechx/speechx/asr/decoder/common.h b/runtime/engine/asr/decoder/common.h similarity index 100% rename from speechx/speechx/asr/decoder/common.h rename to runtime/engine/asr/decoder/common.h diff --git 
a/speechx/speechx/asr/decoder/ctc_beam_search_opt.h b/runtime/engine/asr/decoder/ctc_beam_search_opt.h similarity index 100% rename from speechx/speechx/asr/decoder/ctc_beam_search_opt.h rename to runtime/engine/asr/decoder/ctc_beam_search_opt.h diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc similarity index 99% rename from speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc rename to runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc index 8361f06d..3e3ca2c2 100644 --- a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc +++ b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc @@ -63,9 +63,7 @@ void CTCPrefixBeamSearch::Reset() { times_.emplace_back(empty); } -void CTCPrefixBeamSearch::InitDecoder() { - Reset(); -} +void CTCPrefixBeamSearch::InitDecoder() { Reset(); } void CTCPrefixBeamSearch::AdvanceDecode( const std::shared_ptr& decodable) { diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h similarity index 100% rename from speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h rename to runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc similarity index 100% rename from speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc rename to runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_score.h b/runtime/engine/asr/decoder/ctc_prefix_beam_search_score.h similarity index 100% rename from speechx/speechx/asr/decoder/ctc_prefix_beam_search_score.h rename to runtime/engine/asr/decoder/ctc_prefix_beam_search_score.h diff --git a/speechx/speechx/asr/decoder/ctc_tlg_decoder.cc 
b/runtime/engine/asr/decoder/ctc_tlg_decoder.cc similarity index 97% rename from speechx/speechx/asr/decoder/ctc_tlg_decoder.cc rename to runtime/engine/asr/decoder/ctc_tlg_decoder.cc index ca7d65c8..ac30da92 100644 --- a/speechx/speechx/asr/decoder/ctc_tlg_decoder.cc +++ b/runtime/engine/asr/decoder/ctc_tlg_decoder.cc @@ -29,6 +29,11 @@ TLGDecoder::TLGDecoder(TLGDecoderOptions opts) : opts_(opts) { void TLGDecoder::Reset() { decoder_->InitDecoding(); + hypotheses_.clear(); + likelihood_.clear(); + olabels_.clear(); + times_.clear(); + num_frame_decoded_ = 0; return; } @@ -103,7 +108,7 @@ void TLGDecoder::FinalizeSearch() { time.push_back(idx); // fake time, todo later hypotheses_.push_back(hypothese); times_.push_back(time); - olabels.push_back(words_id); + olabels_.push_back(words_id); likelihood_.push_back(-(weight.Value2() + weight.Value1())); } } diff --git a/speechx/speechx/asr/decoder/ctc_tlg_decoder.h b/runtime/engine/asr/decoder/ctc_tlg_decoder.h similarity index 95% rename from speechx/speechx/asr/decoder/ctc_tlg_decoder.h rename to runtime/engine/asr/decoder/ctc_tlg_decoder.h index 1ea6d634..4540bc46 100644 --- a/speechx/speechx/asr/decoder/ctc_tlg_decoder.h +++ b/runtime/engine/asr/decoder/ctc_tlg_decoder.h @@ -24,6 +24,7 @@ DECLARE_string(graph_path); DECLARE_int32(max_active); DECLARE_double(beam); DECLARE_double(lattice_beam); +DECLARE_int32(nbest); namespace ppspeech { @@ -46,7 +47,7 @@ struct TLGDecoderOptions { decoder_opts.opts.max_active = FLAGS_max_active; decoder_opts.opts.beam = FLAGS_beam; decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; - // decoder_opts.nbest = FLAGS_lattice_nbest; + decoder_opts.nbest = FLAGS_nbest; LOG(INFO) << "LatticeFasterDecoder max active: " << decoder_opts.opts.max_active; LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam; @@ -85,7 +86,7 @@ class TLGDecoder : public DecoderBase { return hypotheses_; } const std::vector>& Outputs() const override { - return olabels; + return olabels_; } // 
outputs_; } const std::vector& Likelihood() const override { return likelihood_; @@ -111,8 +112,9 @@ class TLGDecoder : public DecoderBase { private: void AdvanceDecoding(kaldi::DecodableInterface* decodable); + int num_frame_decoded_; std::vector> hypotheses_; - std::vector> olabels; + std::vector> olabels_; std::vector likelihood_; std::vector> times_; @@ -123,4 +125,4 @@ class TLGDecoder : public DecoderBase { }; -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc b/runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc similarity index 100% rename from speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc rename to runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc diff --git a/speechx/speechx/asr/decoder/decoder_itf.h b/runtime/engine/asr/decoder/decoder_itf.h similarity index 100% rename from speechx/speechx/asr/decoder/decoder_itf.h rename to runtime/engine/asr/decoder/decoder_itf.h diff --git a/speechx/speechx/asr/decoder/param.h b/runtime/engine/asr/decoder/param.h similarity index 98% rename from speechx/speechx/asr/decoder/param.h rename to runtime/engine/asr/decoder/param.h index 83e2c7fb..b9e9fd20 100644 --- a/speechx/speechx/asr/decoder/param.h +++ b/runtime/engine/asr/decoder/param.h @@ -15,7 +15,6 @@ #pragma once #include "base/common.h" -//#include "decoder/ctc_tlg_decoder.h" // feature DEFINE_bool(use_fbank, false, "False for fbank; or linear feature"); diff --git a/speechx/speechx/asr/nnet/CMakeLists.txt b/runtime/engine/asr/nnet/CMakeLists.txt similarity index 100% rename from speechx/speechx/asr/nnet/CMakeLists.txt rename to runtime/engine/asr/nnet/CMakeLists.txt diff --git a/speechx/speechx/asr/nnet/decodable.cc b/runtime/engine/asr/nnet/decodable.cc similarity index 100% rename from speechx/speechx/asr/nnet/decodable.cc rename to runtime/engine/asr/nnet/decodable.cc diff --git a/speechx/speechx/asr/nnet/decodable.h b/runtime/engine/asr/nnet/decodable.h 
similarity index 100% rename from speechx/speechx/asr/nnet/decodable.h rename to runtime/engine/asr/nnet/decodable.h diff --git a/speechx/speechx/asr/nnet/nnet_itf.h b/runtime/engine/asr/nnet/nnet_itf.h similarity index 100% rename from speechx/speechx/asr/nnet/nnet_itf.h rename to runtime/engine/asr/nnet/nnet_itf.h diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/runtime/engine/asr/nnet/nnet_producer.cc similarity index 100% rename from speechx/speechx/asr/nnet/nnet_producer.cc rename to runtime/engine/asr/nnet/nnet_producer.cc diff --git a/speechx/speechx/asr/nnet/nnet_producer.h b/runtime/engine/asr/nnet/nnet_producer.h similarity index 100% rename from speechx/speechx/asr/nnet/nnet_producer.h rename to runtime/engine/asr/nnet/nnet_producer.h diff --git a/speechx/speechx/asr/nnet/u2_nnet.cc b/runtime/engine/asr/nnet/u2_nnet.cc similarity index 100% rename from speechx/speechx/asr/nnet/u2_nnet.cc rename to runtime/engine/asr/nnet/u2_nnet.cc diff --git a/speechx/speechx/asr/nnet/u2_nnet.h b/runtime/engine/asr/nnet/u2_nnet.h similarity index 100% rename from speechx/speechx/asr/nnet/u2_nnet.h rename to runtime/engine/asr/nnet/u2_nnet.h diff --git a/speechx/speechx/asr/nnet/u2_nnet_main.cc b/runtime/engine/asr/nnet/u2_nnet_main.cc similarity index 97% rename from speechx/speechx/asr/nnet/u2_nnet_main.cc rename to runtime/engine/asr/nnet/u2_nnet_main.cc index e60ae7e8..699f4258 100644 --- a/speechx/speechx/asr/nnet/u2_nnet_main.cc +++ b/runtime/engine/asr/nnet/u2_nnet_main.cc @@ -13,13 +13,13 @@ // limitations under the License. 
+#include "nnet/u2_nnet.h" #include "base/common.h" #include "decoder/param.h" #include "frontend/assembler.h" #include "frontend/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" -#include "nnet/u2_nnet.h" DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); @@ -93,9 +93,9 @@ int main(int argc, char* argv[]) { ori_feature_len - chunk_idx * chunk_stride, chunk_size); } if (this_chunk_size < receptive_field_length) { - LOG(WARNING) - << "utt: " << utt << " skip last " << this_chunk_size - << " frames, expect is " << receptive_field_length; + LOG(WARNING) << "utt: " << utt << " skip last " + << this_chunk_size << " frames, expect is " + << receptive_field_length; break; } diff --git a/speechx/speechx/asr/nnet/u2_nnet_thread_main.cc b/runtime/engine/asr/nnet/u2_nnet_thread_main.cc similarity index 99% rename from speechx/speechx/asr/nnet/u2_nnet_thread_main.cc rename to runtime/engine/asr/nnet/u2_nnet_thread_main.cc index ce523e59..4339bdbe 100644 --- a/speechx/speechx/asr/nnet/u2_nnet_thread_main.cc +++ b/runtime/engine/asr/nnet/u2_nnet_thread_main.cc @@ -13,13 +13,13 @@ // limitations under the License. 
+#include "nnet/u2_nnet.h" #include "base/common.h" #include "decoder/param.h" -#include "frontend/wave-reader.h" #include "frontend/feature_pipeline.h" +#include "frontend/wave-reader.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" -#include "nnet/u2_nnet.h" #include "nnet/nnet_producer.h" DEFINE_string(wav_rspecifier, "", "test wav rspecifier"); @@ -104,7 +104,7 @@ int main(int argc, char* argv[]) { CHECK(sample_offset == tot_samples); std::vector> prob_vec; - while(1) { + while (1) { std::vector logprobs; bool isok = nnet_producer->Read(&logprobs); if (nnet_producer->IsFinished()) break; diff --git a/speechx/speechx/asr/recognizer/CMakeLists.txt b/runtime/engine/asr/recognizer/CMakeLists.txt similarity index 100% rename from speechx/speechx/asr/recognizer/CMakeLists.txt rename to runtime/engine/asr/recognizer/CMakeLists.txt diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.cc b/runtime/engine/asr/recognizer/u2_recognizer.cc similarity index 97% rename from speechx/speechx/asr/recognizer/u2_recognizer.cc rename to runtime/engine/asr/recognizer/u2_recognizer.cc index f31ceb3b..da1348f5 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.cc +++ b/runtime/engine/asr/recognizer/u2_recognizer.cc @@ -33,12 +33,12 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) decodable_.reset(new Decodable(nnet_producer_, am_scale)); CHECK_NE(resource.vocab_path, ""); - if (resource.decoder_opts.tlg_decoder_opts.fst_path == "") { + if (resource.decoder_opts.tlg_decoder_opts.fst_path.empty()) { LOG(INFO) << resource.decoder_opts.tlg_decoder_opts.fst_path; - decoder_.reset(new CTCPrefixBeamSearch( - resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts)); + decoder_ = std::make_unique( + resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts); } else { - decoder_.reset(new TLGDecoder(resource.decoder_opts.tlg_decoder_opts)); + decoder_ = std::make_unique(resource.decoder_opts.tlg_decoder_opts); } symbol_table_ = 
decoder_->WordSymbolTable(); @@ -268,4 +268,4 @@ void U2Recognizer::SetInputFinished() { } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/runtime/engine/asr/recognizer/u2_recognizer.h similarity index 95% rename from speechx/speechx/asr/recognizer/u2_recognizer.h rename to runtime/engine/asr/recognizer/u2_recognizer.h index 889da85b..299b64ed 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/runtime/engine/asr/recognizer/u2_recognizer.h @@ -31,11 +31,9 @@ DECLARE_double(rescoring_weight); DECLARE_double(reverse_weight); DECLARE_int32(nbest); DECLARE_int32(blank); - DECLARE_double(acoustic_scale); DECLARE_string(vocab_path); DECLARE_string(word_symbol_table); -// DECLARE_string(fst_path); namespace ppspeech { @@ -74,10 +72,6 @@ struct DecodeOptions { decoder_opts.ctc_prefix_search_opts.blank = FLAGS_blank; decoder_opts.ctc_prefix_search_opts.first_beam_size = FLAGS_nbest; decoder_opts.ctc_prefix_search_opts.second_beam_size = FLAGS_nbest; - // decoder_opts.tlg_decoder_opts.fst_path = "";//FLAGS_fst_path; - // decoder_opts.tlg_decoder_opts.word_symbol_table = - // FLAGS_word_symbol_table; - // decoder_opts.tlg_decoder_opts.nbest = FLAGS_nbest; decoder_opts.tlg_decoder_opts = ppspeech::TLGDecoderOptions::InitFromFlags(); @@ -183,4 +177,4 @@ class U2Recognizer { std::thread thread_; }; -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc b/runtime/engine/asr/recognizer/u2_recognizer_batch_main.cc similarity index 100% rename from speechx/speechx/asr/recognizer/u2_recognizer_batch_main.cc rename to runtime/engine/asr/recognizer/u2_recognizer_batch_main.cc diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_main.cc b/runtime/engine/asr/recognizer/u2_recognizer_main.cc similarity index 100% rename from speechx/speechx/asr/recognizer/u2_recognizer_main.cc rename to 
runtime/engine/asr/recognizer/u2_recognizer_main.cc index 178c91db..fb37d050 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_main.cc +++ b/runtime/engine/asr/recognizer/u2_recognizer_main.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "recognizer/u2_recognizer.h" #include "decoder/param.h" #include "frontend/wave-reader.h" #include "kaldi/util/table-types.h" -#include "recognizer/u2_recognizer.h" DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc b/runtime/engine/asr/recognizer/u2_recognizer_thread_main.cc similarity index 98% rename from speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc rename to runtime/engine/asr/recognizer/u2_recognizer_thread_main.cc index 3f45294d..b86853fa 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc +++ b/runtime/engine/asr/recognizer/u2_recognizer_thread_main.cc @@ -100,7 +100,7 @@ int main(int argc, char* argv[]) { continue; } - tot_decode_time += local_timer.Elapsed(); + tot_decode_time += local_timer.Elapsed(); LOG(INFO) << utt << " " << result; LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur << " cost: " << local_timer.Elapsed(); diff --git a/speechx/speechx/asr/server/CMakeLists.txt b/runtime/engine/asr/server/CMakeLists.txt similarity index 100% rename from speechx/speechx/asr/server/CMakeLists.txt rename to runtime/engine/asr/server/CMakeLists.txt diff --git a/speechx/speechx/asr/server/websocket/CMakeLists.txt b/runtime/engine/asr/server/websocket/CMakeLists.txt similarity index 100% rename from speechx/speechx/asr/server/websocket/CMakeLists.txt rename to runtime/engine/asr/server/websocket/CMakeLists.txt diff --git a/speechx/speechx/asr/server/websocket/websocket_client.cc 
b/runtime/engine/asr/server/websocket/websocket_client.cc similarity index 100% rename from speechx/speechx/asr/server/websocket/websocket_client.cc rename to runtime/engine/asr/server/websocket/websocket_client.cc diff --git a/speechx/speechx/asr/server/websocket/websocket_client.h b/runtime/engine/asr/server/websocket/websocket_client.h similarity index 100% rename from speechx/speechx/asr/server/websocket/websocket_client.h rename to runtime/engine/asr/server/websocket/websocket_client.h diff --git a/speechx/speechx/asr/server/websocket/websocket_client_main.cc b/runtime/engine/asr/server/websocket/websocket_client_main.cc similarity index 100% rename from speechx/speechx/asr/server/websocket/websocket_client_main.cc rename to runtime/engine/asr/server/websocket/websocket_client_main.cc index 7c5a4f2f..7ad36e3a 100644 --- a/speechx/speechx/asr/server/websocket/websocket_client_main.cc +++ b/runtime/engine/asr/server/websocket/websocket_client_main.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "websocket/websocket_client.h" #include "kaldi/feat/wave-reader.h" #include "kaldi/util/kaldi-io.h" #include "kaldi/util/table-types.h" -#include "websocket/websocket_client.h" DEFINE_string(host, "127.0.0.1", "host of websocket server"); DEFINE_int32(port, 8082, "port of websocket server"); diff --git a/speechx/speechx/asr/server/websocket/websocket_server.cc b/runtime/engine/asr/server/websocket/websocket_server.cc similarity index 100% rename from speechx/speechx/asr/server/websocket/websocket_server.cc rename to runtime/engine/asr/server/websocket/websocket_server.cc diff --git a/speechx/speechx/asr/server/websocket/websocket_server.h b/runtime/engine/asr/server/websocket/websocket_server.h similarity index 100% rename from speechx/speechx/asr/server/websocket/websocket_server.h rename to runtime/engine/asr/server/websocket/websocket_server.h diff --git a/speechx/speechx/asr/server/websocket/websocket_server_main.cc b/runtime/engine/asr/server/websocket/websocket_server_main.cc similarity index 100% rename from speechx/speechx/asr/server/websocket/websocket_server_main.cc rename to runtime/engine/asr/server/websocket/websocket_server_main.cc index 5c32caf2..5f805ac9 100644 --- a/speechx/speechx/asr/server/websocket/websocket_server_main.cc +++ b/runtime/engine/asr/server/websocket/websocket_server_main.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "decoder/param.h" #include "websocket/websocket_server.h" +#include "decoder/param.h" DEFINE_int32(port, 8082, "websocket listening port"); diff --git a/speechx/speechx/codelab/CMakeLists.txt b/runtime/engine/codelab/CMakeLists.txt similarity index 100% rename from speechx/speechx/codelab/CMakeLists.txt rename to runtime/engine/codelab/CMakeLists.txt diff --git a/speechx/speechx/codelab/README.md b/runtime/engine/codelab/README.md similarity index 100% rename from speechx/speechx/codelab/README.md rename to runtime/engine/codelab/README.md diff --git a/speechx/speechx/codelab/glog/CMakeLists.txt b/runtime/engine/codelab/glog/CMakeLists.txt similarity index 100% rename from speechx/speechx/codelab/glog/CMakeLists.txt rename to runtime/engine/codelab/glog/CMakeLists.txt diff --git a/speechx/speechx/codelab/glog/README.md b/runtime/engine/codelab/glog/README.md similarity index 100% rename from speechx/speechx/codelab/glog/README.md rename to runtime/engine/codelab/glog/README.md diff --git a/speechx/speechx/codelab/glog/glog_logtostderr_main.cc b/runtime/engine/codelab/glog/glog_logtostderr_main.cc similarity index 100% rename from speechx/speechx/codelab/glog/glog_logtostderr_main.cc rename to runtime/engine/codelab/glog/glog_logtostderr_main.cc diff --git a/speechx/speechx/codelab/glog/glog_main.cc b/runtime/engine/codelab/glog/glog_main.cc similarity index 100% rename from speechx/speechx/codelab/glog/glog_main.cc rename to runtime/engine/codelab/glog/glog_main.cc diff --git a/speechx/speechx/common/CMakeLists.txt b/runtime/engine/common/CMakeLists.txt similarity index 100% rename from speechx/speechx/common/CMakeLists.txt rename to runtime/engine/common/CMakeLists.txt diff --git a/speechx/speechx/common/base/basic_types.h b/runtime/engine/common/base/basic_types.h similarity index 97% rename from speechx/speechx/common/base/basic_types.h rename to runtime/engine/common/base/basic_types.h index 2b15a61f..c7fdc924 100644 --- 
a/speechx/speechx/common/base/basic_types.h +++ b/runtime/engine/common/base/basic_types.h @@ -28,7 +28,7 @@ typedef int int32; // NOLINT #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) typedef long int64; // NOLINT #else -typedef long long int64; // NOLINT +typedef long long int64; // NOLINT #endif typedef unsigned char uint8; // NOLINT diff --git a/speechx/speechx/common/base/common.h b/runtime/engine/common/base/common.h similarity index 100% rename from speechx/speechx/common/base/common.h rename to runtime/engine/common/base/common.h index 06fcd9fd..d94dc8a8 100644 --- a/speechx/speechx/common/base/common.h +++ b/runtime/engine/common/base/common.h @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include #include @@ -42,8 +44,6 @@ #include #include #include -#include -#include #include "base/basic_types.h" #include "base/flags.h" diff --git a/speechx/speechx/common/base/flags.h b/runtime/engine/common/base/flags.h similarity index 100% rename from speechx/speechx/common/base/flags.h rename to runtime/engine/common/base/flags.h diff --git a/speechx/speechx/common/base/log.h b/runtime/engine/common/base/log.h similarity index 100% rename from speechx/speechx/common/base/log.h rename to runtime/engine/common/base/log.h diff --git a/speechx/speechx/common/base/macros.h b/runtime/engine/common/base/macros.h similarity index 100% rename from speechx/speechx/common/base/macros.h rename to runtime/engine/common/base/macros.h diff --git a/speechx/speechx/common/base/safe_queue.h b/runtime/engine/common/base/safe_queue.h similarity index 100% rename from speechx/speechx/common/base/safe_queue.h rename to runtime/engine/common/base/safe_queue.h diff --git a/speechx/speechx/kaldi/.gitkeep b/runtime/engine/common/base/safe_queue_inl.h similarity index 100% rename from speechx/speechx/kaldi/.gitkeep rename to runtime/engine/common/base/safe_queue_inl.h diff --git a/speechx/speechx/common/base/thread_pool.h 
b/runtime/engine/common/base/thread_pool.h similarity index 100% rename from speechx/speechx/common/base/thread_pool.h rename to runtime/engine/common/base/thread_pool.h diff --git a/speechx/speechx/common/frontend/CMakeLists.txt b/runtime/engine/common/frontend/CMakeLists.txt similarity index 100% rename from speechx/speechx/common/frontend/CMakeLists.txt rename to runtime/engine/common/frontend/CMakeLists.txt diff --git a/speechx/speechx/common/frontend/assembler.cc b/runtime/engine/common/frontend/assembler.cc similarity index 97% rename from speechx/speechx/common/frontend/assembler.cc rename to runtime/engine/common/frontend/assembler.cc index 5f019c42..487951cd 100644 --- a/speechx/speechx/common/frontend/assembler.cc +++ b/runtime/engine/common/frontend/assembler.cc @@ -97,8 +97,8 @@ bool Assembler::Compute(vector* feats) { CHECK(val.size() == dim_) << val.size(); int32 start = counter * dim_; - std::memcpy(feats->data() + start, - val.data(), val.size() * sizeof(BaseFloat)); + std::memcpy( + feats->data() + start, val.data(), val.size() * sizeof(BaseFloat)); if (this_chunk_size - counter <= cache_size_) { feature_cache_.push(val); diff --git a/speechx/speechx/common/frontend/assembler.h b/runtime/engine/common/frontend/assembler.h similarity index 100% rename from speechx/speechx/common/frontend/assembler.h rename to runtime/engine/common/frontend/assembler.h diff --git a/speechx/speechx/common/frontend/audio_cache.cc b/runtime/engine/common/frontend/audio_cache.cc similarity index 100% rename from speechx/speechx/common/frontend/audio_cache.cc rename to runtime/engine/common/frontend/audio_cache.cc diff --git a/speechx/speechx/common/frontend/audio_cache.h b/runtime/engine/common/frontend/audio_cache.h similarity index 100% rename from speechx/speechx/common/frontend/audio_cache.h rename to runtime/engine/common/frontend/audio_cache.h diff --git a/speechx/speechx/common/frontend/cmvn.cc b/runtime/engine/common/frontend/cmvn.cc similarity index 90% rename 
from speechx/speechx/common/frontend/cmvn.cc rename to runtime/engine/common/frontend/cmvn.cc index 2fac1506..8375d3d1 100644 --- a/speechx/speechx/common/frontend/cmvn.cc +++ b/runtime/engine/common/frontend/cmvn.cc @@ -84,11 +84,12 @@ void CMVN::Compute(vector* feats) const { KALDI_ASSERT(feats != NULL); if (feats->size() % dim_ != 0) { - LOG(ERROR)<< "Dim mismatch: cmvn " << mean_stats_.size() << ',' - << var_stats_.size() - 1 << ", feats " << feats->size() << 'x'; + LOG(ERROR) << "Dim mismatch: cmvn " << mean_stats_.size() << ',' + << var_stats_.size() - 1 << ", feats " << feats->size() + << 'x'; } if (var_stats_.size() == 0 && var_norm_) { - LOG(ERROR) + LOG(ERROR) << "You requested variance normalization but no variance stats_ " << "are supplied."; } @@ -98,8 +99,8 @@ void CMVN::Compute(vector* feats) const { // computing an offset and representing it as stats_, we use a count of one. if (count < 1.0) LOG(ERROR) << "Insufficient stats_ for cepstral mean and variance " - "normalization: " - << "count = " << count; + "normalization: " + << "count = " << count; if (!var_norm_) { vector offset(feats->size()); @@ -112,11 +113,12 @@ void CMVN::Compute(vector* feats) const { // with the dim_ of feature. 
// the dim_ of feats = dim_ * num_frames; for (int32 idx = 0; idx < feats->size() / dim_; ++idx) { - std::memcpy(mean_stats_apply.data() + dim_ * idx, - mean_stats.data(), dim_* sizeof(double)); + std::memcpy(mean_stats_apply.data() + dim_ * idx, + mean_stats.data(), + dim_ * sizeof(double)); } for (size_t idx = 0; idx < feats->size(); ++idx) { - feats->at(idx) += offset[idx]; + feats->at(idx) += offset[idx]; } return; } @@ -130,7 +132,7 @@ void CMVN::Compute(vector* feats) const { double var = (var_stats_[d] / count) - mean * mean, floor = 1.0e-20; if (var < floor) { LOG(WARNING) << "Flooring cepstral variance from " << var << " to " - << floor; + << floor; var = floor; } scale = 1.0 / sqrt(var); @@ -146,7 +148,7 @@ void CMVN::Compute(vector* feats) const { } // Apply the normalization. for (size_t idx = 0; idx < feats->size(); ++idx) { - feats->at(idx) *= norm1[idx]; + feats->at(idx) *= norm1[idx]; } for (size_t idx = 0; idx < feats->size(); ++idx) { diff --git a/speechx/speechx/common/frontend/cmvn.h b/runtime/engine/common/frontend/cmvn.h similarity index 100% rename from speechx/speechx/common/frontend/cmvn.h rename to runtime/engine/common/frontend/cmvn.h diff --git a/speechx/speechx/common/frontend/compute_fbank_main.cc b/runtime/engine/common/frontend/compute_fbank_main.cc similarity index 100% rename from speechx/speechx/common/frontend/compute_fbank_main.cc rename to runtime/engine/common/frontend/compute_fbank_main.cc diff --git a/speechx/speechx/common/frontend/compute_linear_spectrogram_main.cc b/runtime/engine/common/frontend/compute_linear_spectrogram_main.cc similarity index 100% rename from speechx/speechx/common/frontend/compute_linear_spectrogram_main.cc rename to runtime/engine/common/frontend/compute_linear_spectrogram_main.cc diff --git a/speechx/speechx/common/frontend/data_cache.h b/runtime/engine/common/frontend/data_cache.h similarity index 100% rename from speechx/speechx/common/frontend/data_cache.h rename to 
runtime/engine/common/frontend/data_cache.h diff --git a/speechx/speechx/common/frontend/db_norm.cc b/runtime/engine/common/frontend/db_norm.cc similarity index 100% rename from speechx/speechx/common/frontend/db_norm.cc rename to runtime/engine/common/frontend/db_norm.cc diff --git a/speechx/speechx/common/frontend/db_norm.h b/runtime/engine/common/frontend/db_norm.h similarity index 100% rename from speechx/speechx/common/frontend/db_norm.h rename to runtime/engine/common/frontend/db_norm.h diff --git a/speechx/speechx/common/frontend/fbank.cc b/runtime/engine/common/frontend/fbank.cc similarity index 100% rename from speechx/speechx/common/frontend/fbank.cc rename to runtime/engine/common/frontend/fbank.cc diff --git a/speechx/speechx/common/frontend/fbank.h b/runtime/engine/common/frontend/fbank.h similarity index 100% rename from speechx/speechx/common/frontend/fbank.h rename to runtime/engine/common/frontend/fbank.h index 61d9c9aa..4398e72f 100644 --- a/speechx/speechx/common/frontend/fbank.h +++ b/runtime/engine/common/frontend/fbank.h @@ -15,8 +15,8 @@ #pragma once #include "base/common.h" -#include "frontend/feature_common.h" #include "frontend/feature-fbank.h" +#include "frontend/feature_common.h" namespace ppspeech { diff --git a/speechx/speechx/common/frontend/feature-fbank.cc b/runtime/engine/common/frontend/feature-fbank.cc similarity index 100% rename from speechx/speechx/common/frontend/feature-fbank.cc rename to runtime/engine/common/frontend/feature-fbank.cc diff --git a/speechx/speechx/common/frontend/feature-fbank.h b/runtime/engine/common/frontend/feature-fbank.h similarity index 100% rename from speechx/speechx/common/frontend/feature-fbank.h rename to runtime/engine/common/frontend/feature-fbank.h diff --git a/speechx/speechx/common/frontend/feature-functions.cc b/runtime/engine/common/frontend/feature-functions.cc similarity index 100% rename from speechx/speechx/common/frontend/feature-functions.cc rename to 
runtime/engine/common/frontend/feature-functions.cc diff --git a/speechx/speechx/common/frontend/feature-functions.h b/runtime/engine/common/frontend/feature-functions.h similarity index 100% rename from speechx/speechx/common/frontend/feature-functions.h rename to runtime/engine/common/frontend/feature-functions.h diff --git a/speechx/speechx/common/frontend/feature-window.cc b/runtime/engine/common/frontend/feature-window.cc similarity index 100% rename from speechx/speechx/common/frontend/feature-window.cc rename to runtime/engine/common/frontend/feature-window.cc diff --git a/speechx/speechx/common/frontend/feature-window.h b/runtime/engine/common/frontend/feature-window.h similarity index 100% rename from speechx/speechx/common/frontend/feature-window.h rename to runtime/engine/common/frontend/feature-window.h diff --git a/speechx/speechx/common/frontend/feature_cache.cc b/runtime/engine/common/frontend/feature_cache.cc similarity index 97% rename from speechx/speechx/common/frontend/feature_cache.cc rename to runtime/engine/common/frontend/feature_cache.cc index c166bd64..bdbe6931 100644 --- a/speechx/speechx/common/frontend/feature_cache.cc +++ b/runtime/engine/common/frontend/feature_cache.cc @@ -67,7 +67,7 @@ bool FeatureCache::Compute() { for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { int32 start = chunk_idx * dim_; - vector feature_chunk(feature.data() + start, + vector feature_chunk(feature.data() + start, feature.data() + start + dim_); // feed cache cache_.push(feature_chunk); diff --git a/speechx/speechx/common/frontend/feature_cache.h b/runtime/engine/common/frontend/feature_cache.h similarity index 97% rename from speechx/speechx/common/frontend/feature_cache.h rename to runtime/engine/common/frontend/feature_cache.h index b87612d6..25cb86f8 100644 --- a/speechx/speechx/common/frontend/feature_cache.h +++ b/runtime/engine/common/frontend/feature_cache.h @@ -57,7 +57,7 @@ class FeatureCache : public FrontendInterface { bool Compute(); 
int32 dim_; - size_t max_size_; // cache capacity + size_t max_size_; // cache capacity std::unique_ptr base_extractor_; std::queue> cache_; // feature cache diff --git a/speechx/speechx/common/frontend/feature_common.h b/runtime/engine/common/frontend/feature_common.h similarity index 100% rename from speechx/speechx/common/frontend/feature_common.h rename to runtime/engine/common/frontend/feature_common.h index 7864bd30..fcc9100c 100644 --- a/speechx/speechx/common/frontend/feature_common.h +++ b/runtime/engine/common/frontend/feature_common.h @@ -14,8 +14,8 @@ #pragma once -#include "frontend_itf.h" #include "frontend/feature-window.h" +#include "frontend_itf.h" namespace ppspeech { diff --git a/speechx/speechx/common/frontend/feature_common_inl.h b/runtime/engine/common/frontend/feature_common_inl.h similarity index 100% rename from speechx/speechx/common/frontend/feature_common_inl.h rename to runtime/engine/common/frontend/feature_common_inl.h diff --git a/speechx/speechx/common/frontend/feature_pipeline.cc b/runtime/engine/common/frontend/feature_pipeline.cc similarity index 100% rename from speechx/speechx/common/frontend/feature_pipeline.cc rename to runtime/engine/common/frontend/feature_pipeline.cc diff --git a/speechx/speechx/common/frontend/feature_pipeline.h b/runtime/engine/common/frontend/feature_pipeline.h similarity index 100% rename from speechx/speechx/common/frontend/feature_pipeline.h rename to runtime/engine/common/frontend/feature_pipeline.h index c9a649fd..7509814f 100644 --- a/speechx/speechx/common/frontend/feature_pipeline.h +++ b/runtime/engine/common/frontend/feature_pipeline.h @@ -18,11 +18,11 @@ #include "frontend/assembler.h" #include "frontend/audio_cache.h" +#include "frontend/cmvn.h" #include "frontend/data_cache.h" #include "frontend/fbank.h" #include "frontend/feature_cache.h" #include "frontend/frontend_itf.h" -#include "frontend/cmvn.h" // feature DECLARE_bool(fill_zero); diff --git a/speechx/speechx/common/frontend/fftsg.c 
b/runtime/engine/common/frontend/fftsg.c similarity index 100% rename from speechx/speechx/common/frontend/fftsg.c rename to runtime/engine/common/frontend/fftsg.c diff --git a/speechx/speechx/common/frontend/frontend_itf.h b/runtime/engine/common/frontend/frontend_itf.h similarity index 100% rename from speechx/speechx/common/frontend/frontend_itf.h rename to runtime/engine/common/frontend/frontend_itf.h diff --git a/speechx/speechx/common/frontend/linear_spectrogram.cc b/runtime/engine/common/frontend/linear_spectrogram.cc similarity index 100% rename from speechx/speechx/common/frontend/linear_spectrogram.cc rename to runtime/engine/common/frontend/linear_spectrogram.cc diff --git a/speechx/speechx/common/frontend/linear_spectrogram.h b/runtime/engine/common/frontend/linear_spectrogram.h similarity index 100% rename from speechx/speechx/common/frontend/linear_spectrogram.h rename to runtime/engine/common/frontend/linear_spectrogram.h diff --git a/speechx/speechx/common/frontend/mel-computations.cc b/runtime/engine/common/frontend/mel-computations.cc similarity index 100% rename from speechx/speechx/common/frontend/mel-computations.cc rename to runtime/engine/common/frontend/mel-computations.cc diff --git a/speechx/speechx/common/frontend/mel-computations.h b/runtime/engine/common/frontend/mel-computations.h similarity index 100% rename from speechx/speechx/common/frontend/mel-computations.h rename to runtime/engine/common/frontend/mel-computations.h diff --git a/speechx/speechx/common/frontend/mfcc.cc b/runtime/engine/common/frontend/mfcc.cc similarity index 100% rename from speechx/speechx/common/frontend/mfcc.cc rename to runtime/engine/common/frontend/mfcc.cc diff --git a/speechx/speechx/common/frontend/mfcc.h b/runtime/engine/common/frontend/mfcc.h similarity index 100% rename from speechx/speechx/common/frontend/mfcc.h rename to runtime/engine/common/frontend/mfcc.h diff --git a/speechx/speechx/common/frontend/normalizer.h 
b/runtime/engine/common/frontend/normalizer.h similarity index 100% rename from speechx/speechx/common/frontend/normalizer.h rename to runtime/engine/common/frontend/normalizer.h diff --git a/speechx/speechx/common/frontend/rfft.cc b/runtime/engine/common/frontend/rfft.cc similarity index 100% rename from speechx/speechx/common/frontend/rfft.cc rename to runtime/engine/common/frontend/rfft.cc diff --git a/speechx/speechx/common/frontend/rfft.h b/runtime/engine/common/frontend/rfft.h similarity index 100% rename from speechx/speechx/common/frontend/rfft.h rename to runtime/engine/common/frontend/rfft.h diff --git a/runtime/engine/common/frontend/wave-reader.cc b/runtime/engine/common/frontend/wave-reader.cc new file mode 100644 index 00000000..b64dcc9e --- /dev/null +++ b/runtime/engine/common/frontend/wave-reader.cc @@ -0,0 +1,376 @@ +// feat/wave-reader.cc + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek +// 2013 Florent Masson +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "base/kaldi-error.h" +#include "base/kaldi-utils.h" +#include "frontend/wave-reader.h" + +namespace kaldi { + +// A utility class for reading wave header. 
+struct WaveHeaderReadGofer { + std::istream &is; + bool swap; + char tag[5]; + + WaveHeaderReadGofer(std::istream &is) : is(is), swap(false) { + memset(tag, '\0', sizeof tag); + } + + void Expect4ByteTag(const char *expected) { + is.read(tag, 4); + if (is.fail()) + KALDI_ERR << "WaveData: expected " << expected + << ", failed to read anything"; + if (strcmp(tag, expected)) + KALDI_ERR << "WaveData: expected " << expected << ", got " << tag; + } + + void Read4ByteTag() { + is.read(tag, 4); + if (is.fail()) + KALDI_ERR << "WaveData: expected 4-byte chunk-name, got read error"; + } + + uint32 ReadUint32() { + union { + char result[4]; + uint32 ans; + } u; + is.read(u.result, 4); + if (swap) KALDI_SWAP4(u.result); + if (is.fail()) + KALDI_ERR << "WaveData: unexpected end of file or read error"; + return u.ans; + } + + uint16 ReadUint16() { + union { + char result[2]; + int16 ans; + } u; + is.read(u.result, 2); + if (swap) KALDI_SWAP2(u.result); + if (is.fail()) + KALDI_ERR << "WaveData: unexpected end of file or read error"; + return u.ans; + } +}; + +static void WriteUint32(std::ostream &os, int32 i) { + union { + char buf[4]; + int i; + } u; + u.i = i; +#ifdef __BIG_ENDIAN__ + KALDI_SWAP4(u.buf); +#endif + os.write(u.buf, 4); + if (os.fail()) KALDI_ERR << "WaveData: error writing to stream."; +} + +static void WriteUint16(std::ostream &os, int16 i) { + union { + char buf[2]; + int16 i; + } u; + u.i = i; +#ifdef __BIG_ENDIAN__ + KALDI_SWAP2(u.buf); +#endif + os.write(u.buf, 2); + if (os.fail()) KALDI_ERR << "WaveData: error writing to stream."; +} + +void WaveInfo::Read(std::istream &is) { + WaveHeaderReadGofer reader(is); + reader.Read4ByteTag(); + if (strcmp(reader.tag, "RIFF") == 0) + reverse_bytes_ = false; + else if (strcmp(reader.tag, "RIFX") == 0) + reverse_bytes_ = true; + else + KALDI_ERR << "WaveData: expected RIFF or RIFX, got " << reader.tag; + +#ifdef __BIG_ENDIAN__ + reverse_bytes_ = !reverse_bytes_; +#endif + reader.swap = reverse_bytes_; + + uint32 
riff_chunk_size = reader.ReadUint32(); + reader.Expect4ByteTag("WAVE"); + + uint32 riff_chunk_read = 0; + riff_chunk_read += 4; // WAVE included in riff_chunk_size. + + // Possibly skip any RIFF tags between 'WAVE' and 'fmt '. + // Apple devices produce a filler tag 'JUNK' for memory alignment. + reader.Read4ByteTag(); + riff_chunk_read += 4; + while (strcmp(reader.tag, "fmt ") != 0) { + uint32 filler_size = reader.ReadUint32(); + riff_chunk_read += 4; + for (uint32 i = 0; i < filler_size; i++) { + is.get(); // read 1 byte, + } + riff_chunk_read += filler_size; + // get next RIFF tag, + reader.Read4ByteTag(); + riff_chunk_read += 4; + } + + KALDI_ASSERT(strcmp(reader.tag, "fmt ") == 0); + uint32 subchunk1_size = reader.ReadUint32(); + uint16 audio_format = reader.ReadUint16(); + num_channels_ = reader.ReadUint16(); + uint32 sample_rate = reader.ReadUint32(), byte_rate = reader.ReadUint32(), + block_align = reader.ReadUint16(), + bits_per_sample = reader.ReadUint16(); + samp_freq_ = static_cast(sample_rate); + + uint32 fmt_chunk_read = 16; + if (audio_format == 1) { + if (subchunk1_size < 16) { + KALDI_ERR << "WaveData: expect PCM format data to have fmt chunk " + << "of at least size 16."; + } + } else if (audio_format == 0xFFFE) { // WAVE_FORMAT_EXTENSIBLE + uint16 extra_size = reader.ReadUint16(); + if (subchunk1_size < 40 || extra_size < 22) { + KALDI_ERR + << "WaveData: malformed WAVE_FORMAT_EXTENSIBLE format data."; + } + reader.ReadUint16(); // Unused for PCM. + reader.ReadUint32(); // Channel map: we do not care. + uint32 guid1 = reader.ReadUint32(), guid2 = reader.ReadUint32(), + guid3 = reader.ReadUint32(), guid4 = reader.ReadUint32(); + fmt_chunk_read = 40; + + // Support only KSDATAFORMAT_SUBTYPE_PCM for now. 
Interesting formats: + // ("00000001-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_PCM) + // ("00000003-0000-0010-8000-00aa00389b71", + // KSDATAFORMAT_SUBTYPE_IEEE_FLOAT) + // ("00000006-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_ALAW) + // ("00000007-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_MULAW) + if (guid1 != 0x00000001 || guid2 != 0x00100000 || guid3 != 0xAA000080 || + guid4 != 0x719B3800) { + KALDI_ERR << "WaveData: unsupported WAVE_FORMAT_EXTENSIBLE format."; + } + } else { + KALDI_ERR << "WaveData: can read only PCM data, format id in file is: " + << audio_format; + } + + for (uint32 i = fmt_chunk_read; i < subchunk1_size; ++i) + is.get(); // use up extra data. + + if (num_channels_ == 0) KALDI_ERR << "WaveData: no channels present"; + if (bits_per_sample != 16) + KALDI_ERR << "WaveData: unsupported bits_per_sample = " + << bits_per_sample; + if (byte_rate != sample_rate * bits_per_sample / 8 * num_channels_) + KALDI_ERR << "Unexpected byte rate " << byte_rate << " vs. " + << sample_rate << " * " << (bits_per_sample / 8) << " * " + << num_channels_; + if (block_align != num_channels_ * bits_per_sample / 8) + KALDI_ERR << "Unexpected block_align: " << block_align << " vs. " + << num_channels_ << " * " << (bits_per_sample / 8); + + riff_chunk_read += 4 + subchunk1_size; + // size of what we just read, 4 for subchunk1_size + subchunk1_size itself. + + // We support an optional "fact" chunk (which is useless but which + // we encountered), and then a single "data" chunk. + + reader.Read4ByteTag(); + riff_chunk_read += 4; + + // Skip any subchunks between "fmt" and "data". Usually there will + // be a single "fact" subchunk, but on Windows there can also be a + // "list" subchunk. + while (strcmp(reader.tag, "data") != 0) { + // We will just ignore the data in these chunks. 
+ uint32 chunk_sz = reader.ReadUint32(); + if (chunk_sz != 4 && strcmp(reader.tag, "fact") == 0) + KALDI_WARN << "Expected fact chunk to be 4 bytes long."; + for (uint32 i = 0; i < chunk_sz; i++) is.get(); + riff_chunk_read += + 4 + chunk_sz; // for chunk_sz (4) + chunk contents (chunk-sz) + + // Now read the next chunk name. + reader.Read4ByteTag(); + riff_chunk_read += 4; + } + + KALDI_ASSERT(strcmp(reader.tag, "data") == 0); + uint32 data_chunk_size = reader.ReadUint32(); + riff_chunk_read += 4; + + // Figure out if the file is going to be read to the end. Values as + // observed in the wild: + bool is_stream_mode = + riff_chunk_size == 0 || riff_chunk_size == 0xFFFFFFFF || + data_chunk_size == 0 || data_chunk_size == 0xFFFFFFFF || + data_chunk_size == 0x7FFFF000; // This value is used by SoX. + + if (is_stream_mode) + KALDI_VLOG(1) << "Read in RIFF chunk size: " << riff_chunk_size + << ", data chunk size: " << data_chunk_size + << ". Assume 'stream mode' (reading data to EOF)."; + + if (!is_stream_mode && + std::abs(static_cast(riff_chunk_read) + + static_cast(data_chunk_size) - + static_cast(riff_chunk_size)) > 1) { + // We allow the size to be off by one without warning, because there is + // a + // weirdness in the format of RIFF files that means that the input may + // sometimes be padded with 1 unused byte to make the total size even. + KALDI_WARN << "Expected " << riff_chunk_size + << " bytes in RIFF chunk, but " + << "after first data block there will be " << riff_chunk_read + << " + " << data_chunk_size << " bytes " + << "(we do not support reading multiple data chunks)."; + } + + if (is_stream_mode) + samp_count_ = -1; + else + samp_count_ = data_chunk_size / block_align; +} + +void WaveData::Read(std::istream &is) { + const uint32 kBlockSize = 1024 * 1024; + + WaveInfo header; + header.Read(is); + + data_.Resize(0, 0); // clear the data. + samp_freq_ = header.SampFreq(); + + std::vector buffer; + uint32 bytes_to_go = header.IsStreamed() ? 
kBlockSize : header.DataBytes(); + + // Once in a while header.DataBytes() will report an insane value; + // read the file to the end + while (is && bytes_to_go > 0) { + uint32 block_bytes = std::min(bytes_to_go, kBlockSize); + uint32 offset = buffer.size(); + buffer.resize(offset + block_bytes); + is.read(&buffer[offset], block_bytes); + uint32 bytes_read = is.gcount(); + buffer.resize(offset + bytes_read); + if (!header.IsStreamed()) bytes_to_go -= bytes_read; + } + + if (is.bad()) KALDI_ERR << "WaveData: file read error"; + + if (buffer.size() == 0) KALDI_ERR << "WaveData: empty file (no data)"; + + if (!header.IsStreamed() && buffer.size() < header.DataBytes()) { + KALDI_WARN << "Expected " << header.DataBytes() + << " bytes of wave data, " + << "but read only " << buffer.size() << " bytes. " + << "Truncated file?"; + } + + uint16 *data_ptr = reinterpret_cast(&buffer[0]); + + // The matrix is arranged row per channel, column per sample. + data_.Resize(header.NumChannels(), buffer.size() / header.BlockAlign()); + for (uint32 i = 0; i < data_.NumCols(); ++i) { + for (uint32 j = 0; j < data_.NumRows(); ++j) { + int16 k = *data_ptr++; + if (header.ReverseBytes()) KALDI_SWAP2(k); + data_(j, i) = k; + } + } +} + + +// Write 16-bit PCM. + +// note: the WAVE chunk contains 2 subchunks. +// +// subchunk2size = data.NumRows() * data.NumCols() * 2. 
+ + +void WaveData::Write(std::ostream &os) const { + os << "RIFF"; + if (data_.NumRows() == 0) + KALDI_ERR << "Error: attempting to write empty WAVE file"; + + int32 num_chan = data_.NumRows(), num_samp = data_.NumCols(), + bytes_per_samp = 2; + + int32 subchunk2size = (num_chan * num_samp * bytes_per_samp); + int32 chunk_size = 36 + subchunk2size; + WriteUint32(os, chunk_size); + os << "WAVE"; + os << "fmt "; + WriteUint32(os, 16); + WriteUint16(os, 1); + WriteUint16(os, num_chan); + KALDI_ASSERT(samp_freq_ > 0); + WriteUint32(os, static_cast(samp_freq_)); + WriteUint32(os, static_cast(samp_freq_) * num_chan * bytes_per_samp); + WriteUint16(os, num_chan * bytes_per_samp); + WriteUint16(os, 8 * bytes_per_samp); + os << "data"; + WriteUint32(os, subchunk2size); + + const BaseFloat *data_ptr = data_.Data(); + int32 stride = data_.Stride(); + + int num_clipped = 0; + for (int32 i = 0; i < num_samp; i++) { + for (int32 j = 0; j < num_chan; j++) { + int32 elem = static_cast(trunc(data_ptr[j * stride + i])); + int16 elem_16 = static_cast(elem); + if (elem < std::numeric_limits::min()) { + elem_16 = std::numeric_limits::min(); + ++num_clipped; + } else if (elem > std::numeric_limits::max()) { + elem_16 = std::numeric_limits::max(); + ++num_clipped; + } +#ifdef __BIG_ENDIAN__ + KALDI_SWAP2(elem_16); +#endif + os.write(reinterpret_cast(&elem_16), 2); + } + } + if (os.fail()) KALDI_ERR << "Error writing wave data to stream."; + if (num_clipped > 0) + KALDI_WARN << "WARNING: clipped " << num_clipped + << " samples out of total " << num_chan * num_samp + << ". 
Reduce volume?"; +} + + +} // end namespace kaldi diff --git a/runtime/engine/common/frontend/wave-reader.h b/runtime/engine/common/frontend/wave-reader.h new file mode 100644 index 00000000..6cd471b8 --- /dev/null +++ b/runtime/engine/common/frontend/wave-reader.h @@ -0,0 +1,248 @@ +// feat/wave-reader.h + +// Copyright 2009-2011 Karel Vesely; Microsoft Corporation +// 2013 Florent Masson +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +/* +// THE WAVE FORMAT IS SPECIFIED IN: +// https:// ccrma.stanford.edu/courses/422/projects/WaveFormat/ +// +// +// +// RIFF +// | +// WAVE +// | \ \ \ +// fmt_ data ... data +// +// +// Riff is a general container, which usually contains one WAVE chunk +// each WAVE chunk has header sub-chunk 'fmt_' +// and one or more data sub-chunks 'data' +// +// [Note from Dan: to say that the wave format was ever "specified" anywhere is +// not quite right. The guy who invented the wave format attempted to create +// a formal specification but it did not completely make sense. And there +// doesn't seem to be a consensus on what makes a valid wave file, +// particularly where the accuracy of header information is concerned.] 
+*/ + + +#ifndef KALDI_FEAT_WAVE_READER_H_ +#define KALDI_FEAT_WAVE_READER_H_ + +#include + +#include "base/kaldi-types.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/kaldi-vector.h" + + +namespace kaldi { + +/// For historical reasons, we scale waveforms to the range +/// (2^15-1)*[-1, 1], not the usual default DSP range [-1, 1]. +const BaseFloat kWaveSampleMax = 32768.0; + +/// This class reads and hold wave file header information. +class WaveInfo { + public: + WaveInfo() + : samp_freq_(0), samp_count_(0), num_channels_(0), reverse_bytes_(0) {} + + /// Is stream size unknown? Duration and SampleCount not valid if true. + bool IsStreamed() const { return samp_count_ < 0; } + + /// Sample frequency, Hz. + BaseFloat SampFreq() const { return samp_freq_; } + + /// Number of samples in stream. Invalid if IsStreamed() is true. + uint32 SampleCount() const { return samp_count_; } + + /// Approximate duration, seconds. Invalid if IsStreamed() is true. + BaseFloat Duration() const { return samp_count_ / samp_freq_; } + + /// Number of channels, 1 to 16. + int32 NumChannels() const { return num_channels_; } + + /// Bytes per sample. + size_t BlockAlign() const { return 2 * num_channels_; } + + /// Wave data bytes. Invalid if IsStreamed() is true. + size_t DataBytes() const { return samp_count_ * BlockAlign(); } + + /// Is data file byte order different from machine byte order? + bool ReverseBytes() const { return reverse_bytes_; } + + /// 'is' should be opened in binary mode. Read() will throw on error. + /// On success 'is' will be positioned at the beginning of wave data. + void Read(std::istream &is); + + private: + BaseFloat samp_freq_; + int32 samp_count_; // 0 if empty, -1 if undefined length. + uint8 num_channels_; + bool reverse_bytes_; // File endianness differs from host. +}; + +/// This class's purpose is to read in Wave files. 
+class WaveData { + public: + WaveData(BaseFloat samp_freq, const MatrixBase &data) + : data_(data), samp_freq_(samp_freq) {} + + WaveData() : samp_freq_(0.0) {} + + /// Read() will throw on error. It's valid to call Read() more than once-- + /// in this case it will destroy what was there before. + /// "is" should be opened in binary mode. + void Read(std::istream &is); + + /// Write() will throw on error. os should be opened in binary mode. + void Write(std::ostream &os) const; + + // This function returns the wave data-- it's in a matrix + // because there may be multiple channels. In the normal case + // there's just one channel so Data() will have one row. + const Matrix &Data() const { return data_; } + + BaseFloat SampFreq() const { return samp_freq_; } + + // Returns the duration in seconds + BaseFloat Duration() const { return data_.NumCols() / samp_freq_; } + + void CopyFrom(const WaveData &other) { + samp_freq_ = other.samp_freq_; + data_.CopyFromMat(other.data_); + } + + void Clear() { + data_.Resize(0, 0); + samp_freq_ = 0.0; + } + + void Swap(WaveData *other) { + data_.Swap(&(other->data_)); + std::swap(samp_freq_, other->samp_freq_); + } + + private: + static const uint32 kBlockSize = 1024 * 1024; // Use 1M bytes. + Matrix data_; + BaseFloat samp_freq_; +}; + + +// Holder class for .wav files that enables us to read (but not write) .wav +// files. c.f. util/kaldi-holder.h we don't use the KaldiObjectHolder template +// because we don't want to check for the \0B binary header. We could have faked +// it by pretending to read in the wave data in text mode after failing to find +// the \0B header, but that would have been a little ugly. +class WaveHolder { + public: + typedef WaveData T; + + static bool Write(std::ostream &os, bool binary, const T &t) { + // We don't write the binary-mode header here [always binary]. + if (!binary) + KALDI_ERR << "Wave data can only be written in binary mode."; + try { + t.Write(os); // throws exception on failure. 
+ return true; + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught in WaveHolder object (writing). " + << e.what(); + return false; // write failure. + } + } + void Copy(const T &t) { t_.CopyFrom(t); } + + static bool IsReadInBinary() { return true; } + + void Clear() { t_.Clear(); } + + T &Value() { return t_; } + + WaveHolder &operator=(const WaveHolder &other) { + t_.CopyFrom(other.t_); + return *this; + } + WaveHolder(const WaveHolder &other) : t_(other.t_) {} + + WaveHolder() {} + + bool Read(std::istream &is) { + // We don't look for the binary-mode header here [always binary] + try { + t_.Read(is); // Throws exception on failure. + return true; + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught in WaveHolder::Read(). " + << e.what(); + return false; + } + } + + void Swap(WaveHolder *other) { t_.Swap(&(other->t_)); } + + bool ExtractRange(const WaveHolder &other, const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + private: + T t_; +}; + +// This is like WaveHolder but when you just want the metadata- +// it leaves the actual data undefined, it doesn't read it. +class WaveInfoHolder { + public: + typedef WaveInfo T; + + void Clear() { info_ = WaveInfo(); } + void Swap(WaveInfoHolder *other) { std::swap(info_, other->info_); } + T &Value() { return info_; } + static bool IsReadInBinary() { return true; } + + bool Read(std::istream &is) { + try { + info_.Read(is); // Throws exception on failure. + return true; + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught in WaveInfoHolder::Read(). 
" + << e.what(); + return false; + } + } + + bool ExtractRange(const WaveInfoHolder &other, const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + private: + WaveInfo info_; +}; + + +} // namespace kaldi + +#endif // KALDI_FEAT_WAVE_READER_H_ diff --git a/speechx/speechx/common/matrix/CMakeLists.txt b/runtime/engine/common/matrix/CMakeLists.txt similarity index 100% rename from speechx/speechx/common/matrix/CMakeLists.txt rename to runtime/engine/common/matrix/CMakeLists.txt diff --git a/speechx/speechx/common/matrix/kaldi-matrix-inl.h b/runtime/engine/common/matrix/kaldi-matrix-inl.h similarity index 100% rename from speechx/speechx/common/matrix/kaldi-matrix-inl.h rename to runtime/engine/common/matrix/kaldi-matrix-inl.h diff --git a/speechx/speechx/common/matrix/kaldi-matrix.cc b/runtime/engine/common/matrix/kaldi-matrix.cc similarity index 100% rename from speechx/speechx/common/matrix/kaldi-matrix.cc rename to runtime/engine/common/matrix/kaldi-matrix.cc diff --git a/speechx/speechx/common/matrix/kaldi-matrix.h b/runtime/engine/common/matrix/kaldi-matrix.h similarity index 100% rename from speechx/speechx/common/matrix/kaldi-matrix.h rename to runtime/engine/common/matrix/kaldi-matrix.h diff --git a/speechx/speechx/common/matrix/kaldi-vector-inl.h b/runtime/engine/common/matrix/kaldi-vector-inl.h similarity index 100% rename from speechx/speechx/common/matrix/kaldi-vector-inl.h rename to runtime/engine/common/matrix/kaldi-vector-inl.h diff --git a/speechx/speechx/common/matrix/kaldi-vector.cc b/runtime/engine/common/matrix/kaldi-vector.cc similarity index 100% rename from speechx/speechx/common/matrix/kaldi-vector.cc rename to runtime/engine/common/matrix/kaldi-vector.cc diff --git a/speechx/speechx/common/matrix/kaldi-vector.h b/runtime/engine/common/matrix/kaldi-vector.h similarity index 100% rename from speechx/speechx/common/matrix/kaldi-vector.h rename to 
runtime/engine/common/matrix/kaldi-vector.h diff --git a/speechx/speechx/common/matrix/matrix-common.h b/runtime/engine/common/matrix/matrix-common.h similarity index 100% rename from speechx/speechx/common/matrix/matrix-common.h rename to runtime/engine/common/matrix/matrix-common.h diff --git a/speechx/speechx/common/utils/CMakeLists.txt b/runtime/engine/common/utils/CMakeLists.txt similarity index 100% rename from speechx/speechx/common/utils/CMakeLists.txt rename to runtime/engine/common/utils/CMakeLists.txt diff --git a/speechx/speechx/common/utils/file_utils.cc b/runtime/engine/common/utils/file_utils.cc similarity index 100% rename from speechx/speechx/common/utils/file_utils.cc rename to runtime/engine/common/utils/file_utils.cc diff --git a/speechx/speechx/common/utils/file_utils.h b/runtime/engine/common/utils/file_utils.h similarity index 100% rename from speechx/speechx/common/utils/file_utils.h rename to runtime/engine/common/utils/file_utils.h diff --git a/speechx/speechx/common/utils/math.cc b/runtime/engine/common/utils/math.cc similarity index 100% rename from speechx/speechx/common/utils/math.cc rename to runtime/engine/common/utils/math.cc index e5832cbd..1f0c9c93 100644 --- a/speechx/speechx/common/utils/math.cc +++ b/runtime/engine/common/utils/math.cc @@ -20,8 +20,8 @@ #include #include #include -#include #include +#include #include diff --git a/speechx/speechx/common/utils/math.h b/runtime/engine/common/utils/math.h similarity index 100% rename from speechx/speechx/common/utils/math.h rename to runtime/engine/common/utils/math.h diff --git a/runtime/engine/common/utils/picojson.h b/runtime/engine/common/utils/picojson.h new file mode 100644 index 00000000..2ac265f5 --- /dev/null +++ b/runtime/engine/common/utils/picojson.h @@ -0,0 +1,1230 @@ +/* + * Copyright 2009-2010 Cybozu Labs, Inc. + * Copyright 2011-2014 Kazuho Oku + * All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ +#ifndef picojson_h +#define picojson_h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define PICOJSON_USE_INT64 1 + +// for isnan/isinf +#if __cplusplus >= 201103L +#include +#else +extern "C" { +#ifdef _MSC_VER +#include +#elif defined(__INTEL_COMPILER) +#include +#else +#include +#endif +} +#endif + +#ifndef PICOJSON_USE_RVALUE_REFERENCE +#if (defined(__cpp_rvalue_references) && __cpp_rvalue_references >= 200610) || \ + (defined(_MSC_VER) && _MSC_VER >= 1600) +#define PICOJSON_USE_RVALUE_REFERENCE 1 +#else +#define PICOJSON_USE_RVALUE_REFERENCE 0 +#endif +#endif // PICOJSON_USE_RVALUE_REFERENCE + +#ifndef PICOJSON_NOEXCEPT +#if PICOJSON_USE_RVALUE_REFERENCE +#define PICOJSON_NOEXCEPT noexcept +#else +#define PICOJSON_NOEXCEPT throw() +#endif +#endif + +// experimental support for int64_t (see README.mkdn for detail) +#ifdef PICOJSON_USE_INT64 +#define __STDC_FORMAT_MACROS +#include +#if __cplusplus >= 201103L +#include +#else +extern "C" { +#include +} +#endif +#endif + +// to disable the use of localeconv(3), set PICOJSON_USE_LOCALE to 0 +#ifndef PICOJSON_USE_LOCALE +#define PICOJSON_USE_LOCALE 1 +#endif +#if PICOJSON_USE_LOCALE +extern "C" { +#include +} +#endif + +#ifndef PICOJSON_ASSERT +#define PICOJSON_ASSERT(e) \ + do { \ + if (!(e)) throw std::runtime_error(#e); \ + } while (0) +#endif + +#ifdef _MSC_VER +#define SNPRINTF _snprintf_s +#pragma warning(push) +#pragma warning(disable : 4244) // conversion from int to char +#pragma warning(disable : 4127) // conditional expression is constant +#pragma warning(disable : 4702) // unreachable code +#pragma warning(disable : 4706) // assignment within conditional expression +#else +#define SNPRINTF snprintf +#endif + +namespace picojson { + +enum { + null_type, + boolean_type, + number_type, + string_type, + array_type, + object_type +#ifdef PICOJSON_USE_INT64 + , + int64_type +#endif +}; + +enum { INDENT_WIDTH = 2, 
DEFAULT_MAX_DEPTHS = 100 }; + +struct null {}; + +class value { + public: + typedef std::vector array; + typedef std::map object; + union _storage { + bool boolean_; + double number_; +#ifdef PICOJSON_USE_INT64 + int64_t int64_; +#endif + std::string *string_; + array *array_; + object *object_; + }; + + protected: + int type_; + _storage u_; + + public: + value(); + value(int type, bool); + explicit value(bool b); +#ifdef PICOJSON_USE_INT64 + explicit value(int64_t i); +#endif + explicit value(double n); + explicit value(const std::string &s); + explicit value(const array &a); + explicit value(const object &o); +#if PICOJSON_USE_RVALUE_REFERENCE + explicit value(std::string &&s); + explicit value(array &&a); + explicit value(object &&o); +#endif + explicit value(const char *s); + value(const char *s, size_t len); + ~value(); + value(const value &x); + value &operator=(const value &x); +#if PICOJSON_USE_RVALUE_REFERENCE + value(value &&x) PICOJSON_NOEXCEPT; + value &operator=(value &&x) PICOJSON_NOEXCEPT; +#endif + void swap(value &x) PICOJSON_NOEXCEPT; + template + bool is() const; + template + const T &get() const; + template + T &get(); + template + void set(const T &); +#if PICOJSON_USE_RVALUE_REFERENCE + template + void set(T &&); +#endif + bool evaluate_as_boolean() const; + const value &get(const size_t idx) const; + const value &get(const std::string &key) const; + value &get(const size_t idx); + value &get(const std::string &key); + + bool contains(const size_t idx) const; + bool contains(const std::string &key) const; + std::string to_str() const; + template + void serialize(Iter os, bool prettify = false) const; + std::string serialize(bool prettify = false) const; + + private: + template + value(const T *); // intentionally defined to block implicit conversion of + // pointer to bool + template + static void _indent(Iter os, int indent); + template + void _serialize(Iter os, int indent) const; + std::string _serialize(int indent) const; + void clear(); 
+}; + +typedef value::array array; +typedef value::object object; + +inline value::value() : type_(null_type), u_() {} + +inline value::value(int type, bool) : type_(type), u_() { + switch (type) { +#define INIT(p, v) \ + case p##type: \ + u_.p = v; \ + break + INIT(boolean_, false); + INIT(number_, 0.0); +#ifdef PICOJSON_USE_INT64 + INIT(int64_, 0); +#endif + INIT(string_, new std::string()); + INIT(array_, new array()); + INIT(object_, new object()); +#undef INIT + default: + break; + } +} + +inline value::value(bool b) : type_(boolean_type), u_() { u_.boolean_ = b; } + +#ifdef PICOJSON_USE_INT64 +inline value::value(int64_t i) : type_(int64_type), u_() { u_.int64_ = i; } +#endif + +inline value::value(double n) : type_(number_type), u_() { + if ( +#ifdef _MSC_VER + !_finite(n) +#elif __cplusplus >= 201103L + std::isnan(n) || std::isinf(n) +#else + isnan(n) || isinf(n) +#endif + ) { + throw std::overflow_error(""); + } + u_.number_ = n; +} + +inline value::value(const std::string &s) : type_(string_type), u_() { + u_.string_ = new std::string(s); +} + +inline value::value(const array &a) : type_(array_type), u_() { + u_.array_ = new array(a); +} + +inline value::value(const object &o) : type_(object_type), u_() { + u_.object_ = new object(o); +} + +#if PICOJSON_USE_RVALUE_REFERENCE +inline value::value(std::string &&s) : type_(string_type), u_() { + u_.string_ = new std::string(std::move(s)); +} + +inline value::value(array &&a) : type_(array_type), u_() { + u_.array_ = new array(std::move(a)); +} + +inline value::value(object &&o) : type_(object_type), u_() { + u_.object_ = new object(std::move(o)); +} +#endif + +inline value::value(const char *s) : type_(string_type), u_() { + u_.string_ = new std::string(s); +} + +inline value::value(const char *s, size_t len) : type_(string_type), u_() { + u_.string_ = new std::string(s, len); +} + +inline void value::clear() { + switch (type_) { +#define DEINIT(p) \ + case p##type: \ + delete u_.p; \ + break + 
DEINIT(string_); + DEINIT(array_); + DEINIT(object_); +#undef DEINIT + default: + break; + } +} + +inline value::~value() { clear(); } + +inline value::value(const value &x) : type_(x.type_), u_() { + switch (type_) { +#define INIT(p, v) \ + case p##type: \ + u_.p = v; \ + break + INIT(string_, new std::string(*x.u_.string_)); + INIT(array_, new array(*x.u_.array_)); + INIT(object_, new object(*x.u_.object_)); +#undef INIT + default: + u_ = x.u_; + break; + } +} + +inline value &value::operator=(const value &x) { + if (this != &x) { + value t(x); + swap(t); + } + return *this; +} + +#if PICOJSON_USE_RVALUE_REFERENCE +inline value::value(value &&x) PICOJSON_NOEXCEPT : type_(null_type), u_() { + swap(x); +} +inline value &value::operator=(value &&x) PICOJSON_NOEXCEPT { + swap(x); + return *this; +} +#endif +inline void value::swap(value &x) PICOJSON_NOEXCEPT { + std::swap(type_, x.type_); + std::swap(u_, x.u_); +} + +#define IS(ctype, jtype) \ + template <> \ + inline bool value::is() const { \ + return type_ == jtype##_type; \ + } +IS(null, null) +IS(bool, boolean) +#ifdef PICOJSON_USE_INT64 +IS(int64_t, int64) +#endif +IS(std::string, string) +IS(array, array) +IS(object, object) +#undef IS +template <> +inline bool value::is() const { + return type_ == number_type +#ifdef PICOJSON_USE_INT64 + || type_ == int64_type +#endif + ; +} + +#define GET(ctype, var) \ + template <> \ + inline const ctype &value::get() const { \ + PICOJSON_ASSERT("type mismatch! call is() before get()" && \ + is()); \ + return var; \ + } \ + template <> \ + inline ctype &value::get() { \ + PICOJSON_ASSERT("type mismatch! 
call is() before get()" && \ + is()); \ + return var; \ + } +GET(bool, u_.boolean_) +GET(std::string, *u_.string_) +GET(array, *u_.array_) +GET(object, *u_.object_) +#ifdef PICOJSON_USE_INT64 +GET(double, + (type_ == int64_type && + (const_cast(this)->type_ = number_type, + (const_cast(this)->u_.number_ = u_.int64_)), + u_.number_)) +GET(int64_t, u_.int64_) +#else +GET(double, u_.number_) +#endif +#undef GET + +#define SET(ctype, jtype, setter) \ + template <> \ + inline void value::set(const ctype &_val) { \ + clear(); \ + type_ = jtype##_type; \ + setter \ + } +SET(bool, boolean, u_.boolean_ = _val;) +SET(std::string, string, u_.string_ = new std::string(_val);) +SET(array, array, u_.array_ = new array(_val);) +SET(object, object, u_.object_ = new object(_val);) +SET(double, number, u_.number_ = _val;) +#ifdef PICOJSON_USE_INT64 +SET(int64_t, int64, u_.int64_ = _val;) +#endif +#undef SET + +#if PICOJSON_USE_RVALUE_REFERENCE +#define MOVESET(ctype, jtype, setter) \ + template <> \ + inline void value::set(ctype && _val) { \ + clear(); \ + type_ = jtype##_type; \ + setter \ + } +MOVESET(std::string, string, u_.string_ = new std::string(std::move(_val));) +MOVESET(array, array, u_.array_ = new array(std::move(_val));) +MOVESET(object, object, u_.object_ = new object(std::move(_val));) +#undef MOVESET +#endif + +inline bool value::evaluate_as_boolean() const { + switch (type_) { + case null_type: + return false; + case boolean_type: + return u_.boolean_; + case number_type: + return u_.number_ != 0; +#ifdef PICOJSON_USE_INT64 + case int64_type: + return u_.int64_ != 0; +#endif + case string_type: + return !u_.string_->empty(); + default: + return true; + } +} + +inline const value &value::get(const size_t idx) const { + static value s_null; + PICOJSON_ASSERT(is()); + return idx < u_.array_->size() ? (*u_.array_)[idx] : s_null; +} + +inline value &value::get(const size_t idx) { + static value s_null; + PICOJSON_ASSERT(is()); + return idx < u_.array_->size() ? 
(*u_.array_)[idx] : s_null; +} + +inline const value &value::get(const std::string &key) const { + static value s_null; + PICOJSON_ASSERT(is()); + object::const_iterator i = u_.object_->find(key); + return i != u_.object_->end() ? i->second : s_null; +} + +inline value &value::get(const std::string &key) { + static value s_null; + PICOJSON_ASSERT(is()); + object::iterator i = u_.object_->find(key); + return i != u_.object_->end() ? i->second : s_null; +} + +inline bool value::contains(const size_t idx) const { + PICOJSON_ASSERT(is()); + return idx < u_.array_->size(); +} + +inline bool value::contains(const std::string &key) const { + PICOJSON_ASSERT(is()); + object::const_iterator i = u_.object_->find(key); + return i != u_.object_->end(); +} + +inline std::string value::to_str() const { + switch (type_) { + case null_type: + return "null"; + case boolean_type: + return u_.boolean_ ? "true" : "false"; +#ifdef PICOJSON_USE_INT64 + case int64_type: { + char buf[sizeof("-9223372036854775808")]; + SNPRINTF(buf, sizeof(buf), "%" PRId64, u_.int64_); + return buf; + } +#endif + case number_type: { + char buf[256]; + double tmp; + SNPRINTF( + buf, + sizeof(buf), + fabs(u_.number_) < (1ULL << 53) && modf(u_.number_, &tmp) == 0 + ? "%.f" + : "%.17g", + u_.number_); +#if PICOJSON_USE_LOCALE + char *decimal_point = localeconv()->decimal_point; + if (strcmp(decimal_point, ".") != 0) { + size_t decimal_point_len = strlen(decimal_point); + for (char *p = buf; *p != '\0'; ++p) { + if (strncmp(p, decimal_point, decimal_point_len) == 0) { + return std::string(buf, p) + "." 
+ + (p + decimal_point_len); + } + } + } +#endif + return buf; + } + case string_type: + return *u_.string_; + case array_type: + return "array"; + case object_type: + return "object"; + default: + PICOJSON_ASSERT(0); +#ifdef _MSC_VER + __assume(0); +#endif + } + return std::string(); +} + +template +void copy(const std::string &s, Iter oi) { + std::copy(s.begin(), s.end(), oi); +} + +template +struct serialize_str_char { + Iter oi; + void operator()(char c) { + switch (c) { +#define MAP(val, sym) \ + case val: \ + copy(sym, oi); \ + break + MAP('"', "\\\""); + MAP('\\', "\\\\"); + MAP('/', "\\/"); + MAP('\b', "\\b"); + MAP('\f', "\\f"); + MAP('\n', "\\n"); + MAP('\r', "\\r"); + MAP('\t', "\\t"); +#undef MAP + default: + if (static_cast(c) < 0x20 || c == 0x7f) { + char buf[7]; + SNPRINTF(buf, sizeof(buf), "\\u%04x", c & 0xff); + copy(buf, buf + 6, oi); + } else { + *oi++ = c; + } + break; + } + } +}; + +template +void serialize_str(const std::string &s, Iter oi) { + *oi++ = '"'; + serialize_str_char process_char = {oi}; + std::for_each(s.begin(), s.end(), process_char); + *oi++ = '"'; +} + +template +void value::serialize(Iter oi, bool prettify) const { + return _serialize(oi, prettify ? 0 : -1); +} + +inline std::string value::serialize(bool prettify) const { + return _serialize(prettify ? 
0 : -1); +} + +template +void value::_indent(Iter oi, int indent) { + *oi++ = '\n'; + for (int i = 0; i < indent * INDENT_WIDTH; ++i) { + *oi++ = ' '; + } +} + +template +void value::_serialize(Iter oi, int indent) const { + switch (type_) { + case string_type: + serialize_str(*u_.string_, oi); + break; + case array_type: { + *oi++ = '['; + if (indent != -1) { + ++indent; + } + for (array::const_iterator i = u_.array_->begin(); + i != u_.array_->end(); + ++i) { + if (i != u_.array_->begin()) { + *oi++ = ','; + } + if (indent != -1) { + _indent(oi, indent); + } + i->_serialize(oi, indent); + } + if (indent != -1) { + --indent; + if (!u_.array_->empty()) { + _indent(oi, indent); + } + } + *oi++ = ']'; + break; + } + case object_type: { + *oi++ = '{'; + if (indent != -1) { + ++indent; + } + for (object::const_iterator i = u_.object_->begin(); + i != u_.object_->end(); + ++i) { + if (i != u_.object_->begin()) { + *oi++ = ','; + } + if (indent != -1) { + _indent(oi, indent); + } + serialize_str(i->first, oi); + *oi++ = ':'; + if (indent != -1) { + *oi++ = ' '; + } + i->second._serialize(oi, indent); + } + if (indent != -1) { + --indent; + if (!u_.object_->empty()) { + _indent(oi, indent); + } + } + *oi++ = '}'; + break; + } + default: + copy(to_str(), oi); + break; + } + if (indent == 0) { + *oi++ = '\n'; + } +} + +inline std::string value::_serialize(int indent) const { + std::string s; + _serialize(std::back_inserter(s), indent); + return s; +} + +template +class input { + protected: + Iter cur_, end_; + bool consumed_; + int line_; + + public: + input(const Iter &first, const Iter &last) + : cur_(first), end_(last), consumed_(false), line_(1) {} + int getc() { + if (consumed_) { + if (*cur_ == '\n') { + ++line_; + } + ++cur_; + } + if (cur_ == end_) { + consumed_ = false; + return -1; + } + consumed_ = true; + return *cur_ & 0xff; + } + void ungetc() { consumed_ = false; } + Iter cur() const { + if (consumed_) { + input *self = const_cast *>(this); + self->consumed_ 
= false; + ++self->cur_; + } + return cur_; + } + int line() const { return line_; } + void skip_ws() { + while (1) { + int ch = getc(); + if (!(ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r')) { + ungetc(); + break; + } + } + } + bool expect(const int expected) { + skip_ws(); + if (getc() != expected) { + ungetc(); + return false; + } + return true; + } + bool match(const std::string &pattern) { + for (std::string::const_iterator pi(pattern.begin()); + pi != pattern.end(); + ++pi) { + if (getc() != *pi) { + ungetc(); + return false; + } + } + return true; + } +}; + +template +inline int _parse_quadhex(input &in) { + int uni_ch = 0, hex; + for (int i = 0; i < 4; i++) { + if ((hex = in.getc()) == -1) { + return -1; + } + if ('0' <= hex && hex <= '9') { + hex -= '0'; + } else if ('A' <= hex && hex <= 'F') { + hex -= 'A' - 0xa; + } else if ('a' <= hex && hex <= 'f') { + hex -= 'a' - 0xa; + } else { + in.ungetc(); + return -1; + } + uni_ch = uni_ch * 16 + hex; + } + return uni_ch; +} + +template +inline bool _parse_codepoint(String &out, input &in) { + int uni_ch; + if ((uni_ch = _parse_quadhex(in)) == -1) { + return false; + } + if (0xd800 <= uni_ch && uni_ch <= 0xdfff) { + if (0xdc00 <= uni_ch) { + // a second 16-bit of a surrogate pair appeared + return false; + } + // first 16-bit of surrogate pair, get the next one + if (in.getc() != '\\' || in.getc() != 'u') { + in.ungetc(); + return false; + } + int second = _parse_quadhex(in); + if (!(0xdc00 <= second && second <= 0xdfff)) { + return false; + } + uni_ch = ((uni_ch - 0xd800) << 10) | ((second - 0xdc00) & 0x3ff); + uni_ch += 0x10000; + } + if (uni_ch < 0x80) { + out.push_back(static_cast(uni_ch)); + } else { + if (uni_ch < 0x800) { + out.push_back(static_cast(0xc0 | (uni_ch >> 6))); + } else { + if (uni_ch < 0x10000) { + out.push_back(static_cast(0xe0 | (uni_ch >> 12))); + } else { + out.push_back(static_cast(0xf0 | (uni_ch >> 18))); + out.push_back( + static_cast(0x80 | ((uni_ch >> 12) & 0x3f))); + } + 
out.push_back(static_cast(0x80 | ((uni_ch >> 6) & 0x3f))); + } + out.push_back(static_cast(0x80 | (uni_ch & 0x3f))); + } + return true; +} + +template +inline bool _parse_string(String &out, input &in) { + while (1) { + int ch = in.getc(); + if (ch < ' ') { + in.ungetc(); + return false; + } else if (ch == '"') { + return true; + } else if (ch == '\\') { + if ((ch = in.getc()) == -1) { + return false; + } + switch (ch) { +#define MAP(sym, val) \ + case sym: \ + out.push_back(val); \ + break + MAP('"', '\"'); + MAP('\\', '\\'); + MAP('/', '/'); + MAP('b', '\b'); + MAP('f', '\f'); + MAP('n', '\n'); + MAP('r', '\r'); + MAP('t', '\t'); +#undef MAP + case 'u': + if (!_parse_codepoint(out, in)) { + return false; + } + break; + default: + return false; + } + } else { + out.push_back(static_cast(ch)); + } + } + return false; +} + +template +inline bool _parse_array(Context &ctx, input &in) { + if (!ctx.parse_array_start()) { + return false; + } + size_t idx = 0; + if (in.expect(']')) { + return ctx.parse_array_stop(idx); + } + do { + if (!ctx.parse_array_item(in, idx)) { + return false; + } + idx++; + } while (in.expect(',')); + return in.expect(']') && ctx.parse_array_stop(idx); +} + +template +inline bool _parse_object(Context &ctx, input &in) { + if (!ctx.parse_object_start()) { + return false; + } + if (in.expect('}')) { + return ctx.parse_object_stop(); + } + do { + std::string key; + if (!in.expect('"') || !_parse_string(key, in) || !in.expect(':')) { + return false; + } + if (!ctx.parse_object_item(in, key)) { + return false; + } + } while (in.expect(',')); + return in.expect('}') && ctx.parse_object_stop(); +} + +template +inline std::string _parse_number(input &in) { + std::string num_str; + while (1) { + int ch = in.getc(); + if (('0' <= ch && ch <= '9') || ch == '+' || ch == '-' || ch == 'e' || + ch == 'E') { + num_str.push_back(static_cast(ch)); + } else if (ch == '.') { +#if PICOJSON_USE_LOCALE + num_str += localeconv()->decimal_point; +#else + 
num_str.push_back('.'); +#endif + } else { + in.ungetc(); + break; + } + } + return num_str; +} + +template +inline bool _parse(Context &ctx, input &in) { + in.skip_ws(); + int ch = in.getc(); + switch (ch) { +#define IS(ch, text, op) \ + case ch: \ + if (in.match(text) && op) { \ + return true; \ + } else { \ + return false; \ + } + IS('n', "ull", ctx.set_null()); + IS('f', "alse", ctx.set_bool(false)); + IS('t', "rue", ctx.set_bool(true)); +#undef IS + case '"': + return ctx.parse_string(in); + case '[': + return _parse_array(ctx, in); + case '{': + return _parse_object(ctx, in); + default: + if (('0' <= ch && ch <= '9') || ch == '-') { + double f; + char *endp; + in.ungetc(); + std::string num_str(_parse_number(in)); + if (num_str.empty()) { + return false; + } +#ifdef PICOJSON_USE_INT64 + { + errno = 0; + intmax_t ival = strtoimax(num_str.c_str(), &endp, 10); + if (errno == 0 && + std::numeric_limits::min() <= ival && + ival <= std::numeric_limits::max() && + endp == num_str.c_str() + num_str.size()) { + ctx.set_int64(ival); + return true; + } + } +#endif + f = strtod(num_str.c_str(), &endp); + if (endp == num_str.c_str() + num_str.size()) { + ctx.set_number(f); + return true; + } + return false; + } + break; + } + in.ungetc(); + return false; +} + +class deny_parse_context { + public: + bool set_null() { return false; } + bool set_bool(bool) { return false; } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t) { return false; } +#endif + bool set_number(double) { return false; } + template + bool parse_string(input &) { + return false; + } + bool parse_array_start() { return false; } + template + bool parse_array_item(input &, size_t) { + return false; + } + bool parse_array_stop(size_t) { return false; } + bool parse_object_start() { return false; } + template + bool parse_object_item(input &, const std::string &) { + return false; + } +}; + +class default_parse_context { + protected: + value *out_; + size_t depths_; + + public: + default_parse_context(value 
*out, size_t depths = DEFAULT_MAX_DEPTHS) + : out_(out), depths_(depths) {} + bool set_null() { + *out_ = value(); + return true; + } + bool set_bool(bool b) { + *out_ = value(b); + return true; + } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t i) { + *out_ = value(i); + return true; + } +#endif + bool set_number(double f) { + *out_ = value(f); + return true; + } + template + bool parse_string(input &in) { + *out_ = value(string_type, false); + return _parse_string(out_->get(), in); + } + bool parse_array_start() { + if (depths_ == 0) return false; + --depths_; + *out_ = value(array_type, false); + return true; + } + template + bool parse_array_item(input &in, size_t) { + array &a = out_->get(); + a.push_back(value()); + default_parse_context ctx(&a.back(), depths_); + return _parse(ctx, in); + } + bool parse_array_stop(size_t) { + ++depths_; + return true; + } + bool parse_object_start() { + if (depths_ == 0) return false; + *out_ = value(object_type, false); + return true; + } + template + bool parse_object_item(input &in, const std::string &key) { + object &o = out_->get(); + default_parse_context ctx(&o[key], depths_); + return _parse(ctx, in); + } + bool parse_object_stop() { + ++depths_; + return true; + } + + private: + default_parse_context(const default_parse_context &); + default_parse_context &operator=(const default_parse_context &); +}; + +class null_parse_context { + protected: + size_t depths_; + + public: + struct dummy_str { + void push_back(int) {} + }; + + public: + null_parse_context(size_t depths = DEFAULT_MAX_DEPTHS) : depths_(depths) {} + bool set_null() { return true; } + bool set_bool(bool) { return true; } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t) { return true; } +#endif + bool set_number(double) { return true; } + template + bool parse_string(input &in) { + dummy_str s; + return _parse_string(s, in); + } + bool parse_array_start() { + if (depths_ == 0) return false; + --depths_; + return true; + } + template + bool 
parse_array_item(input &in, size_t) { + return _parse(*this, in); + } + bool parse_array_stop(size_t) { + ++depths_; + return true; + } + bool parse_object_start() { + if (depths_ == 0) return false; + --depths_; + return true; + } + template + bool parse_object_item(input &in, const std::string &) { + ++depths_; + return _parse(*this, in); + } + bool parse_object_stop() { return true; } + + private: + null_parse_context(const null_parse_context &); + null_parse_context &operator=(const null_parse_context &); +}; + +// obsolete, use the version below +template +inline std::string parse(value &out, Iter &pos, const Iter &last) { + std::string err; + pos = parse(out, pos, last, &err); + return err; +} + +template +inline Iter _parse(Context &ctx, + const Iter &first, + const Iter &last, + std::string *err) { + input in(first, last); + if (!_parse(ctx, in) && err != NULL) { + char buf[64]; + SNPRINTF(buf, sizeof(buf), "syntax error at line %d near: ", in.line()); + *err = buf; + while (1) { + int ch = in.getc(); + if (ch == -1 || ch == '\n') { + break; + } else if (ch >= ' ') { + err->push_back(static_cast(ch)); + } + } + } + return in.cur(); +} + +template +inline Iter parse(value &out, + const Iter &first, + const Iter &last, + std::string *err) { + default_parse_context ctx(&out); + return _parse(ctx, first, last, err); +} + +inline std::string parse(value &out, const std::string &s) { + std::string err; + parse(out, s.begin(), s.end(), &err); + return err; +} + +inline std::string parse(value &out, std::istream &is) { + std::string err; + parse(out, + std::istreambuf_iterator(is.rdbuf()), + std::istreambuf_iterator(), + &err); + return err; +} + +template +struct last_error_t { + static std::string s; +}; +template +std::string last_error_t::s; + +inline void set_last_error(const std::string &s) { last_error_t::s = s; } + +inline const std::string &get_last_error() { return last_error_t::s; } + +inline bool operator==(const value &x, const value &y) { + if 
(x.is()) return y.is(); +#define PICOJSON_CMP(type) \ + if (x.is()) return y.is() && x.get() == y.get() + PICOJSON_CMP(bool); + PICOJSON_CMP(double); + PICOJSON_CMP(std::string); + PICOJSON_CMP(array); + PICOJSON_CMP(object); +#undef PICOJSON_CMP + PICOJSON_ASSERT(0); +#ifdef _MSC_VER + __assume(0); +#endif + return false; +} + +inline bool operator!=(const value &x, const value &y) { return !(x == y); } +} + +#if !PICOJSON_USE_RVALUE_REFERENCE +namespace std { +template <> +inline void swap(picojson::value &x, picojson::value &y) { + x.swap(y); +} +} +#endif + +inline std::istream &operator>>(std::istream &is, picojson::value &x) { + picojson::set_last_error(std::string()); + const std::string err(picojson::parse(x, is)); + if (!err.empty()) { + picojson::set_last_error(err); + is.setstate(std::ios::failbit); + } + return is; +} + +inline std::ostream &operator<<(std::ostream &os, const picojson::value &x) { + x.serialize(std::ostream_iterator(os)); + return os; +} +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +#endif \ No newline at end of file diff --git a/speechx/speechx/common/utils/strings.cc b/runtime/engine/common/utils/strings.cc similarity index 80% rename from speechx/speechx/common/utils/strings.cc rename to runtime/engine/common/utils/strings.cc index 6aa8af47..e453cf65 100644 --- a/speechx/speechx/common/utils/strings.cc +++ b/runtime/engine/common/utils/strings.cc @@ -18,15 +18,17 @@ namespace ppspeech { -std::vector StrSplit(const std::string& str, const char *delim, bool omit_empty_string){ +std::vector StrSplit(const std::string& str, + const char* delim, + bool omit_empty_string) { std::vector outs; int start = 0; int end = str.size(); int found = 0; - while(found != std::string::npos){ + while (found != std::string::npos) { found = str.find_first_of(delim, start); // start != end condition is for when the delimiter is at the end - if (!omit_empty_string || (found != start && start != end)){ + if (!omit_empty_string || (found != start && start 
!= end)) { outs.push_back(str.substr(start, found - start)); } start = found + 1; @@ -38,13 +40,13 @@ std::vector StrSplit(const std::string& str, const char *delim, boo std::string StrJoin(const std::vector& strs, const char* delim) { std::stringstream ss; - for (ssize_t i = 0; i < strs.size(); ++i){ + for (ssize_t i = 0; i < strs.size(); ++i) { ss << strs[i]; - if ( i < strs.size() -1){ + if (i < strs.size() - 1) { ss << std::string(delim); } } return ss.str(); } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/common/utils/strings.h b/runtime/engine/common/utils/strings.h similarity index 79% rename from speechx/speechx/common/utils/strings.h rename to runtime/engine/common/utils/strings.h index e2629164..175506a5 100644 --- a/speechx/speechx/common/utils/strings.h +++ b/runtime/engine/common/utils/strings.h @@ -14,13 +14,15 @@ #pragma once -#include #include +#include namespace ppspeech { -std::vector StrSplit(const std::string& str, const char *delim, bool omit_empty_string=true); +std::vector StrSplit(const std::string& str, + const char* delim, + bool omit_empty_string = true); std::string StrJoin(const std::vector& strs, const char* delim); -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/common/utils/strings_test.cc b/runtime/engine/common/utils/strings_test.cc similarity index 81% rename from speechx/speechx/common/utils/strings_test.cc rename to runtime/engine/common/utils/strings_test.cc index a2950d32..f158a532 100644 --- a/speechx/speechx/common/utils/strings_test.cc +++ b/runtime/engine/common/utils/strings_test.cc @@ -15,16 +15,16 @@ #include "utils/strings.h" -#include #include +#include TEST(StringTest, StrSplitTest) { - using ::testing::ElementsAre; + using ::testing::ElementsAre; - std::string test_str = "hello world"; - std::vector outs = ppspeech::StrSplit(test_str, " \t"); 
- EXPECT_THAT(outs, ElementsAre("hello", "world")); + std::string test_str = "hello world"; + std::vector outs = ppspeech::StrSplit(test_str, " \t"); + EXPECT_THAT(outs, ElementsAre("hello", "world")); } diff --git a/speechx/speechx/kaldi/CMakeLists.txt b/runtime/engine/kaldi/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/CMakeLists.txt rename to runtime/engine/kaldi/CMakeLists.txt diff --git a/speechx/speechx/kaldi/base/CMakeLists.txt b/runtime/engine/kaldi/base/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/base/CMakeLists.txt rename to runtime/engine/kaldi/base/CMakeLists.txt diff --git a/speechx/speechx/kaldi/base/io-funcs-inl.h b/runtime/engine/kaldi/base/io-funcs-inl.h similarity index 100% rename from speechx/speechx/kaldi/base/io-funcs-inl.h rename to runtime/engine/kaldi/base/io-funcs-inl.h diff --git a/speechx/speechx/kaldi/base/io-funcs.cc b/runtime/engine/kaldi/base/io-funcs.cc similarity index 100% rename from speechx/speechx/kaldi/base/io-funcs.cc rename to runtime/engine/kaldi/base/io-funcs.cc diff --git a/speechx/speechx/kaldi/base/io-funcs.h b/runtime/engine/kaldi/base/io-funcs.h similarity index 100% rename from speechx/speechx/kaldi/base/io-funcs.h rename to runtime/engine/kaldi/base/io-funcs.h diff --git a/speechx/speechx/kaldi/base/kaldi-common.h b/runtime/engine/kaldi/base/kaldi-common.h similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-common.h rename to runtime/engine/kaldi/base/kaldi-common.h diff --git a/speechx/speechx/kaldi/base/kaldi-error.cc b/runtime/engine/kaldi/base/kaldi-error.cc similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-error.cc rename to runtime/engine/kaldi/base/kaldi-error.cc diff --git a/speechx/speechx/kaldi/base/kaldi-error.h b/runtime/engine/kaldi/base/kaldi-error.h similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-error.h rename to runtime/engine/kaldi/base/kaldi-error.h diff --git 
a/speechx/speechx/kaldi/base/kaldi-math.cc b/runtime/engine/kaldi/base/kaldi-math.cc similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-math.cc rename to runtime/engine/kaldi/base/kaldi-math.cc diff --git a/speechx/speechx/kaldi/base/kaldi-math.h b/runtime/engine/kaldi/base/kaldi-math.h similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-math.h rename to runtime/engine/kaldi/base/kaldi-math.h diff --git a/speechx/speechx/kaldi/base/kaldi-types.h b/runtime/engine/kaldi/base/kaldi-types.h similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-types.h rename to runtime/engine/kaldi/base/kaldi-types.h diff --git a/speechx/speechx/kaldi/base/kaldi-utils.cc b/runtime/engine/kaldi/base/kaldi-utils.cc similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-utils.cc rename to runtime/engine/kaldi/base/kaldi-utils.cc diff --git a/speechx/speechx/kaldi/base/kaldi-utils.h b/runtime/engine/kaldi/base/kaldi-utils.h similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-utils.h rename to runtime/engine/kaldi/base/kaldi-utils.h diff --git a/speechx/speechx/kaldi/base/timer.cc b/runtime/engine/kaldi/base/timer.cc similarity index 100% rename from speechx/speechx/kaldi/base/timer.cc rename to runtime/engine/kaldi/base/timer.cc diff --git a/speechx/speechx/kaldi/base/timer.h b/runtime/engine/kaldi/base/timer.h similarity index 100% rename from speechx/speechx/kaldi/base/timer.h rename to runtime/engine/kaldi/base/timer.h diff --git a/speechx/speechx/kaldi/base/version.h b/runtime/engine/kaldi/base/version.h similarity index 100% rename from speechx/speechx/kaldi/base/version.h rename to runtime/engine/kaldi/base/version.h diff --git a/speechx/speechx/kaldi/decoder/CMakeLists.txt b/runtime/engine/kaldi/decoder/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/decoder/CMakeLists.txt rename to runtime/engine/kaldi/decoder/CMakeLists.txt diff --git a/speechx/speechx/kaldi/decoder/decodable-itf.h 
b/runtime/engine/kaldi/decoder/decodable-itf.h similarity index 100% rename from speechx/speechx/kaldi/decoder/decodable-itf.h rename to runtime/engine/kaldi/decoder/decodable-itf.h diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc b/runtime/engine/kaldi/decoder/lattice-faster-decoder.cc similarity index 100% rename from speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc rename to runtime/engine/kaldi/decoder/lattice-faster-decoder.cc diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h b/runtime/engine/kaldi/decoder/lattice-faster-decoder.h similarity index 100% rename from speechx/speechx/kaldi/decoder/lattice-faster-decoder.h rename to runtime/engine/kaldi/decoder/lattice-faster-decoder.h diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc b/runtime/engine/kaldi/decoder/lattice-faster-online-decoder.cc similarity index 100% rename from speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc rename to runtime/engine/kaldi/decoder/lattice-faster-online-decoder.cc diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h b/runtime/engine/kaldi/decoder/lattice-faster-online-decoder.h similarity index 100% rename from speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h rename to runtime/engine/kaldi/decoder/lattice-faster-online-decoder.h diff --git a/speechx/speechx/kaldi/fstbin/CMakeLists.txt b/runtime/engine/kaldi/fstbin/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/fstbin/CMakeLists.txt rename to runtime/engine/kaldi/fstbin/CMakeLists.txt diff --git a/speechx/speechx/kaldi/fstbin/fstaddselfloops.cc b/runtime/engine/kaldi/fstbin/fstaddselfloops.cc similarity index 100% rename from speechx/speechx/kaldi/fstbin/fstaddselfloops.cc rename to runtime/engine/kaldi/fstbin/fstaddselfloops.cc diff --git a/speechx/speechx/kaldi/fstbin/fstdeterminizestar.cc b/runtime/engine/kaldi/fstbin/fstdeterminizestar.cc similarity index 100% rename from 
speechx/speechx/kaldi/fstbin/fstdeterminizestar.cc rename to runtime/engine/kaldi/fstbin/fstdeterminizestar.cc diff --git a/speechx/speechx/kaldi/fstbin/fstisstochastic.cc b/runtime/engine/kaldi/fstbin/fstisstochastic.cc similarity index 100% rename from speechx/speechx/kaldi/fstbin/fstisstochastic.cc rename to runtime/engine/kaldi/fstbin/fstisstochastic.cc diff --git a/speechx/speechx/kaldi/fstbin/fstminimizeencoded.cc b/runtime/engine/kaldi/fstbin/fstminimizeencoded.cc similarity index 100% rename from speechx/speechx/kaldi/fstbin/fstminimizeencoded.cc rename to runtime/engine/kaldi/fstbin/fstminimizeencoded.cc diff --git a/speechx/speechx/kaldi/fstbin/fsttablecompose.cc b/runtime/engine/kaldi/fstbin/fsttablecompose.cc similarity index 100% rename from speechx/speechx/kaldi/fstbin/fsttablecompose.cc rename to runtime/engine/kaldi/fstbin/fsttablecompose.cc diff --git a/speechx/speechx/kaldi/fstext/CMakeLists.txt b/runtime/engine/kaldi/fstext/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/fstext/CMakeLists.txt rename to runtime/engine/kaldi/fstext/CMakeLists.txt diff --git a/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h b/runtime/engine/kaldi/fstext/determinize-lattice-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/determinize-lattice-inl.h rename to runtime/engine/kaldi/fstext/determinize-lattice-inl.h diff --git a/speechx/speechx/kaldi/fstext/determinize-lattice.h b/runtime/engine/kaldi/fstext/determinize-lattice.h similarity index 100% rename from speechx/speechx/kaldi/fstext/determinize-lattice.h rename to runtime/engine/kaldi/fstext/determinize-lattice.h diff --git a/speechx/speechx/kaldi/fstext/determinize-star-inl.h b/runtime/engine/kaldi/fstext/determinize-star-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/determinize-star-inl.h rename to runtime/engine/kaldi/fstext/determinize-star-inl.h diff --git a/speechx/speechx/kaldi/fstext/determinize-star.h 
b/runtime/engine/kaldi/fstext/determinize-star.h similarity index 100% rename from speechx/speechx/kaldi/fstext/determinize-star.h rename to runtime/engine/kaldi/fstext/determinize-star.h diff --git a/speechx/speechx/kaldi/fstext/fstext-lib.h b/runtime/engine/kaldi/fstext/fstext-lib.h similarity index 100% rename from speechx/speechx/kaldi/fstext/fstext-lib.h rename to runtime/engine/kaldi/fstext/fstext-lib.h diff --git a/speechx/speechx/kaldi/fstext/fstext-utils-inl.h b/runtime/engine/kaldi/fstext/fstext-utils-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/fstext-utils-inl.h rename to runtime/engine/kaldi/fstext/fstext-utils-inl.h diff --git a/speechx/speechx/kaldi/fstext/fstext-utils.h b/runtime/engine/kaldi/fstext/fstext-utils.h similarity index 100% rename from speechx/speechx/kaldi/fstext/fstext-utils.h rename to runtime/engine/kaldi/fstext/fstext-utils.h diff --git a/speechx/speechx/kaldi/fstext/kaldi-fst-io-inl.h b/runtime/engine/kaldi/fstext/kaldi-fst-io-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/kaldi-fst-io-inl.h rename to runtime/engine/kaldi/fstext/kaldi-fst-io-inl.h diff --git a/speechx/speechx/kaldi/fstext/kaldi-fst-io.cc b/runtime/engine/kaldi/fstext/kaldi-fst-io.cc similarity index 100% rename from speechx/speechx/kaldi/fstext/kaldi-fst-io.cc rename to runtime/engine/kaldi/fstext/kaldi-fst-io.cc diff --git a/speechx/speechx/kaldi/fstext/kaldi-fst-io.h b/runtime/engine/kaldi/fstext/kaldi-fst-io.h similarity index 100% rename from speechx/speechx/kaldi/fstext/kaldi-fst-io.h rename to runtime/engine/kaldi/fstext/kaldi-fst-io.h diff --git a/speechx/speechx/kaldi/fstext/lattice-utils-inl.h b/runtime/engine/kaldi/fstext/lattice-utils-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/lattice-utils-inl.h rename to runtime/engine/kaldi/fstext/lattice-utils-inl.h diff --git a/speechx/speechx/kaldi/fstext/lattice-utils.h b/runtime/engine/kaldi/fstext/lattice-utils.h similarity index 100% rename 
from speechx/speechx/kaldi/fstext/lattice-utils.h rename to runtime/engine/kaldi/fstext/lattice-utils.h diff --git a/speechx/speechx/kaldi/fstext/lattice-weight.h b/runtime/engine/kaldi/fstext/lattice-weight.h similarity index 100% rename from speechx/speechx/kaldi/fstext/lattice-weight.h rename to runtime/engine/kaldi/fstext/lattice-weight.h diff --git a/speechx/speechx/kaldi/fstext/pre-determinize-inl.h b/runtime/engine/kaldi/fstext/pre-determinize-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/pre-determinize-inl.h rename to runtime/engine/kaldi/fstext/pre-determinize-inl.h diff --git a/speechx/speechx/kaldi/fstext/pre-determinize.h b/runtime/engine/kaldi/fstext/pre-determinize.h similarity index 100% rename from speechx/speechx/kaldi/fstext/pre-determinize.h rename to runtime/engine/kaldi/fstext/pre-determinize.h diff --git a/speechx/speechx/kaldi/fstext/remove-eps-local-inl.h b/runtime/engine/kaldi/fstext/remove-eps-local-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/remove-eps-local-inl.h rename to runtime/engine/kaldi/fstext/remove-eps-local-inl.h diff --git a/speechx/speechx/kaldi/fstext/remove-eps-local.h b/runtime/engine/kaldi/fstext/remove-eps-local.h similarity index 100% rename from speechx/speechx/kaldi/fstext/remove-eps-local.h rename to runtime/engine/kaldi/fstext/remove-eps-local.h diff --git a/speechx/speechx/kaldi/fstext/table-matcher.h b/runtime/engine/kaldi/fstext/table-matcher.h similarity index 100% rename from speechx/speechx/kaldi/fstext/table-matcher.h rename to runtime/engine/kaldi/fstext/table-matcher.h diff --git a/speechx/speechx/kaldi/lat/CMakeLists.txt b/runtime/engine/kaldi/lat/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/lat/CMakeLists.txt rename to runtime/engine/kaldi/lat/CMakeLists.txt diff --git a/speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc b/runtime/engine/kaldi/lat/determinize-lattice-pruned.cc similarity index 100% rename from 
speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc rename to runtime/engine/kaldi/lat/determinize-lattice-pruned.cc diff --git a/speechx/speechx/kaldi/lat/determinize-lattice-pruned.h b/runtime/engine/kaldi/lat/determinize-lattice-pruned.h similarity index 100% rename from speechx/speechx/kaldi/lat/determinize-lattice-pruned.h rename to runtime/engine/kaldi/lat/determinize-lattice-pruned.h diff --git a/speechx/speechx/kaldi/lat/kaldi-lattice.cc b/runtime/engine/kaldi/lat/kaldi-lattice.cc similarity index 100% rename from speechx/speechx/kaldi/lat/kaldi-lattice.cc rename to runtime/engine/kaldi/lat/kaldi-lattice.cc diff --git a/speechx/speechx/kaldi/lat/kaldi-lattice.h b/runtime/engine/kaldi/lat/kaldi-lattice.h similarity index 100% rename from speechx/speechx/kaldi/lat/kaldi-lattice.h rename to runtime/engine/kaldi/lat/kaldi-lattice.h diff --git a/speechx/speechx/kaldi/lat/lattice-functions.cc b/runtime/engine/kaldi/lat/lattice-functions.cc similarity index 100% rename from speechx/speechx/kaldi/lat/lattice-functions.cc rename to runtime/engine/kaldi/lat/lattice-functions.cc diff --git a/speechx/speechx/kaldi/lat/lattice-functions.h b/runtime/engine/kaldi/lat/lattice-functions.h similarity index 100% rename from speechx/speechx/kaldi/lat/lattice-functions.h rename to runtime/engine/kaldi/lat/lattice-functions.h diff --git a/speechx/speechx/kaldi/lm/CMakeLists.txt b/runtime/engine/kaldi/lm/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/lm/CMakeLists.txt rename to runtime/engine/kaldi/lm/CMakeLists.txt diff --git a/speechx/speechx/kaldi/lm/arpa-file-parser.cc b/runtime/engine/kaldi/lm/arpa-file-parser.cc similarity index 100% rename from speechx/speechx/kaldi/lm/arpa-file-parser.cc rename to runtime/engine/kaldi/lm/arpa-file-parser.cc diff --git a/speechx/speechx/kaldi/lm/arpa-file-parser.h b/runtime/engine/kaldi/lm/arpa-file-parser.h similarity index 100% rename from speechx/speechx/kaldi/lm/arpa-file-parser.h rename to 
runtime/engine/kaldi/lm/arpa-file-parser.h diff --git a/speechx/speechx/kaldi/lm/arpa-lm-compiler.cc b/runtime/engine/kaldi/lm/arpa-lm-compiler.cc similarity index 100% rename from speechx/speechx/kaldi/lm/arpa-lm-compiler.cc rename to runtime/engine/kaldi/lm/arpa-lm-compiler.cc diff --git a/speechx/speechx/kaldi/lm/arpa-lm-compiler.h b/runtime/engine/kaldi/lm/arpa-lm-compiler.h similarity index 100% rename from speechx/speechx/kaldi/lm/arpa-lm-compiler.h rename to runtime/engine/kaldi/lm/arpa-lm-compiler.h diff --git a/speechx/speechx/kaldi/lmbin/CMakeLists.txt b/runtime/engine/kaldi/lmbin/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/lmbin/CMakeLists.txt rename to runtime/engine/kaldi/lmbin/CMakeLists.txt diff --git a/speechx/speechx/kaldi/lmbin/arpa2fst.cc b/runtime/engine/kaldi/lmbin/arpa2fst.cc similarity index 100% rename from speechx/speechx/kaldi/lmbin/arpa2fst.cc rename to runtime/engine/kaldi/lmbin/arpa2fst.cc diff --git a/speechx/speechx/kaldi/util/CMakeLists.txt b/runtime/engine/kaldi/util/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/util/CMakeLists.txt rename to runtime/engine/kaldi/util/CMakeLists.txt diff --git a/speechx/speechx/kaldi/util/basic-filebuf.h b/runtime/engine/kaldi/util/basic-filebuf.h similarity index 100% rename from speechx/speechx/kaldi/util/basic-filebuf.h rename to runtime/engine/kaldi/util/basic-filebuf.h diff --git a/speechx/speechx/kaldi/util/common-utils.h b/runtime/engine/kaldi/util/common-utils.h similarity index 100% rename from speechx/speechx/kaldi/util/common-utils.h rename to runtime/engine/kaldi/util/common-utils.h diff --git a/speechx/speechx/kaldi/util/const-integer-set-inl.h b/runtime/engine/kaldi/util/const-integer-set-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/const-integer-set-inl.h rename to runtime/engine/kaldi/util/const-integer-set-inl.h diff --git a/speechx/speechx/kaldi/util/const-integer-set.h 
b/runtime/engine/kaldi/util/const-integer-set.h similarity index 100% rename from speechx/speechx/kaldi/util/const-integer-set.h rename to runtime/engine/kaldi/util/const-integer-set.h diff --git a/speechx/speechx/kaldi/util/edit-distance-inl.h b/runtime/engine/kaldi/util/edit-distance-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/edit-distance-inl.h rename to runtime/engine/kaldi/util/edit-distance-inl.h diff --git a/speechx/speechx/kaldi/util/edit-distance.h b/runtime/engine/kaldi/util/edit-distance.h similarity index 100% rename from speechx/speechx/kaldi/util/edit-distance.h rename to runtime/engine/kaldi/util/edit-distance.h diff --git a/speechx/speechx/kaldi/util/hash-list-inl.h b/runtime/engine/kaldi/util/hash-list-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/hash-list-inl.h rename to runtime/engine/kaldi/util/hash-list-inl.h diff --git a/speechx/speechx/kaldi/util/hash-list.h b/runtime/engine/kaldi/util/hash-list.h similarity index 100% rename from speechx/speechx/kaldi/util/hash-list.h rename to runtime/engine/kaldi/util/hash-list.h diff --git a/speechx/speechx/kaldi/util/kaldi-cygwin-io-inl.h b/runtime/engine/kaldi/util/kaldi-cygwin-io-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-cygwin-io-inl.h rename to runtime/engine/kaldi/util/kaldi-cygwin-io-inl.h diff --git a/speechx/speechx/kaldi/util/kaldi-holder-inl.h b/runtime/engine/kaldi/util/kaldi-holder-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-holder-inl.h rename to runtime/engine/kaldi/util/kaldi-holder-inl.h diff --git a/speechx/speechx/kaldi/util/kaldi-holder.cc b/runtime/engine/kaldi/util/kaldi-holder.cc similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-holder.cc rename to runtime/engine/kaldi/util/kaldi-holder.cc diff --git a/speechx/speechx/kaldi/util/kaldi-holder.h b/runtime/engine/kaldi/util/kaldi-holder.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-holder.h 
rename to runtime/engine/kaldi/util/kaldi-holder.h diff --git a/speechx/speechx/kaldi/util/kaldi-io-inl.h b/runtime/engine/kaldi/util/kaldi-io-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-io-inl.h rename to runtime/engine/kaldi/util/kaldi-io-inl.h diff --git a/speechx/speechx/kaldi/util/kaldi-io.cc b/runtime/engine/kaldi/util/kaldi-io.cc similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-io.cc rename to runtime/engine/kaldi/util/kaldi-io.cc diff --git a/speechx/speechx/kaldi/util/kaldi-io.h b/runtime/engine/kaldi/util/kaldi-io.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-io.h rename to runtime/engine/kaldi/util/kaldi-io.h diff --git a/speechx/speechx/kaldi/util/kaldi-pipebuf.h b/runtime/engine/kaldi/util/kaldi-pipebuf.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-pipebuf.h rename to runtime/engine/kaldi/util/kaldi-pipebuf.h diff --git a/speechx/speechx/kaldi/util/kaldi-semaphore.cc b/runtime/engine/kaldi/util/kaldi-semaphore.cc similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-semaphore.cc rename to runtime/engine/kaldi/util/kaldi-semaphore.cc diff --git a/speechx/speechx/kaldi/util/kaldi-semaphore.h b/runtime/engine/kaldi/util/kaldi-semaphore.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-semaphore.h rename to runtime/engine/kaldi/util/kaldi-semaphore.h diff --git a/speechx/speechx/kaldi/util/kaldi-table-inl.h b/runtime/engine/kaldi/util/kaldi-table-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-table-inl.h rename to runtime/engine/kaldi/util/kaldi-table-inl.h diff --git a/speechx/speechx/kaldi/util/kaldi-table.cc b/runtime/engine/kaldi/util/kaldi-table.cc similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-table.cc rename to runtime/engine/kaldi/util/kaldi-table.cc diff --git a/speechx/speechx/kaldi/util/kaldi-table.h b/runtime/engine/kaldi/util/kaldi-table.h similarity index 100% rename from 
speechx/speechx/kaldi/util/kaldi-table.h rename to runtime/engine/kaldi/util/kaldi-table.h diff --git a/speechx/speechx/kaldi/util/kaldi-thread.cc b/runtime/engine/kaldi/util/kaldi-thread.cc similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-thread.cc rename to runtime/engine/kaldi/util/kaldi-thread.cc diff --git a/speechx/speechx/kaldi/util/kaldi-thread.h b/runtime/engine/kaldi/util/kaldi-thread.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-thread.h rename to runtime/engine/kaldi/util/kaldi-thread.h diff --git a/speechx/speechx/kaldi/util/options-itf.h b/runtime/engine/kaldi/util/options-itf.h similarity index 100% rename from speechx/speechx/kaldi/util/options-itf.h rename to runtime/engine/kaldi/util/options-itf.h diff --git a/speechx/speechx/kaldi/util/parse-options.cc b/runtime/engine/kaldi/util/parse-options.cc similarity index 100% rename from speechx/speechx/kaldi/util/parse-options.cc rename to runtime/engine/kaldi/util/parse-options.cc diff --git a/speechx/speechx/kaldi/util/parse-options.h b/runtime/engine/kaldi/util/parse-options.h similarity index 100% rename from speechx/speechx/kaldi/util/parse-options.h rename to runtime/engine/kaldi/util/parse-options.h diff --git a/speechx/speechx/kaldi/util/simple-io-funcs.cc b/runtime/engine/kaldi/util/simple-io-funcs.cc similarity index 100% rename from speechx/speechx/kaldi/util/simple-io-funcs.cc rename to runtime/engine/kaldi/util/simple-io-funcs.cc diff --git a/speechx/speechx/kaldi/util/simple-io-funcs.h b/runtime/engine/kaldi/util/simple-io-funcs.h similarity index 100% rename from speechx/speechx/kaldi/util/simple-io-funcs.h rename to runtime/engine/kaldi/util/simple-io-funcs.h diff --git a/speechx/speechx/kaldi/util/simple-options.cc b/runtime/engine/kaldi/util/simple-options.cc similarity index 100% rename from speechx/speechx/kaldi/util/simple-options.cc rename to runtime/engine/kaldi/util/simple-options.cc diff --git a/speechx/speechx/kaldi/util/simple-options.h 
b/runtime/engine/kaldi/util/simple-options.h similarity index 100% rename from speechx/speechx/kaldi/util/simple-options.h rename to runtime/engine/kaldi/util/simple-options.h diff --git a/speechx/speechx/kaldi/util/stl-utils.h b/runtime/engine/kaldi/util/stl-utils.h similarity index 100% rename from speechx/speechx/kaldi/util/stl-utils.h rename to runtime/engine/kaldi/util/stl-utils.h diff --git a/speechx/speechx/kaldi/util/table-types.h b/runtime/engine/kaldi/util/table-types.h similarity index 100% rename from speechx/speechx/kaldi/util/table-types.h rename to runtime/engine/kaldi/util/table-types.h diff --git a/speechx/speechx/kaldi/util/text-utils.cc b/runtime/engine/kaldi/util/text-utils.cc similarity index 100% rename from speechx/speechx/kaldi/util/text-utils.cc rename to runtime/engine/kaldi/util/text-utils.cc diff --git a/speechx/speechx/kaldi/util/text-utils.h b/runtime/engine/kaldi/util/text-utils.h similarity index 100% rename from speechx/speechx/kaldi/util/text-utils.h rename to runtime/engine/kaldi/util/text-utils.h diff --git a/speechx/examples/.gitignore b/runtime/examples/.gitignore similarity index 80% rename from speechx/examples/.gitignore rename to runtime/examples/.gitignore index b7075fa5..38290f34 100644 --- a/speechx/examples/.gitignore +++ b/runtime/examples/.gitignore @@ -1,2 +1,3 @@ *.ark +*.scp paddle_asr_model/ diff --git a/speechx/examples/README.md b/runtime/examples/README.md similarity index 100% rename from speechx/examples/README.md rename to runtime/examples/README.md diff --git a/speechx/examples/codelab/README.md b/runtime/examples/codelab/README.md similarity index 100% rename from speechx/examples/codelab/README.md rename to runtime/examples/codelab/README.md diff --git a/speechx/examples/codelab/decoder/.gitignore b/runtime/examples/codelab/decoder/.gitignore similarity index 100% rename from speechx/examples/codelab/decoder/.gitignore rename to runtime/examples/codelab/decoder/.gitignore diff --git 
a/speechx/examples/codelab/decoder/README.md b/runtime/examples/codelab/decoder/README.md similarity index 100% rename from speechx/examples/codelab/decoder/README.md rename to runtime/examples/codelab/decoder/README.md diff --git a/speechx/examples/codelab/decoder/path.sh b/runtime/examples/codelab/decoder/path.sh similarity index 100% rename from speechx/examples/codelab/decoder/path.sh rename to runtime/examples/codelab/decoder/path.sh diff --git a/speechx/examples/codelab/decoder/run.sh b/runtime/examples/codelab/decoder/run.sh similarity index 100% rename from speechx/examples/codelab/decoder/run.sh rename to runtime/examples/codelab/decoder/run.sh diff --git a/speechx/examples/codelab/decoder/valgrind.sh b/runtime/examples/codelab/decoder/valgrind.sh similarity index 100% rename from speechx/examples/codelab/decoder/valgrind.sh rename to runtime/examples/codelab/decoder/valgrind.sh diff --git a/speechx/examples/codelab/feat/.gitignore b/runtime/examples/codelab/feat/.gitignore similarity index 100% rename from speechx/examples/codelab/feat/.gitignore rename to runtime/examples/codelab/feat/.gitignore diff --git a/speechx/examples/codelab/feat/README.md b/runtime/examples/codelab/feat/README.md similarity index 100% rename from speechx/examples/codelab/feat/README.md rename to runtime/examples/codelab/feat/README.md diff --git a/speechx/examples/codelab/feat/path.sh b/runtime/examples/codelab/feat/path.sh similarity index 100% rename from speechx/examples/codelab/feat/path.sh rename to runtime/examples/codelab/feat/path.sh diff --git a/speechx/examples/codelab/feat/run.sh b/runtime/examples/codelab/feat/run.sh similarity index 100% rename from speechx/examples/codelab/feat/run.sh rename to runtime/examples/codelab/feat/run.sh diff --git a/speechx/examples/codelab/feat/valgrind.sh b/runtime/examples/codelab/feat/valgrind.sh similarity index 100% rename from speechx/examples/codelab/feat/valgrind.sh rename to runtime/examples/codelab/feat/valgrind.sh diff --git 
a/speechx/examples/codelab/nnet/.gitignore b/runtime/examples/codelab/nnet/.gitignore similarity index 100% rename from speechx/examples/codelab/nnet/.gitignore rename to runtime/examples/codelab/nnet/.gitignore diff --git a/speechx/examples/codelab/nnet/README.md b/runtime/examples/codelab/nnet/README.md similarity index 100% rename from speechx/examples/codelab/nnet/README.md rename to runtime/examples/codelab/nnet/README.md diff --git a/speechx/examples/codelab/nnet/path.sh b/runtime/examples/codelab/nnet/path.sh similarity index 100% rename from speechx/examples/codelab/nnet/path.sh rename to runtime/examples/codelab/nnet/path.sh diff --git a/speechx/examples/codelab/nnet/run.sh b/runtime/examples/codelab/nnet/run.sh similarity index 100% rename from speechx/examples/codelab/nnet/run.sh rename to runtime/examples/codelab/nnet/run.sh diff --git a/speechx/examples/codelab/nnet/valgrind.sh b/runtime/examples/codelab/nnet/valgrind.sh similarity index 100% rename from speechx/examples/codelab/nnet/valgrind.sh rename to runtime/examples/codelab/nnet/valgrind.sh diff --git a/speechx/examples/codelab/u2/.gitignore b/runtime/examples/codelab/u2/.gitignore similarity index 100% rename from speechx/examples/codelab/u2/.gitignore rename to runtime/examples/codelab/u2/.gitignore diff --git a/speechx/examples/codelab/u2/README.md b/runtime/examples/codelab/u2/README.md similarity index 100% rename from speechx/examples/codelab/u2/README.md rename to runtime/examples/codelab/u2/README.md diff --git a/speechx/examples/codelab/u2/local/decode.sh b/runtime/examples/codelab/u2/local/decode.sh similarity index 100% rename from speechx/examples/codelab/u2/local/decode.sh rename to runtime/examples/codelab/u2/local/decode.sh diff --git a/speechx/examples/codelab/u2/local/feat.sh b/runtime/examples/codelab/u2/local/feat.sh similarity index 100% rename from speechx/examples/codelab/u2/local/feat.sh rename to runtime/examples/codelab/u2/local/feat.sh diff --git 
a/speechx/examples/codelab/u2/local/nnet.sh b/runtime/examples/codelab/u2/local/nnet.sh similarity index 100% rename from speechx/examples/codelab/u2/local/nnet.sh rename to runtime/examples/codelab/u2/local/nnet.sh diff --git a/speechx/examples/codelab/u2/local/recognizer.sh b/runtime/examples/codelab/u2/local/recognizer.sh similarity index 100% rename from speechx/examples/codelab/u2/local/recognizer.sh rename to runtime/examples/codelab/u2/local/recognizer.sh diff --git a/speechx/examples/codelab/u2/path.sh b/runtime/examples/codelab/u2/path.sh similarity index 100% rename from speechx/examples/codelab/u2/path.sh rename to runtime/examples/codelab/u2/path.sh diff --git a/speechx/examples/codelab/u2/run.sh b/runtime/examples/codelab/u2/run.sh similarity index 100% rename from speechx/examples/codelab/u2/run.sh rename to runtime/examples/codelab/u2/run.sh diff --git a/speechx/examples/codelab/u2/utils b/runtime/examples/codelab/u2/utils similarity index 100% rename from speechx/examples/codelab/u2/utils rename to runtime/examples/codelab/u2/utils diff --git a/speechx/examples/custom_asr/README.md b/runtime/examples/custom_asr/README.md similarity index 100% rename from speechx/examples/custom_asr/README.md rename to runtime/examples/custom_asr/README.md diff --git a/speechx/examples/custom_asr/local/compile_lexicon_token_fst.sh b/runtime/examples/custom_asr/local/compile_lexicon_token_fst.sh similarity index 100% rename from speechx/examples/custom_asr/local/compile_lexicon_token_fst.sh rename to runtime/examples/custom_asr/local/compile_lexicon_token_fst.sh diff --git a/speechx/examples/custom_asr/local/mk_slot_graph.sh b/runtime/examples/custom_asr/local/mk_slot_graph.sh similarity index 100% rename from speechx/examples/custom_asr/local/mk_slot_graph.sh rename to runtime/examples/custom_asr/local/mk_slot_graph.sh diff --git a/speechx/examples/custom_asr/local/mk_tlg_with_slot.sh b/runtime/examples/custom_asr/local/mk_tlg_with_slot.sh similarity index 100% 
rename from speechx/examples/custom_asr/local/mk_tlg_with_slot.sh rename to runtime/examples/custom_asr/local/mk_tlg_with_slot.sh diff --git a/speechx/examples/custom_asr/local/train_lm_with_slot.sh b/runtime/examples/custom_asr/local/train_lm_with_slot.sh similarity index 100% rename from speechx/examples/custom_asr/local/train_lm_with_slot.sh rename to runtime/examples/custom_asr/local/train_lm_with_slot.sh diff --git a/speechx/examples/custom_asr/path.sh b/runtime/examples/custom_asr/path.sh similarity index 70% rename from speechx/examples/custom_asr/path.sh rename to runtime/examples/custom_asr/path.sh index 1907c79f..3f5dd476 100644 --- a/speechx/examples/custom_asr/path.sh +++ b/runtime/examples/custom_asr/path.sh @@ -1,8 +1,8 @@ # This contains the locations of binarys build required for running the examples. MAIN_ROOT=`realpath $PWD/../../../` -SPEECHX_ROOT=`realpath $MAIN_ROOT/speechx` -SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples +RUNTIME_ROOT=`realpath $MAIN_ROOT/runtime` +RUNTIME_EXAMPLES=$RUNTIME_ROOT/build/examples export LC_AL=C @@ -12,6 +12,6 @@ export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs export SRILM=${MAIN_ROOT}/tools/srilm # kaldi lm -KALDI_DIR=$SPEECHX_ROOT/build/speechx/kaldi/ -OPENFST_DIR=$SPEECHX_ROOT/fc_patch/openfst-build/src +KALDI_DIR=$RUNTIME_ROOT/build/engine/kaldi/ +OPENFST_DIR=$RUNTIME_ROOT/fc_patch/openfst-build/src export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64:$KALDI_DIR/lmbin:$KALDI_DIR/fstbin:$OPENFST_DIR/bin:$SPEECHX_EXAMPLES/ds2_ol/decoder diff --git a/speechx/examples/custom_asr/run.sh b/runtime/examples/custom_asr/run.sh similarity index 100% rename from speechx/examples/custom_asr/run.sh rename to runtime/examples/custom_asr/run.sh diff --git a/speechx/examples/custom_asr/utils b/runtime/examples/custom_asr/utils similarity index 100% rename from speechx/examples/custom_asr/utils rename to runtime/examples/custom_asr/utils diff --git a/speechx/examples/text_lm/.gitignore 
b/runtime/examples/text_lm/.gitignore similarity index 100% rename from speechx/examples/text_lm/.gitignore rename to runtime/examples/text_lm/.gitignore diff --git a/speechx/examples/text_lm/README.md b/runtime/examples/text_lm/README.md similarity index 100% rename from speechx/examples/text_lm/README.md rename to runtime/examples/text_lm/README.md diff --git a/speechx/examples/text_lm/local/data/chars.dic b/runtime/examples/text_lm/local/data/chars.dic similarity index 100% rename from speechx/examples/text_lm/local/data/chars.dic rename to runtime/examples/text_lm/local/data/chars.dic diff --git a/speechx/examples/text_lm/local/data/words.dic b/runtime/examples/text_lm/local/data/words.dic similarity index 100% rename from speechx/examples/text_lm/local/data/words.dic rename to runtime/examples/text_lm/local/data/words.dic diff --git a/speechx/examples/text_lm/local/mmseg.py b/runtime/examples/text_lm/local/mmseg.py similarity index 100% rename from speechx/examples/text_lm/local/mmseg.py rename to runtime/examples/text_lm/local/mmseg.py diff --git a/runtime/examples/text_lm/path.sh b/runtime/examples/text_lm/path.sh new file mode 100644 index 00000000..dc8fc8dd --- /dev/null +++ b/runtime/examples/text_lm/path.sh @@ -0,0 +1,4 @@ +MAIN_ROOT=`realpath $PWD/../../../` +ENGINE_ROOT=`realpath $MAIN_ROOT/runtime` + +export LC_AL=C diff --git a/speechx/examples/text_lm/run.sh b/runtime/examples/text_lm/run.sh similarity index 100% rename from speechx/examples/text_lm/run.sh rename to runtime/examples/text_lm/run.sh diff --git a/speechx/examples/text_lm/utils b/runtime/examples/text_lm/utils similarity index 100% rename from speechx/examples/text_lm/utils rename to runtime/examples/text_lm/utils diff --git a/speechx/examples/u2pp_ol/README.md b/runtime/examples/u2pp_ol/README.md similarity index 100% rename from speechx/examples/u2pp_ol/README.md rename to runtime/examples/u2pp_ol/README.md diff --git a/speechx/examples/u2pp_ol/wenetspeech/.gitignore 
b/runtime/examples/u2pp_ol/wenetspeech/.gitignore similarity index 100% rename from speechx/examples/u2pp_ol/wenetspeech/.gitignore rename to runtime/examples/u2pp_ol/wenetspeech/.gitignore diff --git a/speechx/examples/u2pp_ol/wenetspeech/README.md b/runtime/examples/u2pp_ol/wenetspeech/README.md similarity index 90% rename from speechx/examples/u2pp_ol/wenetspeech/README.md rename to runtime/examples/u2pp_ol/wenetspeech/README.md index 6999fe3c..d66aacc7 100644 --- a/speechx/examples/u2pp_ol/wenetspeech/README.md +++ b/runtime/examples/u2pp_ol/wenetspeech/README.md @@ -50,10 +50,10 @@ This stage using `u2_recognizer_main` to recognize wav file. The input is `scp` file which look like this: ```text # head data/split1/1/aishell_test.scp -BAC009S0764W0121 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav -BAC009S0764W0122 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav +BAC009S0764W0121 /workspace/PaddleSpeech/runtime/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0121.wav +BAC009S0764W0122 /workspace/PaddleSpeech/runtime/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0122.wav ... 
-BAC009S0764W0125 /workspace/PaddleSpeech/speechx/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav +BAC009S0764W0125 /workspace/PaddleSpeech/runtime/examples/u2pp_ol/wenetspeech/data/test/S0764/BAC009S0764W0125.wav ``` If you want to recognize one wav, you can make `scp` file like this: diff --git a/speechx/examples/u2pp_ol/wenetspeech/RESULTS.md b/runtime/examples/u2pp_ol/wenetspeech/RESULTS.md similarity index 100% rename from speechx/examples/u2pp_ol/wenetspeech/RESULTS.md rename to runtime/examples/u2pp_ol/wenetspeech/RESULTS.md diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/aishell_train_lms.sh b/runtime/examples/u2pp_ol/wenetspeech/local/aishell_train_lms.sh similarity index 100% rename from speechx/examples/u2pp_ol/wenetspeech/local/aishell_train_lms.sh rename to runtime/examples/u2pp_ol/wenetspeech/local/aishell_train_lms.sh diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/decode.sh b/runtime/examples/u2pp_ol/wenetspeech/local/decode.sh similarity index 100% rename from speechx/examples/u2pp_ol/wenetspeech/local/decode.sh rename to runtime/examples/u2pp_ol/wenetspeech/local/decode.sh diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/feat.sh b/runtime/examples/u2pp_ol/wenetspeech/local/feat.sh similarity index 100% rename from speechx/examples/u2pp_ol/wenetspeech/local/feat.sh rename to runtime/examples/u2pp_ol/wenetspeech/local/feat.sh diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/nnet.sh b/runtime/examples/u2pp_ol/wenetspeech/local/nnet.sh similarity index 100% rename from speechx/examples/u2pp_ol/wenetspeech/local/nnet.sh rename to runtime/examples/u2pp_ol/wenetspeech/local/nnet.sh diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/recognizer.sh b/runtime/examples/u2pp_ol/wenetspeech/local/recognizer.sh similarity index 100% rename from speechx/examples/u2pp_ol/wenetspeech/local/recognizer.sh rename to runtime/examples/u2pp_ol/wenetspeech/local/recognizer.sh diff --git 
a/speechx/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh b/runtime/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh similarity index 100% rename from speechx/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh rename to runtime/examples/u2pp_ol/wenetspeech/local/recognizer_quant.sh diff --git a/runtime/examples/u2pp_ol/wenetspeech/local/run_build_tlg.sh b/runtime/examples/u2pp_ol/wenetspeech/local/run_build_tlg.sh new file mode 100755 index 00000000..30ea2020 --- /dev/null +++ b/runtime/examples/u2pp_ol/wenetspeech/local/run_build_tlg.sh @@ -0,0 +1,84 @@ +#!/bin/bash +set -eo pipefail + +#. path.sh + +# attention, please replace the vocab is only for this script. +# different acustic model has different vocab +ckpt_dir=data/model/asr1_chunk_conformer_u2pp_wenetspeech_static_1.3.0.model +unit=$ckpt_dir/vocab.txt # vocab file, line: char/spm_pice +model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ + +stage=2 +stop_stage=100 +corpus=aishell +lexicon=data/lexicon.txt # line: word ph0 ... phn, aishell/resource_aishell/lexicon.txt +text=data/text # line: utt text, aishell/data_aishell/transcript/aishell_transcript_v0.8.txt + +. utils/parse_options.sh + +data=$PWD/data +mkdir -p $data + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + if [ ! -f $data/speech.ngram.zh.tar.gz ];then + # download ngram + pushd $data + wget -c http://paddlespeech.bj.bcebos.com/speechx/examples/ngram/zh/speech.ngram.zh.tar.gz + tar xvzf speech.ngram.zh.tar.gz + popd + fi +fi + +if [ ! -f $unit ]; then + echo "$0: No such file $unit" + exit 1; +fi + +if ! which ngram-count; then + # need srilm install + pushd $MAIN_ROOT/tools + make srilm.done + popd +fi + +echo "done." +mkdir -p data/local/dict +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # Prepare dict + # line: char/spm_pices + cp $unit data/local/dict/units.txt + + if [ ! 
-f $lexicon ];then + utils/text_to_lexicon.py --has_key true --text $text --lexicon $lexicon + echo "Generate $lexicon from $text" + fi + + # filter by vocab + # line: word ph0 ... phn -> line: word char0 ... charn + utils/fst/prepare_dict.py \ + --unit_file $unit \ + --in_lexicon ${lexicon} \ + --out_lexicon data/local/dict/lexicon.txt +fi + +lm=data/local/lm +mkdir -p $lm + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # Train ngram lm + cp $text $lm/text + local/aishell_train_lms.sh + echo "build LM done." +fi + +# build TLG +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # build T & L + utils/fst/compile_lexicon_token_fst.sh \ + data/local/dict data/local/tmp data/local/lang + + # build G & TLG + utils/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1; + +fi diff --git a/speechx/examples/u2pp_ol/wenetspeech/local/split_data.sh b/runtime/examples/u2pp_ol/wenetspeech/local/split_data.sh similarity index 100% rename from speechx/examples/u2pp_ol/wenetspeech/local/split_data.sh rename to runtime/examples/u2pp_ol/wenetspeech/local/split_data.sh diff --git a/runtime/examples/u2pp_ol/wenetspeech/path.sh b/runtime/examples/u2pp_ol/wenetspeech/path.sh new file mode 100644 index 00000000..ad3a7358 --- /dev/null +++ b/runtime/examples/u2pp_ol/wenetspeech/path.sh @@ -0,0 +1,18 @@ +# This contains the locations of binarys build required for running the examples. + +unset GREP_OPTIONS + +ENGINE_ROOT=$PWD/../../../ +ENGINE_BUILD=$ENGINE_ROOT/build/engine/asr + +ENGINE_TOOLS=$ENGINE_ROOT/tools +TOOLS_BIN=$ENGINE_TOOLS/valgrind/install/bin + +[ -d $ENGINE_BUILD ] || { echo "Error: 'build/runtime' directory not found. 
please ensure that the project build successfully"; } + +export LC_AL=C + +export PATH=$PATH:$TOOLS_BIN:$ENGINE_BUILD/nnet:$ENGINE_BUILD/decoder:$ENGINE_BUILD/../common/frontend/audio:$ENGINE_BUILD/recognizer + +#PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);") +export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH diff --git a/speechx/examples/u2pp_ol/wenetspeech/run.sh b/runtime/examples/u2pp_ol/wenetspeech/run.sh similarity index 100% rename from speechx/examples/u2pp_ol/wenetspeech/run.sh rename to runtime/examples/u2pp_ol/wenetspeech/run.sh diff --git a/speechx/examples/u2pp_ol/wenetspeech/utils b/runtime/examples/u2pp_ol/wenetspeech/utils similarity index 100% rename from speechx/examples/u2pp_ol/wenetspeech/utils rename to runtime/examples/u2pp_ol/wenetspeech/utils diff --git a/speechx/patch/CPPLINT.cfg b/runtime/patch/CPPLINT.cfg similarity index 100% rename from speechx/patch/CPPLINT.cfg rename to runtime/patch/CPPLINT.cfg diff --git a/speechx/patch/README.md b/runtime/patch/README.md similarity index 100% rename from speechx/patch/README.md rename to runtime/patch/README.md diff --git a/speechx/patch/openfst/src/include/fst/flags.h b/runtime/patch/openfst/src/include/fst/flags.h similarity index 100% rename from speechx/patch/openfst/src/include/fst/flags.h rename to runtime/patch/openfst/src/include/fst/flags.h diff --git a/speechx/patch/openfst/src/include/fst/log.h b/runtime/patch/openfst/src/include/fst/log.h similarity index 100% rename from speechx/patch/openfst/src/include/fst/log.h rename to runtime/patch/openfst/src/include/fst/log.h diff --git a/speechx/patch/openfst/src/lib/flags.cc b/runtime/patch/openfst/src/lib/flags.cc similarity index 100% rename from speechx/patch/openfst/src/lib/flags.cc rename to 
runtime/patch/openfst/src/lib/flags.cc diff --git a/speechx/tools/clang-format.sh b/runtime/tools/clang-format.sh similarity index 100% rename from speechx/tools/clang-format.sh rename to runtime/tools/clang-format.sh diff --git a/speechx/tools/setup_valgrind.sh b/runtime/tools/setup_valgrind.sh similarity index 100% rename from speechx/tools/setup_valgrind.sh rename to runtime/tools/setup_valgrind.sh diff --git a/speechx/tools/venv.sh b/runtime/tools/venv.sh similarity index 100% rename from speechx/tools/venv.sh rename to runtime/tools/venv.sh diff --git a/speechx/examples/text_lm/path.sh b/speechx/examples/text_lm/path.sh deleted file mode 100644 index 541f852c..00000000 --- a/speechx/examples/text_lm/path.sh +++ /dev/null @@ -1,4 +0,0 @@ -MAIN_ROOT=`realpath $PWD/../../../../` -SPEECHX_ROOT=`realpath $MAIN_ROOT/speechx` - -export LC_AL=C diff --git a/speechx/examples/u2pp_ol/wenetspeech/path.sh b/speechx/examples/u2pp_ol/wenetspeech/path.sh deleted file mode 100644 index 9518db11..00000000 --- a/speechx/examples/u2pp_ol/wenetspeech/path.sh +++ /dev/null @@ -1,18 +0,0 @@ -# This contains the locations of binarys build required for running the examples. - -unset GREP_OPTIONS - -SPEECHX_ROOT=$PWD/../../../ -SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx/asr - -SPEECHX_TOOLS=$SPEECHX_ROOT/tools -TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin - -[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. 
please ensure that the project build successfully"; } - -export LC_AL=C - -export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/../common/frontend/audio:$SPEECHX_BUILD/recognizer - -PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);") -export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH diff --git a/speechx/speechx/common/frontend/wave-reader.cc b/speechx/speechx/common/frontend/wave-reader.cc deleted file mode 100644 index 42bf79c6..00000000 --- a/speechx/speechx/common/frontend/wave-reader.cc +++ /dev/null @@ -1,387 +0,0 @@ -// feat/wave-reader.cc - -// Copyright 2009-2011 Karel Vesely; Petr Motlicek -// 2013 Florent Masson -// 2013 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include - -#include "frontend/wave-reader.h" -#include "base/kaldi-error.h" -#include "base/kaldi-utils.h" - -namespace kaldi { - -// A utility class for reading wave header. 
-struct WaveHeaderReadGofer { - std::istream &is; - bool swap; - char tag[5]; - - WaveHeaderReadGofer(std::istream &is) : is(is), swap(false) { - memset(tag, '\0', sizeof tag); - } - - void Expect4ByteTag(const char *expected) { - is.read(tag, 4); - if (is.fail()) - KALDI_ERR << "WaveData: expected " << expected - << ", failed to read anything"; - if (strcmp(tag, expected)) - KALDI_ERR << "WaveData: expected " << expected << ", got " << tag; - } - - void Read4ByteTag() { - is.read(tag, 4); - if (is.fail()) - KALDI_ERR << "WaveData: expected 4-byte chunk-name, got read error"; - } - - uint32 ReadUint32() { - union { - char result[4]; - uint32 ans; - } u; - is.read(u.result, 4); - if (swap) - KALDI_SWAP4(u.result); - if (is.fail()) - KALDI_ERR << "WaveData: unexpected end of file or read error"; - return u.ans; - } - - uint16 ReadUint16() { - union { - char result[2]; - int16 ans; - } u; - is.read(u.result, 2); - if (swap) - KALDI_SWAP2(u.result); - if (is.fail()) - KALDI_ERR << "WaveData: unexpected end of file or read error"; - return u.ans; - } -}; - -static void WriteUint32(std::ostream &os, int32 i) { - union { - char buf[4]; - int i; - } u; - u.i = i; -#ifdef __BIG_ENDIAN__ - KALDI_SWAP4(u.buf); -#endif - os.write(u.buf, 4); - if (os.fail()) - KALDI_ERR << "WaveData: error writing to stream."; -} - -static void WriteUint16(std::ostream &os, int16 i) { - union { - char buf[2]; - int16 i; - } u; - u.i = i; -#ifdef __BIG_ENDIAN__ - KALDI_SWAP2(u.buf); -#endif - os.write(u.buf, 2); - if (os.fail()) - KALDI_ERR << "WaveData: error writing to stream."; -} - -void WaveInfo::Read(std::istream &is) { - WaveHeaderReadGofer reader(is); - reader.Read4ByteTag(); - if (strcmp(reader.tag, "RIFF") == 0) - reverse_bytes_ = false; - else if (strcmp(reader.tag, "RIFX") == 0) - reverse_bytes_ = true; - else - KALDI_ERR << "WaveData: expected RIFF or RIFX, got " << reader.tag; - -#ifdef __BIG_ENDIAN__ - reverse_bytes_ = !reverse_bytes_; -#endif - reader.swap = reverse_bytes_; - - 
uint32 riff_chunk_size = reader.ReadUint32(); - reader.Expect4ByteTag("WAVE"); - - uint32 riff_chunk_read = 0; - riff_chunk_read += 4; // WAVE included in riff_chunk_size. - - // Possibly skip any RIFF tags between 'WAVE' and 'fmt '. - // Apple devices produce a filler tag 'JUNK' for memory alignment. - reader.Read4ByteTag(); - riff_chunk_read += 4; - while (strcmp(reader.tag,"fmt ") != 0) { - uint32 filler_size = reader.ReadUint32(); - riff_chunk_read += 4; - for (uint32 i = 0; i < filler_size; i++) { - is.get(); // read 1 byte, - } - riff_chunk_read += filler_size; - // get next RIFF tag, - reader.Read4ByteTag(); - riff_chunk_read += 4; - } - - KALDI_ASSERT(strcmp(reader.tag,"fmt ") == 0); - uint32 subchunk1_size = reader.ReadUint32(); - uint16 audio_format = reader.ReadUint16(); - num_channels_ = reader.ReadUint16(); - uint32 sample_rate = reader.ReadUint32(), - byte_rate = reader.ReadUint32(), - block_align = reader.ReadUint16(), - bits_per_sample = reader.ReadUint16(); - samp_freq_ = static_cast(sample_rate); - - uint32 fmt_chunk_read = 16; - if (audio_format == 1) { - if (subchunk1_size < 16) { - KALDI_ERR << "WaveData: expect PCM format data to have fmt chunk " - << "of at least size 16."; - } - } else if (audio_format == 0xFFFE) { // WAVE_FORMAT_EXTENSIBLE - uint16 extra_size = reader.ReadUint16(); - if (subchunk1_size < 40 || extra_size < 22) { - KALDI_ERR << "WaveData: malformed WAVE_FORMAT_EXTENSIBLE format data."; - } - reader.ReadUint16(); // Unused for PCM. - reader.ReadUint32(); // Channel map: we do not care. - uint32 guid1 = reader.ReadUint32(), - guid2 = reader.ReadUint32(), - guid3 = reader.ReadUint32(), - guid4 = reader.ReadUint32(); - fmt_chunk_read = 40; - - // Support only KSDATAFORMAT_SUBTYPE_PCM for now. 
Interesting formats: - // ("00000001-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_PCM) - // ("00000003-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_IEEE_FLOAT) - // ("00000006-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_ALAW) - // ("00000007-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_MULAW) - if (guid1 != 0x00000001 || guid2 != 0x00100000 || - guid3 != 0xAA000080 || guid4 != 0x719B3800) { - KALDI_ERR << "WaveData: unsupported WAVE_FORMAT_EXTENSIBLE format."; - } - } else { - KALDI_ERR << "WaveData: can read only PCM data, format id in file is: " - << audio_format; - } - - for (uint32 i = fmt_chunk_read; i < subchunk1_size; ++i) - is.get(); // use up extra data. - - if (num_channels_ == 0) - KALDI_ERR << "WaveData: no channels present"; - if (bits_per_sample != 16) - KALDI_ERR << "WaveData: unsupported bits_per_sample = " << bits_per_sample; - if (byte_rate != sample_rate * bits_per_sample/8 * num_channels_) - KALDI_ERR << "Unexpected byte rate " << byte_rate << " vs. " - << sample_rate << " * " << (bits_per_sample/8) - << " * " << num_channels_; - if (block_align != num_channels_ * bits_per_sample/8) - KALDI_ERR << "Unexpected block_align: " << block_align << " vs. " - << num_channels_ << " * " << (bits_per_sample/8); - - riff_chunk_read += 4 + subchunk1_size; - // size of what we just read, 4 for subchunk1_size + subchunk1_size itself. - - // We support an optional "fact" chunk (which is useless but which - // we encountered), and then a single "data" chunk. - - reader.Read4ByteTag(); - riff_chunk_read += 4; - - // Skip any subchunks between "fmt" and "data". Usually there will - // be a single "fact" subchunk, but on Windows there can also be a - // "list" subchunk. - while (strcmp(reader.tag, "data") != 0) { - // We will just ignore the data in these chunks. 
- uint32 chunk_sz = reader.ReadUint32(); - if (chunk_sz != 4 && strcmp(reader.tag, "fact") == 0) - KALDI_WARN << "Expected fact chunk to be 4 bytes long."; - for (uint32 i = 0; i < chunk_sz; i++) - is.get(); - riff_chunk_read += 4 + chunk_sz; // for chunk_sz (4) + chunk contents (chunk-sz) - - // Now read the next chunk name. - reader.Read4ByteTag(); - riff_chunk_read += 4; - } - - KALDI_ASSERT(strcmp(reader.tag, "data") == 0); - uint32 data_chunk_size = reader.ReadUint32(); - riff_chunk_read += 4; - - // Figure out if the file is going to be read to the end. Values as - // observed in the wild: - bool is_stream_mode = - riff_chunk_size == 0 - || riff_chunk_size == 0xFFFFFFFF - || data_chunk_size == 0 - || data_chunk_size == 0xFFFFFFFF - || data_chunk_size == 0x7FFFF000; // This value is used by SoX. - - if (is_stream_mode) - KALDI_VLOG(1) << "Read in RIFF chunk size: " << riff_chunk_size - << ", data chunk size: " << data_chunk_size - << ". Assume 'stream mode' (reading data to EOF)."; - - if (!is_stream_mode - && std::abs(static_cast(riff_chunk_read) + - static_cast(data_chunk_size) - - static_cast(riff_chunk_size)) > 1) { - // We allow the size to be off by one without warning, because there is a - // weirdness in the format of RIFF files that means that the input may - // sometimes be padded with 1 unused byte to make the total size even. - KALDI_WARN << "Expected " << riff_chunk_size << " bytes in RIFF chunk, but " - << "after first data block there will be " << riff_chunk_read - << " + " << data_chunk_size << " bytes " - << "(we do not support reading multiple data chunks)."; - } - - if (is_stream_mode) - samp_count_ = -1; - else - samp_count_ = data_chunk_size / block_align; -} - -void WaveData::Read(std::istream &is) { - const uint32 kBlockSize = 1024 * 1024; - - WaveInfo header; - header.Read(is); - - data_.Resize(0, 0); // clear the data. - samp_freq_ = header.SampFreq(); - - std::vector buffer; - uint32 bytes_to_go = header.IsStreamed() ? 
kBlockSize : header.DataBytes(); - - // Once in a while header.DataBytes() will report an insane value; - // read the file to the end - while (is && bytes_to_go > 0) { - uint32 block_bytes = std::min(bytes_to_go, kBlockSize); - uint32 offset = buffer.size(); - buffer.resize(offset + block_bytes); - is.read(&buffer[offset], block_bytes); - uint32 bytes_read = is.gcount(); - buffer.resize(offset + bytes_read); - if (!header.IsStreamed()) - bytes_to_go -= bytes_read; - } - - if (is.bad()) - KALDI_ERR << "WaveData: file read error"; - - if (buffer.size() == 0) - KALDI_ERR << "WaveData: empty file (no data)"; - - if (!header.IsStreamed() && buffer.size() < header.DataBytes()) { - KALDI_WARN << "Expected " << header.DataBytes() << " bytes of wave data, " - << "but read only " << buffer.size() << " bytes. " - << "Truncated file?"; - } - - uint16 *data_ptr = reinterpret_cast(&buffer[0]); - - // The matrix is arranged row per channel, column per sample. - data_.Resize(header.NumChannels(), - buffer.size() / header.BlockAlign()); - for (uint32 i = 0; i < data_.NumCols(); ++i) { - for (uint32 j = 0; j < data_.NumRows(); ++j) { - int16 k = *data_ptr++; - if (header.ReverseBytes()) - KALDI_SWAP2(k); - data_(j, i) = k; - } - } -} - - -// Write 16-bit PCM. - -// note: the WAVE chunk contains 2 subchunks. -// -// subchunk2size = data.NumRows() * data.NumCols() * 2. 
- - -void WaveData::Write(std::ostream &os) const { - os << "RIFF"; - if (data_.NumRows() == 0) - KALDI_ERR << "Error: attempting to write empty WAVE file"; - - int32 num_chan = data_.NumRows(), - num_samp = data_.NumCols(), - bytes_per_samp = 2; - - int32 subchunk2size = (num_chan * num_samp * bytes_per_samp); - int32 chunk_size = 36 + subchunk2size; - WriteUint32(os, chunk_size); - os << "WAVE"; - os << "fmt "; - WriteUint32(os, 16); - WriteUint16(os, 1); - WriteUint16(os, num_chan); - KALDI_ASSERT(samp_freq_ > 0); - WriteUint32(os, static_cast(samp_freq_)); - WriteUint32(os, static_cast(samp_freq_) * num_chan * bytes_per_samp); - WriteUint16(os, num_chan * bytes_per_samp); - WriteUint16(os, 8 * bytes_per_samp); - os << "data"; - WriteUint32(os, subchunk2size); - - const BaseFloat *data_ptr = data_.Data(); - int32 stride = data_.Stride(); - - int num_clipped = 0; - for (int32 i = 0; i < num_samp; i++) { - for (int32 j = 0; j < num_chan; j++) { - int32 elem = static_cast(trunc(data_ptr[j * stride + i])); - int16 elem_16 = static_cast(elem); - if (elem < std::numeric_limits::min()) { - elem_16 = std::numeric_limits::min(); - ++num_clipped; - } else if (elem > std::numeric_limits::max()) { - elem_16 = std::numeric_limits::max(); - ++num_clipped; - } -#ifdef __BIG_ENDIAN__ - KALDI_SWAP2(elem_16); -#endif - os.write(reinterpret_cast(&elem_16), 2); - } - } - if (os.fail()) - KALDI_ERR << "Error writing wave data to stream."; - if (num_clipped > 0) - KALDI_WARN << "WARNING: clipped " << num_clipped - << " samples out of total " << num_chan * num_samp - << ". 
Reduce volume?"; -} - - -} // end namespace kaldi diff --git a/speechx/speechx/common/frontend/wave-reader.h b/speechx/speechx/common/frontend/wave-reader.h deleted file mode 100644 index dae74139..00000000 --- a/speechx/speechx/common/frontend/wave-reader.h +++ /dev/null @@ -1,248 +0,0 @@ -// feat/wave-reader.h - -// Copyright 2009-2011 Karel Vesely; Microsoft Corporation -// 2013 Florent Masson -// 2013 Johns Hopkins University (author: Daniel Povey) - -// See ../../COPYING for clarification regarding multiple authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -// MERCHANTABLITY OR NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing permissions and -// limitations under the License. - - -/* -// THE WAVE FORMAT IS SPECIFIED IN: -// https:// ccrma.stanford.edu/courses/422/projects/WaveFormat/ -// -// -// -// RIFF -// | -// WAVE -// | \ \ \ -// fmt_ data ... data -// -// -// Riff is a general container, which usually contains one WAVE chunk -// each WAVE chunk has header sub-chunk 'fmt_' -// and one or more data sub-chunks 'data' -// -// [Note from Dan: to say that the wave format was ever "specified" anywhere is -// not quite right. The guy who invented the wave format attempted to create -// a formal specification but it did not completely make sense. And there -// doesn't seem to be a consensus on what makes a valid wave file, -// particularly where the accuracy of header information is concerned.] 
-*/ - - -#ifndef KALDI_FEAT_WAVE_READER_H_ -#define KALDI_FEAT_WAVE_READER_H_ - -#include - -#include "base/kaldi-types.h" -#include "matrix/kaldi-vector.h" -#include "matrix/kaldi-matrix.h" - - -namespace kaldi { - -/// For historical reasons, we scale waveforms to the range -/// (2^15-1)*[-1, 1], not the usual default DSP range [-1, 1]. -const BaseFloat kWaveSampleMax = 32768.0; - -/// This class reads and hold wave file header information. -class WaveInfo { - public: - WaveInfo() : samp_freq_(0), samp_count_(0), - num_channels_(0), reverse_bytes_(0) {} - - /// Is stream size unknown? Duration and SampleCount not valid if true. - bool IsStreamed() const { return samp_count_ < 0; } - - /// Sample frequency, Hz. - BaseFloat SampFreq() const { return samp_freq_; } - - /// Number of samples in stream. Invalid if IsStreamed() is true. - uint32 SampleCount() const { return samp_count_; } - - /// Approximate duration, seconds. Invalid if IsStreamed() is true. - BaseFloat Duration() const { return samp_count_ / samp_freq_; } - - /// Number of channels, 1 to 16. - int32 NumChannels() const { return num_channels_; } - - /// Bytes per sample. - size_t BlockAlign() const { return 2 * num_channels_; } - - /// Wave data bytes. Invalid if IsStreamed() is true. - size_t DataBytes() const { return samp_count_ * BlockAlign(); } - - /// Is data file byte order different from machine byte order? - bool ReverseBytes() const { return reverse_bytes_; } - - /// 'is' should be opened in binary mode. Read() will throw on error. - /// On success 'is' will be positioned at the beginning of wave data. - void Read(std::istream &is); - - private: - BaseFloat samp_freq_; - int32 samp_count_; // 0 if empty, -1 if undefined length. - uint8 num_channels_; - bool reverse_bytes_; // File endianness differs from host. -}; - -/// This class's purpose is to read in Wave files. 
-class WaveData { - public: - WaveData(BaseFloat samp_freq, const MatrixBase &data) - : data_(data), samp_freq_(samp_freq) {} - - WaveData() : samp_freq_(0.0) {} - - /// Read() will throw on error. It's valid to call Read() more than once-- - /// in this case it will destroy what was there before. - /// "is" should be opened in binary mode. - void Read(std::istream &is); - - /// Write() will throw on error. os should be opened in binary mode. - void Write(std::ostream &os) const; - - // This function returns the wave data-- it's in a matrix - // because there may be multiple channels. In the normal case - // there's just one channel so Data() will have one row. - const Matrix &Data() const { return data_; } - - BaseFloat SampFreq() const { return samp_freq_; } - - // Returns the duration in seconds - BaseFloat Duration() const { return data_.NumCols() / samp_freq_; } - - void CopyFrom(const WaveData &other) { - samp_freq_ = other.samp_freq_; - data_.CopyFromMat(other.data_); - } - - void Clear() { - data_.Resize(0, 0); - samp_freq_ = 0.0; - } - - void Swap(WaveData *other) { - data_.Swap(&(other->data_)); - std::swap(samp_freq_, other->samp_freq_); - } - - private: - static const uint32 kBlockSize = 1024 * 1024; // Use 1M bytes. - Matrix data_; - BaseFloat samp_freq_; -}; - - -// Holder class for .wav files that enables us to read (but not write) .wav -// files. c.f. util/kaldi-holder.h we don't use the KaldiObjectHolder template -// because we don't want to check for the \0B binary header. We could have faked -// it by pretending to read in the wave data in text mode after failing to find -// the \0B header, but that would have been a little ugly. -class WaveHolder { - public: - typedef WaveData T; - - static bool Write(std::ostream &os, bool binary, const T &t) { - // We don't write the binary-mode header here [always binary]. - if (!binary) - KALDI_ERR << "Wave data can only be written in binary mode."; - try { - t.Write(os); // throws exception on failure. 
- return true; - } catch (const std::exception &e) { - KALDI_WARN << "Exception caught in WaveHolder object (writing). " - << e.what(); - return false; // write failure. - } - } - void Copy(const T &t) { t_.CopyFrom(t); } - - static bool IsReadInBinary() { return true; } - - void Clear() { t_.Clear(); } - - T &Value() { return t_; } - - WaveHolder &operator = (const WaveHolder &other) { - t_.CopyFrom(other.t_); - return *this; - } - WaveHolder(const WaveHolder &other): t_(other.t_) {} - - WaveHolder() {} - - bool Read(std::istream &is) { - // We don't look for the binary-mode header here [always binary] - try { - t_.Read(is); // Throws exception on failure. - return true; - } catch (const std::exception &e) { - KALDI_WARN << "Exception caught in WaveHolder::Read(). " << e.what(); - return false; - } - } - - void Swap(WaveHolder *other) { - t_.Swap(&(other->t_)); - } - - bool ExtractRange(const WaveHolder &other, const std::string &range) { - KALDI_ERR << "ExtractRange is not defined for this type of holder."; - return false; - } - - private: - T t_; -}; - -// This is like WaveHolder but when you just want the metadata- -// it leaves the actual data undefined, it doesn't read it. -class WaveInfoHolder { - public: - typedef WaveInfo T; - - void Clear() { info_ = WaveInfo(); } - void Swap(WaveInfoHolder *other) { std::swap(info_, other->info_); } - T &Value() { return info_; } - static bool IsReadInBinary() { return true; } - - bool Read(std::istream &is) { - try { - info_.Read(is); // Throws exception on failure. - return true; - } catch (const std::exception &e) { - KALDI_WARN << "Exception caught in WaveInfoHolder::Read(). 
" << e.what(); - return false; - } - } - - bool ExtractRange(const WaveInfoHolder &other, const std::string &range) { - KALDI_ERR << "ExtractRange is not defined for this type of holder."; - return false; - } - - private: - WaveInfo info_; -}; - - -} // namespace kaldi - -#endif // KALDI_FEAT_WAVE_READER_H_ diff --git a/speechx/speechx/common/utils/picojson.h b/speechx/speechx/common/utils/picojson.h deleted file mode 100644 index 28c5b7fa..00000000 --- a/speechx/speechx/common/utils/picojson.h +++ /dev/null @@ -1,1202 +0,0 @@ -/* - * Copyright 2009-2010 Cybozu Labs, Inc. - * Copyright 2011-2014 Kazuho Oku - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. 
- */ -#ifndef picojson_h -#define picojson_h - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#define PICOJSON_USE_INT64 1 - -// for isnan/isinf -#if __cplusplus >= 201103L -#include -#else -extern "C" { -#ifdef _MSC_VER -#include -#elif defined(__INTEL_COMPILER) -#include -#else -#include -#endif -} -#endif - -#ifndef PICOJSON_USE_RVALUE_REFERENCE -#if (defined(__cpp_rvalue_references) && __cpp_rvalue_references >= 200610) || (defined(_MSC_VER) && _MSC_VER >= 1600) -#define PICOJSON_USE_RVALUE_REFERENCE 1 -#else -#define PICOJSON_USE_RVALUE_REFERENCE 0 -#endif -#endif // PICOJSON_USE_RVALUE_REFERENCE - -#ifndef PICOJSON_NOEXCEPT -#if PICOJSON_USE_RVALUE_REFERENCE -#define PICOJSON_NOEXCEPT noexcept -#else -#define PICOJSON_NOEXCEPT throw() -#endif -#endif - -// experimental support for int64_t (see README.mkdn for detail) -#ifdef PICOJSON_USE_INT64 -#define __STDC_FORMAT_MACROS -#include -#if __cplusplus >= 201103L -#include -#else -extern "C" { -#include -} -#endif -#endif - -// to disable the use of localeconv(3), set PICOJSON_USE_LOCALE to 0 -#ifndef PICOJSON_USE_LOCALE -#define PICOJSON_USE_LOCALE 1 -#endif -#if PICOJSON_USE_LOCALE -extern "C" { -#include -} -#endif - -#ifndef PICOJSON_ASSERT -#define PICOJSON_ASSERT(e) \ - do { \ - if (!(e)) \ - throw std::runtime_error(#e); \ - } while (0) -#endif - -#ifdef _MSC_VER -#define SNPRINTF _snprintf_s -#pragma warning(push) -#pragma warning(disable : 4244) // conversion from int to char -#pragma warning(disable : 4127) // conditional expression is constant -#pragma warning(disable : 4702) // unreachable code -#pragma warning(disable : 4706) // assignment within conditional expression -#else -#define SNPRINTF snprintf -#endif - -namespace picojson { - -enum { - null_type, - boolean_type, - number_type, - string_type, - array_type, - object_type -#ifdef PICOJSON_USE_INT64 - , - int64_type -#endif -}; - -enum { INDENT_WIDTH = 2, 
DEFAULT_MAX_DEPTHS = 100 }; - -struct null {}; - -class value { -public: - typedef std::vector array; - typedef std::map object; - union _storage { - bool boolean_; - double number_; -#ifdef PICOJSON_USE_INT64 - int64_t int64_; -#endif - std::string *string_; - array *array_; - object *object_; - }; - -protected: - int type_; - _storage u_; - -public: - value(); - value(int type, bool); - explicit value(bool b); -#ifdef PICOJSON_USE_INT64 - explicit value(int64_t i); -#endif - explicit value(double n); - explicit value(const std::string &s); - explicit value(const array &a); - explicit value(const object &o); -#if PICOJSON_USE_RVALUE_REFERENCE - explicit value(std::string &&s); - explicit value(array &&a); - explicit value(object &&o); -#endif - explicit value(const char *s); - value(const char *s, size_t len); - ~value(); - value(const value &x); - value &operator=(const value &x); -#if PICOJSON_USE_RVALUE_REFERENCE - value(value &&x) PICOJSON_NOEXCEPT; - value &operator=(value &&x) PICOJSON_NOEXCEPT; -#endif - void swap(value &x) PICOJSON_NOEXCEPT; - template bool is() const; - template const T &get() const; - template T &get(); - template void set(const T &); -#if PICOJSON_USE_RVALUE_REFERENCE - template void set(T &&); -#endif - bool evaluate_as_boolean() const; - const value &get(const size_t idx) const; - const value &get(const std::string &key) const; - value &get(const size_t idx); - value &get(const std::string &key); - - bool contains(const size_t idx) const; - bool contains(const std::string &key) const; - std::string to_str() const; - template void serialize(Iter os, bool prettify = false) const; - std::string serialize(bool prettify = false) const; - -private: - template value(const T *); // intentionally defined to block implicit conversion of pointer to bool - template static void _indent(Iter os, int indent); - template void _serialize(Iter os, int indent) const; - std::string _serialize(int indent) const; - void clear(); -}; - -typedef value::array 
array; -typedef value::object object; - -inline value::value() : type_(null_type), u_() { -} - -inline value::value(int type, bool) : type_(type), u_() { - switch (type) { -#define INIT(p, v) \ - case p##type: \ - u_.p = v; \ - break - INIT(boolean_, false); - INIT(number_, 0.0); -#ifdef PICOJSON_USE_INT64 - INIT(int64_, 0); -#endif - INIT(string_, new std::string()); - INIT(array_, new array()); - INIT(object_, new object()); -#undef INIT - default: - break; - } -} - -inline value::value(bool b) : type_(boolean_type), u_() { - u_.boolean_ = b; -} - -#ifdef PICOJSON_USE_INT64 -inline value::value(int64_t i) : type_(int64_type), u_() { - u_.int64_ = i; -} -#endif - -inline value::value(double n) : type_(number_type), u_() { - if ( -#ifdef _MSC_VER - !_finite(n) -#elif __cplusplus >= 201103L - std::isnan(n) || std::isinf(n) -#else - isnan(n) || isinf(n) -#endif - ) { - throw std::overflow_error(""); - } - u_.number_ = n; -} - -inline value::value(const std::string &s) : type_(string_type), u_() { - u_.string_ = new std::string(s); -} - -inline value::value(const array &a) : type_(array_type), u_() { - u_.array_ = new array(a); -} - -inline value::value(const object &o) : type_(object_type), u_() { - u_.object_ = new object(o); -} - -#if PICOJSON_USE_RVALUE_REFERENCE -inline value::value(std::string &&s) : type_(string_type), u_() { - u_.string_ = new std::string(std::move(s)); -} - -inline value::value(array &&a) : type_(array_type), u_() { - u_.array_ = new array(std::move(a)); -} - -inline value::value(object &&o) : type_(object_type), u_() { - u_.object_ = new object(std::move(o)); -} -#endif - -inline value::value(const char *s) : type_(string_type), u_() { - u_.string_ = new std::string(s); -} - -inline value::value(const char *s, size_t len) : type_(string_type), u_() { - u_.string_ = new std::string(s, len); -} - -inline void value::clear() { - switch (type_) { -#define DEINIT(p) \ - case p##type: \ - delete u_.p; \ - break - DEINIT(string_); - DEINIT(array_); 
- DEINIT(object_); -#undef DEINIT - default: - break; - } -} - -inline value::~value() { - clear(); -} - -inline value::value(const value &x) : type_(x.type_), u_() { - switch (type_) { -#define INIT(p, v) \ - case p##type: \ - u_.p = v; \ - break - INIT(string_, new std::string(*x.u_.string_)); - INIT(array_, new array(*x.u_.array_)); - INIT(object_, new object(*x.u_.object_)); -#undef INIT - default: - u_ = x.u_; - break; - } -} - -inline value &value::operator=(const value &x) { - if (this != &x) { - value t(x); - swap(t); - } - return *this; -} - -#if PICOJSON_USE_RVALUE_REFERENCE -inline value::value(value &&x) PICOJSON_NOEXCEPT : type_(null_type), u_() { - swap(x); -} -inline value &value::operator=(value &&x) PICOJSON_NOEXCEPT { - swap(x); - return *this; -} -#endif -inline void value::swap(value &x) PICOJSON_NOEXCEPT { - std::swap(type_, x.type_); - std::swap(u_, x.u_); -} - -#define IS(ctype, jtype) \ - template <> inline bool value::is() const { \ - return type_ == jtype##_type; \ - } -IS(null, null) -IS(bool, boolean) -#ifdef PICOJSON_USE_INT64 -IS(int64_t, int64) -#endif -IS(std::string, string) -IS(array, array) -IS(object, object) -#undef IS -template <> inline bool value::is() const { - return type_ == number_type -#ifdef PICOJSON_USE_INT64 - || type_ == int64_type -#endif - ; -} - -#define GET(ctype, var) \ - template <> inline const ctype &value::get() const { \ - PICOJSON_ASSERT("type mismatch! call is() before get()" && is()); \ - return var; \ - } \ - template <> inline ctype &value::get() { \ - PICOJSON_ASSERT("type mismatch! 
call is() before get()" && is()); \ - return var; \ - } -GET(bool, u_.boolean_) -GET(std::string, *u_.string_) -GET(array, *u_.array_) -GET(object, *u_.object_) -#ifdef PICOJSON_USE_INT64 -GET(double, - (type_ == int64_type && (const_cast(this)->type_ = number_type, (const_cast(this)->u_.number_ = u_.int64_)), - u_.number_)) -GET(int64_t, u_.int64_) -#else -GET(double, u_.number_) -#endif -#undef GET - -#define SET(ctype, jtype, setter) \ - template <> inline void value::set(const ctype &_val) { \ - clear(); \ - type_ = jtype##_type; \ - setter \ - } -SET(bool, boolean, u_.boolean_ = _val;) -SET(std::string, string, u_.string_ = new std::string(_val);) -SET(array, array, u_.array_ = new array(_val);) -SET(object, object, u_.object_ = new object(_val);) -SET(double, number, u_.number_ = _val;) -#ifdef PICOJSON_USE_INT64 -SET(int64_t, int64, u_.int64_ = _val;) -#endif -#undef SET - -#if PICOJSON_USE_RVALUE_REFERENCE -#define MOVESET(ctype, jtype, setter) \ - template <> inline void value::set(ctype && _val) { \ - clear(); \ - type_ = jtype##_type; \ - setter \ - } -MOVESET(std::string, string, u_.string_ = new std::string(std::move(_val));) -MOVESET(array, array, u_.array_ = new array(std::move(_val));) -MOVESET(object, object, u_.object_ = new object(std::move(_val));) -#undef MOVESET -#endif - -inline bool value::evaluate_as_boolean() const { - switch (type_) { - case null_type: - return false; - case boolean_type: - return u_.boolean_; - case number_type: - return u_.number_ != 0; -#ifdef PICOJSON_USE_INT64 - case int64_type: - return u_.int64_ != 0; -#endif - case string_type: - return !u_.string_->empty(); - default: - return true; - } -} - -inline const value &value::get(const size_t idx) const { - static value s_null; - PICOJSON_ASSERT(is()); - return idx < u_.array_->size() ? (*u_.array_)[idx] : s_null; -} - -inline value &value::get(const size_t idx) { - static value s_null; - PICOJSON_ASSERT(is()); - return idx < u_.array_->size() ? 
(*u_.array_)[idx] : s_null; -} - -inline const value &value::get(const std::string &key) const { - static value s_null; - PICOJSON_ASSERT(is()); - object::const_iterator i = u_.object_->find(key); - return i != u_.object_->end() ? i->second : s_null; -} - -inline value &value::get(const std::string &key) { - static value s_null; - PICOJSON_ASSERT(is()); - object::iterator i = u_.object_->find(key); - return i != u_.object_->end() ? i->second : s_null; -} - -inline bool value::contains(const size_t idx) const { - PICOJSON_ASSERT(is()); - return idx < u_.array_->size(); -} - -inline bool value::contains(const std::string &key) const { - PICOJSON_ASSERT(is()); - object::const_iterator i = u_.object_->find(key); - return i != u_.object_->end(); -} - -inline std::string value::to_str() const { - switch (type_) { - case null_type: - return "null"; - case boolean_type: - return u_.boolean_ ? "true" : "false"; -#ifdef PICOJSON_USE_INT64 - case int64_type: { - char buf[sizeof("-9223372036854775808")]; - SNPRINTF(buf, sizeof(buf), "%" PRId64, u_.int64_); - return buf; - } -#endif - case number_type: { - char buf[256]; - double tmp; - SNPRINTF(buf, sizeof(buf), fabs(u_.number_) < (1ULL << 53) && modf(u_.number_, &tmp) == 0 ? "%.f" : "%.17g", u_.number_); -#if PICOJSON_USE_LOCALE - char *decimal_point = localeconv()->decimal_point; - if (strcmp(decimal_point, ".") != 0) { - size_t decimal_point_len = strlen(decimal_point); - for (char *p = buf; *p != '\0'; ++p) { - if (strncmp(p, decimal_point, decimal_point_len) == 0) { - return std::string(buf, p) + "." 
+ (p + decimal_point_len); - } - } - } -#endif - return buf; - } - case string_type: - return *u_.string_; - case array_type: - return "array"; - case object_type: - return "object"; - default: - PICOJSON_ASSERT(0); -#ifdef _MSC_VER - __assume(0); -#endif - } - return std::string(); -} - -template void copy(const std::string &s, Iter oi) { - std::copy(s.begin(), s.end(), oi); -} - -template struct serialize_str_char { - Iter oi; - void operator()(char c) { - switch (c) { -#define MAP(val, sym) \ - case val: \ - copy(sym, oi); \ - break - MAP('"', "\\\""); - MAP('\\', "\\\\"); - MAP('/', "\\/"); - MAP('\b', "\\b"); - MAP('\f', "\\f"); - MAP('\n', "\\n"); - MAP('\r', "\\r"); - MAP('\t', "\\t"); -#undef MAP - default: - if (static_cast(c) < 0x20 || c == 0x7f) { - char buf[7]; - SNPRINTF(buf, sizeof(buf), "\\u%04x", c & 0xff); - copy(buf, buf + 6, oi); - } else { - *oi++ = c; - } - break; - } - } -}; - -template void serialize_str(const std::string &s, Iter oi) { - *oi++ = '"'; - serialize_str_char process_char = {oi}; - std::for_each(s.begin(), s.end(), process_char); - *oi++ = '"'; -} - -template void value::serialize(Iter oi, bool prettify) const { - return _serialize(oi, prettify ? 0 : -1); -} - -inline std::string value::serialize(bool prettify) const { - return _serialize(prettify ? 
0 : -1); -} - -template void value::_indent(Iter oi, int indent) { - *oi++ = '\n'; - for (int i = 0; i < indent * INDENT_WIDTH; ++i) { - *oi++ = ' '; - } -} - -template void value::_serialize(Iter oi, int indent) const { - switch (type_) { - case string_type: - serialize_str(*u_.string_, oi); - break; - case array_type: { - *oi++ = '['; - if (indent != -1) { - ++indent; - } - for (array::const_iterator i = u_.array_->begin(); i != u_.array_->end(); ++i) { - if (i != u_.array_->begin()) { - *oi++ = ','; - } - if (indent != -1) { - _indent(oi, indent); - } - i->_serialize(oi, indent); - } - if (indent != -1) { - --indent; - if (!u_.array_->empty()) { - _indent(oi, indent); - } - } - *oi++ = ']'; - break; - } - case object_type: { - *oi++ = '{'; - if (indent != -1) { - ++indent; - } - for (object::const_iterator i = u_.object_->begin(); i != u_.object_->end(); ++i) { - if (i != u_.object_->begin()) { - *oi++ = ','; - } - if (indent != -1) { - _indent(oi, indent); - } - serialize_str(i->first, oi); - *oi++ = ':'; - if (indent != -1) { - *oi++ = ' '; - } - i->second._serialize(oi, indent); - } - if (indent != -1) { - --indent; - if (!u_.object_->empty()) { - _indent(oi, indent); - } - } - *oi++ = '}'; - break; - } - default: - copy(to_str(), oi); - break; - } - if (indent == 0) { - *oi++ = '\n'; - } -} - -inline std::string value::_serialize(int indent) const { - std::string s; - _serialize(std::back_inserter(s), indent); - return s; -} - -template class input { -protected: - Iter cur_, end_; - bool consumed_; - int line_; - -public: - input(const Iter &first, const Iter &last) : cur_(first), end_(last), consumed_(false), line_(1) { - } - int getc() { - if (consumed_) { - if (*cur_ == '\n') { - ++line_; - } - ++cur_; - } - if (cur_ == end_) { - consumed_ = false; - return -1; - } - consumed_ = true; - return *cur_ & 0xff; - } - void ungetc() { - consumed_ = false; - } - Iter cur() const { - if (consumed_) { - input *self = const_cast *>(this); - self->consumed_ = false; 
- ++self->cur_; - } - return cur_; - } - int line() const { - return line_; - } - void skip_ws() { - while (1) { - int ch = getc(); - if (!(ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r')) { - ungetc(); - break; - } - } - } - bool expect(const int expected) { - skip_ws(); - if (getc() != expected) { - ungetc(); - return false; - } - return true; - } - bool match(const std::string &pattern) { - for (std::string::const_iterator pi(pattern.begin()); pi != pattern.end(); ++pi) { - if (getc() != *pi) { - ungetc(); - return false; - } - } - return true; - } -}; - -template inline int _parse_quadhex(input &in) { - int uni_ch = 0, hex; - for (int i = 0; i < 4; i++) { - if ((hex = in.getc()) == -1) { - return -1; - } - if ('0' <= hex && hex <= '9') { - hex -= '0'; - } else if ('A' <= hex && hex <= 'F') { - hex -= 'A' - 0xa; - } else if ('a' <= hex && hex <= 'f') { - hex -= 'a' - 0xa; - } else { - in.ungetc(); - return -1; - } - uni_ch = uni_ch * 16 + hex; - } - return uni_ch; -} - -template inline bool _parse_codepoint(String &out, input &in) { - int uni_ch; - if ((uni_ch = _parse_quadhex(in)) == -1) { - return false; - } - if (0xd800 <= uni_ch && uni_ch <= 0xdfff) { - if (0xdc00 <= uni_ch) { - // a second 16-bit of a surrogate pair appeared - return false; - } - // first 16-bit of surrogate pair, get the next one - if (in.getc() != '\\' || in.getc() != 'u') { - in.ungetc(); - return false; - } - int second = _parse_quadhex(in); - if (!(0xdc00 <= second && second <= 0xdfff)) { - return false; - } - uni_ch = ((uni_ch - 0xd800) << 10) | ((second - 0xdc00) & 0x3ff); - uni_ch += 0x10000; - } - if (uni_ch < 0x80) { - out.push_back(static_cast(uni_ch)); - } else { - if (uni_ch < 0x800) { - out.push_back(static_cast(0xc0 | (uni_ch >> 6))); - } else { - if (uni_ch < 0x10000) { - out.push_back(static_cast(0xe0 | (uni_ch >> 12))); - } else { - out.push_back(static_cast(0xf0 | (uni_ch >> 18))); - out.push_back(static_cast(0x80 | ((uni_ch >> 12) & 0x3f))); - } - 
out.push_back(static_cast(0x80 | ((uni_ch >> 6) & 0x3f))); - } - out.push_back(static_cast(0x80 | (uni_ch & 0x3f))); - } - return true; -} - -template inline bool _parse_string(String &out, input &in) { - while (1) { - int ch = in.getc(); - if (ch < ' ') { - in.ungetc(); - return false; - } else if (ch == '"') { - return true; - } else if (ch == '\\') { - if ((ch = in.getc()) == -1) { - return false; - } - switch (ch) { -#define MAP(sym, val) \ - case sym: \ - out.push_back(val); \ - break - MAP('"', '\"'); - MAP('\\', '\\'); - MAP('/', '/'); - MAP('b', '\b'); - MAP('f', '\f'); - MAP('n', '\n'); - MAP('r', '\r'); - MAP('t', '\t'); -#undef MAP - case 'u': - if (!_parse_codepoint(out, in)) { - return false; - } - break; - default: - return false; - } - } else { - out.push_back(static_cast(ch)); - } - } - return false; -} - -template inline bool _parse_array(Context &ctx, input &in) { - if (!ctx.parse_array_start()) { - return false; - } - size_t idx = 0; - if (in.expect(']')) { - return ctx.parse_array_stop(idx); - } - do { - if (!ctx.parse_array_item(in, idx)) { - return false; - } - idx++; - } while (in.expect(',')); - return in.expect(']') && ctx.parse_array_stop(idx); -} - -template inline bool _parse_object(Context &ctx, input &in) { - if (!ctx.parse_object_start()) { - return false; - } - if (in.expect('}')) { - return ctx.parse_object_stop(); - } - do { - std::string key; - if (!in.expect('"') || !_parse_string(key, in) || !in.expect(':')) { - return false; - } - if (!ctx.parse_object_item(in, key)) { - return false; - } - } while (in.expect(',')); - return in.expect('}') && ctx.parse_object_stop(); -} - -template inline std::string _parse_number(input &in) { - std::string num_str; - while (1) { - int ch = in.getc(); - if (('0' <= ch && ch <= '9') || ch == '+' || ch == '-' || ch == 'e' || ch == 'E') { - num_str.push_back(static_cast(ch)); - } else if (ch == '.') { -#if PICOJSON_USE_LOCALE - num_str += localeconv()->decimal_point; -#else - 
num_str.push_back('.'); -#endif - } else { - in.ungetc(); - break; - } - } - return num_str; -} - -template inline bool _parse(Context &ctx, input &in) { - in.skip_ws(); - int ch = in.getc(); - switch (ch) { -#define IS(ch, text, op) \ - case ch: \ - if (in.match(text) && op) { \ - return true; \ - } else { \ - return false; \ - } - IS('n', "ull", ctx.set_null()); - IS('f', "alse", ctx.set_bool(false)); - IS('t', "rue", ctx.set_bool(true)); -#undef IS - case '"': - return ctx.parse_string(in); - case '[': - return _parse_array(ctx, in); - case '{': - return _parse_object(ctx, in); - default: - if (('0' <= ch && ch <= '9') || ch == '-') { - double f; - char *endp; - in.ungetc(); - std::string num_str(_parse_number(in)); - if (num_str.empty()) { - return false; - } -#ifdef PICOJSON_USE_INT64 - { - errno = 0; - intmax_t ival = strtoimax(num_str.c_str(), &endp, 10); - if (errno == 0 && std::numeric_limits::min() <= ival && ival <= std::numeric_limits::max() && - endp == num_str.c_str() + num_str.size()) { - ctx.set_int64(ival); - return true; - } - } -#endif - f = strtod(num_str.c_str(), &endp); - if (endp == num_str.c_str() + num_str.size()) { - ctx.set_number(f); - return true; - } - return false; - } - break; - } - in.ungetc(); - return false; -} - -class deny_parse_context { -public: - bool set_null() { - return false; - } - bool set_bool(bool) { - return false; - } -#ifdef PICOJSON_USE_INT64 - bool set_int64(int64_t) { - return false; - } -#endif - bool set_number(double) { - return false; - } - template bool parse_string(input &) { - return false; - } - bool parse_array_start() { - return false; - } - template bool parse_array_item(input &, size_t) { - return false; - } - bool parse_array_stop(size_t) { - return false; - } - bool parse_object_start() { - return false; - } - template bool parse_object_item(input &, const std::string &) { - return false; - } -}; - -class default_parse_context { -protected: - value *out_; - size_t depths_; - -public: - 
default_parse_context(value *out, size_t depths = DEFAULT_MAX_DEPTHS) : out_(out), depths_(depths) { - } - bool set_null() { - *out_ = value(); - return true; - } - bool set_bool(bool b) { - *out_ = value(b); - return true; - } -#ifdef PICOJSON_USE_INT64 - bool set_int64(int64_t i) { - *out_ = value(i); - return true; - } -#endif - bool set_number(double f) { - *out_ = value(f); - return true; - } - template bool parse_string(input &in) { - *out_ = value(string_type, false); - return _parse_string(out_->get(), in); - } - bool parse_array_start() { - if (depths_ == 0) - return false; - --depths_; - *out_ = value(array_type, false); - return true; - } - template bool parse_array_item(input &in, size_t) { - array &a = out_->get(); - a.push_back(value()); - default_parse_context ctx(&a.back(), depths_); - return _parse(ctx, in); - } - bool parse_array_stop(size_t) { - ++depths_; - return true; - } - bool parse_object_start() { - if (depths_ == 0) - return false; - *out_ = value(object_type, false); - return true; - } - template bool parse_object_item(input &in, const std::string &key) { - object &o = out_->get(); - default_parse_context ctx(&o[key], depths_); - return _parse(ctx, in); - } - bool parse_object_stop() { - ++depths_; - return true; - } - -private: - default_parse_context(const default_parse_context &); - default_parse_context &operator=(const default_parse_context &); -}; - -class null_parse_context { -protected: - size_t depths_; - -public: - struct dummy_str { - void push_back(int) { - } - }; - -public: - null_parse_context(size_t depths = DEFAULT_MAX_DEPTHS) : depths_(depths) { - } - bool set_null() { - return true; - } - bool set_bool(bool) { - return true; - } -#ifdef PICOJSON_USE_INT64 - bool set_int64(int64_t) { - return true; - } -#endif - bool set_number(double) { - return true; - } - template bool parse_string(input &in) { - dummy_str s; - return _parse_string(s, in); - } - bool parse_array_start() { - if (depths_ == 0) - return false; - 
--depths_; - return true; - } - template bool parse_array_item(input &in, size_t) { - return _parse(*this, in); - } - bool parse_array_stop(size_t) { - ++depths_; - return true; - } - bool parse_object_start() { - if (depths_ == 0) - return false; - --depths_; - return true; - } - template bool parse_object_item(input &in, const std::string &) { - ++depths_; - return _parse(*this, in); - } - bool parse_object_stop() { - return true; - } - -private: - null_parse_context(const null_parse_context &); - null_parse_context &operator=(const null_parse_context &); -}; - -// obsolete, use the version below -template inline std::string parse(value &out, Iter &pos, const Iter &last) { - std::string err; - pos = parse(out, pos, last, &err); - return err; -} - -template inline Iter _parse(Context &ctx, const Iter &first, const Iter &last, std::string *err) { - input in(first, last); - if (!_parse(ctx, in) && err != NULL) { - char buf[64]; - SNPRINTF(buf, sizeof(buf), "syntax error at line %d near: ", in.line()); - *err = buf; - while (1) { - int ch = in.getc(); - if (ch == -1 || ch == '\n') { - break; - } else if (ch >= ' ') { - err->push_back(static_cast(ch)); - } - } - } - return in.cur(); -} - -template inline Iter parse(value &out, const Iter &first, const Iter &last, std::string *err) { - default_parse_context ctx(&out); - return _parse(ctx, first, last, err); -} - -inline std::string parse(value &out, const std::string &s) { - std::string err; - parse(out, s.begin(), s.end(), &err); - return err; -} - -inline std::string parse(value &out, std::istream &is) { - std::string err; - parse(out, std::istreambuf_iterator(is.rdbuf()), std::istreambuf_iterator(), &err); - return err; -} - -template struct last_error_t { static std::string s; }; -template std::string last_error_t::s; - -inline void set_last_error(const std::string &s) { - last_error_t::s = s; -} - -inline const std::string &get_last_error() { - return last_error_t::s; -} - -inline bool operator==(const value &x, 
const value &y) { - if (x.is()) - return y.is(); -#define PICOJSON_CMP(type) \ - if (x.is()) \ - return y.is() && x.get() == y.get() - PICOJSON_CMP(bool); - PICOJSON_CMP(double); - PICOJSON_CMP(std::string); - PICOJSON_CMP(array); - PICOJSON_CMP(object); -#undef PICOJSON_CMP - PICOJSON_ASSERT(0); -#ifdef _MSC_VER - __assume(0); -#endif - return false; -} - -inline bool operator!=(const value &x, const value &y) { - return !(x == y); -} -} - -#if !PICOJSON_USE_RVALUE_REFERENCE -namespace std { -template <> inline void swap(picojson::value &x, picojson::value &y) { - x.swap(y); -} -} -#endif - -inline std::istream &operator>>(std::istream &is, picojson::value &x) { - picojson::set_last_error(std::string()); - const std::string err(picojson::parse(x, is)); - if (!err.empty()) { - picojson::set_last_error(err); - is.setstate(std::ios::failbit); - } - return is; -} - -inline std::ostream &operator<<(std::ostream &os, const picojson::value &x) { - x.serialize(std::ostream_iterator(os)); - return os; -} -#ifdef _MSC_VER -#pragma warning(pop) -#endif - -#endif \ No newline at end of file From 2f8aad95e030d31b3f457eadef1580948a6839f6 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Wed, 8 Feb 2023 15:52:31 +0800 Subject: [PATCH 11/50] Update .mergify.yml --- .mergify.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.mergify.yml b/.mergify.yml index 5cb1f486..0f182b51 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -136,7 +136,7 @@ pull_request_rules: add: ["Docker"] - name: "auto add label=Deployment" conditions: - - files~=^speechx/ + - files~=^runtime/ actions: label: add: ["Deployment"] From 78e29c8ec4c12173357acf5e89281b742c444408 Mon Sep 17 00:00:00 2001 From: masimeng1994 <1057010459@qq.com> Date: Mon, 20 Feb 2023 19:59:57 +0800 Subject: [PATCH 12/50] add cls engine (#2923) --- runtime/CMakeLists.txt | 2 +- runtime/cmake/fastdeploy.cmake | 39 +++ runtime/engine/CMakeLists.txt | 1 + runtime/engine/cls/CMakeLists.txt | 7 + 
runtime/engine/cls/nnet/CMakeLists.txt | 8 + runtime/engine/cls/nnet/panns_interface.cc | 78 +++++ runtime/engine/cls/nnet/panns_interface.h | 27 ++ runtime/engine/cls/nnet/panns_nnet.cc | 228 +++++++++++++ runtime/engine/cls/nnet/panns_nnet.h | 74 ++++ runtime/engine/cls/nnet/panns_nnet_main.cc | 49 +++ runtime/engine/common/base/config.h | 338 +++++++++++++++++++ runtime/engine/common/utils/CMakeLists.txt | 1 + runtime/engine/common/utils/audio_process.cc | 83 +++++ runtime/engine/common/utils/audio_process.h | 32 ++ 14 files changed, 966 insertions(+), 1 deletion(-) create mode 100644 runtime/cmake/fastdeploy.cmake create mode 100644 runtime/engine/cls/CMakeLists.txt create mode 100644 runtime/engine/cls/nnet/CMakeLists.txt create mode 100644 runtime/engine/cls/nnet/panns_interface.cc create mode 100644 runtime/engine/cls/nnet/panns_interface.h create mode 100644 runtime/engine/cls/nnet/panns_nnet.cc create mode 100644 runtime/engine/cls/nnet/panns_nnet.h create mode 100644 runtime/engine/cls/nnet/panns_nnet_main.cc create mode 100644 runtime/engine/common/base/config.h create mode 100644 runtime/engine/common/utils/audio_process.cc create mode 100644 runtime/engine/common/utils/audio_process.h diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt index 8bd3f28c..44ee3a58 100644 --- a/runtime/CMakeLists.txt +++ b/runtime/CMakeLists.txt @@ -139,7 +139,7 @@ out=':'.join([libs_dir, fluid_dir]); print(out); \ OUTPUT_VARIABLE PADDLE_LIB_DIRS) message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS}) - +add_compile_options(-fPIC) ############################################################################### # Add local library ############################################################################### diff --git a/runtime/cmake/fastdeploy.cmake b/runtime/cmake/fastdeploy.cmake new file mode 100644 index 00000000..773414c1 --- /dev/null +++ b/runtime/cmake/fastdeploy.cmake @@ -0,0 +1,39 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +set(ARCH 
"mserver_x86_64" CACHE STRING "Target Architecture: +android_arm, android_armv7, android_armv8, android_x86, android_x86_64, +mserver_x86_64, ubuntu_x86_64, ios_armv7, ios_armv7s, ios_armv8, ios_x86_64, ios_x86, +windows_x86") + +set(CMAKE_VERBOSE_MAKEFILE ON) + +set(FASTDEPLOY_DIR ${CMAKE_SOURCE_DIR}/fc_patch/fastdeploy) +if(NOT EXISTS ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.2.tgz) + exec_program("mkdir -p ${FASTDEPLOY_DIR} && + wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-1.0.2.tgz -P ${FASTDEPLOY_DIR} && + tar xzvf ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.2.tgz -C ${FASTDEPLOY_DIR} && + mv ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.2 ${FASTDEPLOY_DIR}/linux-x64") +endif() + +if(NOT EXISTS ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.0-shared.tgz) + exec_program("mkdir -p ${FASTDEPLOY_DIR} && + wget https://bj.bcebos.com/fastdeploy/release/android/fastdeploy-android-1.0.0-shared.tgz -P ${FASTDEPLOY_DIR} && + tar xzvf ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.0-shared.tgz -C ${FASTDEPLOY_DIR} && + mv ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.0-shared ${FASTDEPLOY_DIR}/android-armv7v8") +endif() + +if (ARCH STREQUAL "mserver_x86_64") + set(FASTDEPLOY_INSTALL_DIR ${FASTDEPLOY_DIR}/linux-x64) + add_definitions("-DUSE_PADDLE_INFERENCE_BACKEND") + # add_definitions("-DUSE_ORT_BACKEND") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -msse -msse2") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -msse -msse2 -mavx -O3") +elseif (ARCH STREQUAL "android_armv7") + set(FASTDEPLOY_INSTALL_DIR ${FASTDEPLOY_DIR}/android-armv7v8) + add_definitions("-DUSE_PADDLE_LITE_BAKEND") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -g0 -O3 -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE") +endif() + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) +include_directories(${FASTDEPLOY_INCS}) \ No newline at end 
of file diff --git a/runtime/engine/CMakeLists.txt b/runtime/engine/CMakeLists.txt index b522e158..42399fe9 100644 --- a/runtime/engine/CMakeLists.txt +++ b/runtime/engine/CMakeLists.txt @@ -10,3 +10,4 @@ add_subdirectory(asr) add_subdirectory(common) add_subdirectory(kaldi) add_subdirectory(codelab) +add_subdirectory(cls) \ No newline at end of file diff --git a/runtime/engine/cls/CMakeLists.txt b/runtime/engine/cls/CMakeLists.txt new file mode 100644 index 00000000..4d5e0cff --- /dev/null +++ b/runtime/engine/cls/CMakeLists.txt @@ -0,0 +1,7 @@ +project(cls) + +include(fastdeploy) +# add_definitions("-DTEST_DEBUG") +# add_definitions("-DPRINT_TIME") + +add_subdirectory(nnet) \ No newline at end of file diff --git a/runtime/engine/cls/nnet/CMakeLists.txt b/runtime/engine/cls/nnet/CMakeLists.txt new file mode 100644 index 00000000..b4b76120 --- /dev/null +++ b/runtime/engine/cls/nnet/CMakeLists.txt @@ -0,0 +1,8 @@ +set(srcs panns_nnet.cc panns_interface.cc) + +add_library(cls SHARED ${srcs}) +target_link_libraries(cls -static-libstdc++;-Wl,-Bsymbolic ${FASTDEPLOY_LIBS} kaldi-matrix kaldi-base frontend utils) + +set(bin_name panns_nnet_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_link_libraries(${bin_name} -static-libstdc++;-Wl,-Bsymbolic cls gflags glog) \ No newline at end of file diff --git a/runtime/engine/cls/nnet/panns_interface.cc b/runtime/engine/cls/nnet/panns_interface.cc new file mode 100644 index 00000000..257ee44f --- /dev/null +++ b/runtime/engine/cls/nnet/panns_interface.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cls/nnet/panns_interface.h" +#include "cls/nnet/panns_nnet.h" +#include "common/base/config.h" + +namespace ppspeech { + +void* ClsCreateInstance(const char* conf_path) { + Config conf(conf_path); + // cls init + ppspeech::ClsNnetConf cls_nnet_conf; + cls_nnet_conf.wav_normal_ = conf.Read("wav_normal", true); + cls_nnet_conf.wav_normal_type_ = + conf.Read("wav_normal_type", std::string("linear")); + cls_nnet_conf.wav_norm_mul_factor_ = conf.Read("wav_norm_mul_factor", 1.0); + cls_nnet_conf.model_file_path_ = conf.Read("model_path", std::string("")); + cls_nnet_conf.param_file_path_ = conf.Read("param_path", std::string("")); + cls_nnet_conf.dict_file_path_ = conf.Read("dict_path", std::string("")); + cls_nnet_conf.num_cpu_thread_ = conf.Read("num_cpu_thread", 12); + cls_nnet_conf.samp_freq = conf.Read("samp_freq", 32000); + cls_nnet_conf.frame_length_ms = conf.Read("frame_length_ms", 32); + cls_nnet_conf.frame_shift_ms = conf.Read("frame_shift_ms", 10); + cls_nnet_conf.num_bins = conf.Read("num_bins", 64); + cls_nnet_conf.low_freq = conf.Read("low_freq", 50); + cls_nnet_conf.high_freq = conf.Read("high_freq", 14000); + cls_nnet_conf.dither = conf.Read("dither", 0.0); + + ppspeech::ClsNnet* cls_model = new ppspeech::ClsNnet(); + int ret = cls_model->Init(cls_nnet_conf); + return static_cast(cls_model); +} + +int ClsDestroyInstance(void* instance) { + ppspeech::ClsNnet* cls_model = static_cast(instance); + if (cls_model != NULL) { + delete cls_model; + cls_model = NULL; + } + return 0; +} + +int ClsFeedForward(void* instance, + const 
char* wav_path, + int topk, + char* result, + int result_max_len) { + ppspeech::ClsNnet* cls_model = static_cast(instance); + if (cls_model == NULL) { + printf("instance is null\n"); + return -1; + } + int ret = cls_model->Forward(wav_path, topk, result, result_max_len); + return 0; +} + +int ClsReset(void* instance) { + ppspeech::ClsNnet* cls_model = static_cast(instance); + if (cls_model == NULL) { + printf("instance is null\n"); + return -1; + } + cls_model->Reset(); + return 0; +} +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/cls/nnet/panns_interface.h b/runtime/engine/cls/nnet/panns_interface.h new file mode 100644 index 00000000..0d1ce95f --- /dev/null +++ b/runtime/engine/cls/nnet/panns_interface.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace ppspeech { + +void* ClsCreateInstance(const char* conf_path); +int ClsDestroyInstance(void* instance); +int ClsFeedForward(void* instance, + const char* wav_path, + int topk, + char* result, + int result_max_len); +int ClsReset(void* instance); +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/cls/nnet/panns_nnet.cc b/runtime/engine/cls/nnet/panns_nnet.cc new file mode 100644 index 00000000..6b8213f6 --- /dev/null +++ b/runtime/engine/cls/nnet/panns_nnet.cc @@ -0,0 +1,228 @@ +// Copyright (c) 2022 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cls/nnet/panns_nnet.h" +#ifdef PRINT_TIME +#include "kaldi/base/timer.h" +#endif + +namespace ppspeech { + +ClsNnet::ClsNnet() { + // wav_reader_ = NULL; + runtime_ = NULL; +} + +void ClsNnet::Reset() { + // wav_reader_->Clear(); + ss_.str(""); +} + +int ClsNnet::Init(const ClsNnetConf& conf) { + conf_ = conf; + // init fbank opts + fbank_opts_.frame_opts.samp_freq = conf.samp_freq; + fbank_opts_.frame_opts.frame_length_ms = conf.frame_length_ms; + fbank_opts_.frame_opts.frame_shift_ms = conf.frame_shift_ms; + fbank_opts_.mel_opts.num_bins = conf.num_bins; + fbank_opts_.mel_opts.low_freq = conf.low_freq; + fbank_opts_.mel_opts.high_freq = conf.high_freq; + fbank_opts_.frame_opts.dither = conf.dither; + fbank_opts_.use_log_fbank = false; + + // init dict + if (conf.dict_file_path_ != "") { + ReadFileToVector(conf.dict_file_path_, &dict_); + } + + // init model + fastdeploy::RuntimeOption runtime_option; + +#ifdef USE_ORT_BACKEND + runtime_option.SetModelPath( + conf.model_file_path_, "", fastdeploy::ModelFormat::ONNX); // onnx + runtime_option.UseOrtBackend(); // onnx +#endif +#ifdef USE_PADDLE_LITE_BACKEND + runtime_option.SetModelPath(conf.model_file_path_, + conf.param_file_path_, + fastdeploy::ModelFormat::PADDLE); + runtime_option.UseLiteBackend(); +#endif +#ifdef USE_PADDLE_INFERENCE_BACKEND + runtime_option.SetModelPath(conf.model_file_path_, + conf.param_file_path_, + 
fastdeploy::ModelFormat::PADDLE); + runtime_option.UsePaddleInferBackend(); +#endif + runtime_option.SetCpuThreadNum(conf.num_cpu_thread_); + runtime_option.DeletePaddleBackendPass("simplify_with_basic_ops_pass"); + runtime_ = std::unique_ptr(new fastdeploy::Runtime()); + if (!runtime_->Init(runtime_option)) { + std::cerr << "--- Init FastDeploy Runitme Failed! " + << "\n--- Model: " << conf.model_file_path_ << std::endl; + return -1; + } else { + std::cout << "--- Init FastDeploy Runitme Done! " + << "\n--- Model: " << conf.model_file_path_ << std::endl; + } + + Reset(); + return 0; +} + +int ClsNnet::Forward(const char* wav_path, + int topk, + char* result, + int result_max_len) { +#ifdef PRINT_TIME + kaldi::Timer timer; + timer.Reset(); +#endif + // read wav + std::ifstream infile(wav_path, std::ifstream::in); + kaldi::WaveData wave_data; + wave_data.Read(infile); + int32 this_channel = 0; + kaldi::Matrix wavform_kaldi = wave_data.Data(); + // only get channel 0 + int wavform_len = wavform_kaldi.NumCols(); + std::vector wavform(wavform_kaldi.Data(), + wavform_kaldi.Data() + wavform_len); + WaveformFloatNormal(&wavform); + WaveformNormal(&wavform, + conf_.wav_normal_, + conf_.wav_normal_type_, + conf_.wav_norm_mul_factor_); +#ifdef TEST_DEBUG + { + std::ofstream fp("cls.wavform", std::ios::out); + for (int i = 0; i < wavform.size(); ++i) { + fp << std::setprecision(18) << wavform[i] << " "; + } + fp << "\n"; + } +#endif +#ifdef PRINT_TIME + printf("wav read consume: %fs\n", timer.Elapsed()); +#endif + +#ifdef PRINT_TIME + timer.Reset(); +#endif + + std::vector feats; + std::unique_ptr data_source( + new ppspeech::DataCache()); + ppspeech::Fbank fbank(fbank_opts_, std::move(data_source)); + fbank.Accept(wavform); + fbank.SetFinished(); + fbank.Read(&feats); + + int feat_dim = fbank_opts_.mel_opts.num_bins; + int num_frames = feats.size() / feat_dim; + + for (int i = 0; i < num_frames; ++i) { + for (int j = 0; j < feat_dim; ++j) { + feats[i * feat_dim + j] = 
PowerTodb(feats[i * feat_dim + j]); + } + } +#ifdef TEST_DEBUG + { + std::ofstream fp("cls.feat", std::ios::out); + for (int i = 0; i < num_frames; ++i) { + for (int j = 0; j < feat_dim; ++j) { + fp << std::setprecision(18) << feats[i * feat_dim + j] << " "; + } + fp << "\n"; + } + } +#endif +#ifdef PRINT_TIME + printf("extract fbank consume: %fs\n", timer.Elapsed()); +#endif + + // infer + std::vector model_out; +#ifdef PRINT_TIME + timer.Reset(); +#endif + ModelForward(feats.data(), num_frames, feat_dim, &model_out); +#ifdef PRINT_TIME + printf("fast deploy infer consume: %fs\n", timer.Elapsed()); +#endif +#ifdef TEST_DEBUG + { + std::ofstream fp("cls.logits", std::ios::out); + for (int i = 0; i < model_out.size(); ++i) { + fp << std::setprecision(18) << model_out[i] << "\n"; + } + } +#endif + + // construct result str + ss_ << "{"; + GetTopkResult(topk, model_out); + ss_ << "}"; + + if (result_max_len <= ss_.str().size()) { + printf("result_max_len is short than result len\n"); + } + snprintf(result, result_max_len, "%s", ss_.str().c_str()); + return 0; +} + +int ClsNnet::ModelForward(float* features, + const int num_frames, + const int feat_dim, + std::vector* model_out) { + // init input tensor shape + fastdeploy::TensorInfo info = runtime_->GetInputInfo(0); + info.shape = {1, num_frames, feat_dim}; + + std::vector input_tensors(1); + std::vector output_tensors(1); + + input_tensors[0].SetExternalData({1, num_frames, feat_dim}, + fastdeploy::FDDataType::FP32, + static_cast(features)); + + // get input name + input_tensors[0].name = info.name; + + runtime_->Infer(input_tensors, &output_tensors); + + // output_tensors[0].PrintInfo(); + std::vector output_shape = output_tensors[0].Shape(); + model_out->resize(output_shape[0] * output_shape[1]); + memcpy(static_cast(model_out->data()), + output_tensors[0].Data(), + output_shape[0] * output_shape[1] * sizeof(float)); + return 0; +} + +int ClsNnet::GetTopkResult(int k, const std::vector& model_out) { + std::vector 
values; + std::vector indics; + TopK(model_out, k, &values, &indics); + for (int i = 0; i < k; ++i) { + if (i != 0) { + ss_ << ","; + } + ss_ << "\"" << dict_[indics[i]] << "\":\"" << values[i] << "\""; + } + return 0; +} + +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/cls/nnet/panns_nnet.h b/runtime/engine/cls/nnet/panns_nnet.h new file mode 100644 index 00000000..3a4a5718 --- /dev/null +++ b/runtime/engine/cls/nnet/panns_nnet.h @@ -0,0 +1,74 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "common/frontend/data_cache.h" +#include "common/frontend/fbank.h" +#include "common/frontend/feature-fbank.h" +#include "common/frontend/frontend_itf.h" +#include "common/frontend/wave-reader.h" +#include "common/utils/audio_process.h" +#include "common/utils/file_utils.h" +#include "fastdeploy/runtime.h" +#include "kaldi/util/kaldi-io.h" +#include "kaldi/util/table-types.h" + +namespace ppspeech { +struct ClsNnetConf { + // wav + bool wav_normal_; + std::string wav_normal_type_; + float wav_norm_mul_factor_; + // model + std::string model_file_path_; + std::string param_file_path_; + std::string dict_file_path_; + int num_cpu_thread_; + // fbank + float samp_freq; + float frame_length_ms; + float frame_shift_ms; + int num_bins; + float low_freq; + float high_freq; + float dither; +}; + +class ClsNnet { + public: + ClsNnet(); + int Init(const ClsNnetConf& conf); + int Forward(const char* wav_path, + int topk, + char* result, + int result_max_len); + void Reset(); + + private: + int ModelForward(float* features, + const int num_frames, + const int feat_dim, + std::vector* model_out); + int ModelForwardStream(std::vector* feats); + int GetTopkResult(int k, const std::vector& model_out); + + ClsNnetConf conf_; + knf::FbankOptions fbank_opts_; + std::unique_ptr runtime_; + std::vector dict_; + std::stringstream ss_; +}; + +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/cls/nnet/panns_nnet_main.cc b/runtime/engine/cls/nnet/panns_nnet_main.cc new file mode 100644 index 00000000..4280d14c --- /dev/null +++ b/runtime/engine/cls/nnet/panns_nnet_main.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "base/flags.h" +#include "cls/nnet/panns_interface.h" + +DEFINE_string(conf_path, "", "config path"); +DEFINE_string(scp_path, "", "wav scp path"); +DEFINE_string(topk, "", "print topk results"); + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + CHECK_GT(FLAGS_conf_path.size(), 0); + CHECK_GT(FLAGS_scp_path.size(), 0); + CHECK_GT(FLAGS_topk.size(), 0); + void* instance = ppspeech::ClsCreateInstance(FLAGS_conf_path.c_str()); + int ret = 0; + // read wav + std::ifstream ifs(FLAGS_scp_path); + std::string line = ""; + int topk = std::atoi(FLAGS_topk.c_str()); + while (getline(ifs, line)) { + // read wav + char result[1024] = {0}; + ret = ppspeech::ClsFeedForward( + instance, line.c_str(), topk, result, 1024); + printf("%s %s\n", line.c_str(), result); + ret = ppspeech::ClsReset(instance); + } + ret = ppspeech::ClsDestroyInstance(instance); + return 0; +} diff --git a/runtime/engine/common/base/config.h b/runtime/engine/common/base/config.h new file mode 100644 index 00000000..c59c3ab8 --- /dev/null +++ b/runtime/engine/common/base/config.h @@ -0,0 +1,338 @@ +// Copyright (c) code is from +// https://blog.csdn.net/huixingshao/article/details/45969887. 
+ +#include +#include +#include +#include +#include +using namespace std; + +#pragma once + +#pragma region ParseIniFile +/* +* \brief Generic configuration Class +* +*/ +class Config { + // Data + protected: + std::string m_Delimiter; //!< separator between key and value + std::string m_Comment; //!< separator between value and comments + std::map + m_Contents; //!< extracted keys and values + + typedef std::map::iterator mapi; + typedef std::map::const_iterator mapci; + // Methods + public: + Config(std::string filename, + std::string delimiter = "=", + std::string comment = "#"); + Config(); + template + T Read(const std::string& in_key) const; //! + template + T Read(const std::string& in_key, const T& in_value) const; + template + bool ReadInto(T* out_var, const std::string& in_key) const; + template + bool ReadInto(T* out_var, + const std::string& in_key, + const T& in_value) const; + bool FileExist(std::string filename); + void ReadFile(std::string filename, + std::string delimiter = "=", + std::string comment = "#"); + + // Check whether key exists in configuration + bool KeyExists(const std::string& in_key) const; + + // Modify keys and values + template + void Add(const std::string& in_key, const T& in_value); + void Remove(const std::string& in_key); + + // Check or change configuration syntax + std::string GetDelimiter() const { return m_Delimiter; } + std::string GetComment() const { return m_Comment; } + std::string SetDelimiter(const std::string& in_s) { + std::string old = m_Delimiter; + m_Delimiter = in_s; + return old; + } + std::string SetComment(const std::string& in_s) { + std::string old = m_Comment; + m_Comment = in_s; + return old; + } + + // Write or read configuration + friend std::ostream& operator<<(std::ostream& os, const Config& cf); + friend std::istream& operator>>(std::istream& is, Config& cf); + + protected: + template + static std::string T_as_string(const T& t); + template + static T string_as_T(const std::string& s); + static 
void Trim(std::string* inout_s); + + + // Exception types + public: + struct File_not_found { + std::string filename; + explicit File_not_found(const std::string& filename_ = std::string()) + : filename(filename_) {} + }; + struct Key_not_found { // thrown only by T read(key) variant of read() + std::string key; + explicit Key_not_found(const std::string& key_ = std::string()) + : key(key_) {} + }; +}; + +/* static */ +template +std::string Config::T_as_string(const T& t) { + // Convert from a T to a string + // Type T must support << operator + std::ostringstream ost; + ost << t; + return ost.str(); +} + + +/* static */ +template +T Config::string_as_T(const std::string& s) { + // Convert from a string to a T + // Type T must support >> operator + T t; + std::istringstream ist(s); + ist >> t; + return t; +} + + +/* static */ +template <> +inline std::string Config::string_as_T(const std::string& s) { + // Convert from a string to a string + // In other words, do nothing + return s; +} + + +/* static */ +template <> +inline bool Config::string_as_T(const std::string& s) { + // Convert from a string to a bool + // Interpret "false", "F", "no", "n", "0" as false + // Interpret "true", "T", "yes", "y", "1", "-1", or anything else as true + bool b = true; + std::string sup = s; + for (std::string::iterator p = sup.begin(); p != sup.end(); ++p) + *p = toupper(*p); // make string all caps + if (sup == std::string("FALSE") || sup == std::string("F") || + sup == std::string("NO") || sup == std::string("N") || + sup == std::string("0") || sup == std::string("NONE")) + b = false; + return b; +} + + +template +T Config::Read(const std::string& key) const { + // Read the value corresponding to key + mapci p = m_Contents.find(key); + if (p == m_Contents.end()) throw Key_not_found(key); + return string_as_T(p->second); +} + + +template +T Config::Read(const std::string& key, const T& value) const { + // Return the value corresponding to key or given default value + // if key is 
not found + mapci p = m_Contents.find(key); + if (p == m_Contents.end()) { + printf("%s = %s(default)\n", key.c_str(), T_as_string(value).c_str()); + return value; + } else { + printf("%s = %s\n", key.c_str(), T_as_string(p->second).c_str()); + return string_as_T(p->second); + } +} + + +template +bool Config::ReadInto(T* var, const std::string& key) const { + // Get the value corresponding to key and store in var + // Return true if key is found + // Otherwise leave var untouched + mapci p = m_Contents.find(key); + bool found = (p != m_Contents.end()); + if (found) *var = string_as_T(p->second); + return found; +} + + +template +bool Config::ReadInto(T* var, const std::string& key, const T& value) const { + // Get the value corresponding to key and store in var + // Return true if key is found + // Otherwise set var to given default + mapci p = m_Contents.find(key); + bool found = (p != m_Contents.end()); + if (found) + *var = string_as_T(p->second); + else + var = value; + return found; +} + + +template +void Config::Add(const std::string& in_key, const T& value) { + // Add a key with given value + std::string v = T_as_string(value); + std::string key = in_key; + Trim(&key); + Trim(&v); + m_Contents[key] = v; + return; +} + +Config::Config(string filename, string delimiter, string comment) + : m_Delimiter(delimiter), m_Comment(comment) { + // Construct a Config, getting keys and values from given file + + std::ifstream in(filename.c_str()); + + if (!in) throw File_not_found(filename); + + in >> (*this); +} + + +Config::Config() : m_Delimiter(string(1, '=')), m_Comment(string(1, '#')) { + // Construct a Config without a file; empty +} + + +bool Config::KeyExists(const string& key) const { + // Indicate whether key is found + mapci p = m_Contents.find(key); + return (p != m_Contents.end()); +} + + +/* static */ +void Config::Trim(string* inout_s) { + // Remove leading and trailing whitespace + static const char whitespace[] = " \n\t\v\r\f"; + inout_s->erase(0, 
inout_s->find_first_not_of(whitespace)); + inout_s->erase(inout_s->find_last_not_of(whitespace) + 1U); +} + + +std::ostream& operator<<(std::ostream& os, const Config& cf) { + // Save a Config to os + for (Config::mapci p = cf.m_Contents.begin(); p != cf.m_Contents.end(); + ++p) { + os << p->first << " " << cf.m_Delimiter << " "; + os << p->second << std::endl; + } + return os; +} + +void Config::Remove(const string& key) { + // Remove key and its value + m_Contents.erase(m_Contents.find(key)); + return; +} + +std::istream& operator>>(std::istream& is, Config& cf) { + // Load a Config from is + // Read in keys and values, keeping internal whitespace + typedef string::size_type pos; + const string& delim = cf.m_Delimiter; // separator + const string& comm = cf.m_Comment; // comment + const pos skip = delim.length(); // length of separator + + string nextline = ""; // might need to read ahead to see where value ends + + while (is || nextline.length() > 0) { + // Read an entire line at a time + string line; + if (nextline.length() > 0) { + line = nextline; // we read ahead; use it now + nextline = ""; + } else { + std::getline(is, line); + } + + // Ignore comments + line = line.substr(0, line.find(comm)); + + // Parse the line if it contains a delimiter + pos delimPos = line.find(delim); + if (delimPos < string::npos) { + // Extract the key + string key = line.substr(0, delimPos); + line.replace(0, delimPos + skip, ""); + + // See if value continues on the next line + // Stop at blank line, next line with a key, end of stream, + // or end of file sentry + bool terminate = false; + while (!terminate && is) { + std::getline(is, nextline); + terminate = true; + + string nlcopy = nextline; + Config::Trim(&nlcopy); + if (nlcopy == "") continue; + + nextline = nextline.substr(0, nextline.find(comm)); + if (nextline.find(delim) != string::npos) continue; + + nlcopy = nextline; + Config::Trim(&nlcopy); + if (nlcopy != "") line += "\n"; + line += nextline; + terminate = false; 
+ } + + // Store key and value + Config::Trim(&key); + Config::Trim(&line); + cf.m_Contents[key] = line; // overwrites if key is repeated + } + } + + return is; +} +bool Config::FileExist(std::string filename) { + bool exist = false; + std::ifstream in(filename.c_str()); + if (in) exist = true; + return exist; +} + +void Config::ReadFile(string filename, string delimiter, string comment) { + m_Delimiter = delimiter; + m_Comment = comment; + std::ifstream in(filename.c_str()); + + if (!in) throw File_not_found(filename); + + in >> (*this); +} + +#pragma endregion ParseIniFIle diff --git a/runtime/engine/common/utils/CMakeLists.txt b/runtime/engine/common/utils/CMakeLists.txt index c47b25c0..8589b19a 100644 --- a/runtime/engine/common/utils/CMakeLists.txt +++ b/runtime/engine/common/utils/CMakeLists.txt @@ -3,6 +3,7 @@ add_library(utils file_utils.cc math.cc strings.cc + audio_process.cc ) diff --git a/runtime/engine/common/utils/audio_process.cc b/runtime/engine/common/utils/audio_process.cc new file mode 100644 index 00000000..54540b85 --- /dev/null +++ b/runtime/engine/common/utils/audio_process.cc @@ -0,0 +1,83 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "utils/audio_process.h" + +namespace ppspeech{ + +int WaveformFloatNormal(std::vector* waveform) { + int tot_samples = waveform->size(); + for (int i = 0; i < tot_samples; i++) { + (*waveform)[i] = (*waveform)[i] / 32768.0; + } + return 0; +} + +int WaveformNormal(std::vector* waveform, + bool wav_normal, + const std::string& wav_normal_type, + float wav_norm_mul_factor) { + if (wav_normal == false) { + return 0; + } + if (wav_normal_type == "linear") { + float amax = INT32_MIN; + for (int i = 0; i < waveform->size(); ++i) { + float tmp = std::abs((*waveform)[i]); + amax = std::max(amax, tmp); + } + float factor = 1.0 / (amax + 1e-8); + for (int i = 0; i < waveform->size(); ++i) { + (*waveform)[i] = (*waveform)[i] * factor * wav_norm_mul_factor; + } + } else if (wav_normal_type == "gaussian") { + double sum = std::accumulate(waveform->begin(), waveform->end(), 0.0); + double mean = sum / waveform->size(); //均值 + + double accum = 0.0; + std::for_each(waveform->begin(), waveform->end(), [&](const double d) { + accum += (d - mean) * (d - mean); + }); + + double stdev = sqrt(accum / (waveform->size() - 1)); //方差 + stdev = std::max(stdev, 1e-8); + + for (int i = 0; i < waveform->size(); ++i) { + (*waveform)[i] = + wav_norm_mul_factor * ((*waveform)[i] - mean) / stdev; + } + } else { + printf("don't support\n"); + return -1; + } + return 0; +} + +float PowerTodb(float in, float ref_value, float amin, float top_db) { + if (amin <= 0) { + printf("amin must be strictly positive\n"); + return -1; + } + + if (ref_value <= 0) { + printf("ref_value must be strictly positive\n"); + return -1; + } + + float out = 10.0 * log10(std::max(amin, in)); + out -= 10.0 * log10(std::max(ref_value, amin)); + return out; +} + +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/common/utils/audio_process.h b/runtime/engine/common/utils/audio_process.h new file mode 100644 index 00000000..164d4c07 --- /dev/null +++ 
b/runtime/engine/common/utils/audio_process.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +namespace ppspeech{ +int WaveformFloatNormal(std::vector* waveform); +int WaveformNormal(std::vector* waveform, + bool wav_normal, + const std::string& wav_normal_type, + float wav_norm_mul_factor); +float PowerTodb(float in, + float ref_value = 1.0, + float amin = 1e-10, + float top_db = 80.0); +} // namespace ppspeech \ No newline at end of file From b35fc01a3a451b34655e23b9db183fb529c983e2 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 28 Feb 2023 10:38:28 +0800 Subject: [PATCH 13/50] opt to compile asr,cls,vad; add vad; format code (#2968) --- runtime/.gitignore | 3 + runtime/CMakeLists.txt | 177 +- runtime/build.sh | 2 +- runtime/cmake/fastdeploy.cmake | 15 +- runtime/engine/CMakeLists.txt | 19 +- .../engine/asr/recognizer/u2_recognizer.cc | 3 +- runtime/engine/common/CMakeLists.txt | 2 +- runtime/engine/common/base/CMakeLists.txt | 20 + .../common/base/{flags.h => flags.h.in} | 2 +- .../engine/common/base/{log.h => log.h.in} | 2 +- runtime/engine/common/frontend/cmvn.cc | 2 +- .../engine/common/frontend/feature-fbank.h | 1 + .../engine/common/frontend/feature-window.cc | 1 + runtime/engine/common/frontend/rfft.cc | 4 +- .../engine/common/matrix/kaldi-matrix-inl.h | 37 +- 
runtime/engine/common/matrix/kaldi-matrix.cc | 1396 ++++++++------- runtime/engine/common/matrix/kaldi-matrix.h | 1443 ++++++++------- .../engine/common/matrix/kaldi-vector-inl.h | 39 +- runtime/engine/common/matrix/kaldi-vector.cc | 1592 +++++++++-------- runtime/engine/common/matrix/kaldi-vector.h | 533 +++--- runtime/engine/common/matrix/matrix-common.h | 72 +- runtime/engine/kaldi/CMakeLists.txt | 15 +- runtime/engine/kaldi/base/kaldi-types.h | 12 + runtime/engine/vad/CMakeLists.txt | 18 + runtime/engine/vad/README.md | 121 ++ runtime/engine/vad/README_CN.md | 119 ++ runtime/engine/vad/infer_onnx_silero_vad.cc | 65 + runtime/engine/vad/vad.cc | 306 ++++ runtime/engine/vad/vad.h | 124 ++ runtime/engine/vad/wav.h | 197 ++ 30 files changed, 3811 insertions(+), 2531 deletions(-) create mode 100644 runtime/engine/common/base/CMakeLists.txt rename runtime/engine/common/base/{flags.h => flags.h.in} (96%) rename runtime/engine/common/base/{log.h => log.h.in} (96%) create mode 100644 runtime/engine/vad/CMakeLists.txt create mode 100644 runtime/engine/vad/README.md create mode 100644 runtime/engine/vad/README_CN.md create mode 100644 runtime/engine/vad/infer_onnx_silero_vad.cc create mode 100644 runtime/engine/vad/vad.cc create mode 100644 runtime/engine/vad/vad.h create mode 100644 runtime/engine/vad/wav.h diff --git a/runtime/.gitignore b/runtime/.gitignore index 0783b138..9aa98ef7 100644 --- a/runtime/.gitignore +++ b/runtime/.gitignore @@ -1,3 +1,6 @@ +engine/common/base/flags.h +engine/common/base/log.h + tools/valgrind* *log fc_patch/* diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt index 44ee3a58..015a1088 100644 --- a/runtime/CMakeLists.txt +++ b/runtime/CMakeLists.txt @@ -20,8 +20,7 @@ project(paddlespeech VERSION 0.1) set(CMAKE_VERBOSE_MAKEFILE on) -# set std-14 -set(CMAKE_CXX_STANDARD 14) + include(FetchContent) include(ExternalProject) @@ -31,15 +30,28 @@ set(FETCHCONTENT_QUIET off) get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR 
"${CMAKE_SOURCE_DIR}") set(FETCHCONTENT_BASE_DIR ${fc_patch}) +set(CMAKE_CXX_FLAGS) +set(CMAKE_CXX_FLAGS_DEBUG) +set(CMAKE_CXX_FLAGS_RELEASE) + +# set std-14 +set(CMAKE_CXX_STANDARD 14) + # compiler option # Keep the same with openfst, -fPIC or -fpic set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ldl") SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb") SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall") + +add_compile_options(-fPIC) ############################################################################### # Option Configurations ############################################################################### +option(WITH_ASR "build asr" ON) +option(WITH_CLS "build cls" ON) +option(WITH_VAD "build vad" ON) + option(TEST_DEBUG "option for debug" OFF) option(USE_PROFILING "enable c++ profling" OFF) option(WITH_TESTING "unit test" ON) @@ -47,102 +59,117 @@ option(WITH_TESTING "unit test" ON) option(USING_GPU "u2 compute on GPU." 
OFF) ############################################################################### -# Include third party +# Include Third Party ############################################################################### include(gflags) include(glog) -# openfst -include(openfst) -add_dependencies(openfst gflags glog) - -# paddle lib -include(paddleinference) - # gtest if(WITH_TESTING) include(gtest) # download, build, install gtest endif() + +# fastdeploy +include(fastdeploy) + +if(WITH_ASR) + # openfst + include(openfst) + add_dependencies(openfst gflags glog) +endif() + +############################################################################### +# Find Package +############################################################################### + # python/pybind11/threads find_package(Threads REQUIRED) # https://cmake.org/cmake/help/latest/module/FindPython3.html#module:FindPython3 find_package(Python3 COMPONENTS Interpreter Development) find_package(pybind11 CONFIG) -if(Python3_FOUND) - message(STATUS "Python3_FOUND = ${Python3_FOUND}") - message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}") - message(STATUS "Python3_LIBRARIES = ${Python3_LIBRARIES}") - message(STATUS "Python3_INCLUDE_DIRS = ${Python3_INCLUDE_DIRS}") - message(STATUS "Python3_LINK_OPTIONS = ${Python3_LINK_OPTIONS}") - set(PYTHON_LIBRARIES ${Python3_LIBRARIES} CACHE STRING "python lib" FORCE) - set(PYTHON_INCLUDE_DIR ${Python3_INCLUDE_DIRS} CACHE STRING "python inc" FORCE) -endif() - -message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}") -message(STATUS "PYTHON_INCLUDE_DIR = ${PYTHON_INCLUDE_DIR}") -if(pybind11_FOUND) - message(STATUS "pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}") - message(STATUS "pybind11_LIBRARIES=${pybind11_LIBRARIES}") - message(STATUS "pybind11_DEFINITIONS=${pybind11_DEFINITIONS}") +if(WITH_ASR) + if(Python3_FOUND) + message(STATUS "Python3_FOUND = ${Python3_FOUND}") + message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}") + message(STATUS "Python3_LIBRARIES = 
${Python3_LIBRARIES}") + message(STATUS "Python3_INCLUDE_DIRS = ${Python3_INCLUDE_DIRS}") + message(STATUS "Python3_LINK_OPTIONS = ${Python3_LINK_OPTIONS}") + set(PYTHON_LIBRARIES ${Python3_LIBRARIES} CACHE STRING "python lib" FORCE) + set(PYTHON_INCLUDE_DIR ${Python3_INCLUDE_DIRS} CACHE STRING "python inc" FORCE) + endif() + + message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}") + message(STATUS "PYTHON_INCLUDE_DIR = ${PYTHON_INCLUDE_DIR}") + + if(pybind11_FOUND) + message(STATUS "pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}") + message(STATUS "pybind11_LIBRARIES=${pybind11_LIBRARIES}") + message(STATUS "pybind11_DEFINITIONS=${pybind11_DEFINITIONS}") + endif() + + + # paddle libpaddle.so + # paddle include and link option + # -L/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so + execute_process( + COMMAND python -c "\ + import os;\ + import paddle;\ + include_dir=paddle.sysconfig.get_include();\ + paddle_dir=os.path.split(include_dir)[0];\ + libs_dir=os.path.join(paddle_dir, 'libs');\ + fluid_dir=os.path.join(paddle_dir, 'fluid');\ + out=' '.join([\"-L\" + libs_dir, \"-L\" + fluid_dir]);\ + out += \" -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so\"; print(out);\ + " + OUTPUT_VARIABLE PADDLE_LINK_FLAGS + RESULT_VARIABLE SUCESS) + + message(STATUS PADDLE_LINK_FLAGS= ${PADDLE_LINK_FLAGS}) + string(STRIP ${PADDLE_LINK_FLAGS} PADDLE_LINK_FLAGS) + + # paddle compile option + # -I/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/include + execute_process( + COMMAND python -c "\ + import paddle; \ + include_dir = paddle.sysconfig.get_include(); \ + print(f\"-I{include_dir}\"); \ + " + OUTPUT_VARIABLE PADDLE_COMPILE_FLAGS) + message(STATUS PADDLE_COMPILE_FLAGS= ${PADDLE_COMPILE_FLAGS}) + string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS) + + + # for LD_LIBRARY_PATH + # 
set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/) + execute_process( + COMMAND python -c "\ + import os; \ + import paddle; \ + include_dir=paddle.sysconfig.get_include(); \ + paddle_dir=os.path.split(include_dir)[0]; \ + libs_dir=os.path.join(paddle_dir, 'libs'); \ + fluid_dir=os.path.join(paddle_dir, 'fluid'); \ + out=':'.join([libs_dir, fluid_dir]); print(out); \ + " + OUTPUT_VARIABLE PADDLE_LIB_DIRS) + message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS}) endif() -# paddle libpaddle.so -# paddle include and link option -# -L/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so -execute_process( - COMMAND python -c "\ -import os;\ -import paddle;\ -include_dir=paddle.sysconfig.get_include();\ -paddle_dir=os.path.split(include_dir)[0];\ -libs_dir=os.path.join(paddle_dir, 'libs');\ -fluid_dir=os.path.join(paddle_dir, 'fluid');\ -out=' '.join([\"-L\" + libs_dir, \"-L\" + fluid_dir]);\ -out += \" -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so\"; print(out);\ - " - OUTPUT_VARIABLE PADDLE_LINK_FLAGS - RESULT_VARIABLE SUCESS) - -message(STATUS PADDLE_LINK_FLAGS= ${PADDLE_LINK_FLAGS}) -string(STRIP ${PADDLE_LINK_FLAGS} PADDLE_LINK_FLAGS) - -# paddle compile option -# -I/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/include -execute_process( - COMMAND python -c "\ -import paddle; \ -include_dir = paddle.sysconfig.get_include(); \ -print(f\"-I{include_dir}\"); \ - " - OUTPUT_VARIABLE PADDLE_COMPILE_FLAGS) -message(STATUS PADDLE_COMPILE_FLAGS= ${PADDLE_COMPILE_FLAGS}) -string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS) - - -# for LD_LIBRARY_PATH -# set(PADDLE_LIB_DIRS 
/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/) -execute_process( - COMMAND python -c "\ -import os; \ -import paddle; \ -include_dir=paddle.sysconfig.get_include(); \ -paddle_dir=os.path.split(include_dir)[0]; \ -libs_dir=os.path.join(paddle_dir, 'libs'); \ -fluid_dir=os.path.join(paddle_dir, 'fluid'); \ -out=':'.join([libs_dir, fluid_dir]); print(out); \ - " - OUTPUT_VARIABLE PADDLE_LIB_DIRS) -message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS}) - -add_compile_options(-fPIC) ############################################################################### # Add local library ############################################################################### set(ENGINE_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/engine) +message(STATUS "CMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}") +message(STATUS "CMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}") +message(STATUS "CMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}") + + add_subdirectory(engine) diff --git a/runtime/build.sh b/runtime/build.sh index 94d250f5..131fb7f1 100755 --- a/runtime/build.sh +++ b/runtime/build.sh @@ -4,5 +4,5 @@ set -xe # the build script had verified in the paddlepaddle docker image. # please follow the instruction below to install PaddlePaddle image. 
# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html -cmake -B build +cmake -B build -DWITH_ASR=OFF -DWITH_CLS=OFF cmake --build build -j diff --git a/runtime/cmake/fastdeploy.cmake b/runtime/cmake/fastdeploy.cmake index 773414c1..cb9ceacd 100644 --- a/runtime/cmake/fastdeploy.cmake +++ b/runtime/cmake/fastdeploy.cmake @@ -8,11 +8,11 @@ windows_x86") set(CMAKE_VERBOSE_MAKEFILE ON) set(FASTDEPLOY_DIR ${CMAKE_SOURCE_DIR}/fc_patch/fastdeploy) -if(NOT EXISTS ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.2.tgz) +if(NOT EXISTS ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.4.tgz) exec_program("mkdir -p ${FASTDEPLOY_DIR} && - wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-1.0.2.tgz -P ${FASTDEPLOY_DIR} && - tar xzvf ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.2.tgz -C ${FASTDEPLOY_DIR} && - mv ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.2 ${FASTDEPLOY_DIR}/linux-x64") + wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-1.0.4.tgz -P ${FASTDEPLOY_DIR} && + tar xzvf ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.4.tgz -C ${FASTDEPLOY_DIR} && + mv ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.4 ${FASTDEPLOY_DIR}/linux-x64") endif() if(NOT EXISTS ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.0-shared.tgz) @@ -36,4 +36,9 @@ elseif (ARCH STREQUAL "android_armv7") endif() include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) -include_directories(${FASTDEPLOY_INCS}) \ No newline at end of file + +# fix compiler flags conflict, since fastdeploy using c++11 for project +set(CMAKE_CXX_STANDARD 14) + +include_directories(${FASTDEPLOY_INCS}) +message(STATUS "FASTDEPLOY_INCS=${FASTDEPLOY_INCS}") \ No newline at end of file diff --git a/runtime/engine/CMakeLists.txt b/runtime/engine/CMakeLists.txt index 42399fe9..242a579b 100644 --- a/runtime/engine/CMakeLists.txt +++ b/runtime/engine/CMakeLists.txt @@ -6,8 +6,19 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/kaldi) 
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/common) -add_subdirectory(asr) -add_subdirectory(common) add_subdirectory(kaldi) -add_subdirectory(codelab) -add_subdirectory(cls) \ No newline at end of file +add_subdirectory(common) + +if(WITH_ASR) + add_subdirectory(asr) +endif() + +if(WITH_CLS) + add_subdirectory(cls) +endif() + +if(WITH_VAD) + add_subdirectory(vad) +endif() + +add_subdirectory(codelab) \ No newline at end of file diff --git a/runtime/engine/asr/recognizer/u2_recognizer.cc b/runtime/engine/asr/recognizer/u2_recognizer.cc index da1348f5..36fecb0a 100644 --- a/runtime/engine/asr/recognizer/u2_recognizer.cc +++ b/runtime/engine/asr/recognizer/u2_recognizer.cc @@ -38,7 +38,8 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) decoder_ = std::make_unique( resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts); } else { - decoder_ = std::make_unique(resource.decoder_opts.tlg_decoder_opts); + decoder_ = std::make_unique( + resource.decoder_opts.tlg_decoder_opts); } symbol_table_ = decoder_->WordSymbolTable(); diff --git a/runtime/engine/common/CMakeLists.txt b/runtime/engine/common/CMakeLists.txt index 5e0a7d57..4f399eea 100644 --- a/runtime/engine/common/CMakeLists.txt +++ b/runtime/engine/common/CMakeLists.txt @@ -3,7 +3,7 @@ ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/../ ) add_subdirectory(utils) - +add_subdirectory(base) add_subdirectory(matrix) include_directories( diff --git a/runtime/engine/common/base/CMakeLists.txt b/runtime/engine/common/base/CMakeLists.txt new file mode 100644 index 00000000..ab710874 --- /dev/null +++ b/runtime/engine/common/base/CMakeLists.txt @@ -0,0 +1,20 @@ +if(WITH_ASR) + add_compile_options(-DWITH_ASR) + set(PPS_FLAGS_LIB "fst/flags.h") + set(PPS_GLOB_LIB "fst/log.h") +else() + set(PPS_FLAGS_LIB "gflags/gflags.h") + set(PPS_GLOB_LIB "glog/logging.h") +endif() + +configure_file( + ${CMAKE_CURRENT_SOURCE_DIR}/flags.h.in + ${CMAKE_CURRENT_SOURCE_DIR}/flags.h @ONLY + ) 
+message(STATUS "Generated ${CMAKE_CURRENT_SOURCE_DIR}/flags.h") + +configure_file( + ${CMAKE_CURRENT_SOURCE_DIR}/log.h.in + ${CMAKE_CURRENT_SOURCE_DIR}/log.h @ONLY + ) +message(STATUS "Generated ${CMAKE_CURRENT_SOURCE_DIR}/log.h") \ No newline at end of file diff --git a/runtime/engine/common/base/flags.h b/runtime/engine/common/base/flags.h.in similarity index 96% rename from runtime/engine/common/base/flags.h rename to runtime/engine/common/base/flags.h.in index 41df0d45..161366e8 100644 --- a/runtime/engine/common/base/flags.h +++ b/runtime/engine/common/base/flags.h.in @@ -14,4 +14,4 @@ #pragma once -#include "fst/flags.h" +#include "@PPS_FLAGS_LIB@" \ No newline at end of file diff --git a/runtime/engine/common/base/log.h b/runtime/engine/common/base/log.h.in similarity index 96% rename from runtime/engine/common/base/log.h rename to runtime/engine/common/base/log.h.in index c613b98c..0dd588bc 100644 --- a/runtime/engine/common/base/log.h +++ b/runtime/engine/common/base/log.h.in @@ -14,4 +14,4 @@ #pragma once -#include "fst/log.h" +#include "@PPS_GLOB_LIB@" diff --git a/runtime/engine/common/frontend/cmvn.cc b/runtime/engine/common/frontend/cmvn.cc index 8375d3d1..0f110820 100644 --- a/runtime/engine/common/frontend/cmvn.cc +++ b/runtime/engine/common/frontend/cmvn.cc @@ -33,7 +33,7 @@ CMVN::CMVN(std::string cmvn_file, unique_ptr base_extractor) dim_ = mean_stats_.size() - 1; } -void CMVN::ReadCMVNFromJson(string cmvn_file) { +void CMVN::ReadCMVNFromJson(std::string cmvn_file) { std::string json_str = ppspeech::ReadFile2String(cmvn_file); picojson::value value; std::string err; diff --git a/runtime/engine/common/frontend/feature-fbank.h b/runtime/engine/common/frontend/feature-fbank.h index 30085245..3dab793f 100644 --- a/runtime/engine/common/frontend/feature-fbank.h +++ b/runtime/engine/common/frontend/feature-fbank.h @@ -21,6 +21,7 @@ #ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_ #define KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_ +#include #include 
#include "frontend/feature-window.h" diff --git a/runtime/engine/common/frontend/feature-window.cc b/runtime/engine/common/frontend/feature-window.cc index 1c474ccb..43c736e0 100644 --- a/runtime/engine/common/frontend/feature-window.cc +++ b/runtime/engine/common/frontend/feature-window.cc @@ -7,6 +7,7 @@ #include "frontend/feature-window.h" #include +#include #include #ifndef M_2PI diff --git a/runtime/engine/common/frontend/rfft.cc b/runtime/engine/common/frontend/rfft.cc index f0a3ebc7..8cdb634f 100644 --- a/runtime/engine/common/frontend/rfft.cc +++ b/runtime/engine/common/frontend/rfft.cc @@ -17,12 +17,12 @@ */ #include "frontend/rfft.h" +#include "base/log.h" #include +#include #include -#include "base/log.h" - // see fftsg.c #ifdef __cplusplus extern "C" void rdft(int n, int isgn, double *a, int *ip, double *w); diff --git a/runtime/engine/common/matrix/kaldi-matrix-inl.h b/runtime/engine/common/matrix/kaldi-matrix-inl.h index eafbc6fb..ed18859d 100644 --- a/runtime/engine/common/matrix/kaldi-matrix-inl.h +++ b/runtime/engine/common/matrix/kaldi-matrix-inl.h @@ -25,40 +25,41 @@ namespace kaldi { /// Empty constructor -template -Matrix::Matrix(): MatrixBase(NULL, 0, 0, 0) { } +template +Matrix::Matrix() : MatrixBase(NULL, 0, 0, 0) {} /* template<> template<> -void MatrixBase::AddVecVec(const float alpha, const VectorBase &ra, const VectorBase &rb); +void MatrixBase::AddVecVec(const float alpha, const VectorBase +&ra, const VectorBase &rb); template<> template<> -void MatrixBase::AddVecVec(const double alpha, const VectorBase &ra, const VectorBase &rb); +void MatrixBase::AddVecVec(const double alpha, const VectorBase +&ra, const VectorBase &rb); */ -template -inline std::ostream & operator << (std::ostream & os, const MatrixBase & M) { - M.Write(os, false); - return os; +template +inline std::ostream& operator<<(std::ostream& os, const MatrixBase& M) { + M.Write(os, false); + return os; } -template -inline std::istream & operator >> (std::istream & is, Matrix 
& M) { - M.Read(is, false); - return is; +template +inline std::istream& operator>>(std::istream& is, Matrix& M) { + M.Read(is, false); + return is; } -template -inline std::istream & operator >> (std::istream & is, MatrixBase & M) { - M.Read(is, false); - return is; +template +inline std::istream& operator>>(std::istream& is, MatrixBase& M) { + M.Read(is, false); + return is; } -}// namespace kaldi +} // namespace kaldi #endif // KALDI_MATRIX_KALDI_MATRIX_INL_H_ - diff --git a/runtime/engine/common/matrix/kaldi-matrix.cc b/runtime/engine/common/matrix/kaldi-matrix.cc index e446a6bf..6f65fb0a 100644 --- a/runtime/engine/common/matrix/kaldi-matrix.cc +++ b/runtime/engine/common/matrix/kaldi-matrix.cc @@ -166,14 +166,19 @@ void MatrixBase::AddMatMat(const Real alpha, const MatrixBase& B, MatrixTransposeType transB, const Real beta) { - KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kTrans && transB == kNoTrans && A.num_rows_ == B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kNoTrans && transB == kTrans && A.num_cols_ == B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) - || (transA == kTrans && transB == kTrans && A.num_rows_ == B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); + KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == +B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kTrans && transB == kNoTrans && A.num_rows_ == +B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kNoTrans && transB == kTrans && A.num_cols_ == +B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) + || (transA == kTrans && transB == kTrans && A.num_rows_ == +B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); KALDI_ASSERT(&A != this && &B != this); if (num_rows_ == 0) return; 
cblas_Xgemm(alpha, transA, A.data_, A.num_rows_, A.num_cols_, A.stride_, - transB, B.data_, B.stride_, beta, data_, num_rows_, num_cols_, stride_); + transB, B.data_, B.stride_, beta, data_, num_rows_, num_cols_, +stride_); } @@ -191,7 +196,8 @@ void MatrixBase::SetMatMatDivMat(const MatrixBase& A, id = od * (o / i); /// o / i is either zero or "scale". } else { id = od; /// Just imagine the scale was 1.0. This is somehow true in - /// expectation; anyway, this case should basically never happen so it doesn't + /// expectation; anyway, this case should basically never happen so it +doesn't /// really matter. } (*this)(r, c) = id; @@ -200,25 +206,25 @@ void MatrixBase::SetMatMatDivMat(const MatrixBase& A, } */ -//template -//void MatrixBase::CopyLowerToUpper() { - //KALDI_ASSERT(num_rows_ == num_cols_); - //Real *data = data_; - //MatrixIndexT num_rows = num_rows_, stride = stride_; - //for (int32 i = 0; i < num_rows; i++) - //for (int32 j = 0; j < i; j++) - //data[j * stride + i ] = data[i * stride + j]; +// template +// void MatrixBase::CopyLowerToUpper() { +// KALDI_ASSERT(num_rows_ == num_cols_); +// Real *data = data_; +// MatrixIndexT num_rows = num_rows_, stride = stride_; +// for (int32 i = 0; i < num_rows; i++) +// for (int32 j = 0; j < i; j++) +// data[j * stride + i ] = data[i * stride + j]; //} -//template -//void MatrixBase::CopyUpperToLower() { - //KALDI_ASSERT(num_rows_ == num_cols_); - //Real *data = data_; - //MatrixIndexT num_rows = num_rows_, stride = stride_; - //for (int32 i = 0; i < num_rows; i++) - //for (int32 j = 0; j < i; j++) - //data[i * stride + j] = data[j * stride + i]; +// template +// void MatrixBase::CopyUpperToLower() { +// KALDI_ASSERT(num_rows_ == num_cols_); +// Real *data = data_; +// MatrixIndexT num_rows = num_rows_, stride = stride_; +// for (int32 i = 0; i < num_rows; i++) +// for (int32 j = 0; j < i; j++) +// data[i * stride + j] = data[j * stride + i]; //} /* @@ -263,10 +269,14 @@ void MatrixBase::AddMatSmat(const Real 
alpha, const MatrixBase &B, MatrixTransposeType transB, const Real beta) { - KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kTrans && transB == kNoTrans && A.num_rows_ == B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kNoTrans && transB == kTrans && A.num_cols_ == B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) - || (transA == kTrans && transB == kTrans && A.num_rows_ == B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); + KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == +B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kTrans && transB == kNoTrans && A.num_rows_ == +B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kNoTrans && transB == kTrans && A.num_cols_ == +B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) + || (transA == kTrans && transB == kTrans && A.num_rows_ == +B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); KALDI_ASSERT(&A != this && &B != this); // We iterate over the columns of B. 
@@ -301,10 +311,14 @@ void MatrixBase::AddSmatMat(const Real alpha, const MatrixBase &B, MatrixTransposeType transB, const Real beta) { - KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kTrans && transB == kNoTrans && A.num_rows_ == B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kNoTrans && transB == kTrans && A.num_cols_ == B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) - || (transA == kTrans && transB == kTrans && A.num_rows_ == B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); + KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == +B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kTrans && transB == kNoTrans && A.num_rows_ == +B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kNoTrans && transB == kTrans && A.num_cols_ == +B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) + || (transA == kTrans && transB == kTrans && A.num_rows_ == +B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); KALDI_ASSERT(&A != this && &B != this); MatrixIndexT Astride = A.stride_, Bstride = B.stride_, stride = this->stride_, @@ -342,7 +356,8 @@ void MatrixBase::AddSpSp(const Real alpha, const SpMatrix &A_in, // fully (to save work, we used the matrix constructor from SpMatrix). 
// CblasLeft means A is on the left: C <-- alpha A B + beta C if (sz == 0) return; - cblas_Xsymm(alpha, sz, A.data_, A.stride_, B.data_, B.stride_, beta, data_, stride_); + cblas_Xsymm(alpha, sz, A.data_, A.stride_, B.data_, B.stride_, beta, data_, +stride_); } template @@ -352,13 +367,15 @@ void MatrixBase::AddMat(const Real alpha, const MatrixBase& A, if (transA == kNoTrans) { Scale(alpha + 1.0); } else { - KALDI_ASSERT(num_rows_ == num_cols_ && "AddMat: adding to self (transposed): not symmetric."); + KALDI_ASSERT(num_rows_ == num_cols_ && "AddMat: adding to self +(transposed): not symmetric."); Real *data = data_; if (alpha == 1.0) { // common case-- handle separately. for (MatrixIndexT row = 0; row < num_rows_; row++) { for (MatrixIndexT col = 0; col < row; col++) { Real *lower = data + (row * stride_) + col, *upper = data + (col - * stride_) + row; + * +stride_) + row; Real sum = *lower + *upper; *lower = *upper = sum; } @@ -368,7 +385,8 @@ void MatrixBase::AddMat(const Real alpha, const MatrixBase& A, for (MatrixIndexT row = 0; row < num_rows_; row++) { for (MatrixIndexT col = 0; col < row; col++) { Real *lower = data + (row * stride_) + col, *upper = data + (col - * stride_) + row; + * +stride_) + row; Real lower_tmp = *lower; *lower += alpha * *upper; *upper += alpha * lower_tmp; @@ -390,7 +408,8 @@ void MatrixBase::AddMat(const Real alpha, const MatrixBase& A, } else { KALDI_ASSERT(A.num_cols_ == num_rows_ && A.num_rows_ == num_cols_); if (num_rows_ == 0) return; - for (MatrixIndexT row = 0; row < num_rows_; row++, adata++, data += stride) + for (MatrixIndexT row = 0; row < num_rows_; row++, adata++, data += +stride) cblas_Xaxpy(num_cols_, alpha, adata, aStride, data, 1); } } @@ -503,7 +522,8 @@ void MatrixBase::AddMatSmat(Real alpha, const MatrixBase &A, Real alpha_B_kj = alpha * p.second; Real *this_col_j = this->Data() + j; // Add to entire 'j'th column of *this at once using cblas_Xaxpy. 
- // pass stride to write a colmun as matrices are stored in row major order. + // pass stride to write a colmun as matrices are stored in row major +order. cblas_Xaxpy(this_num_rows, alpha_B_kj, a_col_k, A.stride_, this_col_j, this->stride_); //for (MatrixIndexT i = 0; i < this_num_rows; ++i) @@ -529,10 +549,11 @@ void MatrixBase::AddMatSmat(Real alpha, const MatrixBase &A, Real alpha_B_jk = alpha * p.second; const Real *a_col_k = A.Data() + k; // Add to entire 'j'th column of *this at once using cblas_Xaxpy. - // pass stride to write a column as matrices are stored in row major order. + // pass stride to write a column as matrices are stored in row major +order. cblas_Xaxpy(this_num_rows, alpha_B_jk, a_col_k, A.stride_, this_col_j, this->stride_); - //for (MatrixIndexT i = 0; i < this_num_rows; ++i) + //for (MatrixIndexT i = 0; i < this_num_rows; ++i) // this_col_j[i*this->stride_] += alpha_B_jk * a_col_k[i*A.stride_]; } } @@ -586,7 +607,8 @@ void MatrixBase::AddDiagVecMat( Real *data = data_; const Real *Mdata = M.Data(), *vdata = v.Data(); if (num_rows_ == 0) return; - for (MatrixIndexT i = 0; i < num_rows; i++, data += stride, Mdata += M_row_stride, vdata++) + for (MatrixIndexT i = 0; i < num_rows; i++, data += stride, Mdata += +M_row_stride, vdata++) cblas_Xaxpy(num_cols, alpha * *vdata, Mdata, M_col_stride, data, 1); } @@ -620,7 +642,8 @@ void MatrixBase::AddMatDiagVec( if (num_rows_ == 0) return; for (MatrixIndexT i = 0; i < num_rows; i++){ for(MatrixIndexT j = 0; j < num_cols; j ++ ){ - data[i*stride + j] += alpha * vdata[j] * Mdata[i*M_row_stride + j*M_col_stride]; + data[i*stride + j] += alpha * vdata[j] * Mdata[i*M_row_stride + +j*M_col_stride]; } } } @@ -655,8 +678,10 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, KALDI_ASSERT(s != NULL && U_in != this && V_in != this); Matrix tmpU, tmpV; - if (U_in == NULL) tmpU.Resize(this->num_rows_, 1); // work-space if U_in empty. 
- if (V_in == NULL) tmpV.Resize(1, this->num_cols_); // work-space if V_in empty. + if (U_in == NULL) tmpU.Resize(this->num_rows_, 1); // work-space if U_in +empty. + if (V_in == NULL) tmpV.Resize(1, this->num_cols_); // work-space if V_in +empty. /// Impementation notes: /// Lapack works in column-order, therefore the dimensions of *this are @@ -690,8 +715,10 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, KaldiBlasInt result; // query for work space - char *u_job = const_cast(U_in ? "s" : "N"); // "s" == skinny, "N" == "none." - char *v_job = const_cast(V_in ? "s" : "N"); // "s" == skinny, "N" == "none." + char *u_job = const_cast(U_in ? "s" : "N"); // "s" == skinny, "N" == +"none." + char *v_job = const_cast(V_in ? "s" : "N"); // "s" == skinny, "N" == +"none." clapack_Xgesvd(v_job, u_job, &M, &N, data_, &LDA, s->Data(), @@ -700,7 +727,8 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, &work_query, &l_work, &result); - KALDI_ASSERT(result >= 0 && "Call to CLAPACK dgesvd_ called with wrong arguments"); + KALDI_ASSERT(result >= 0 && "Call to CLAPACK dgesvd_ called with wrong +arguments"); l_work = static_cast(work_query); Real *p_work; @@ -718,7 +746,8 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, p_work, &l_work, &result); - KALDI_ASSERT(result >= 0 && "Call to CLAPACK dgesvd_ called with wrong arguments"); + KALDI_ASSERT(result >= 0 && "Call to CLAPACK dgesvd_ called with wrong +arguments"); if (result != 0) { KALDI_WARN << "CLAPACK sgesvd_ : some weird convergence not satisfied"; @@ -729,167 +758,166 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, #endif */ // Copy constructor. Copies data to newly allocated memory. 
-template -Matrix::Matrix (const MatrixBase & M, - MatrixTransposeType trans/*=kNoTrans*/) +template +Matrix::Matrix(const MatrixBase &M, + MatrixTransposeType trans /*=kNoTrans*/) : MatrixBase() { - if (trans == kNoTrans) { - Resize(M.num_rows_, M.num_cols_); - this->CopyFromMat(M); - } else { - Resize(M.num_cols_, M.num_rows_); - this->CopyFromMat(M, kTrans); - } + if (trans == kNoTrans) { + Resize(M.num_rows_, M.num_cols_); + this->CopyFromMat(M); + } else { + Resize(M.num_cols_, M.num_rows_); + this->CopyFromMat(M, kTrans); + } } // Copy constructor. Copies data to newly allocated memory. -template -Matrix::Matrix (const Matrix & M): - MatrixBase() { - Resize(M.num_rows_, M.num_cols_); - this->CopyFromMat(M); +template +Matrix::Matrix(const Matrix &M) : MatrixBase() { + Resize(M.num_rows_, M.num_cols_); + this->CopyFromMat(M); } /// Copy constructor from another type. -template -template -Matrix::Matrix(const MatrixBase & M, - MatrixTransposeType trans) : MatrixBase() { - if (trans == kNoTrans) { - Resize(M.NumRows(), M.NumCols()); - this->CopyFromMat(M); - } else { - Resize(M.NumCols(), M.NumRows()); - this->CopyFromMat(M, kTrans); - } +template +template +Matrix::Matrix(const MatrixBase &M, MatrixTransposeType trans) + : MatrixBase() { + if (trans == kNoTrans) { + Resize(M.NumRows(), M.NumCols()); + this->CopyFromMat(M); + } else { + Resize(M.NumCols(), M.NumRows()); + this->CopyFromMat(M, kTrans); + } } // Instantiate this constructor for float->double and double->float. 
-template -Matrix::Matrix(const MatrixBase & M, - MatrixTransposeType trans); -template -Matrix::Matrix(const MatrixBase & M, - MatrixTransposeType trans); +template Matrix::Matrix(const MatrixBase &M, + MatrixTransposeType trans); +template Matrix::Matrix(const MatrixBase &M, + MatrixTransposeType trans); -template +template inline void Matrix::Init(const MatrixIndexT rows, const MatrixIndexT cols, const MatrixStrideType stride_type) { - if (rows * cols == 0) { - KALDI_ASSERT(rows == 0 && cols == 0); - this->num_rows_ = 0; - this->num_cols_ = 0; - this->stride_ = 0; - this->data_ = NULL; - return; - } - KALDI_ASSERT(rows > 0 && cols > 0); - MatrixIndexT skip, stride; - size_t size; - void *data; // aligned memory block - void *temp; // memory block to be really freed - - // compute the size of skip and real cols - skip = ((16 / sizeof(Real)) - cols % (16 / sizeof(Real))) - % (16 / sizeof(Real)); - stride = cols + skip; - size = static_cast(rows) * static_cast(stride) - * sizeof(Real); - - // allocate the memory and set the right dimensions and parameters - if (NULL != (data = KALDI_MEMALIGN(16, size, &temp))) { - MatrixBase::data_ = static_cast (data); - MatrixBase::num_rows_ = rows; - MatrixBase::num_cols_ = cols; - MatrixBase::stride_ = (stride_type == kDefaultStride ? 
stride : cols); - } else { - throw std::bad_alloc(); - } + if (rows * cols == 0) { + KALDI_ASSERT(rows == 0 && cols == 0); + this->num_rows_ = 0; + this->num_cols_ = 0; + this->stride_ = 0; + this->data_ = NULL; + return; + } + KALDI_ASSERT(rows > 0 && cols > 0); + MatrixIndexT skip, stride; + size_t size; + void *data; // aligned memory block + void *temp; // memory block to be really freed + + // compute the size of skip and real cols + skip = ((16 / sizeof(Real)) - cols % (16 / sizeof(Real))) % + (16 / sizeof(Real)); + stride = cols + skip; + size = + static_cast(rows) * static_cast(stride) * sizeof(Real); + + // allocate the memory and set the right dimensions and parameters + if (NULL != (data = KALDI_MEMALIGN(16, size, &temp))) { + MatrixBase::data_ = static_cast(data); + MatrixBase::num_rows_ = rows; + MatrixBase::num_cols_ = cols; + MatrixBase::stride_ = + (stride_type == kDefaultStride ? stride : cols); + } else { + throw std::bad_alloc(); + } } -template +template void Matrix::Resize(const MatrixIndexT rows, const MatrixIndexT cols, MatrixResizeType resize_type, MatrixStrideType stride_type) { - // the next block uses recursion to handle what we have to do if - // resize_type == kCopyData. - if (resize_type == kCopyData) { - if (this->data_ == NULL || rows == 0) resize_type = kSetZero; // nothing to copy. - else if (rows == this->num_rows_ && cols == this->num_cols_ && - (stride_type == kDefaultStride || this->stride_ == this->num_cols_)) { return; } // nothing to do. - else { - // set tmp to a matrix of the desired size; if new matrix - // is bigger in some dimension, zero it. - MatrixResizeType new_resize_type = - (rows > this->num_rows_ || cols > this->num_cols_) ? kSetZero : kUndefined; - Matrix tmp(rows, cols, new_resize_type, stride_type); - MatrixIndexT rows_min = std::min(rows, this->num_rows_), - cols_min = std::min(cols, this->num_cols_); - tmp.Range(0, rows_min, 0, cols_min). 
- CopyFromMat(this->Range(0, rows_min, 0, cols_min)); - tmp.Swap(this); - // and now let tmp go out of scope, deleting what was in *this. - return; + // the next block uses recursion to handle what we have to do if + // resize_type == kCopyData. + if (resize_type == kCopyData) { + if (this->data_ == NULL || rows == 0) + resize_type = kSetZero; // nothing to copy. + else if (rows == this->num_rows_ && cols == this->num_cols_ && + (stride_type == kDefaultStride || + this->stride_ == this->num_cols_)) { + return; + } // nothing to do. + else { + // set tmp to a matrix of the desired size; if new matrix + // is bigger in some dimension, zero it. + MatrixResizeType new_resize_type = + (rows > this->num_rows_ || cols > this->num_cols_) ? kSetZero + : kUndefined; + Matrix tmp(rows, cols, new_resize_type, stride_type); + MatrixIndexT rows_min = std::min(rows, this->num_rows_), + cols_min = std::min(cols, this->num_cols_); + tmp.Range(0, rows_min, 0, cols_min) + .CopyFromMat(this->Range(0, rows_min, 0, cols_min)); + tmp.Swap(this); + // and now let tmp go out of scope, deleting what was in *this. + return; + } } - } - // At this point, resize_type == kSetZero or kUndefined. + // At this point, resize_type == kSetZero or kUndefined. 
- if (MatrixBase::data_ != NULL) { - if (rows == MatrixBase::num_rows_ - && cols == MatrixBase::num_cols_) { - if (resize_type == kSetZero) - this->SetZero(); - return; + if (MatrixBase::data_ != NULL) { + if (rows == MatrixBase::num_rows_ && + cols == MatrixBase::num_cols_) { + if (resize_type == kSetZero) this->SetZero(); + return; + } else + Destroy(); } - else - Destroy(); - } - Init(rows, cols, stride_type); - if (resize_type == kSetZero) MatrixBase::SetZero(); + Init(rows, cols, stride_type); + if (resize_type == kSetZero) MatrixBase::SetZero(); } -template -template +template +template void MatrixBase::CopyFromMat(const MatrixBase &M, MatrixTransposeType Trans) { - if (sizeof(Real) == sizeof(OtherReal) && - static_cast(M.Data()) == - static_cast(this->Data())) { - // CopyFromMat called on same data. Nothing to do (except sanity checks). - KALDI_ASSERT(Trans == kNoTrans && M.NumRows() == NumRows() && - M.NumCols() == NumCols() && M.Stride() == Stride()); - return; - } - if (Trans == kNoTrans) { - KALDI_ASSERT(num_rows_ == M.NumRows() && num_cols_ == M.NumCols()); - for (MatrixIndexT i = 0; i < num_rows_; i++) - (*this).Row(i).CopyFromVec(M.Row(i)); - } else { - KALDI_ASSERT(num_cols_ == M.NumRows() && num_rows_ == M.NumCols()); - int32 this_stride = stride_, other_stride = M.Stride(); - Real *this_data = data_; - const OtherReal *other_data = M.Data(); - for (MatrixIndexT i = 0; i < num_rows_; i++) - for (MatrixIndexT j = 0; j < num_cols_; j++) - this_data[i * this_stride + j] = other_data[j * other_stride + i]; - } + if (sizeof(Real) == sizeof(OtherReal) && + static_cast(M.Data()) == + static_cast(this->Data())) { + // CopyFromMat called on same data. Nothing to do (except sanity + // checks). 
+ KALDI_ASSERT(Trans == kNoTrans && M.NumRows() == NumRows() && + M.NumCols() == NumCols() && M.Stride() == Stride()); + return; + } + if (Trans == kNoTrans) { + KALDI_ASSERT(num_rows_ == M.NumRows() && num_cols_ == M.NumCols()); + for (MatrixIndexT i = 0; i < num_rows_; i++) + (*this).Row(i).CopyFromVec(M.Row(i)); + } else { + KALDI_ASSERT(num_cols_ == M.NumRows() && num_rows_ == M.NumCols()); + int32 this_stride = stride_, other_stride = M.Stride(); + Real *this_data = data_; + const OtherReal *other_data = M.Data(); + for (MatrixIndexT i = 0; i < num_rows_; i++) + for (MatrixIndexT j = 0; j < num_cols_; j++) + this_data[i * this_stride + j] = + other_data[j * other_stride + i]; + } } // template instantiations. -template -void MatrixBase::CopyFromMat(const MatrixBase & M, - MatrixTransposeType Trans); -template -void MatrixBase::CopyFromMat(const MatrixBase & M, - MatrixTransposeType Trans); -template -void MatrixBase::CopyFromMat(const MatrixBase & M, - MatrixTransposeType Trans); -template -void MatrixBase::CopyFromMat(const MatrixBase & M, - MatrixTransposeType Trans); +template void MatrixBase::CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans); +template void MatrixBase::CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans); +template void MatrixBase::CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans); +template void MatrixBase::CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans); /* // Specialize the template for CopyFromSp for float, float. @@ -987,99 +1015,97 @@ void MatrixBase::CopyFromTp(const TpMatrix & M, MatrixTransposeType trans); */ -template +template void MatrixBase::CopyRowsFromVec(const VectorBase &rv) { - if (rv.Dim() == num_rows_*num_cols_) { - if (stride_ == num_cols_) { - // one big copy operation. 
- const Real *rv_data = rv.Data(); - std::memcpy(data_, rv_data, sizeof(Real)*num_rows_*num_cols_); - } else { - const Real *rv_data = rv.Data(); - for (MatrixIndexT r = 0; r < num_rows_; r++) { - Real *row_data = RowData(r); - for (MatrixIndexT c = 0; c < num_cols_; c++) { - row_data[c] = rv_data[c]; + if (rv.Dim() == num_rows_ * num_cols_) { + if (stride_ == num_cols_) { + // one big copy operation. + const Real *rv_data = rv.Data(); + std::memcpy(data_, rv_data, sizeof(Real) * num_rows_ * num_cols_); + } else { + const Real *rv_data = rv.Data(); + for (MatrixIndexT r = 0; r < num_rows_; r++) { + Real *row_data = RowData(r); + for (MatrixIndexT c = 0; c < num_cols_; c++) { + row_data[c] = rv_data[c]; + } + rv_data += num_cols_; + } } - rv_data += num_cols_; - } + } else if (rv.Dim() == num_cols_) { + const Real *rv_data = rv.Data(); + for (MatrixIndexT r = 0; r < num_rows_; r++) + std::memcpy(RowData(r), rv_data, sizeof(Real) * num_cols_); + } else { + KALDI_ERR << "Wrong sized arguments"; } - } else if (rv.Dim() == num_cols_) { - const Real *rv_data = rv.Data(); - for (MatrixIndexT r = 0; r < num_rows_; r++) - std::memcpy(RowData(r), rv_data, sizeof(Real)*num_cols_); - } else { - KALDI_ERR << "Wrong sized arguments"; - } } -template -template +template +template void MatrixBase::CopyRowsFromVec(const VectorBase &rv) { - if (rv.Dim() == num_rows_*num_cols_) { - const OtherReal *rv_data = rv.Data(); - for (MatrixIndexT r = 0; r < num_rows_; r++) { - Real *row_data = RowData(r); - for (MatrixIndexT c = 0; c < num_cols_; c++) { - row_data[c] = static_cast(rv_data[c]); - } - rv_data += num_cols_; + if (rv.Dim() == num_rows_ * num_cols_) { + const OtherReal *rv_data = rv.Data(); + for (MatrixIndexT r = 0; r < num_rows_; r++) { + Real *row_data = RowData(r); + for (MatrixIndexT c = 0; c < num_cols_; c++) { + row_data[c] = static_cast(rv_data[c]); + } + rv_data += num_cols_; + } + } else if (rv.Dim() == num_cols_) { + const OtherReal *rv_data = rv.Data(); + Real 
*first_row_data = RowData(0); + for (MatrixIndexT c = 0; c < num_cols_; c++) + first_row_data[c] = rv_data[c]; + for (MatrixIndexT r = 1; r < num_rows_; r++) + std::memcpy(RowData(r), first_row_data, sizeof(Real) * num_cols_); + } else { + KALDI_ERR << "Wrong sized arguments."; } - } else if (rv.Dim() == num_cols_) { - const OtherReal *rv_data = rv.Data(); - Real *first_row_data = RowData(0); - for (MatrixIndexT c = 0; c < num_cols_; c++) - first_row_data[c] = rv_data[c]; - for (MatrixIndexT r = 1; r < num_rows_; r++) - std::memcpy(RowData(r), first_row_data, sizeof(Real)*num_cols_); - } else { - KALDI_ERR << "Wrong sized arguments."; - } } -template -void MatrixBase::CopyRowsFromVec(const VectorBase &rv); -template -void MatrixBase::CopyRowsFromVec(const VectorBase &rv); +template void MatrixBase::CopyRowsFromVec(const VectorBase &rv); +template void MatrixBase::CopyRowsFromVec(const VectorBase &rv); -template +template void MatrixBase::CopyColsFromVec(const VectorBase &rv) { - if (rv.Dim() == num_rows_*num_cols_) { - const Real *v_inc_data = rv.Data(); - Real *m_inc_data = data_; + if (rv.Dim() == num_rows_ * num_cols_) { + const Real *v_inc_data = rv.Data(); + Real *m_inc_data = data_; - for (MatrixIndexT c = 0; c < num_cols_; c++) { - for (MatrixIndexT r = 0; r < num_rows_; r++) { - m_inc_data[r * stride_] = v_inc_data[r]; - } - v_inc_data += num_rows_; - m_inc_data ++; - } - } else if (rv.Dim() == num_rows_) { - const Real *v_inc_data = rv.Data(); - Real *m_inc_data = data_; - for (MatrixIndexT r = 0; r < num_rows_; r++) { - Real value = *(v_inc_data++); - for (MatrixIndexT c = 0; c < num_cols_; c++) - m_inc_data[c] = value; - m_inc_data += stride_; + for (MatrixIndexT c = 0; c < num_cols_; c++) { + for (MatrixIndexT r = 0; r < num_rows_; r++) { + m_inc_data[r * stride_] = v_inc_data[r]; + } + v_inc_data += num_rows_; + m_inc_data++; + } + } else if (rv.Dim() == num_rows_) { + const Real *v_inc_data = rv.Data(); + Real *m_inc_data = data_; + for (MatrixIndexT 
r = 0; r < num_rows_; r++) { + Real value = *(v_inc_data++); + for (MatrixIndexT c = 0; c < num_cols_; c++) m_inc_data[c] = value; + m_inc_data += stride_; + } + } else { + KALDI_ERR << "Wrong size of arguments."; } - } else { - KALDI_ERR << "Wrong size of arguments."; - } } -template -void MatrixBase::CopyRowFromVec(const VectorBase &rv, const MatrixIndexT row) { - KALDI_ASSERT(rv.Dim() == num_cols_ && - static_cast(row) < - static_cast(num_rows_)); +template +void MatrixBase::CopyRowFromVec(const VectorBase &rv, + const MatrixIndexT row) { + KALDI_ASSERT(rv.Dim() == num_cols_ && + static_cast(row) < + static_cast(num_rows_)); - const Real *rv_data = rv.Data(); - Real *row_data = RowData(row); + const Real *rv_data = rv.Data(); + Real *row_data = RowData(row); - std::memcpy(row_data, rv_data, num_cols_ * sizeof(Real)); + std::memcpy(row_data, rv_data, num_cols_ * sizeof(Real)); } /* template @@ -1091,40 +1117,40 @@ void MatrixBase::CopyDiagFromVec(const VectorBase &rv) { *my_data = *rv_data; }*/ -template +template void MatrixBase::CopyColFromVec(const VectorBase &rv, const MatrixIndexT col) { - KALDI_ASSERT(rv.Dim() == num_rows_ && - static_cast(col) < - static_cast(num_cols_)); + KALDI_ASSERT(rv.Dim() == num_rows_ && + static_cast(col) < + static_cast(num_cols_)); - const Real *rv_data = rv.Data(); - Real *col_data = data_ + col; + const Real *rv_data = rv.Data(); + Real *col_data = data_ + col; - for (MatrixIndexT r = 0; r < num_rows_; r++) - col_data[r * stride_] = rv_data[r]; + for (MatrixIndexT r = 0; r < num_rows_; r++) + col_data[r * stride_] = rv_data[r]; } - -template +template void Matrix::RemoveRow(MatrixIndexT i) { - KALDI_ASSERT(static_cast(i) < - static_cast(MatrixBase::num_rows_) - && "Access out of matrix"); - for (MatrixIndexT j = i + 1; j < MatrixBase::num_rows_; j++) - MatrixBase::Row(j-1).CopyFromVec( MatrixBase::Row(j)); - MatrixBase::num_rows_--; + KALDI_ASSERT( + static_cast(i) < + static_cast(MatrixBase::num_rows_) && + "Access out of 
matrix"); + for (MatrixIndexT j = i + 1; j < MatrixBase::num_rows_; j++) + MatrixBase::Row(j - 1).CopyFromVec(MatrixBase::Row(j)); + MatrixBase::num_rows_--; } -template +template void Matrix::Destroy() { - // we need to free the data block if it was defined - if (NULL != MatrixBase::data_) - KALDI_MEMALIGN_FREE( MatrixBase::data_); - MatrixBase::data_ = NULL; - MatrixBase::num_rows_ = MatrixBase::num_cols_ - = MatrixBase::stride_ = 0; + // we need to free the data block if it was defined + if (NULL != MatrixBase::data_) + KALDI_MEMALIGN_FREE(MatrixBase::data_); + MatrixBase::data_ = NULL; + MatrixBase::num_rows_ = MatrixBase::num_cols_ = + MatrixBase::stride_ = 0; } @@ -1248,7 +1274,8 @@ template void MatrixBase::GroupPnormDeriv(const MatrixBase &input, const MatrixBase &output, Real power) { - KALDI_ASSERT(input.NumCols() == this->NumCols() && input.NumRows() == this->NumRows()); + KALDI_ASSERT(input.NumCols() == this->NumCols() && input.NumRows() == +this->NumRows()); KALDI_ASSERT(this->NumCols() % output.NumCols() == 0 && this->NumRows() == output.NumRows()); @@ -1320,22 +1347,22 @@ void MatrixBase::MulColsVec(const VectorBase &scale) { } */ -template +template void MatrixBase::SetZero() { - if (num_cols_ == stride_) - memset(data_, 0, sizeof(Real)*num_rows_*num_cols_); - else - for (MatrixIndexT row = 0; row < num_rows_; row++) - memset(data_ + row*stride_, 0, sizeof(Real)*num_cols_); + if (num_cols_ == stride_) + memset(data_, 0, sizeof(Real) * num_rows_ * num_cols_); + else + for (MatrixIndexT row = 0; row < num_rows_; row++) + memset(data_ + row * stride_, 0, sizeof(Real) * num_cols_); } -template +template void MatrixBase::Set(Real value) { - for (MatrixIndexT row = 0; row < num_rows_; row++) { - for (MatrixIndexT col = 0; col < num_cols_; col++) { - (*this)(row, col) = value; + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < num_cols_; col++) { + (*this)(row, col) = value; + } } - } } /* @@ -1355,7 +1382,8 @@ void 
MatrixBase::SetRandn() { for (MatrixIndexT col = 0; col < nc; col += 2) { kaldi::RandGauss2(row_data + col, row_data + col + 1, &rstate); } - if (nc != num_cols_) row_data[nc] = static_cast(kaldi::RandGauss(&rstate)); + if (nc != num_cols_) row_data[nc] = +static_cast(kaldi::RandGauss(&rstate)); } } @@ -1371,273 +1399,302 @@ void MatrixBase::SetRandUniform() { } */ -template +template void MatrixBase::Write(std::ostream &os, bool binary) const { - if (!os.good()) { - KALDI_ERR << "Failed to write matrix to stream: stream not good"; - } - if (binary) { // Use separate binary and text formats, - // since in binary mode we need to know if it's float or double. - std::string my_token = (sizeof(Real) == 4 ? "FM" : "DM"); - - WriteToken(os, binary, my_token); - { - int32 rows = this->num_rows_; // make the size 32-bit on disk. - int32 cols = this->num_cols_; - KALDI_ASSERT(this->num_rows_ == (MatrixIndexT) rows); - KALDI_ASSERT(this->num_cols_ == (MatrixIndexT) cols); - WriteBasicType(os, binary, rows); - WriteBasicType(os, binary, cols); - } - if (Stride() == NumCols()) - os.write(reinterpret_cast (Data()), sizeof(Real) - * static_cast(num_rows_) * static_cast(num_cols_)); - else - for (MatrixIndexT i = 0; i < num_rows_; i++) - os.write(reinterpret_cast (RowData(i)), sizeof(Real) - * num_cols_); if (!os.good()) { - KALDI_ERR << "Failed to write matrix to stream"; - } - } else { // text mode. - if (num_cols_ == 0) { - os << " [ ]\n"; - } else { - os << " ["; - for (MatrixIndexT i = 0; i < num_rows_; i++) { - os << "\n "; - for (MatrixIndexT j = 0; j < num_cols_; j++) - os << (*this)(i, j) << " "; - } - os << "]\n"; + KALDI_ERR << "Failed to write matrix to stream: stream not good"; + } + if (binary) { // Use separate binary and text formats, + // since in binary mode we need to know if it's float or double. + std::string my_token = (sizeof(Real) == 4 ? "FM" : "DM"); + + WriteToken(os, binary, my_token); + { + int32 rows = this->num_rows_; // make the size 32-bit on disk. 
+ int32 cols = this->num_cols_; + KALDI_ASSERT(this->num_rows_ == (MatrixIndexT)rows); + KALDI_ASSERT(this->num_cols_ == (MatrixIndexT)cols); + WriteBasicType(os, binary, rows); + WriteBasicType(os, binary, cols); + } + if (Stride() == NumCols()) + os.write(reinterpret_cast(Data()), + sizeof(Real) * static_cast(num_rows_) * + static_cast(num_cols_)); + else + for (MatrixIndexT i = 0; i < num_rows_; i++) + os.write(reinterpret_cast(RowData(i)), + sizeof(Real) * num_cols_); + if (!os.good()) { + KALDI_ERR << "Failed to write matrix to stream"; + } + } else { // text mode. + if (num_cols_ == 0) { + os << " [ ]\n"; + } else { + os << " ["; + for (MatrixIndexT i = 0; i < num_rows_; i++) { + os << "\n "; + for (MatrixIndexT j = 0; j < num_cols_; j++) + os << (*this)(i, j) << " "; + } + os << "]\n"; + } } - } } -template -void MatrixBase::Read(std::istream & is, bool binary) { - // In order to avoid rewriting this, we just declare a Matrix and - // use it to read the data, then copy. - Matrix tmp; - tmp.Read(is, binary); - if (tmp.NumRows() != NumRows() || tmp.NumCols() != NumCols()) { - KALDI_ERR << "MatrixBase::Read, size mismatch " - << NumRows() << " x " << NumCols() << " versus " - << tmp.NumRows() << " x " << tmp.NumCols(); - } - CopyFromMat(tmp); +template +void MatrixBase::Read(std::istream &is, bool binary) { + // In order to avoid rewriting this, we just declare a Matrix and + // use it to read the data, then copy. + Matrix tmp; + tmp.Read(is, binary); + if (tmp.NumRows() != NumRows() || tmp.NumCols() != NumCols()) { + KALDI_ERR << "MatrixBase::Read, size mismatch " << NumRows() + << " x " << NumCols() << " versus " << tmp.NumRows() << " x " + << tmp.NumCols(); + } + CopyFromMat(tmp); } -template -void Matrix::Read(std::istream & is, bool binary) { - // now assume add == false. - MatrixIndexT pos_at_start = is.tellg(); - std::ostringstream specific_error; - - if (binary) { // Read in binary mode. 
- int peekval = Peek(is, binary); - if (peekval == 'C') { - // This code enables us to read CompressedMatrix as a regular matrix. - //CompressedMatrix compressed_mat; - //compressed_mat.Read(is, binary); // at this point, add == false. - //this->Resize(compressed_mat.NumRows(), compressed_mat.NumCols()); - //compressed_mat.CopyToMat(this); - return; - } - const char *my_token = (sizeof(Real) == 4 ? "FM" : "DM"); - char other_token_start = (sizeof(Real) == 4 ? 'D' : 'F'); - if (peekval == other_token_start) { // need to instantiate the other type to read it. - typedef typename OtherReal::Real OtherType; // if Real == float, OtherType == double, and vice versa. - Matrix other(this->num_rows_, this->num_cols_); - other.Read(is, binary); // add is false at this point anyway. - this->Resize(other.NumRows(), other.NumCols()); - this->CopyFromMat(other); - return; - } - std::string token; - ReadToken(is, binary, &token); - if (token != my_token) { - if (token.length() > 20) token = token.substr(0, 17) + "..."; - specific_error << ": Expected token " << my_token << ", got " << token; - goto bad; - } - int32 rows, cols; - ReadBasicType(is, binary, &rows); // throws on error. - ReadBasicType(is, binary, &cols); // throws on error. - if ((MatrixIndexT)rows != this->num_rows_ || (MatrixIndexT)cols != this->num_cols_) { - this->Resize(rows, cols); - } - if (this->Stride() == this->NumCols() && rows*cols!=0) { - is.read(reinterpret_cast(this->Data()), - sizeof(Real)*rows*cols); - if (is.fail()) goto bad; - } else { - for (MatrixIndexT i = 0; i < (MatrixIndexT)rows; i++) { - is.read(reinterpret_cast(this->RowData(i)), sizeof(Real)*cols); - if (is.fail()) goto bad; - } - } - if (is.eof()) return; - if (is.fail()) goto bad; - return; - } else { // Text mode. - std::string str; - is >> str; // get a token - if (is.fail()) { specific_error << ": Expected \"[\", got EOF"; goto bad; } - // if ((str.compare("DM") == 0) || (str.compare("FM") == 0)) { // Back compatibility. 
- // is >> str; // get #rows - // is >> str; // get #cols - // is >> str; // get "[" - // } - if (str == "[]") { Resize(0, 0); return; } // Be tolerant of variants. - else if (str != "[") { - if (str.length() > 20) str = str.substr(0, 17) + "..."; - specific_error << ": Expected \"[\", got \"" << str << '"'; - goto bad; - } - // At this point, we have read "[". - std::vector* > data; - std::vector *cur_row = new std::vector; - while (1) { - int i = is.peek(); - if (i == -1) { specific_error << "Got EOF while reading matrix data"; goto cleanup; } - else if (static_cast(i) == ']') { // Finished reading matrix. - is.get(); // eat the "]". - i = is.peek(); - if (static_cast(i) == '\r') { - is.get(); - is.get(); // get \r\n (must eat what we wrote) - } else if (static_cast(i) == '\n') { is.get(); } // get \n (must eat what we wrote) - if (is.fail()) { - KALDI_WARN << "After end of matrix data, read error."; - // we got the data we needed, so just warn for this error. +template +void Matrix::Read(std::istream &is, bool binary) { + // now assume add == false. + MatrixIndexT pos_at_start = is.tellg(); + std::ostringstream specific_error; + + if (binary) { // Read in binary mode. + int peekval = Peek(is, binary); + if (peekval == 'C') { + // This code enables us to read CompressedMatrix as a regular + // matrix. + // CompressedMatrix compressed_mat; + // compressed_mat.Read(is, binary); // at this point, add == false. + // this->Resize(compressed_mat.NumRows(), compressed_mat.NumCols()); + // compressed_mat.CopyToMat(this); + return; } - // Now process the data. - if (!cur_row->empty()) data.push_back(cur_row); - else delete(cur_row); - cur_row = NULL; - if (data.empty()) { this->Resize(0, 0); return; } - else { - int32 num_rows = data.size(), num_cols = data[0]->size(); - this->Resize(num_rows, num_cols); - for (int32 i = 0; i < num_rows; i++) { - if (static_cast(data[i]->size()) != num_cols) { - specific_error << "Matrix has inconsistent #cols: " << num_cols - << " vs." 
<< data[i]->size() << " (processing row" - << i << ")"; - goto cleanup; + const char *my_token = (sizeof(Real) == 4 ? "FM" : "DM"); + char other_token_start = (sizeof(Real) == 4 ? 'D' : 'F'); + if (peekval == other_token_start) { // need to instantiate the other + // type to read it. + typedef typename OtherReal::Real OtherType; // if Real == + // float, + // OtherType == + // double, and + // vice versa. + Matrix other(this->num_rows_, this->num_cols_); + other.Read(is, binary); // add is false at this point anyway. + this->Resize(other.NumRows(), other.NumCols()); + this->CopyFromMat(other); + return; + } + std::string token; + ReadToken(is, binary, &token); + if (token != my_token) { + if (token.length() > 20) token = token.substr(0, 17) + "..."; + specific_error << ": Expected token " << my_token << ", got " + << token; + goto bad; + } + int32 rows, cols; + ReadBasicType(is, binary, &rows); // throws on error. + ReadBasicType(is, binary, &cols); // throws on error. + if ((MatrixIndexT)rows != this->num_rows_ || + (MatrixIndexT)cols != this->num_cols_) { + this->Resize(rows, cols); + } + if (this->Stride() == this->NumCols() && rows * cols != 0) { + is.read(reinterpret_cast(this->Data()), + sizeof(Real) * rows * cols); + if (is.fail()) goto bad; + } else { + for (MatrixIndexT i = 0; i < (MatrixIndexT)rows; i++) { + is.read(reinterpret_cast(this->RowData(i)), + sizeof(Real) * cols); + if (is.fail()) goto bad; } - for (int32 j = 0; j < num_cols; j++) - (*this)(i, j) = (*(data[i]))[j]; - delete data[i]; - data[i] = NULL; - } } + if (is.eof()) return; + if (is.fail()) goto bad; return; - } else if (static_cast(i) == '\n' || static_cast(i) == ';') { - // End of matrix row. - is.get(); - if (cur_row->size() != 0) { - data.push_back(cur_row); - cur_row = new std::vector; - cur_row->reserve(data.back()->size()); - } - } else if ( (i >= '0' && i <= '9') || i == '-' ) { // A number... - Real r; - is >> r; + } else { // Text mode. 
+ std::string str; + is >> str; // get a token if (is.fail()) { - specific_error << "Stream failure/EOF while reading matrix data."; - goto cleanup; + specific_error << ": Expected \"[\", got EOF"; + goto bad; } - cur_row->push_back(r); - } else if (isspace(i)) { - is.get(); // eat the space and do nothing. - } else { // NaN or inf or error. - std::string str; - is >> str; - if (!KALDI_STRCASECMP(str.c_str(), "inf") || - !KALDI_STRCASECMP(str.c_str(), "infinity")) { - cur_row->push_back(std::numeric_limits::infinity()); - KALDI_WARN << "Reading infinite value into matrix."; - } else if (!KALDI_STRCASECMP(str.c_str(), "nan")) { - cur_row->push_back(std::numeric_limits::quiet_NaN()); - KALDI_WARN << "Reading NaN value into matrix."; - } else { - if (str.length() > 20) str = str.substr(0, 17) + "..."; - specific_error << "Expecting numeric matrix data, got " << str; - goto cleanup; + // if ((str.compare("DM") == 0) || (str.compare("FM") == 0)) { // Back + // compatibility. + // is >> str; // get #rows + // is >> str; // get #cols + // is >> str; // get "[" + // } + if (str == "[]") { + Resize(0, 0); + return; + } // Be tolerant of variants. + else if (str != "[") { + if (str.length() > 20) str = str.substr(0, 17) + "..."; + specific_error << ": Expected \"[\", got \"" << str << '"'; + goto bad; + } + // At this point, we have read "[". + std::vector *> data; + std::vector *cur_row = new std::vector; + while (1) { + int i = is.peek(); + if (i == -1) { + specific_error << "Got EOF while reading matrix data"; + goto cleanup; + } else if (static_cast(i) == + ']') { // Finished reading matrix. + is.get(); // eat the "]". + i = is.peek(); + if (static_cast(i) == '\r') { + is.get(); + is.get(); // get \r\n (must eat what we wrote) + } else if (static_cast(i) == '\n') { + is.get(); + } // get \n (must eat what we wrote) + if (is.fail()) { + KALDI_WARN << "After end of matrix data, read error."; + // we got the data we needed, so just warn for this error. 
+ } + // Now process the data. + if (!cur_row->empty()) + data.push_back(cur_row); + else + delete (cur_row); + cur_row = NULL; + if (data.empty()) { + this->Resize(0, 0); + return; + } else { + int32 num_rows = data.size(), num_cols = data[0]->size(); + this->Resize(num_rows, num_cols); + for (int32 i = 0; i < num_rows; i++) { + if (static_cast(data[i]->size()) != num_cols) { + specific_error + << "Matrix has inconsistent #cols: " << num_cols + << " vs." << data[i]->size() + << " (processing row" << i << ")"; + goto cleanup; + } + for (int32 j = 0; j < num_cols; j++) + (*this)(i, j) = (*(data[i]))[j]; + delete data[i]; + data[i] = NULL; + } + } + return; + } else if (static_cast(i) == '\n' || + static_cast(i) == ';') { + // End of matrix row. + is.get(); + if (cur_row->size() != 0) { + data.push_back(cur_row); + cur_row = new std::vector; + cur_row->reserve(data.back()->size()); + } + } else if ((i >= '0' && i <= '9') || i == '-') { // A number... + Real r; + is >> r; + if (is.fail()) { + specific_error + << "Stream failure/EOF while reading matrix data."; + goto cleanup; + } + cur_row->push_back(r); + } else if (isspace(i)) { + is.get(); // eat the space and do nothing. + } else { // NaN or inf or error. + std::string str; + is >> str; + if (!KALDI_STRCASECMP(str.c_str(), "inf") || + !KALDI_STRCASECMP(str.c_str(), "infinity")) { + cur_row->push_back(std::numeric_limits::infinity()); + KALDI_WARN << "Reading infinite value into matrix."; + } else if (!KALDI_STRCASECMP(str.c_str(), "nan")) { + cur_row->push_back(std::numeric_limits::quiet_NaN()); + KALDI_WARN << "Reading NaN value into matrix."; + } else { + if (str.length() > 20) str = str.substr(0, 17) + "..."; + specific_error << "Expecting numeric matrix data, got " + << str; + goto cleanup; + } + } } - } - } // Note, we never leave the while () loop before this // line (we return from it.) - cleanup: // We only reach here in case of error in the while loop above. 
- if(cur_row != NULL) - delete cur_row; - for (size_t i = 0; i < data.size(); i++) - if(data[i] != NULL) - delete data[i]; - // and then go on to "bad" below, where we print error. - } + cleanup: // We only reach here in case of error in the while loop above. + if (cur_row != NULL) delete cur_row; + for (size_t i = 0; i < data.size(); i++) + if (data[i] != NULL) delete data[i]; + // and then go on to "bad" below, where we print error. + } bad: - KALDI_ERR << "Failed to read matrix from stream. " << specific_error.str() - << " File position at start is " - << pos_at_start << ", currently " << is.tellg(); + KALDI_ERR << "Failed to read matrix from stream. " << specific_error.str() + << " File position at start is " << pos_at_start << ", currently " + << is.tellg(); } // Constructor... note that this is not const-safe as it would // be quite complicated to implement a "const SubMatrix" class that // would not allow its contents to be changed. -template +template SubMatrix::SubMatrix(const MatrixBase &M, const MatrixIndexT ro, const MatrixIndexT r, const MatrixIndexT co, const MatrixIndexT c) { - if (r == 0 || c == 0) { - // we support the empty sub-matrix as a special case. - KALDI_ASSERT(c == 0 && r == 0); - this->data_ = NULL; - this->num_cols_ = 0; - this->num_rows_ = 0; - this->stride_ = 0; - return; - } - KALDI_ASSERT(static_cast(ro) < - static_cast(M.num_rows_) && - static_cast(co) < - static_cast(M.num_cols_) && - static_cast(r) <= - static_cast(M.num_rows_ - ro) && - static_cast(c) <= - static_cast(M.num_cols_ - co)); - // point to the begining of window - MatrixBase::num_rows_ = r; - MatrixBase::num_cols_ = c; - MatrixBase::stride_ = M.Stride(); - MatrixBase::data_ = M.Data_workaround() + - static_cast(co) + - static_cast(ro) * static_cast(M.Stride()); + if (r == 0 || c == 0) { + // we support the empty sub-matrix as a special case. 
+ KALDI_ASSERT(c == 0 && r == 0); + this->data_ = NULL; + this->num_cols_ = 0; + this->num_rows_ = 0; + this->stride_ = 0; + return; + } + KALDI_ASSERT(static_cast(ro) < + static_cast(M.num_rows_) && + static_cast(co) < + static_cast(M.num_cols_) && + static_cast(r) <= + static_cast(M.num_rows_ - ro) && + static_cast(c) <= + static_cast(M.num_cols_ - co)); + // point to the begining of window + MatrixBase::num_rows_ = r; + MatrixBase::num_cols_ = c; + MatrixBase::stride_ = M.Stride(); + MatrixBase::data_ = + M.Data_workaround() + static_cast(co) + + static_cast(ro) * static_cast(M.Stride()); } -template +template SubMatrix::SubMatrix(Real *data, MatrixIndexT num_rows, MatrixIndexT num_cols, - MatrixIndexT stride): - MatrixBase(data, num_cols, num_rows, stride) { // caution: reversed order! - if (data == NULL) { - KALDI_ASSERT(num_rows * num_cols == 0); - this->num_rows_ = 0; - this->num_cols_ = 0; - this->stride_ = 0; - } else { - KALDI_ASSERT(this->stride_ >= this->num_cols_); - } + MatrixIndexT stride) + : MatrixBase( + data, num_cols, num_rows, stride) { // caution: reversed order! + if (data == NULL) { + KALDI_ASSERT(num_rows * num_cols == 0); + this->num_rows_ = 0; + this->num_cols_ = 0; + this->stride_ = 0; + } else { + KALDI_ASSERT(this->stride_ >= this->num_cols_); + } } /* @@ -1665,9 +1722,11 @@ Real MatrixBase::Cond() const { KALDI_ASSERT(num_rows_ > 0&&num_cols_ > 0); Vector singular_values(std::min(num_rows_, num_cols_)); Svd(&singular_values); // Get singular values... - Real min = singular_values(0), max = singular_values(0); // both absolute values... + Real min = singular_values(0), max = singular_values(0); // both absolute +values... 
for (MatrixIndexT i = 1;i < singular_values.Dim();i++) { - min = std::min((Real)std::abs(singular_values(i)), min); max = std::max((Real)std::abs(singular_values(i)), max); + min = std::min((Real)std::abs(singular_values(i)), min); max = +std::max((Real)std::abs(singular_values(i)), max); } if (min > 0) return max/min; else return std::numeric_limits::infinity(); @@ -1677,7 +1736,8 @@ template Real MatrixBase::Trace(bool check_square) const { KALDI_ASSERT(!check_square || num_rows_ == num_cols_); Real ans = 0.0; - for (MatrixIndexT r = 0;r < std::min(num_rows_, num_cols_);r++) ans += data_ [r + stride_*r]; + for (MatrixIndexT r = 0;r < std::min(num_rows_, num_cols_);r++) ans += data_ +[r + stride_*r]; return ans; } @@ -1707,22 +1767,29 @@ Real MatrixBase::Min() const { template void MatrixBase::AddMatMatMat(Real alpha, - const MatrixBase &A, MatrixTransposeType transA, - const MatrixBase &B, MatrixTransposeType transB, - const MatrixBase &C, MatrixTransposeType transC, + const MatrixBase &A, +MatrixTransposeType transA, + const MatrixBase &B, +MatrixTransposeType transB, + const MatrixBase &C, +MatrixTransposeType transC, Real beta) { - // Note on time taken with different orders of computation. Assume not transposed in this / - // discussion. Firstly, normalize expressions using A.NumCols == B.NumRows and B.NumCols == C.NumRows, prefer + // Note on time taken with different orders of computation. Assume not +transposed in this / + // discussion. Firstly, normalize expressions using A.NumCols == B.NumRows and +B.NumCols == C.NumRows, prefer // rows where there is a choice. // time taken for (AB) is: A.NumRows*B.NumRows*C.Rows // time taken for (AB)C is A.NumRows*C.NumRows*C.Cols - // so this order is A.NumRows*B.NumRows*C.NumRows + A.NumRows*C.NumRows*C.NumCols. + // so this order is A.NumRows*B.NumRows*C.NumRows + +A.NumRows*C.NumRows*C.NumCols. 
// time taken for (BC) is: B.NumRows*C.NumRows*C.Cols // time taken for A(BC) is: A.NumRows*B.NumRows*C.Cols // so this order is B.NumRows*C.NumRows*C.NumCols + A.NumRows*B.NumRows*C.Cols - MatrixIndexT ARows = A.num_rows_, ACols = A.num_cols_, BRows = B.num_rows_, BCols = B.num_cols_, + MatrixIndexT ARows = A.num_rows_, ACols = A.num_cols_, BRows = B.num_rows_, +BCols = B.num_cols_, CRows = C.num_rows_, CCols = C.num_cols_; if (transA == kTrans) std::swap(ARows, ACols); if (transB == kTrans) std::swap(BRows, BCols); @@ -1746,58 +1813,71 @@ void MatrixBase::AddMatMatMat(Real alpha, template -void MatrixBase::DestructiveSvd(VectorBase *s, MatrixBase *U, MatrixBase *Vt) { +void MatrixBase::DestructiveSvd(VectorBase *s, MatrixBase *U, +MatrixBase *Vt) { // Svd, *this = U*diag(s)*Vt. // With (*this).num_rows_ == m, (*this).num_cols_ == n, - // Support only skinny Svd with m>=n (NumRows>=NumCols), and zero sizes for U and Vt mean + // Support only skinny Svd with m>=n (NumRows>=NumCols), and zero sizes for U +and Vt mean // we do not want that output. We expect that s.Dim() == m, // U is either 0 by 0 or m by n, and rv is either 0 by 0 or n by n. // Throws exception on error. - KALDI_ASSERT(num_rows_>=num_cols_ && "Svd requires that #rows by >= #cols."); // For compatibility with JAMA code. + KALDI_ASSERT(num_rows_>=num_cols_ && "Svd requires that #rows by >= #cols."); +// For compatibility with JAMA code. KALDI_ASSERT(s->Dim() == num_cols_); // s should be the smaller dim. - KALDI_ASSERT(U == NULL || (U->num_rows_ == num_rows_&&U->num_cols_ == num_cols_)); - KALDI_ASSERT(Vt == NULL || (Vt->num_rows_ == num_cols_&&Vt->num_cols_ == num_cols_)); + KALDI_ASSERT(U == NULL || (U->num_rows_ == num_rows_&&U->num_cols_ == +num_cols_)); + KALDI_ASSERT(Vt == NULL || (Vt->num_rows_ == num_cols_&&Vt->num_cols_ == +num_cols_)); Real prescale = 1.0; - if ( std::abs((*this)(0, 0) ) < 1.0e-30) { // Very tiny value... can cause problems in Svd. 
+ if ( std::abs((*this)(0, 0) ) < 1.0e-30) { // Very tiny value... can cause +problems in Svd. Real max_elem = LargestAbsElem(); if (max_elem != 0) { prescale = 1.0 / max_elem; - if (std::abs(prescale) == std::numeric_limits::infinity()) { prescale = 1.0e+40; } + if (std::abs(prescale) == std::numeric_limits::infinity()) { +prescale = 1.0e+40; } (*this).Scale(prescale); } } #if !defined(HAVE_ATLAS) && !defined(USE_KALDI_SVD) - // "S" == skinny Svd (only one we support because of compatibility with Jama one which is only skinny), + // "S" == skinny Svd (only one we support because of compatibility with Jama +one which is only skinny), // "N"== no eigenvectors wanted. LapackGesvd(s, U, Vt); #else /* if (num_rows_ > 1 && num_cols_ > 1 && (*this)(0, 0) == (*this)(1, 1) - && Max() == Min() && (*this)(0, 0) != 0.0) { // special case that JamaSvd sometimes crashes on. - KALDI_WARN << "Jama SVD crashes on this type of matrix, perturbing it to prevent crash."; + && Max() == Min() && (*this)(0, 0) != 0.0) { // special case that JamaSvd +sometimes crashes on. + KALDI_WARN << "Jama SVD crashes on this type of matrix, perturbing it to +prevent crash."; for(int32 i = 0; i < NumRows(); i++) (*this)(i, i) *= 1.00001; }*/ // bool ans = JamaSvd(s, U, Vt); - //if (Vt != NULL) Vt->Transpose(); // possibly to do: change this and also the transpose inside the JamaSvd routine. note, Vt is square. - //if (!ans) { - //KALDI_ERR << "Error doing Svd"; // This one will be caught. - //} +// if (Vt != NULL) Vt->Transpose(); // possibly to do: change this and also the +// transpose inside the JamaSvd routine. note, Vt is square. +// if (!ans) { +// KALDI_ERR << "Error doing Svd"; // This one will be caught. 
+//} //#endif - //if (prescale != 1.0) s->Scale(1.0/prescale); +// if (prescale != 1.0) s->Scale(1.0/prescale); //} /* template -void MatrixBase::Svd(VectorBase *s, MatrixBase *U, MatrixBase *Vt) const { +void MatrixBase::Svd(VectorBase *s, MatrixBase *U, +MatrixBase *Vt) const { try { if (num_rows_ >= num_cols_) { Matrix tmp(*this); tmp.DestructiveSvd(s, U, Vt); } else { Matrix tmp(*this, kTrans); // transpose of *this. - // rVt will have different dim so cannot transpose in-place --> use a temp matrix. + // rVt will have different dim so cannot transpose in-place --> use a temp +matrix. Matrix Vt_Trans(Vt ? Vt->num_cols_ : 0, Vt ? Vt->num_rows_ : 0); // U will be transpose tmp.DestructiveSvd(s, Vt ? &Vt_Trans : NULL, U); @@ -1806,7 +1886,8 @@ void MatrixBase::Svd(VectorBase *s, MatrixBase *U, MatrixBase< } } catch (...) { KALDI_ERR << "Error doing Svd (did not converge), first part of matrix is\n" - << SubMatrix(*this, 0, std::min((MatrixIndexT)10, num_rows_), + << SubMatrix(*this, 0, std::min((MatrixIndexT)10, +num_rows_), 0, std::min((MatrixIndexT)10, num_cols_)) << ", min and max are: " << Min() << ", " << Max(); } @@ -1819,7 +1900,8 @@ bool MatrixBase::IsSymmetric(Real cutoff) const { Real bad_sum = 0.0, good_sum = 0.0; for (MatrixIndexT i = 0;i < R;i++) { for (MatrixIndexT j = 0;j < i;j++) { - Real a = (*this)(i, j), b = (*this)(j, i), avg = 0.5*(a+b), diff = 0.5*(a-b); + Real a = (*this)(i, j), b = (*this)(j, i), avg = 0.5*(a+b), diff = +0.5*(a-b); good_sum += std::abs(avg); bad_sum += std::abs(diff); } good_sum += std::abs((*this)(i, i)); @@ -1860,7 +1942,8 @@ bool MatrixBase::IsUnit(Real cutoff) const { Real bad_max = 0.0; for (MatrixIndexT i = 0; i < R;i++) for (MatrixIndexT j = 0; j < C;j++) - bad_max = std::max(bad_max, static_cast(std::abs( (*this)(i, j) - (i == j?1.0:0.0)))); + bad_max = std::max(bad_max, static_cast(std::abs( (*this)(i, j) - (i +== j?1.0:0.0)))); return (bad_max <= cutoff); } @@ -1880,7 +1963,8 @@ Real MatrixBase::FrobeniusNorm() 
const{ } template -bool MatrixBase::ApproxEqual(const MatrixBase &other, float tol) const { +bool MatrixBase::ApproxEqual(const MatrixBase &other, float tol) +const { if (num_rows_ != other.num_rows_ || num_cols_ != other.num_cols_) KALDI_ERR << "ApproxEqual: size mismatch."; Matrix tmp(*this); @@ -1953,27 +2037,35 @@ void MatrixBase::OrthogonalizeRows() { } -// Uses Svd to compute the eigenvalue decomposition of a symmetric positive semidefinite +// Uses Svd to compute the eigenvalue decomposition of a symmetric positive +semidefinite // matrix: -// (*this) = rU * diag(rs) * rU^T, with rU an orthogonal matrix so rU^{-1} = rU^T. -// Does this by computing svd (*this) = U diag(rs) V^T ... answer is just U diag(rs) U^T. -// Throws exception if this failed to within supplied precision (typically because *this was not +// (*this) = rU * diag(rs) * rU^T, with rU an orthogonal matrix so rU^{-1} = +rU^T. +// Does this by computing svd (*this) = U diag(rs) V^T ... answer is just U +diag(rs) U^T. +// Throws exception if this failed to within supplied precision (typically +because *this was not // symmetric positive definite). template -void MatrixBase::SymPosSemiDefEig(VectorBase *rs, MatrixBase *rU, Real check_thresh) // e.g. check_thresh = 0.001 +void MatrixBase::SymPosSemiDefEig(VectorBase *rs, MatrixBase +*rU, Real check_thresh) // e.g. check_thresh = 0.001 { const MatrixIndexT D = num_rows_; KALDI_ASSERT(num_rows_ == num_cols_); - KALDI_ASSERT(IsSymmetric() && "SymPosSemiDefEig: expecting input to be symmetrical."); + KALDI_ASSERT(IsSymmetric() && "SymPosSemiDefEig: expecting input to be +symmetrical."); KALDI_ASSERT(rU->num_rows_ == D && rU->num_cols_ == D && rs->Dim() == D); Matrix Vt(D, D); Svd(rs, rU, &Vt); - // First just zero any singular values if the column of U and V do not have +ve dot product-- - // this may mean we have small negative eigenvalues, and if we zero them the result will be closer to correct. 
+ // First just zero any singular values if the column of U and V do not have ++ve dot product-- + // this may mean we have small negative eigenvalues, and if we zero them the +result will be closer to correct. for (MatrixIndexT i = 0;i < D;i++) { Real sum = 0.0; for (MatrixIndexT j = 0;j < D;j++) sum += (*rU)(j, i) * Vt(i, j); @@ -1992,9 +2084,12 @@ void MatrixBase::SymPosSemiDefEig(VectorBase *rs, MatrixBase * if (!(old_norm == 0 && new_norm == 0)) { float diff_norm = tmpThisFull.FrobeniusNorm(); - if (std::abs(new_norm-old_norm) > old_norm*check_thresh || diff_norm > old_norm*check_thresh) { - KALDI_WARN << "SymPosSemiDefEig seems to have failed " << diff_norm << " !<< " - << check_thresh << "*" << old_norm << ", maybe matrix was not " + if (std::abs(new_norm-old_norm) > old_norm*check_thresh || diff_norm > +old_norm*check_thresh) { + KALDI_WARN << "SymPosSemiDefEig seems to have failed " << diff_norm << " +!<< " + << check_thresh << "*" << old_norm << ", maybe matrix was not +" << "positive semi definite. Continuing anyway."; } } @@ -2006,7 +2101,8 @@ template Real MatrixBase::LogDet(Real *det_sign) const { Real log_det; Matrix tmp(*this); - tmp.Invert(&log_det, det_sign, false); // false== output not needed (saves some computation). + tmp.Invert(&log_det, det_sign, false); // false== output not needed (saves +some computation). 
return log_det; } @@ -2022,26 +2118,25 @@ void MatrixBase::InvertDouble(Real *log_det, Real *det_sign, } */ -//template -//void MatrixBase::CopyFromMat(const CompressedMatrix &mat) { - //mat.CopyToMat(this); +// template +// void MatrixBase::CopyFromMat(const CompressedMatrix &mat) { +// mat.CopyToMat(this); //} -//template -//Matrix::Matrix(const CompressedMatrix &M): MatrixBase() { - //Resize(M.NumRows(), M.NumCols(), kUndefined); - //M.CopyToMat(this); +// template +// Matrix::Matrix(const CompressedMatrix &M): MatrixBase() { +// Resize(M.NumRows(), M.NumCols(), kUndefined); +// M.CopyToMat(this); //} - -template +template void MatrixBase::InvertElements() { - for (MatrixIndexT r = 0; r < num_rows_; r++) { - for (MatrixIndexT c = 0; c < num_cols_; c++) { - (*this)(r, c) = static_cast(1.0 / (*this)(r, c)); + for (MatrixIndexT r = 0; r < num_rows_; r++) { + for (MatrixIndexT c = 0; c < num_cols_; c++) { + (*this)(r, c) = static_cast(1.0 / (*this)(r, c)); + } } - } } /* template @@ -2108,7 +2203,8 @@ void MatrixBase::Pow(const MatrixBase &src, Real power) { } template -void MatrixBase::PowAbs(const MatrixBase &src, Real power, bool include_sign) { +void MatrixBase::PowAbs(const MatrixBase &src, Real power, bool +include_sign) { KALDI_ASSERT(SameDim(*this, src)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; Real *row_data = data_; @@ -2117,9 +2213,9 @@ void MatrixBase::PowAbs(const MatrixBase &src, Real power, bool incl row++,row_data += stride_, src_row_data += src.stride_) { for (MatrixIndexT col = 0; col < num_cols; col ++) { if (include_sign == true && src_row_data[col] < 0) { - row_data[col] = -pow(std::abs(src_row_data[col]), power); + row_data[col] = -pow(std::abs(src_row_data[col]), power); } else { - row_data[col] = pow(std::abs(src_row_data[col]), power); + row_data[col] = pow(std::abs(src_row_data[col]), power); } } } @@ -2134,7 +2230,8 @@ void MatrixBase::Floor(const MatrixBase &src, Real floor_val) { for (MatrixIndexT row = 0; row < 
num_rows; row++,row_data += stride_, src_row_data += src.stride_) { for (MatrixIndexT col = 0; col < num_cols; col++) - row_data[col] = (src_row_data[col] < floor_val ? floor_val : src_row_data[col]); + row_data[col] = (src_row_data[col] < floor_val ? floor_val : +src_row_data[col]); } } @@ -2147,7 +2244,8 @@ void MatrixBase::Ceiling(const MatrixBase &src, Real ceiling_val) { for (MatrixIndexT row = 0; row < num_rows; row++,row_data += stride_, src_row_data += src.stride_) { for (MatrixIndexT col = 0; col < num_cols; col++) - row_data[col] = (src_row_data[col] > ceiling_val ? ceiling_val : src_row_data[col]); + row_data[col] = (src_row_data[col] > ceiling_val ? ceiling_val : +src_row_data[col]); } } @@ -2173,12 +2271,14 @@ void MatrixBase::ExpSpecial(const MatrixBase &src) { for (MatrixIndexT row = 0; row < num_rows; row++,row_data += stride_, src_row_data += src.stride_) { for (MatrixIndexT col = 0; col < num_cols; col++) - row_data[col] = (src_row_data[col] < Real(0) ? kaldi::Exp(src_row_data[col]) : (src_row_data[col] + Real(1))); + row_data[col] = (src_row_data[col] < Real(0) ? 
+kaldi::Exp(src_row_data[col]) : (src_row_data[col] + Real(1))); } } template -void MatrixBase::ExpLimited(const MatrixBase &src, Real lower_limit, Real upper_limit) { +void MatrixBase::ExpLimited(const MatrixBase &src, Real lower_limit, +Real upper_limit) { KALDI_ASSERT(SameDim(*this, src)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; Real *row_data = data_; @@ -2188,11 +2288,11 @@ void MatrixBase::ExpLimited(const MatrixBase &src, Real lower_limit, for (MatrixIndexT col = 0; col < num_cols; col++) { const Real x = src_row_data[col]; if (!(x >= lower_limit)) - row_data[col] = kaldi::Exp(lower_limit); + row_data[col] = kaldi::Exp(lower_limit); else if (x > upper_limit) - row_data[col] = kaldi::Exp(upper_limit); + row_data[col] = kaldi::Exp(upper_limit); else - row_data[col] = kaldi::Exp(x); + row_data[col] = kaldi::Exp(x); } } } @@ -2220,12 +2320,12 @@ bool MatrixBase::Power(Real power) { return true; } */ -template +template void Matrix::Swap(Matrix *other) { - std::swap(this->data_, other->data_); - std::swap(this->num_cols_, other->num_cols_); - std::swap(this->num_rows_, other->num_rows_); - std::swap(this->stride_, other->stride_); + std::swap(this->data_, other->data_); + std::swap(this->num_cols_, other->num_cols_); + std::swap(this->num_rows_, other->num_rows_); + std::swap(this->stride_, other->stride_); } /* // Repeating this comment that appeared in the header: @@ -2238,12 +2338,14 @@ void Matrix::Swap(Matrix *other) { // be block diagonal, with 2x2 blocks corresponding to any such pairs. If a // pair is lambda +- i*mu, D will have a corresponding 2x2 block // [lambda, mu; -mu, lambda]. -// Note that if the input matrix (*this) is non-invertible, P may not be invertible +// Note that if the input matrix (*this) is non-invertible, P may not be +invertible // so in this case instead of the equation (*this) = P D P^{-1} holding, we have // instead (*this) P = P D. 
// // By making the pointer arguments non-NULL or NULL, the user can choose to take -// not to take the eigenvalues directly, and/or the matrix D which is block-diagonal +// not to take the eigenvalues directly, and/or the matrix D which is +block-diagonal // with 2x2 blocks. template void MatrixBase::Eig(MatrixBase *P, @@ -2369,7 +2471,8 @@ template bool ReadHtk(std::istream &is, Matrix *M, HtkHeader *header_ptr); template -bool WriteHtk(std::ostream &os, const MatrixBase &M, HtkHeader htk_hdr) // header may be derived from a previous call to ReadHtk. Must be in binary mode. +bool WriteHtk(std::ostream &os, const MatrixBase &M, HtkHeader htk_hdr) // +header may be derived from a previous call to ReadHtk. Must be in binary mode. { KALDI_ASSERT(M.NumRows() == static_cast(htk_hdr.mNSamples)); KALDI_ASSERT(M.NumCols() == static_cast(htk_hdr.mSampleSize) / @@ -2471,12 +2574,14 @@ template Real TraceMatMatMat(const MatrixBase &A, MatrixTransposeType transA, const MatrixBase &B, MatrixTransposeType transB, const MatrixBase &C, MatrixTransposeType transC) { - MatrixIndexT ARows = A.NumRows(), ACols = A.NumCols(), BRows = B.NumRows(), BCols = B.NumCols(), + MatrixIndexT ARows = A.NumRows(), ACols = A.NumCols(), BRows = B.NumRows(), +BCols = B.NumCols(), CRows = C.NumRows(), CCols = C.NumCols(); if (transA == kTrans) std::swap(ARows, ACols); if (transB == kTrans) std::swap(BRows, BCols); if (transC == kTrans) std::swap(CRows, CCols); - KALDI_ASSERT( CCols == ARows && ACols == BRows && BCols == CRows && "TraceMatMatMat: args have mismatched dimensions."); + KALDI_ASSERT( CCols == ARows && ACols == BRows && BCols == CRows && +"TraceMatMatMat: args have mismatched dimensions."); if (ARows*BCols < std::min(BRows*CCols, CRows*ACols)) { Matrix AB(ARows, BCols); AB.AddMatMat(1.0, A, transA, B, transB, 0.0); // AB = A * B. 
@@ -2508,13 +2613,16 @@ Real TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, const MatrixBase &B, MatrixTransposeType transB, const MatrixBase &C, MatrixTransposeType transC, const MatrixBase &D, MatrixTransposeType transD) { - MatrixIndexT ARows = A.NumRows(), ACols = A.NumCols(), BRows = B.NumRows(), BCols = B.NumCols(), - CRows = C.NumRows(), CCols = C.NumCols(), DRows = D.NumRows(), DCols = D.NumCols(); + MatrixIndexT ARows = A.NumRows(), ACols = A.NumCols(), BRows = B.NumRows(), +BCols = B.NumCols(), + CRows = C.NumRows(), CCols = C.NumCols(), DRows = D.NumRows(), DCols = +D.NumCols(); if (transA == kTrans) std::swap(ARows, ACols); if (transB == kTrans) std::swap(BRows, BCols); if (transC == kTrans) std::swap(CRows, CCols); if (transD == kTrans) std::swap(DRows, DCols); - KALDI_ASSERT( DCols == ARows && ACols == BRows && BCols == CRows && CCols == DRows && "TraceMatMatMat: args have mismatched dimensions."); + KALDI_ASSERT( DCols == ARows && ACols == BRows && BCols == CRows && CCols == +DRows && "TraceMatMatMat: args have mismatched dimensions."); if (ARows*BCols < std::min(BRows*CCols, std::min(CRows*DCols, DRows*ACols))) { Matrix AB(ARows, BCols); AB.AddMatMat(1.0, A, transA, B, transB, 0.0); // AB = A * B. 
@@ -2541,13 +2649,18 @@ float TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, const MatrixBase &D, MatrixTransposeType transD); template -double TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, - const MatrixBase &B, MatrixTransposeType transB, - const MatrixBase &C, MatrixTransposeType transC, - const MatrixBase &D, MatrixTransposeType transD); +double TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType +transA, + const MatrixBase &B, MatrixTransposeType +transB, + const MatrixBase &C, MatrixTransposeType +transC, + const MatrixBase &D, MatrixTransposeType +transD); template void SortSvd(VectorBase *s, MatrixBase *U, - MatrixBase *Vt, bool sort_on_absolute_value) { + MatrixBase *Vt, bool +sort_on_absolute_value) { /// Makes sure the Svd is sorted (from greatest to least absolute value). MatrixIndexT num_singval = s->Dim(); KALDI_ASSERT(U == NULL || U->NumCols() == num_singval); @@ -2589,7 +2702,8 @@ void SortSvd(VectorBase *s, MatrixBase *U, MatrixBase *Vt, bool); template -void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase &im, +void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase +&im, MatrixBase *D) { MatrixIndexT n = re.Dim(); KALDI_ASSERT(im.Dim() == n && D->NumRows() == n && D->NumCols() == n); @@ -2603,7 +2717,8 @@ void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase & } else { // First of a complex pair KALDI_ASSERT(j+1 < n && ApproxEqual(im(j+1), -im(j)) && ApproxEqual(re(j+1), re(j))); - /// if (im(j) < 0.0) KALDI_WARN << "Negative first im part of pair"; // TEMP + /// if (im(j) < 0.0) KALDI_WARN << "Negative first im part of pair"; // +TEMP Real lambda = re(j), mu = im(j); // create 2x2 block [lambda, mu; -mu, lambda] (*D)(j, j) = lambda; @@ -2616,10 +2731,12 @@ void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase & } template -void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase &im, +void CreateEigenvalueMatrix(const VectorBase &re, 
const VectorBase +&im, MatrixBase *D); template -void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase &im, +void CreateEigenvalueMatrix(const VectorBase &re, const +VectorBase &im, MatrixBase *D); @@ -2660,7 +2777,8 @@ bool AttemptComplexPower(double *x_re, double *x_im, double power); template Real TraceMatMat(const MatrixBase &A, const MatrixBase &B, - MatrixTransposeType trans) { // tr(A B), equivalent to sum of each element of A times same element in B' + MatrixTransposeType trans) { // tr(A B), equivalent to sum of +each element of A times same element in B' MatrixIndexT aStride = A.stride_, bStride = B.stride_; if (trans == kNoTrans) { KALDI_ASSERT(A.NumRows() == B.NumCols() && A.NumCols() == B.NumRows()); @@ -2791,29 +2909,32 @@ void MatrixBase::GroupMax(const MatrixBase &src) { } } */ -template +template void MatrixBase::CopyCols(const MatrixBase &src, const MatrixIndexT *indices) { - KALDI_ASSERT(NumRows() == src.NumRows()); - MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, - this_stride = stride_, src_stride = src.stride_; - Real *this_data = this->data_; - const Real *src_data = src.data_; + KALDI_ASSERT(NumRows() == src.NumRows()); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, + this_stride = stride_, src_stride = src.stride_; + Real *this_data = this->data_; + const Real *src_data = src.data_; #ifdef KALDI_PARANOID - MatrixIndexT src_cols = src.NumCols(); - for (MatrixIndexT i = 0; i < num_cols; i++) - KALDI_ASSERT(indices[i] >= -1 && indices[i] < src_cols); + MatrixIndexT src_cols = src.NumCols(); + for (MatrixIndexT i = 0; i < num_cols; i++) + KALDI_ASSERT(indices[i] >= -1 && indices[i] < src_cols); #endif - // For the sake of memory locality we do this row by row, rather - // than doing it column-wise using cublas_Xcopy - for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride, src_data += src_stride) { - const MatrixIndexT *index_ptr = &(indices[0]); - for (MatrixIndexT c = 0; c < num_cols; c++, 
index_ptr++) { - if (*index_ptr < 0) this_data[c] = 0; - else this_data[c] = src_data[*index_ptr]; + // For the sake of memory locality we do this row by row, rather + // than doing it column-wise using cublas_Xcopy + for (MatrixIndexT r = 0; r < num_rows; + r++, this_data += this_stride, src_data += src_stride) { + const MatrixIndexT *index_ptr = &(indices[0]); + for (MatrixIndexT c = 0; c < num_cols; c++, index_ptr++) { + if (*index_ptr < 0) + this_data[c] = 0; + else + this_data[c] = src_data[*index_ptr]; + } } - } } /* @@ -2833,7 +2954,8 @@ void MatrixBase::AddCols(const MatrixBase &src, // For the sake of memory locality we do this row by row, rather // than doing it column-wise using cublas_Xcopy - for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride, src_data += src_stride) { + for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride, src_data ++= src_stride) { const MatrixIndexT *index_ptr = &(indices[0]); for (MatrixIndexT c = 0; c < num_cols; c++, index_ptr++) { if (*index_ptr >= 0) @@ -2965,7 +3087,8 @@ void MatrixBase::DiffSigmoid(const MatrixBase &value, const MatrixBase &diff) { KALDI_ASSERT(SameDim(*this, value) && SameDim(*this, diff)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, - stride = stride_, value_stride = value.stride_, diff_stride = diff.stride_; + stride = stride_, value_stride = value.stride_, diff_stride = +diff.stride_; Real *data = data_; const Real *value_data = value.data_, *diff_data = diff.data_; for (MatrixIndexT r = 0; r < num_rows; r++) { @@ -2982,7 +3105,8 @@ void MatrixBase::DiffTanh(const MatrixBase &value, const MatrixBase &diff) { KALDI_ASSERT(SameDim(*this, value) && SameDim(*this, diff)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, - stride = stride_, value_stride = value.stride_, diff_stride = diff.stride_; + stride = stride_, value_stride = value.stride_, diff_stride = +diff.stride_; Real *data = data_; const Real *value_data = value.data_, *diff_data = diff.data_; 
for (MatrixIndexT r = 0; r < num_rows; r++) { @@ -2997,7 +3121,8 @@ void MatrixBase::DiffTanh(const MatrixBase &value, /* template template -void MatrixBase::AddVecToRows(const Real alpha, const VectorBase &v) { +void MatrixBase::AddVecToRows(const Real alpha, const +VectorBase &v) { const MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, stride = stride_; KALDI_ASSERT(v.Dim() == num_cols); @@ -3028,7 +3153,8 @@ template void MatrixBase::AddVecToRows(const double alpha, template template -void MatrixBase::AddVecToCols(const Real alpha, const VectorBase &v) { +void MatrixBase::AddVecToCols(const Real alpha, const +VectorBase &v) { const MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, stride = stride_; KALDI_ASSERT(v.Dim() == num_rows); @@ -3058,10 +3184,10 @@ template void MatrixBase::AddVecToCols(const double alpha, template void MatrixBase::AddVecToCols(const double alpha, const VectorBase &v); */ -//Explicit instantiation of the classes -//Apparently, it seems to be necessary that the instantiation -//happens at the end of the file. Otherwise, not all the member -//functions will get instantiated. +// Explicit instantiation of the classes +// Apparently, it seems to be necessary that the instantiation +// happens at the end of the file. Otherwise, not all the member +// functions will get instantiated. template class Matrix; template class Matrix; @@ -3070,4 +3196,4 @@ template class MatrixBase; template class SubMatrix; template class SubMatrix; -} // namespace kaldi +} // namespace kaldi diff --git a/runtime/engine/common/matrix/kaldi-matrix.h b/runtime/engine/common/matrix/kaldi-matrix.h index 92274487..c082a731 100644 --- a/runtime/engine/common/matrix/kaldi-matrix.h +++ b/runtime/engine/common/matrix/kaldi-matrix.h @@ -38,669 +38,715 @@ namespace kaldi { /// Base class which provides matrix operations not involving resizing /// or allocation. Classes Matrix and SubMatrix inherit from it and take care /// of allocation and resizing. 
-template +template class MatrixBase { - public: - // so this child can access protected members of other instances. - friend class Matrix; - friend class SubMatrix; - // friend declarations for CUDA matrices (see ../cudamatrix/) - - /// Returns number of rows (or zero for empty matrix). - inline MatrixIndexT NumRows() const { return num_rows_; } - - /// Returns number of columns (or zero for empty matrix). - inline MatrixIndexT NumCols() const { return num_cols_; } - - /// Stride (distance in memory between each row). Will be >= NumCols. - inline MatrixIndexT Stride() const { return stride_; } - - /// Returns size in bytes of the data held by the matrix. - size_t SizeInBytes() const { - return static_cast(num_rows_) * static_cast(stride_) * - sizeof(Real); - } - - /// Gives pointer to raw data (const). - inline const Real* Data() const { - return data_; - } - - /// Gives pointer to raw data (non-const). - inline Real* Data() { return data_; } - - /// Returns pointer to data for one row (non-const) - inline Real* RowData(MatrixIndexT i) { - KALDI_ASSERT(static_cast(i) < - static_cast(num_rows_)); - return data_ + i * stride_; - } - - /// Returns pointer to data for one row (const) - inline const Real* RowData(MatrixIndexT i) const { - KALDI_ASSERT(static_cast(i) < - static_cast(num_rows_)); - return data_ + i * stride_; - } - - /// Indexing operator, non-const - /// (only checks sizes if compiled with -DKALDI_PARANOID) - inline Real& operator() (MatrixIndexT r, MatrixIndexT c) { - KALDI_PARANOID_ASSERT(static_cast(r) < - static_cast(num_rows_) && - static_cast(c) < - static_cast(num_cols_)); - return *(data_ + r * stride_ + c); - } - /// Indexing operator, provided for ease of debugging (gdb doesn't work - /// with parenthesis operator). 
- Real &Index (MatrixIndexT r, MatrixIndexT c) { return (*this)(r, c); } - - /// Indexing operator, const - /// (only checks sizes if compiled with -DKALDI_PARANOID) - inline const Real operator() (MatrixIndexT r, MatrixIndexT c) const { - KALDI_PARANOID_ASSERT(static_cast(r) < - static_cast(num_rows_) && - static_cast(c) < - static_cast(num_cols_)); - return *(data_ + r * stride_ + c); - } - - /* Basic setting-to-special values functions. */ - - /// Sets matrix to zero. - void SetZero(); - /// Sets all elements to a specific value. - void Set(Real); - /// Sets to zero, except ones along diagonal [for non-square matrices too] - - /// Copy given matrix. (no resize is done). - template - void CopyFromMat(const MatrixBase & M, - MatrixTransposeType trans = kNoTrans); - - /// Copy from compressed matrix. - //void CopyFromMat(const CompressedMatrix &M); - - /// Copy given tpmatrix. (no resize is done). - //template - //void CopyFromTp(const TpMatrix &M, - //MatrixTransposeType trans = kNoTrans); - - /// Copy from CUDA matrix. Implemented in ../cudamatrix/cu-matrix.h - //template - //void CopyFromMat(const CuMatrixBase &M, - //MatrixTransposeType trans = kNoTrans); - - /// This function has two modes of operation. If v.Dim() == NumRows() * - /// NumCols(), then treats the vector as a row-by-row concatenation of a - /// matrix and copies to *this. - /// if v.Dim() == NumCols(), it sets each row of *this to a copy of v. - void CopyRowsFromVec(const VectorBase &v); - - /// This version of CopyRowsFromVec is implemented in ../cudamatrix/cu-vector.cc - //void CopyRowsFromVec(const CuVectorBase &v); - - template - void CopyRowsFromVec(const VectorBase &v); - - /// Copies vector into matrix, column-by-column. - /// Note that rv.Dim() must either equal NumRows()*NumCols() or NumRows(); - /// this has two modes of operation. - void CopyColsFromVec(const VectorBase &v); - - /// Copy vector into specific column of matrix. 
- void CopyColFromVec(const VectorBase &v, const MatrixIndexT col); - /// Copy vector into specific row of matrix. - void CopyRowFromVec(const VectorBase &v, const MatrixIndexT row); - /// Copy vector into diagonal of matrix. - void CopyDiagFromVec(const VectorBase &v); - - /* Accessing of sub-parts of the matrix. */ - - /// Return specific row of matrix [const]. - inline const SubVector Row(MatrixIndexT i) const { - KALDI_ASSERT(static_cast(i) < - static_cast(num_rows_)); - return SubVector(data_ + (i * stride_), NumCols()); - } - - /// Return specific row of matrix. - inline SubVector Row(MatrixIndexT i) { - KALDI_ASSERT(static_cast(i) < - static_cast(num_rows_)); - return SubVector(data_ + (i * stride_), NumCols()); - } - - /// Return a sub-part of matrix. - inline SubMatrix Range(const MatrixIndexT row_offset, - const MatrixIndexT num_rows, - const MatrixIndexT col_offset, - const MatrixIndexT num_cols) const { - return SubMatrix(*this, row_offset, num_rows, - col_offset, num_cols); - } - inline SubMatrix RowRange(const MatrixIndexT row_offset, - const MatrixIndexT num_rows) const { - return SubMatrix(*this, row_offset, num_rows, 0, num_cols_); - } - inline SubMatrix ColRange(const MatrixIndexT col_offset, - const MatrixIndexT num_cols) const { - return SubMatrix(*this, 0, num_rows_, col_offset, num_cols); - } - -/* - /// Returns sum of all elements in matrix. - Real Sum() const; - /// Returns trace of matrix. - Real Trace(bool check_square = true) const; - // If check_square = true, will crash if matrix is not square. - - /// Returns maximum element of matrix. - Real Max() const; - /// Returns minimum element of matrix. - Real Min() const; - - /// Element by element multiplication with a given matrix. - void MulElements(const MatrixBase &A); - - /// Divide each element by the corresponding element of a given matrix. - void DivElements(const MatrixBase &A); - - /// Multiply each element with a scalar value. 
- void Scale(Real alpha); - - /// Set, element-by-element, *this = max(*this, A) - void Max(const MatrixBase &A); - /// Set, element-by-element, *this = min(*this, A) - void Min(const MatrixBase &A); - - /// Equivalent to (*this) = (*this) * diag(scale). Scaling - /// each column by a scalar taken from that dimension of the vector. - void MulColsVec(const VectorBase &scale); - - /// Equivalent to (*this) = diag(scale) * (*this). Scaling - /// each row by a scalar taken from that dimension of the vector. - void MulRowsVec(const VectorBase &scale); - - /// Divide each row into src.NumCols() equal groups, and then scale i'th row's - /// j'th group of elements by src(i, j). Requires src.NumRows() == - /// this->NumRows() and this->NumCols() % src.NumCols() == 0. - void MulRowsGroupMat(const MatrixBase &src); - - /// Returns logdet of matrix. - Real LogDet(Real *det_sign = NULL) const; - - /// matrix inverse. - /// if inverse_needed = false, will fill matrix with garbage. - /// (only useful if logdet wanted). - void Invert(Real *log_det = NULL, Real *det_sign = NULL, - bool inverse_needed = true); - /// matrix inverse [double]. - /// if inverse_needed = false, will fill matrix with garbage - /// (only useful if logdet wanted). - /// Does inversion in double precision even if matrix was not double. - void InvertDouble(Real *LogDet = NULL, Real *det_sign = NULL, - bool inverse_needed = true); -*/ - /// Inverts all the elements of the matrix - void InvertElements(); -/* - /// Transpose the matrix. This one is only - /// applicable to square matrices (the one in the - /// Matrix child class works also for non-square. - void Transpose(); - -*/ - /// Copies column r from column indices[r] of src. - /// As a special case, if indexes[i] == -1, sets column i to zero. 
- /// all elements of "indices" must be in [-1, src.NumCols()-1], - /// and src.NumRows() must equal this.NumRows() - void CopyCols(const MatrixBase &src, - const MatrixIndexT *indices); - - /// Copies row r from row indices[r] of src (does nothing - /// As a special case, if indexes[i] == -1, sets row i to zero. - /// all elements of "indices" must be in [-1, src.NumRows()-1], - /// and src.NumCols() must equal this.NumCols() - void CopyRows(const MatrixBase &src, - const MatrixIndexT *indices); - - /// Add column indices[r] of src to column r. - /// As a special case, if indexes[i] == -1, skip column i - /// indices.size() must equal this->NumCols(), - /// all elements of "reorder" must be in [-1, src.NumCols()-1], - /// and src.NumRows() must equal this.NumRows() - //void AddCols(const MatrixBase &src, - // const MatrixIndexT *indices); - - /// Copies row r of this matrix from an array of floats at the location given - /// by src[r]. If any src[r] is NULL then this.Row(r) will be set to zero. - /// Note: we are using "pointer to const pointer to const object" for "src", - /// because we may create "src" by calling Data() of const CuArray - void CopyRows(const Real *const *src); - - /// Copies row r of this matrix to the array of floats at the location given - /// by dst[r]. If dst[r] is NULL, does not copy anywhere. Requires that none - /// of the memory regions pointed to by the pointers in "dst" overlap (e.g. - /// none of the pointers should be the same). - void CopyToRows(Real *const *dst) const; - - /// Does for each row r, this.Row(r) += alpha * src.row(indexes[r]). - /// If indexes[r] < 0, does not add anything. all elements of "indexes" must - /// be in [-1, src.NumRows()-1], and src.NumCols() must equal this.NumCols(). 
- // void AddRows(Real alpha, - // const MatrixBase &src, - // const MatrixIndexT *indexes); - - /// Does for each row r, this.Row(r) += alpha * src[r], treating src[r] as the - /// beginning of a region of memory representing a vector of floats, of the - /// same length as this.NumCols(). If src[r] is NULL, does not add anything. - //void AddRows(Real alpha, const Real *const *src); - - /// For each row r of this matrix, adds it (times alpha) to the array of - /// floats at the location given by dst[r]. If dst[r] is NULL, does not do - /// anything for that row. Requires that none of the memory regions pointed - /// to by the pointers in "dst" overlap (e.g. none of the pointers should be - /// the same). - //void AddToRows(Real alpha, Real *const *dst) const; - - /// For each row i of *this, adds this->Row(i) to - /// dst->Row(indexes(i)) if indexes(i) >= 0, else do nothing. - /// Requires that all the indexes[i] that are >= 0 - /// be distinct, otherwise the behavior is undefined. - //void AddToRows(Real alpha, - // const MatrixIndexT *indexes, + public: + // so this child can access protected members of other instances. + friend class Matrix; + friend class SubMatrix; + // friend declarations for CUDA matrices (see ../cudamatrix/) + + /// Returns number of rows (or zero for empty matrix). + inline MatrixIndexT NumRows() const { return num_rows_; } + + /// Returns number of columns (or zero for empty matrix). + inline MatrixIndexT NumCols() const { return num_cols_; } + + /// Stride (distance in memory between each row). Will be >= NumCols. + inline MatrixIndexT Stride() const { return stride_; } + + /// Returns size in bytes of the data held by the matrix. + size_t SizeInBytes() const { + return static_cast(num_rows_) * static_cast(stride_) * + sizeof(Real); + } + + /// Gives pointer to raw data (const). + inline const Real *Data() const { return data_; } + + /// Gives pointer to raw data (non-const). 
+ inline Real *Data() { return data_; } + + /// Returns pointer to data for one row (non-const) + inline Real *RowData(MatrixIndexT i) { + KALDI_ASSERT(static_cast(i) < + static_cast(num_rows_)); + return data_ + i * stride_; + } + + /// Returns pointer to data for one row (const) + inline const Real *RowData(MatrixIndexT i) const { + KALDI_ASSERT(static_cast(i) < + static_cast(num_rows_)); + return data_ + i * stride_; + } + + /// Indexing operator, non-const + /// (only checks sizes if compiled with -DKALDI_PARANOID) + inline Real &operator()(MatrixIndexT r, MatrixIndexT c) { + KALDI_PARANOID_ASSERT( + static_cast(r) < + static_cast(num_rows_) && + static_cast(c) < + static_cast(num_cols_)); + return *(data_ + r * stride_ + c); + } + /// Indexing operator, provided for ease of debugging (gdb doesn't work + /// with parenthesis operator). + Real &Index(MatrixIndexT r, MatrixIndexT c) { return (*this)(r, c); } + + /// Indexing operator, const + /// (only checks sizes if compiled with -DKALDI_PARANOID) + inline const Real operator()(MatrixIndexT r, MatrixIndexT c) const { + KALDI_PARANOID_ASSERT( + static_cast(r) < + static_cast(num_rows_) && + static_cast(c) < + static_cast(num_cols_)); + return *(data_ + r * stride_ + c); + } + + /* Basic setting-to-special values functions. */ + + /// Sets matrix to zero. + void SetZero(); + /// Sets all elements to a specific value. + void Set(Real); + /// Sets to zero, except ones along diagonal [for non-square matrices too] + + /// Copy given matrix. (no resize is done). + template + void CopyFromMat(const MatrixBase &M, + MatrixTransposeType trans = kNoTrans); + + /// Copy from compressed matrix. + // void CopyFromMat(const CompressedMatrix &M); + + /// Copy given tpmatrix. (no resize is done). + // template + // void CopyFromTp(const TpMatrix &M, + // MatrixTransposeType trans = kNoTrans); + + /// Copy from CUDA matrix. 
Implemented in ../cudamatrix/cu-matrix.h + // template + // void CopyFromMat(const CuMatrixBase &M, + // MatrixTransposeType trans = kNoTrans); + + /// This function has two modes of operation. If v.Dim() == NumRows() * + /// NumCols(), then treats the vector as a row-by-row concatenation of a + /// matrix and copies to *this. + /// if v.Dim() == NumCols(), it sets each row of *this to a copy of v. + void CopyRowsFromVec(const VectorBase &v); + + /// This version of CopyRowsFromVec is implemented in + /// ../cudamatrix/cu-vector.cc + // void CopyRowsFromVec(const CuVectorBase &v); + + template + void CopyRowsFromVec(const VectorBase &v); + + /// Copies vector into matrix, column-by-column. + /// Note that rv.Dim() must either equal NumRows()*NumCols() or NumRows(); + /// this has two modes of operation. + void CopyColsFromVec(const VectorBase &v); + + /// Copy vector into specific column of matrix. + void CopyColFromVec(const VectorBase &v, const MatrixIndexT col); + /// Copy vector into specific row of matrix. + void CopyRowFromVec(const VectorBase &v, const MatrixIndexT row); + /// Copy vector into diagonal of matrix. + void CopyDiagFromVec(const VectorBase &v); + + /* Accessing of sub-parts of the matrix. */ + + /// Return specific row of matrix [const]. + inline const SubVector Row(MatrixIndexT i) const { + KALDI_ASSERT(static_cast(i) < + static_cast(num_rows_)); + return SubVector(data_ + (i * stride_), NumCols()); + } + + /// Return specific row of matrix. + inline SubVector Row(MatrixIndexT i) { + KALDI_ASSERT(static_cast(i) < + static_cast(num_rows_)); + return SubVector(data_ + (i * stride_), NumCols()); + } + + /// Return a sub-part of matrix. 
+ inline SubMatrix Range(const MatrixIndexT row_offset, + const MatrixIndexT num_rows, + const MatrixIndexT col_offset, + const MatrixIndexT num_cols) const { + return SubMatrix( + *this, row_offset, num_rows, col_offset, num_cols); + } + inline SubMatrix RowRange(const MatrixIndexT row_offset, + const MatrixIndexT num_rows) const { + return SubMatrix(*this, row_offset, num_rows, 0, num_cols_); + } + inline SubMatrix ColRange(const MatrixIndexT col_offset, + const MatrixIndexT num_cols) const { + return SubMatrix(*this, 0, num_rows_, col_offset, num_cols); + } + + /* + /// Returns sum of all elements in matrix. + Real Sum() const; + /// Returns trace of matrix. + Real Trace(bool check_square = true) const; + // If check_square = true, will crash if matrix is not square. + + /// Returns maximum element of matrix. + Real Max() const; + /// Returns minimum element of matrix. + Real Min() const; + + /// Element by element multiplication with a given matrix. + void MulElements(const MatrixBase &A); + + /// Divide each element by the corresponding element of a given matrix. + void DivElements(const MatrixBase &A); + + /// Multiply each element with a scalar value. + void Scale(Real alpha); + + /// Set, element-by-element, *this = max(*this, A) + void Max(const MatrixBase &A); + /// Set, element-by-element, *this = min(*this, A) + void Min(const MatrixBase &A); + + /// Equivalent to (*this) = (*this) * diag(scale). Scaling + /// each column by a scalar taken from that dimension of the vector. + void MulColsVec(const VectorBase &scale); + + /// Equivalent to (*this) = diag(scale) * (*this). Scaling + /// each row by a scalar taken from that dimension of the vector. + void MulRowsVec(const VectorBase &scale); + + /// Divide each row into src.NumCols() equal groups, and then scale i'th + row's + /// j'th group of elements by src(i, j). Requires src.NumRows() == + /// this->NumRows() and this->NumCols() % src.NumCols() == 0. 
+ void MulRowsGroupMat(const MatrixBase &src); + + /// Returns logdet of matrix. + Real LogDet(Real *det_sign = NULL) const; + + /// matrix inverse. + /// if inverse_needed = false, will fill matrix with garbage. + /// (only useful if logdet wanted). + void Invert(Real *log_det = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + /// matrix inverse [double]. + /// if inverse_needed = false, will fill matrix with garbage + /// (only useful if logdet wanted). + /// Does inversion in double precision even if matrix was not double. + void InvertDouble(Real *LogDet = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + */ + /// Inverts all the elements of the matrix + void InvertElements(); + /* + /// Transpose the matrix. This one is only + /// applicable to square matrices (the one in the + /// Matrix child class works also for non-square. + void Transpose(); + + */ + /// Copies column r from column indices[r] of src. + /// As a special case, if indexes[i] == -1, sets column i to zero. + /// all elements of "indices" must be in [-1, src.NumCols()-1], + /// and src.NumRows() must equal this.NumRows() + void CopyCols(const MatrixBase &src, const MatrixIndexT *indices); + + /// Copies row r from row indices[r] of src (does nothing + /// As a special case, if indexes[i] == -1, sets row i to zero. + /// all elements of "indices" must be in [-1, src.NumRows()-1], + /// and src.NumCols() must equal this.NumCols() + void CopyRows(const MatrixBase &src, const MatrixIndexT *indices); + + /// Add column indices[r] of src to column r. + /// As a special case, if indexes[i] == -1, skip column i + /// indices.size() must equal this->NumCols(), + /// all elements of "reorder" must be in [-1, src.NumCols()-1], + /// and src.NumRows() must equal this.NumRows() + // void AddCols(const MatrixBase &src, + // const MatrixIndexT *indices); + + /// Copies row r of this matrix from an array of floats at the location + /// given + /// by src[r]. 
If any src[r] is NULL then this.Row(r) will be set to zero. + /// Note: we are using "pointer to const pointer to const object" for "src", + /// because we may create "src" by calling Data() of const CuArray + void CopyRows(const Real *const *src); + + /// Copies row r of this matrix to the array of floats at the location given + /// by dst[r]. If dst[r] is NULL, does not copy anywhere. Requires that + /// none + /// of the memory regions pointed to by the pointers in "dst" overlap (e.g. + /// none of the pointers should be the same). + void CopyToRows(Real *const *dst) const; + + /// Does for each row r, this.Row(r) += alpha * src.row(indexes[r]). + /// If indexes[r] < 0, does not add anything. all elements of "indexes" must + /// be in [-1, src.NumRows()-1], and src.NumCols() must equal + /// this.NumCols(). + // void AddRows(Real alpha, + // const MatrixBase &src, + // const MatrixIndexT *indexes); + + /// Does for each row r, this.Row(r) += alpha * src[r], treating src[r] as + /// the + /// beginning of a region of memory representing a vector of floats, of the + /// same length as this.NumCols(). If src[r] is NULL, does not add anything. + // void AddRows(Real alpha, const Real *const *src); + + /// For each row r of this matrix, adds it (times alpha) to the array of + /// floats at the location given by dst[r]. If dst[r] is NULL, does not do + /// anything for that row. Requires that none of the memory regions pointed + /// to by the pointers in "dst" overlap (e.g. none of the pointers should be + /// the same). + // void AddToRows(Real alpha, Real *const *dst) const; + + /// For each row i of *this, adds this->Row(i) to + /// dst->Row(indexes(i)) if indexes(i) >= 0, else do nothing. + /// Requires that all the indexes[i] that are >= 0 + /// be distinct, otherwise the behavior is undefined. 
+ // void AddToRows(Real alpha, + // const MatrixIndexT *indexes, // MatrixBase *dst) const; -/* - inline void ApplyPow(Real power) { - this -> Pow(*this, power); - } - - - inline void ApplyPowAbs(Real power, bool include_sign=false) { - this -> PowAbs(*this, power, include_sign); - } - - inline void ApplyHeaviside() { - this -> Heaviside(*this); - } - - inline void ApplyFloor(Real floor_val) { - this -> Floor(*this, floor_val); - } - - inline void ApplyCeiling(Real ceiling_val) { - this -> Ceiling(*this, ceiling_val); - } - - inline void ApplyExp() { - this -> Exp(*this); - } - - inline void ApplyExpSpecial() { - this -> ExpSpecial(*this); - } - - inline void ApplyExpLimited(Real lower_limit, Real upper_limit) { - this -> ExpLimited(*this, lower_limit, upper_limit); - } - - inline void ApplyLog() { - this -> Log(*this); - } -*/ - /// Eigenvalue Decomposition of a square NxN matrix into the form (*this) = P D - /// P^{-1}. Be careful: the relationship of D to the eigenvalues we output is - /// slightly complicated, due to the need for P to be real. In the symmetric - /// case D is diagonal and real, but in - /// the non-symmetric case there may be complex-conjugate pairs of eigenvalues. - /// In this case, for the equation (*this) = P D P^{-1} to hold, D must actually - /// be block diagonal, with 2x2 blocks corresponding to any such pairs. If a - /// pair is lambda +- i*mu, D will have a corresponding 2x2 block - /// [lambda, mu; -mu, lambda]. - /// Note that if the input matrix (*this) is non-invertible, P may not be invertible - /// so in this case instead of the equation (*this) = P D P^{-1} holding, we have - /// instead (*this) P = P D. - /// - /// The non-member function CreateEigenvalueMatrix creates D from eigs_real and eigs_imag. 
- //void Eig(MatrixBase *P, - // VectorBase *eigs_real, + /* + inline void ApplyPow(Real power) { + this -> Pow(*this, power); + } + + + inline void ApplyPowAbs(Real power, bool include_sign=false) { + this -> PowAbs(*this, power, include_sign); + } + + inline void ApplyHeaviside() { + this -> Heaviside(*this); + } + + inline void ApplyFloor(Real floor_val) { + this -> Floor(*this, floor_val); + } + + inline void ApplyCeiling(Real ceiling_val) { + this -> Ceiling(*this, ceiling_val); + } + + inline void ApplyExp() { + this -> Exp(*this); + } + + inline void ApplyExpSpecial() { + this -> ExpSpecial(*this); + } + + inline void ApplyExpLimited(Real lower_limit, Real upper_limit) { + this -> ExpLimited(*this, lower_limit, upper_limit); + } + + inline void ApplyLog() { + this -> Log(*this); + } + */ + /// Eigenvalue Decomposition of a square NxN matrix into the form (*this) = + /// P D + /// P^{-1}. Be careful: the relationship of D to the eigenvalues we output + /// is + /// slightly complicated, due to the need for P to be real. In the + /// symmetric + /// case D is diagonal and real, but in + /// the non-symmetric case there may be complex-conjugate pairs of + /// eigenvalues. + /// In this case, for the equation (*this) = P D P^{-1} to hold, D must + /// actually + /// be block diagonal, with 2x2 blocks corresponding to any such pairs. If + /// a + /// pair is lambda +- i*mu, D will have a corresponding 2x2 block + /// [lambda, mu; -mu, lambda]. + /// Note that if the input matrix (*this) is non-invertible, P may not be + /// invertible + /// so in this case instead of the equation (*this) = P D P^{-1} holding, we + /// have + /// instead (*this) P = P D. + /// + /// The non-member function CreateEigenvalueMatrix creates D from eigs_real + /// and eigs_imag. 
+ // void Eig(MatrixBase *P, + // VectorBase *eigs_real, // VectorBase *eigs_imag) const; - /// The Power method attempts to take the matrix to a power using a method that - /// works in general for fractional and negative powers. The input matrix must - /// be invertible and have reasonable condition (or we don't guarantee the - /// results. The method is based on the eigenvalue decomposition. It will - /// return false and leave the matrix unchanged, if at entry the matrix had - /// real negative eigenvalues (or if it had zero eigenvalues and the power was - /// negative). -// bool Power(Real pow); - - /** Singular value decomposition - Major limitations: - For nonsquare matrices, we assume m>=n (NumRows >= NumCols), and we return - the "skinny" Svd, i.e. the matrix in the middle is diagonal, and the - one on the left is rectangular. - - In Svd, *this = U*diag(S)*Vt. - Null pointers for U and/or Vt at input mean we do not want that output. We - expect that S.Dim() == m, U is either NULL or m by n, - and v is either NULL or n by n. - The singular values are not sorted (use SortSvd for that). */ - //void DestructiveSvd(VectorBase *s, MatrixBase *U, - // MatrixBase *Vt); // Destroys calling matrix. - - /// Compute SVD (*this) = U diag(s) Vt. Note that the V in the call is already - /// transposed; the normal formulation is U diag(s) V^T. - /// Null pointers for U or V mean we don't want that output (this saves - /// compute). The singular values are not sorted (use SortSvd for that). - //void Svd(VectorBase *s, MatrixBase *U, - // MatrixBase *Vt) const; - /// Compute SVD but only retain the singular values. - //void Svd(VectorBase *s) const { Svd(s, NULL, NULL); } - - - /// Returns smallest singular value. 
- //Real MinSingularValue() const { - // Vector tmp(std::min(NumRows(), NumCols())); - //Svd(&tmp); - //return tmp.Min(); - //} - - //void TestUninitialized() const; // This function is designed so that if any element - // if the matrix is uninitialized memory, valgrind will complain. - - /// Returns condition number by computing Svd. Works even if cols > rows. - /// Returns infinity if all singular values are zero. - /* - Real Cond() const; - - /// Returns true if matrix is Symmetric. - bool IsSymmetric(Real cutoff = 1.0e-05) const; // replace magic number - - /// Returns true if matrix is Diagonal. - bool IsDiagonal(Real cutoff = 1.0e-05) const; // replace magic number - - /// Returns true if the matrix is all zeros, except for ones on diagonal. (it - /// does not have to be square). More specifically, this function returns - /// false if for any i, j, (*this)(i, j) differs by more than cutoff from the - /// expression (i == j ? 1 : 0). - bool IsUnit(Real cutoff = 1.0e-05) const; // replace magic number - - /// Returns true if matrix is all zeros. - bool IsZero(Real cutoff = 1.0e-05) const; // replace magic number - - /// Frobenius norm, which is the sqrt of sum of square elements. Same as Schatten 2-norm, - /// or just "2-norm". - Real FrobeniusNorm() const; - - /// Returns true if ((*this)-other).FrobeniusNorm() - /// <= tol * (*this).FrobeniusNorm(). - bool ApproxEqual(const MatrixBase &other, float tol = 0.01) const; - - /// Tests for exact equality. It's usually preferable to use ApproxEqual. - bool Equal(const MatrixBase &other) const; - - /// largest absolute value. - Real LargestAbsElem() const; // largest absolute value. - - /// Returns log(sum(exp())) without exp overflow - /// If prune > 0.0, it uses a pruning beam, discarding - /// terms less than (max - prune). Note: in future - /// we may change this so that if prune = 0.0, it takes - /// the max, so use -1 if you don't want to prune. 
- Real LogSumExp(Real prune = -1.0) const; - - /// Apply soft-max to the collection of all elements of the - /// matrix and return normalizer (log sum of exponentials). - Real ApplySoftMax(); - - /// Set each element to the sigmoid of the corresponding element of "src". - void Sigmoid(const MatrixBase &src); - - /// Sets each element to the Heaviside step function (x > 0 ? 1 : 0) of the - /// corresponding element in "src". Note: in general you can make different - /// choices for x = 0, but for now please leave it as it (i.e. returning zero) - /// because it affects the RectifiedLinearComponent in the neural net code. - void Heaviside(const MatrixBase &src); - - void Exp(const MatrixBase &src); - - void Pow(const MatrixBase &src, Real power); - - void Log(const MatrixBase &src); - - /// Apply power to the absolute value of each element. - /// If include_sign is true, the result will be multiplied with - /// the sign of the input value. - /// If the power is negative and the input to the power is zero, - /// The output will be set zero. If include_sign is true, it will - /// multiply the result by the sign of the input. - void PowAbs(const MatrixBase &src, Real power, bool include_sign=false); - - void Floor(const MatrixBase &src, Real floor_val); - - void Ceiling(const MatrixBase &src, Real ceiling_val); - - /// For each element x of the matrix, set it to - /// (x < 0 ? exp(x) : x + 1). This function is used - /// in our RNNLM training. - void ExpSpecial(const MatrixBase &src); - - /// This is equivalent to running: - /// Floor(src, lower_limit); - /// Ceiling(src, upper_limit); - /// Exp(src) - void ExpLimited(const MatrixBase &src, Real lower_limit, Real upper_limit); - - /// Set each element to y = log(1 + exp(x)) - void SoftHinge(const MatrixBase &src); - - /// Apply the function y(i) = (sum_{j = i*G}^{(i+1)*G-1} x_j^(power))^(1 / p). - /// Requires src.NumRows() == this->NumRows() and src.NumCols() % this->NumCols() == 0. 
- void GroupPnorm(const MatrixBase &src, Real power); - - /// Calculate derivatives for the GroupPnorm function above... - /// if "input" is the input to the GroupPnorm function above (i.e. the "src" variable), - /// and "output" is the result of the computation (i.e. the "this" of that function - /// call), and *this has the same dimension as "input", then it sets each element - /// of *this to the derivative d(output-elem)/d(input-elem) for each element of "input", where - /// "output-elem" is whichever element of output depends on that input element. - void GroupPnormDeriv(const MatrixBase &input, const MatrixBase &output, - Real power); - - /// Apply the function y(i) = (max_{j = i*G}^{(i+1)*G-1} x_j - /// Requires src.NumRows() == this->NumRows() and src.NumCols() % this->NumCols() == 0. - void GroupMax(const MatrixBase &src); - - /// Calculate derivatives for the GroupMax function above, where - /// "input" is the input to the GroupMax function above (i.e. the "src" variable), - /// and "output" is the result of the computation (i.e. the "this" of that function - /// call), and *this must have the same dimension as "input". Each element - /// of *this will be set to 1 if the corresponding input equals the output of - /// the group, and 0 otherwise. The equals the function derivative where it is - /// defined (it's not defined where multiple inputs in the group are equal to the output). - void GroupMaxDeriv(const MatrixBase &input, const MatrixBase &output); - - /// Set each element to the tanh of the corresponding element of "src". - void Tanh(const MatrixBase &src); - - // Function used in backpropagating derivatives of the sigmoid function: - // element-by-element, set *this = diff * value * (1.0 - value). - void DiffSigmoid(const MatrixBase &value, - const MatrixBase &diff); - - // Function used in backpropagating derivatives of the tanh function: - // element-by-element, set *this = diff * (1.0 - value^2). 
- void DiffTanh(const MatrixBase &value, - const MatrixBase &diff); -*/ - /** Uses Svd to compute the eigenvalue decomposition of a symmetric positive - * semi-definite matrix: (*this) = rP * diag(rS) * rP^T, with rP an - * orthogonal matrix so rP^{-1} = rP^T. Throws exception if input was not - * positive semi-definite (check_thresh controls how stringent the check is; - * set it to 2 to ensure it won't ever complain, but it will zero out negative - * dimensions in your matrix. - * - * Caution: if you want the eigenvalues, it may make more sense to convert to - * SpMatrix and use Eig() function there, which uses eigenvalue decomposition - * directly rather than SVD. - */ + /// The Power method attempts to take the matrix to a power using a method + /// that + /// works in general for fractional and negative powers. The input matrix + /// must + /// be invertible and have reasonable condition (or we don't guarantee the + /// results. The method is based on the eigenvalue decomposition. It will + /// return false and leave the matrix unchanged, if at entry the matrix had + /// real negative eigenvalues (or if it had zero eigenvalues and the power + /// was + /// negative). + // bool Power(Real pow); + + /** Singular value decomposition + Major limitations: + For nonsquare matrices, we assume m>=n (NumRows >= NumCols), and we + return + the "skinny" Svd, i.e. the matrix in the middle is diagonal, and the + one on the left is rectangular. + + In Svd, *this = U*diag(S)*Vt. + Null pointers for U and/or Vt at input mean we do not want that output. + We + expect that S.Dim() == m, U is either NULL or m by n, + and v is either NULL or n by n. + The singular values are not sorted (use SortSvd for that). */ + // void DestructiveSvd(VectorBase *s, MatrixBase *U, + // MatrixBase *Vt); // Destroys calling matrix. + + /// Compute SVD (*this) = U diag(s) Vt. Note that the V in the call is + /// already + /// transposed; the normal formulation is U diag(s) V^T. 
+ /// Null pointers for U or V mean we don't want that output (this saves + /// compute). The singular values are not sorted (use SortSvd for that). + // void Svd(VectorBase *s, MatrixBase *U, + // MatrixBase *Vt) const; + /// Compute SVD but only retain the singular values. + // void Svd(VectorBase *s) const { Svd(s, NULL, NULL); } + + + /// Returns smallest singular value. + // Real MinSingularValue() const { + // Vector tmp(std::min(NumRows(), NumCols())); + // Svd(&tmp); + // return tmp.Min(); + //} - /// stream read. - /// Use instead of stream<<*this, if you want to add to existing contents. - // Will throw exception on failure. - void Read(std::istream & in, bool binary); - /// write to stream. - void Write(std::ostream & out, bool binary) const; - - // Below is internal methods for Svd, user does not have to know about this. - protected: - - /// Initializer, callable only from child. - explicit MatrixBase(Real *data, MatrixIndexT cols, MatrixIndexT rows, MatrixIndexT stride) : - data_(data), num_cols_(cols), num_rows_(rows), stride_(stride) { - KALDI_ASSERT_IS_FLOATING_TYPE(Real); - } - - /// Initializer, callable only from child. - /// Empty initializer, for un-initialized matrix. - explicit MatrixBase(): data_(NULL) { - KALDI_ASSERT_IS_FLOATING_TYPE(Real); - } - - // Make sure pointers to MatrixBase cannot be deleted. - ~MatrixBase() { } - - /// A workaround that allows SubMatrix to get a pointer to non-const data - /// for const Matrix. Unfortunately C++ does not allow us to declare a - /// "public const" inheritance or anything like that, so it would require - /// a lot of work to make the SubMatrix class totally const-correct-- - /// we would have to override many of the Matrix functions. 
- inline Real* Data_workaround() const { - return data_; - } - - /// data memory area - Real* data_; - - /// these attributes store the real matrix size as it is stored in memory - /// including memalignment - MatrixIndexT num_cols_; /// < Number of columns - MatrixIndexT num_rows_; /// < Number of rows - /** True number of columns for the internal matrix. This number may differ - * from num_cols_ as memory alignment might be used. */ - MatrixIndexT stride_; - private: - KALDI_DISALLOW_COPY_AND_ASSIGN(MatrixBase); + // void TestUninitialized() const; // This function is designed so that if + // any element + // if the matrix is uninitialized memory, valgrind will complain. + + /// Returns condition number by computing Svd. Works even if cols > rows. + /// Returns infinity if all singular values are zero. + /* + Real Cond() const; + + /// Returns true if matrix is Symmetric. + bool IsSymmetric(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if matrix is Diagonal. + bool IsDiagonal(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if the matrix is all zeros, except for ones on diagonal. + (it + /// does not have to be square). More specifically, this function returns + /// false if for any i, j, (*this)(i, j) differs by more than cutoff from + the + /// expression (i == j ? 1 : 0). + bool IsUnit(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if matrix is all zeros. + bool IsZero(Real cutoff = 1.0e-05) const; // replace magic number + + /// Frobenius norm, which is the sqrt of sum of square elements. Same as + Schatten 2-norm, + /// or just "2-norm". + Real FrobeniusNorm() const; + + /// Returns true if ((*this)-other).FrobeniusNorm() + /// <= tol * (*this).FrobeniusNorm(). + bool ApproxEqual(const MatrixBase &other, float tol = 0.01) const; + + /// Tests for exact equality. It's usually preferable to use ApproxEqual. 
+ bool Equal(const MatrixBase &other) const; + + /// largest absolute value. + Real LargestAbsElem() const; // largest absolute value. + + /// Returns log(sum(exp())) without exp overflow + /// If prune > 0.0, it uses a pruning beam, discarding + /// terms less than (max - prune). Note: in future + /// we may change this so that if prune = 0.0, it takes + /// the max, so use -1 if you don't want to prune. + Real LogSumExp(Real prune = -1.0) const; + + /// Apply soft-max to the collection of all elements of the + /// matrix and return normalizer (log sum of exponentials). + Real ApplySoftMax(); + + /// Set each element to the sigmoid of the corresponding element of "src". + void Sigmoid(const MatrixBase &src); + + /// Sets each element to the Heaviside step function (x > 0 ? 1 : 0) of the + /// corresponding element in "src". Note: in general you can make different + /// choices for x = 0, but for now please leave it as it (i.e. returning + zero) + /// because it affects the RectifiedLinearComponent in the neural net code. + void Heaviside(const MatrixBase &src); + + void Exp(const MatrixBase &src); + + void Pow(const MatrixBase &src, Real power); + + void Log(const MatrixBase &src); + + /// Apply power to the absolute value of each element. + /// If include_sign is true, the result will be multiplied with + /// the sign of the input value. + /// If the power is negative and the input to the power is zero, + /// The output will be set zero. If include_sign is true, it will + /// multiply the result by the sign of the input. + void PowAbs(const MatrixBase &src, Real power, bool + include_sign=false); + + void Floor(const MatrixBase &src, Real floor_val); + + void Ceiling(const MatrixBase &src, Real ceiling_val); + + /// For each element x of the matrix, set it to + /// (x < 0 ? exp(x) : x + 1). This function is used + /// in our RNNLM training. 
+ void ExpSpecial(const MatrixBase &src); + + /// This is equivalent to running: + /// Floor(src, lower_limit); + /// Ceiling(src, upper_limit); + /// Exp(src) + void ExpLimited(const MatrixBase &src, Real lower_limit, Real + upper_limit); + + /// Set each element to y = log(1 + exp(x)) + void SoftHinge(const MatrixBase &src); + + /// Apply the function y(i) = (sum_{j = i*G}^{(i+1)*G-1} x_j^(power))^(1 / + p). + /// Requires src.NumRows() == this->NumRows() and src.NumCols() % + this->NumCols() == 0. + void GroupPnorm(const MatrixBase &src, Real power); + + /// Calculate derivatives for the GroupPnorm function above... + /// if "input" is the input to the GroupPnorm function above (i.e. the "src" + variable), + /// and "output" is the result of the computation (i.e. the "this" of that + function + /// call), and *this has the same dimension as "input", then it sets each + element + /// of *this to the derivative d(output-elem)/d(input-elem) for each element + of "input", where + /// "output-elem" is whichever element of output depends on that input + element. + void GroupPnormDeriv(const MatrixBase &input, const MatrixBase + &output, + Real power); + + /// Apply the function y(i) = (max_{j = i*G}^{(i+1)*G-1} x_j + /// Requires src.NumRows() == this->NumRows() and src.NumCols() % + this->NumCols() == 0. + void GroupMax(const MatrixBase &src); + + /// Calculate derivatives for the GroupMax function above, where + /// "input" is the input to the GroupMax function above (i.e. the "src" + variable), + /// and "output" is the result of the computation (i.e. the "this" of that + function + /// call), and *this must have the same dimension as "input". Each element + /// of *this will be set to 1 if the corresponding input equals the output + of + /// the group, and 0 otherwise. The equals the function derivative where it + is + /// defined (it's not defined where multiple inputs in the group are equal + to the output). 
+ void GroupMaxDeriv(const MatrixBase &input, const MatrixBase + &output); + + /// Set each element to the tanh of the corresponding element of "src". + void Tanh(const MatrixBase &src); + + // Function used in backpropagating derivatives of the sigmoid function: + // element-by-element, set *this = diff * value * (1.0 - value). + void DiffSigmoid(const MatrixBase &value, + const MatrixBase &diff); + + // Function used in backpropagating derivatives of the tanh function: + // element-by-element, set *this = diff * (1.0 - value^2). + void DiffTanh(const MatrixBase &value, + const MatrixBase &diff); + */ + /** Uses Svd to compute the eigenvalue decomposition of a symmetric positive + * semi-definite matrix: (*this) = rP * diag(rS) * rP^T, with rP an + * orthogonal matrix so rP^{-1} = rP^T. Throws exception if input was not + * positive semi-definite (check_thresh controls how stringent the check is; + * set it to 2 to ensure it won't ever complain, but it will zero out + * negative + * dimensions in your matrix. + * + * Caution: if you want the eigenvalues, it may make more sense to convert + * to + * SpMatrix and use Eig() function there, which uses eigenvalue + * decomposition + * directly rather than SVD. + */ + + /// stream read. + /// Use instead of stream<<*this, if you want to add to existing contents. + // Will throw exception on failure. + void Read(std::istream &in, bool binary); + /// write to stream. + void Write(std::ostream &out, bool binary) const; + + // Below is internal methods for Svd, user does not have to know about this. + protected: + /// Initializer, callable only from child. + explicit MatrixBase(Real *data, + MatrixIndexT cols, + MatrixIndexT rows, + MatrixIndexT stride) + : data_(data), num_cols_(cols), num_rows_(rows), stride_(stride) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } + + /// Initializer, callable only from child. + /// Empty initializer, for un-initialized matrix. 
+ explicit MatrixBase() : data_(NULL) { KALDI_ASSERT_IS_FLOATING_TYPE(Real); } + + // Make sure pointers to MatrixBase cannot be deleted. + ~MatrixBase() {} + + /// A workaround that allows SubMatrix to get a pointer to non-const data + /// for const Matrix. Unfortunately C++ does not allow us to declare a + /// "public const" inheritance or anything like that, so it would require + /// a lot of work to make the SubMatrix class totally const-correct-- + /// we would have to override many of the Matrix functions. + inline Real *Data_workaround() const { return data_; } + + /// data memory area + Real *data_; + + /// these attributes store the real matrix size as it is stored in memory + /// including memalignment + MatrixIndexT num_cols_; /// < Number of columns + MatrixIndexT num_rows_; /// < Number of rows + /** True number of columns for the internal matrix. This number may differ + * from num_cols_ as memory alignment might be used. */ + MatrixIndexT stride_; + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(MatrixBase); }; /// A class for storing matrices. -template +template class Matrix : public MatrixBase { - public: - - /// Empty constructor. - Matrix(); - - /// Basic constructor. - Matrix(const MatrixIndexT r, const MatrixIndexT c, - MatrixResizeType resize_type = kSetZero, - MatrixStrideType stride_type = kDefaultStride): - MatrixBase() { Resize(r, c, resize_type, stride_type); } - - /// Swaps the contents of *this and *other. Shallow swap. - void Swap(Matrix *other); - - /// Constructor from any MatrixBase. Can also copy with transpose. - /// Allocates new memory. - explicit Matrix(const MatrixBase & M, - MatrixTransposeType trans = kNoTrans); + public: + /// Empty constructor. + Matrix(); + + /// Basic constructor. 
+ Matrix(const MatrixIndexT r, + const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero, + MatrixStrideType stride_type = kDefaultStride) + : MatrixBase() { + Resize(r, c, resize_type, stride_type); + } + + /// Swaps the contents of *this and *other. Shallow swap. + void Swap(Matrix *other); + + /// Constructor from any MatrixBase. Can also copy with transpose. + /// Allocates new memory. + explicit Matrix(const MatrixBase &M, + MatrixTransposeType trans = kNoTrans); - /// Same as above, but need to avoid default copy constructor. - Matrix(const Matrix & M); // (cannot make explicit) + /// Same as above, but need to avoid default copy constructor. + Matrix(const Matrix &M); // (cannot make explicit) - /// Copy constructor: as above, but from another type. - template - explicit Matrix(const MatrixBase & M, + /// Copy constructor: as above, but from another type. + template + explicit Matrix(const MatrixBase &M, MatrixTransposeType trans = kNoTrans); - /// Copy constructor taking TpMatrix... - //template - //explicit Matrix(const TpMatrix & M, - //MatrixTransposeType trans = kNoTrans) : MatrixBase() { - //if (trans == kNoTrans) { - //Resize(M.NumRows(), M.NumCols(), kUndefined); - //this->CopyFromTp(M); + /// Copy constructor taking TpMatrix... + // template + // explicit Matrix(const TpMatrix & M, + // MatrixTransposeType trans = kNoTrans) : MatrixBase() { + // if (trans == kNoTrans) { + // Resize(M.NumRows(), M.NumCols(), kUndefined); + // this->CopyFromTp(M); //} else { - //Resize(M.NumCols(), M.NumRows(), kUndefined); - //this->CopyFromTp(M, kTrans); + // Resize(M.NumCols(), M.NumRows(), kUndefined); + // this->CopyFromTp(M, kTrans); + //} //} - //} - - /// read from stream. - // Unlike one in base, allows resizing. - void Read(std::istream & in, bool binary); - - /// Remove a specified row. - void RemoveRow(MatrixIndexT i); - - /// Transpose the matrix. Works for non-square - /// matrices as well as square ones. 
- //void Transpose(); - - /// Distructor to free matrices. - ~Matrix() { Destroy(); } - - /// Sets matrix to a specified size (zero is OK as long as both r and c are - /// zero). The value of the new data depends on resize_type: - /// -if kSetZero, the new data will be zero - /// -if kUndefined, the new data will be undefined - /// -if kCopyData, the new data will be the same as the old data in any - /// shared positions, and zero elsewhere. - /// - /// You can set stride_type to kStrideEqualNumCols to force the stride - /// to equal the number of columns; by default it is set so that the stride - /// in bytes is a multiple of 16. - /// - /// This function takes time proportional to the number of data elements. - void Resize(const MatrixIndexT r, - const MatrixIndexT c, - MatrixResizeType resize_type = kSetZero, - MatrixStrideType stride_type = kDefaultStride); - - /// Assignment operator that takes MatrixBase. - Matrix &operator = (const MatrixBase &other) { - if (MatrixBase::NumRows() != other.NumRows() || - MatrixBase::NumCols() != other.NumCols()) - Resize(other.NumRows(), other.NumCols(), kUndefined); - MatrixBase::CopyFromMat(other); - return *this; - } - - /// Assignment operator. Needed for inclusion in std::vector. - Matrix &operator = (const Matrix &other) { - if (MatrixBase::NumRows() != other.NumRows() || - MatrixBase::NumCols() != other.NumCols()) - Resize(other.NumRows(), other.NumCols(), kUndefined); - MatrixBase::CopyFromMat(other); - return *this; - } - - - private: - /// Deallocates memory and sets to empty matrix (dimension 0, 0). - void Destroy(); - - /// Init assumes the current class contents are invalid (i.e. junk or have - /// already been freed), and it sets the matrix to newly allocated memory with - /// the specified number of rows and columns. r == c == 0 is acceptable. The data - /// memory contents will be undefined. - void Init(const MatrixIndexT r, - const MatrixIndexT c, - const MatrixStrideType stride_type); + /// read from stream. 
+ // Unlike one in base, allows resizing. + void Read(std::istream &in, bool binary); + + /// Remove a specified row. + void RemoveRow(MatrixIndexT i); + + /// Transpose the matrix. Works for non-square + /// matrices as well as square ones. + // void Transpose(); + + /// Distructor to free matrices. + ~Matrix() { Destroy(); } + + /// Sets matrix to a specified size (zero is OK as long as both r and c are + /// zero). The value of the new data depends on resize_type: + /// -if kSetZero, the new data will be zero + /// -if kUndefined, the new data will be undefined + /// -if kCopyData, the new data will be the same as the old data in any + /// shared positions, and zero elsewhere. + /// + /// You can set stride_type to kStrideEqualNumCols to force the stride + /// to equal the number of columns; by default it is set so that the stride + /// in bytes is a multiple of 16. + /// + /// This function takes time proportional to the number of data elements. + void Resize(const MatrixIndexT r, + const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero, + MatrixStrideType stride_type = kDefaultStride); + + /// Assignment operator that takes MatrixBase. + Matrix &operator=(const MatrixBase &other) { + if (MatrixBase::NumRows() != other.NumRows() || + MatrixBase::NumCols() != other.NumCols()) + Resize(other.NumRows(), other.NumCols(), kUndefined); + MatrixBase::CopyFromMat(other); + return *this; + } + + /// Assignment operator. Needed for inclusion in std::vector. + Matrix &operator=(const Matrix &other) { + if (MatrixBase::NumRows() != other.NumRows() || + MatrixBase::NumCols() != other.NumCols()) + Resize(other.NumRows(), other.NumCols(), kUndefined); + MatrixBase::CopyFromMat(other); + return *this; + } + + + private: + /// Deallocates memory and sets to empty matrix (dimension 0, 0). + void Destroy(); + + /// Init assumes the current class contents are invalid (i.e. 
junk or have + /// already been freed), and it sets the matrix to newly allocated memory + /// with + /// the specified number of rows and columns. r == c == 0 is acceptable. + /// The data + /// memory contents will be undefined. + void Init(const MatrixIndexT r, + const MatrixIndexT c, + const MatrixStrideType stride_type); }; /// @} end "addtogroup matrix_group" @@ -710,38 +756,38 @@ class Matrix : public MatrixBase { /// A structure containing the HTK header. /// [TODO: change the style of the variables to Kaldi-compliant] -template +template class SubMatrix : public MatrixBase { - public: - // Initialize a SubMatrix from part of a matrix; this is - // a bit like A(b:c, d:e) in Matlab. - // This initializer is against the proper semantics of "const", since - // SubMatrix can change its contents. It would be hard to implement - // a "const-safe" version of this class. - SubMatrix(const MatrixBase& T, - const MatrixIndexT ro, // row offset, 0 < ro < NumRows() - const MatrixIndexT r, // number of rows, r > 0 - const MatrixIndexT co, // column offset, 0 < co < NumCols() - const MatrixIndexT c); // number of columns, c > 0 - - // This initializer is mostly intended for use in CuMatrix and related - // classes. Be careful! - SubMatrix(Real *data, - MatrixIndexT num_rows, - MatrixIndexT num_cols, - MatrixIndexT stride); - - ~SubMatrix() {} - - /// This type of constructor is needed for Range() to work [in Matrix base - /// class]. Cannot make it explicit. - SubMatrix (const SubMatrix &other): - MatrixBase (other.data_, other.num_cols_, other.num_rows_, - other.stride_) {} - - private: - /// Disallow assignment. - SubMatrix &operator = (const SubMatrix &other); + public: + // Initialize a SubMatrix from part of a matrix; this is + // a bit like A(b:c, d:e) in Matlab. + // This initializer is against the proper semantics of "const", since + // SubMatrix can change its contents. It would be hard to implement + // a "const-safe" version of this class. 
+ SubMatrix(const MatrixBase &T, + const MatrixIndexT ro, // row offset, 0 < ro < NumRows() + const MatrixIndexT r, // number of rows, r > 0 + const MatrixIndexT co, // column offset, 0 < co < NumCols() + const MatrixIndexT c); // number of columns, c > 0 + + // This initializer is mostly intended for use in CuMatrix and related + // classes. Be careful! + SubMatrix(Real *data, + MatrixIndexT num_rows, + MatrixIndexT num_cols, + MatrixIndexT stride); + + ~SubMatrix() {} + + /// This type of constructor is needed for Range() to work [in Matrix base + /// class]. Cannot make it explicit. + SubMatrix(const SubMatrix &other) + : MatrixBase( + other.data_, other.num_cols_, other.num_rows_, other.stride_) {} + + private: + /// Disallow assignment. + SubMatrix &operator=(const SubMatrix &other); }; /// @} End of "addtogroup matrix_funcs_io". @@ -794,25 +840,33 @@ Real TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, /// the same as U->NumCols(), and we sort s from greatest to least absolute /// value (if sort_on_absolute_value == true) or greatest to least value /// otherwise, moving the columns of U, if it exists, and the rows of Vt, if it -/// exists, around in the same way. Note: the "absolute value" part won't matter +/// exists, around in the same way. Note: the "absolute value" part won't +matter /// if this is an actual SVD, since singular values are non-negative. template void SortSvd(VectorBase *s, MatrixBase *U, MatrixBase* Vt = NULL, bool sort_on_absolute_value = true); -/// Creates the eigenvalue matrix D that is part of the decomposition used Matrix::Eig. +/// Creates the eigenvalue matrix D that is part of the decomposition used +Matrix::Eig. /// D will be block-diagonal with blocks of size 1 (for real eigenvalues) or 2x2 -/// for complex pairs. If a complex pair is lambda +- i*mu, D will have a corresponding +/// for complex pairs. If a complex pair is lambda +- i*mu, D will have a +corresponding /// 2x2 block [lambda, mu; -mu, lambda]. 
-/// This function will throw if any complex eigenvalues are not in complex conjugate +/// This function will throw if any complex eigenvalues are not in complex +conjugate /// pairs (or the members of such pairs are not consecutively numbered). template -void CreateEigenvalueMatrix(const VectorBase &real, const VectorBase &imag, +void CreateEigenvalueMatrix(const VectorBase &real, const VectorBase +&imag, MatrixBase *D); -/// The following function is used in Matrix::Power, and separately tested, so we -/// declare it here mainly for the testing code to see. It takes a complex value to -/// a power using a method that will work for noninteger powers (but will fail if the +/// The following function is used in Matrix::Power, and separately tested, so +we +/// declare it here mainly for the testing code to see. It takes a complex +value to +/// a power using a method that will work for noninteger powers (but will fail +if the /// complex value is real and negative). template bool AttemptComplexPower(Real *x_re, Real *x_im, Real power); @@ -823,19 +877,19 @@ bool AttemptComplexPower(Real *x_re, Real *x_im, Real power); /// \addtogroup matrix_funcs_io /// @{ -template -std::ostream & operator << (std::ostream & Out, const MatrixBase & M); +template +std::ostream &operator<<(std::ostream &Out, const MatrixBase &M); -template -std::istream & operator >> (std::istream & In, MatrixBase & M); +template +std::istream &operator>>(std::istream &In, MatrixBase &M); // The Matrix read allows resizing, so we override the MatrixBase one. 
-template -std::istream & operator >> (std::istream & In, Matrix & M); +template +std::istream &operator>>(std::istream &In, Matrix &M); -template +template bool SameDim(const MatrixBase &M, const MatrixBase &N) { - return (M.NumRows() == N.NumRows() && M.NumCols() == N.NumCols()); + return (M.NumRows() == N.NumRows() && M.NumCols() == N.NumCols()); } /// @} end of \addtogroup matrix_funcs_io @@ -844,7 +898,6 @@ bool SameDim(const MatrixBase &M, const MatrixBase &N) { } // namespace kaldi - // we need to include the implementation and some // template specializations. #include "matrix/kaldi-matrix-inl.h" diff --git a/runtime/engine/common/matrix/kaldi-vector-inl.h b/runtime/engine/common/matrix/kaldi-vector-inl.h index 82620276..b3075e59 100644 --- a/runtime/engine/common/matrix/kaldi-vector-inl.h +++ b/runtime/engine/common/matrix/kaldi-vector-inl.h @@ -26,32 +26,33 @@ namespace kaldi { -template -std::ostream & operator << (std::ostream &os, const VectorBase &rv) { - rv.Write(os, false); - return os; +template +std::ostream &operator<<(std::ostream &os, const VectorBase &rv) { + rv.Write(os, false); + return os; } -template -std::istream &operator >> (std::istream &is, VectorBase &rv) { - rv.Read(is, false); - return is; +template +std::istream &operator>>(std::istream &is, VectorBase &rv) { + rv.Read(is, false); + return is; } -template -std::istream &operator >> (std::istream &is, Vector &rv) { - rv.Read(is, false); - return is; +template +std::istream &operator>>(std::istream &is, Vector &rv) { + rv.Read(is, false); + return is; } -//template<> -//template<> -//void VectorBase::AddVec(const float alpha, const VectorBase &rv); +// template<> +// template<> +// void VectorBase::AddVec(const float alpha, const VectorBase +// &rv); -//template<> -//template<> -//void VectorBase::AddVec(const double alpha, - //const VectorBase &rv); +// template<> +// template<> +// void VectorBase::AddVec(const double alpha, +// const VectorBase &rv); } // namespace kaldi diff 
--git a/runtime/engine/common/matrix/kaldi-vector.cc b/runtime/engine/common/matrix/kaldi-vector.cc index 9f2bd08e..90817f0d 100644 --- a/runtime/engine/common/matrix/kaldi-vector.cc +++ b/runtime/engine/common/matrix/kaldi-vector.cc @@ -23,80 +23,85 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. +#include "matrix/kaldi-vector.h" #include #include -#include "matrix/kaldi-vector.h" #include "matrix/kaldi-matrix.h" namespace kaldi { -template +template inline void Vector::Init(const MatrixIndexT dim) { - KALDI_ASSERT(dim >= 0); - if (dim == 0) { - this->dim_ = 0; - this->data_ = NULL; - return; - } - MatrixIndexT size; - void *data; - void *free_data; + KALDI_ASSERT(dim >= 0); + if (dim == 0) { + this->dim_ = 0; + this->data_ = NULL; + return; + } + MatrixIndexT size; + void *data; + void *free_data; - size = dim * sizeof(Real); + size = dim * sizeof(Real); - if ((data = KALDI_MEMALIGN(16, size, &free_data)) != NULL) { - this->data_ = static_cast (data); - this->dim_ = dim; - } else { - throw std::bad_alloc(); - } + if ((data = KALDI_MEMALIGN(16, size, &free_data)) != NULL) { + this->data_ = static_cast(data); + this->dim_ = dim; + } else { + throw std::bad_alloc(); + } } -template -void Vector::Resize(const MatrixIndexT dim, MatrixResizeType resize_type) { - - // the next block uses recursion to handle what we have to do if - // resize_type == kCopyData. - if (resize_type == kCopyData) { - if (this->data_ == NULL || dim == 0) resize_type = kSetZero; // nothing to copy. - else if (this->dim_ == dim) { return; } // nothing to do. - else { - // set tmp to a vector of the desired size. 
- Vector tmp(dim, kUndefined); - if (dim > this->dim_) { - memcpy(tmp.data_, this->data_, sizeof(Real)*this->dim_); - memset(tmp.data_+this->dim_, 0, sizeof(Real)*(dim-this->dim_)); - } else { - memcpy(tmp.data_, this->data_, sizeof(Real)*dim); - } - tmp.Swap(this); - // and now let tmp go out of scope, deleting what was in *this. - return; +template +void Vector::Resize(const MatrixIndexT dim, + MatrixResizeType resize_type) { + // the next block uses recursion to handle what we have to do if + // resize_type == kCopyData. + if (resize_type == kCopyData) { + if (this->data_ == NULL || dim == 0) + resize_type = kSetZero; // nothing to copy. + else if (this->dim_ == dim) { + return; + } // nothing to do. + else { + // set tmp to a vector of the desired size. + Vector tmp(dim, kUndefined); + if (dim > this->dim_) { + memcpy(tmp.data_, this->data_, sizeof(Real) * this->dim_); + memset(tmp.data_ + this->dim_, + 0, + sizeof(Real) * (dim - this->dim_)); + } else { + memcpy(tmp.data_, this->data_, sizeof(Real) * dim); + } + tmp.Swap(this); + // and now let tmp go out of scope, deleting what was in *this. + return; + } } - } - // At this point, resize_type == kSetZero or kUndefined. + // At this point, resize_type == kSetZero or kUndefined. 
- if (this->data_ != NULL) { - if (this->dim_ == dim) { - if (resize_type == kSetZero) this->SetZero(); - return; - } else { - Destroy(); + if (this->data_ != NULL) { + if (this->dim_ == dim) { + if (resize_type == kSetZero) this->SetZero(); + return; + } else { + Destroy(); + } } - } - Init(dim); - if (resize_type == kSetZero) this->SetZero(); + Init(dim); + if (resize_type == kSetZero) this->SetZero(); } /// Copy data from another vector -template +template void VectorBase::CopyFromVec(const VectorBase &v) { - KALDI_ASSERT(Dim() == v.Dim()); - if (data_ != v.data_) { - std::memcpy(this->data_, v.data_, dim_ * sizeof(Real)); - } + KALDI_ASSERT(Dim() == v.Dim()); + if (data_ != v.data_) { + std::memcpy(this->data_, v.data_, dim_ * sizeof(Real)); + } } /* @@ -107,10 +112,14 @@ void VectorBase::CopyFromPacked(const PackedMatrix& M) { this->CopyFromVec(v); } // instantiate the template. -template void VectorBase::CopyFromPacked(const PackedMatrix &other); -template void VectorBase::CopyFromPacked(const PackedMatrix &other); -template void VectorBase::CopyFromPacked(const PackedMatrix &other); -template void VectorBase::CopyFromPacked(const PackedMatrix &other); +template void VectorBase::CopyFromPacked(const PackedMatrix +&other); +template void VectorBase::CopyFromPacked(const PackedMatrix +&other); +template void VectorBase::CopyFromPacked(const PackedMatrix +&other); +template void VectorBase::CopyFromPacked(const PackedMatrix +&other); /// Load data into the vector template @@ -119,50 +128,48 @@ void VectorBase::CopyFromPtr(const Real *data, MatrixIndexT sz) { std::memcpy(this->data_, data, Dim() * sizeof(Real)); }*/ -template -template +template +template void VectorBase::CopyFromVec(const VectorBase &other) { - KALDI_ASSERT(dim_ == other.Dim()); - Real * __restrict__ ptr = data_; - const OtherReal * __restrict__ other_ptr = other.Data(); - for (MatrixIndexT i = 0; i < dim_; i++) - ptr[i] = other_ptr[i]; + KALDI_ASSERT(dim_ == other.Dim()); + Real *__restrict__ 
ptr = data_; + const OtherReal *__restrict__ other_ptr = other.Data(); + for (MatrixIndexT i = 0; i < dim_; i++) ptr[i] = other_ptr[i]; } template void VectorBase::CopyFromVec(const VectorBase &other); template void VectorBase::CopyFromVec(const VectorBase &other); // Remove element from the vector. The vector is not reallocated -template +template void Vector::RemoveElement(MatrixIndexT i) { - KALDI_ASSERT(i < this->dim_ && "Access out of vector"); - for (MatrixIndexT j = i + 1; j < this->dim_; j++) - this->data_[j-1] = this->data_[j]; - this->dim_--; + KALDI_ASSERT(i < this->dim_ && "Access out of vector"); + for (MatrixIndexT j = i + 1; j < this->dim_; j++) + this->data_[j - 1] = this->data_[j]; + this->dim_--; } /// Deallocates memory and sets object to empty vector. -template +template void Vector::Destroy() { - /// we need to free the data block if it was defined - if (this->data_ != NULL) - KALDI_MEMALIGN_FREE(this->data_); - this->data_ = NULL; - this->dim_ = 0; + /// we need to free the data block if it was defined + if (this->data_ != NULL) KALDI_MEMALIGN_FREE(this->data_); + this->data_ = NULL; + this->dim_ = 0; } -template +template void VectorBase::SetZero() { - std::memset(data_, 0, dim_ * sizeof(Real)); + std::memset(data_, 0, dim_ * sizeof(Real)); } -template +template bool VectorBase::IsZero(Real cutoff) const { - Real abs_max = 0.0; - for (MatrixIndexT i = 0; i < Dim(); i++) - abs_max = std::max(std::abs(data_[i]), abs_max); - return (abs_max <= cutoff); + Real abs_max = 0.0; + for (MatrixIndexT i = 0; i < Dim(); i++) + abs_max = std::max(std::abs(data_[i]), abs_max); + return (abs_max <= cutoff); } /* @@ -201,104 +208,107 @@ MatrixIndexT VectorBase::RandCategorical() const { // returns exactly 1, or due to roundoff. }*/ -template +template void VectorBase::Set(Real f) { - // Why not use memset here? - // The basic unit of memset is a byte. - // If f != 0 and sizeof(Real) > 1, then we cannot use memset. 
- if (f == 0) { - this->SetZero(); // calls std::memset - } else { - for (MatrixIndexT i = 0; i < dim_; i++) { data_[i] = f; } - } + // Why not use memset here? + // The basic unit of memset is a byte. + // If f != 0 and sizeof(Real) > 1, then we cannot use memset. + if (f == 0) { + this->SetZero(); // calls std::memset + } else { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = f; + } + } } -template +template void VectorBase::CopyRowsFromMat(const MatrixBase &mat) { - KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows()); + KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows()); - Real *inc_data = data_; - const MatrixIndexT cols = mat.NumCols(), rows = mat.NumRows(); + Real *inc_data = data_; + const MatrixIndexT cols = mat.NumCols(), rows = mat.NumRows(); - if (mat.Stride() == mat.NumCols()) { - memcpy(inc_data, mat.Data(), cols*rows*sizeof(Real)); - } else { - for (MatrixIndexT i = 0; i < rows; i++) { - // copy the data to the propper position - memcpy(inc_data, mat.RowData(i), cols * sizeof(Real)); - // set new copy position - inc_data += cols; + if (mat.Stride() == mat.NumCols()) { + memcpy(inc_data, mat.Data(), cols * rows * sizeof(Real)); + } else { + for (MatrixIndexT i = 0; i < rows; i++) { + // copy the data to the propper position + memcpy(inc_data, mat.RowData(i), cols * sizeof(Real)); + // set new copy position + inc_data += cols; + } } - } } -template -template +template +template void VectorBase::CopyRowsFromMat(const MatrixBase &mat) { - KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows()); - Real *vec_data = data_; - const MatrixIndexT cols = mat.NumCols(), - rows = mat.NumRows(); - - for (MatrixIndexT i = 0; i < rows; i++) { - const OtherReal *mat_row = mat.RowData(i); - for (MatrixIndexT j = 0; j < cols; j++) { - vec_data[j] = static_cast(mat_row[j]); + KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows()); + Real *vec_data = data_; + const MatrixIndexT cols = mat.NumCols(), rows = mat.NumRows(); + + for (MatrixIndexT i = 0; i < rows; i++) { + 
const OtherReal *mat_row = mat.RowData(i); + for (MatrixIndexT j = 0; j < cols; j++) { + vec_data[j] = static_cast(mat_row[j]); + } + vec_data += cols; } - vec_data += cols; - } } -template -void VectorBase::CopyRowsFromMat(const MatrixBase &mat); -template -void VectorBase::CopyRowsFromMat(const MatrixBase &mat); +template void VectorBase::CopyRowsFromMat(const MatrixBase &mat); +template void VectorBase::CopyRowsFromMat(const MatrixBase &mat); -template +template void VectorBase::CopyColsFromMat(const MatrixBase &mat) { - KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows()); + KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows()); - Real* inc_data = data_; - const MatrixIndexT cols = mat.NumCols(), rows = mat.NumRows(), stride = mat.Stride(); - const Real *mat_inc_data = mat.Data(); + Real *inc_data = data_; + const MatrixIndexT cols = mat.NumCols(), rows = mat.NumRows(), + stride = mat.Stride(); + const Real *mat_inc_data = mat.Data(); - for (MatrixIndexT i = 0; i < cols; i++) { - for (MatrixIndexT j = 0; j < rows; j++) { - inc_data[j] = mat_inc_data[j*stride]; + for (MatrixIndexT i = 0; i < cols; i++) { + for (MatrixIndexT j = 0; j < rows; j++) { + inc_data[j] = mat_inc_data[j * stride]; + } + mat_inc_data++; + inc_data += rows; } - mat_inc_data++; - inc_data += rows; - } } -template -void VectorBase::CopyRowFromMat(const MatrixBase &mat, MatrixIndexT row) { - KALDI_ASSERT(row < mat.NumRows()); - KALDI_ASSERT(dim_ == mat.NumCols()); - const Real *mat_row = mat.RowData(row); - memcpy(data_, mat_row, sizeof(Real)*dim_); +template +void VectorBase::CopyRowFromMat(const MatrixBase &mat, + MatrixIndexT row) { + KALDI_ASSERT(row < mat.NumRows()); + KALDI_ASSERT(dim_ == mat.NumCols()); + const Real *mat_row = mat.RowData(row); + memcpy(data_, mat_row, sizeof(Real) * dim_); } -template -template -void VectorBase::CopyRowFromMat(const MatrixBase &mat, MatrixIndexT row) { - KALDI_ASSERT(row < mat.NumRows()); - KALDI_ASSERT(dim_ == mat.NumCols()); - const OtherReal *mat_row 
= mat.RowData(row); - for (MatrixIndexT i = 0; i < dim_; i++) - data_[i] = static_cast(mat_row[i]); +template +template +void VectorBase::CopyRowFromMat(const MatrixBase &mat, + MatrixIndexT row) { + KALDI_ASSERT(row < mat.NumRows()); + KALDI_ASSERT(dim_ == mat.NumCols()); + const OtherReal *mat_row = mat.RowData(row); + for (MatrixIndexT i = 0; i < dim_; i++) + data_[i] = static_cast(mat_row[i]); } -template -void VectorBase::CopyRowFromMat(const MatrixBase &mat, MatrixIndexT row); -template -void VectorBase::CopyRowFromMat(const MatrixBase &mat, MatrixIndexT row); +template void VectorBase::CopyRowFromMat(const MatrixBase &mat, + MatrixIndexT row); +template void VectorBase::CopyRowFromMat(const MatrixBase &mat, + MatrixIndexT row); /* template template -void VectorBase::CopyRowFromSp(const SpMatrix &sp, MatrixIndexT row) { +void VectorBase::CopyRowFromSp(const SpMatrix &sp, MatrixIndexT +row) { KALDI_ASSERT(row < sp.NumRows()); KALDI_ASSERT(dim_ == sp.NumCols()); @@ -313,13 +323,17 @@ void VectorBase::CopyRowFromSp(const SpMatrix &sp, MatrixIndexT } template -void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT row); +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT +row); template -void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT row); +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT +row); template -void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT row); +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT +row); template -void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT row); +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT +row); // takes absolute value of the elements to a power. // Throws exception if could not (but only for power != 1 and power != 2). @@ -333,7 +347,8 @@ void VectorBase::ApplyPowAbs(Real power, bool include_sign) { data_[i] = (include_sign && data_[i] < 0 ? 
-1 : 1) * data_[i] * data_[i]; } else if (power == 0.5) { for (MatrixIndexT i = 0; i < dim_; i++) { - data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * std::sqrt(std::abs(data_[i])); + data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * +std::sqrt(std::abs(data_[i])); } } else if (power < 0.0) { for (MatrixIndexT i = 0; i < dim_; i++) { @@ -346,7 +361,8 @@ void VectorBase::ApplyPowAbs(Real power, bool include_sign) { } } else { for (MatrixIndexT i = 0; i < dim_; i++) { - data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * pow(std::abs(data_[i]), power); + data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * +pow(std::abs(data_[i]), power); if (data_[i] == HUGE_VAL) { // HUGE_VAL is what errno returns on error. KALDI_ERR << "Could not raise element " << i << "to power " << power << ": returned value = " << data_[i]; @@ -401,7 +417,8 @@ Real VectorBase::Norm(Real p) const { } template -bool VectorBase::ApproxEqual(const VectorBase &other, float tol) const { +bool VectorBase::ApproxEqual(const VectorBase &other, float tol) +const { if (dim_ != other.dim_) KALDI_ERR << "ApproxEqual: size mismatch " << dim_ << " vs. " << other.dim_; KALDI_ASSERT(tol >= 0.0); @@ -499,677 +516,716 @@ Real VectorBase::Min(MatrixIndexT *index_out) const { }*/ -template -template -void VectorBase::CopyColFromMat(const MatrixBase &mat, MatrixIndexT col) { - KALDI_ASSERT(col < mat.NumCols()); - KALDI_ASSERT(dim_ == mat.NumRows()); - for (MatrixIndexT i = 0; i < dim_; i++) - data_[i] = mat(i, col); - // can't do this very efficiently so don't really bother. could improve this though. +template +template +void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col) { + KALDI_ASSERT(col < mat.NumCols()); + KALDI_ASSERT(dim_ == mat.NumRows()); + for (MatrixIndexT i = 0; i < dim_; i++) data_[i] = mat(i, col); + // can't do this very efficiently so don't really bother. could improve this + // though. } // instantiate the template above. 
-template -void VectorBase::CopyColFromMat(const MatrixBase &mat, MatrixIndexT col); -template -void VectorBase::CopyColFromMat(const MatrixBase &mat, MatrixIndexT col); -template -void VectorBase::CopyColFromMat(const MatrixBase &mat, MatrixIndexT col); -template -void VectorBase::CopyColFromMat(const MatrixBase &mat, MatrixIndexT col); - -//template -//void VectorBase::CopyDiagFromMat(const MatrixBase &M) { - //KALDI_ASSERT(dim_ == std::min(M.NumRows(), M.NumCols())); - //cblas_Xcopy(dim_, M.Data(), M.Stride() + 1, data_, 1); -//} - -//template -//void VectorBase::CopyDiagFromPacked(const PackedMatrix &M) { - //KALDI_ASSERT(dim_ == M.NumCols()); - //for (MatrixIndexT i = 0; i < dim_; i++) - //data_[i] = M(i, i); - //// could make this more efficient. -//} - -//template -//Real VectorBase::Sum() const { - //// Do a dot-product with a size-1 array with a stride of 0 to - //// implement sum. This allows us to access SIMD operations in a - //// cross-platform way via your BLAS library. - //Real one(1); - //return cblas_Xdot(dim_, data_, 1, &one, 0); -//} - -//template -//Real VectorBase::SumLog() const { - //double sum_log = 0.0; - //double prod = 1.0; - //for (MatrixIndexT i = 0; i < dim_; i++) { - //prod *= data_[i]; - //// Possible future work (arnab): change these magic values to pre-defined - //// constants - //if (prod < 1.0e-10 || prod > 1.0e+10) { - //sum_log += Log(prod); - //prod = 1.0; - //} - //} - //if (prod != 1.0) sum_log += Log(prod); - //return sum_log; -//} - -//template -//void VectorBase::AddRowSumMat(Real alpha, const MatrixBase &M, Real beta) { - //KALDI_ASSERT(dim_ == M.NumCols()); - //MatrixIndexT num_rows = M.NumRows(), stride = M.Stride(), dim = dim_; - //Real *data = data_; - - //// implement the function according to a dimension cutoff for computation efficiency - //if (num_rows <= 64) { - //cblas_Xscal(dim, beta, data, 1); - //const Real *m_data = M.Data(); - //for (MatrixIndexT i = 0; i < num_rows; i++, m_data += stride) - 
//cblas_Xaxpy(dim, alpha, m_data, 1, data, 1); - - //} else { - //Vector ones(M.NumRows()); - //ones.Set(1.0); - //this->AddMatVec(alpha, M, kTrans, ones, beta); - //} -//} - -//template -//void VectorBase::AddColSumMat(Real alpha, const MatrixBase &M, Real beta) { - //KALDI_ASSERT(dim_ == M.NumRows()); - //MatrixIndexT num_cols = M.NumCols(); - - //// implement the function according to a dimension cutoff for computation efficiency - //if (num_cols <= 64) { - //for (MatrixIndexT i = 0; i < dim_; i++) { - //double sum = 0.0; - //const Real *src = M.RowData(i); - //for (MatrixIndexT j = 0; j < num_cols; j++) - //sum += src[j]; - //data_[i] = alpha * sum + beta * data_[i]; - //} - //} else { - //Vector ones(M.NumCols()); - //ones.Set(1.0); - //this->AddMatVec(alpha, M, kNoTrans, ones, beta); - //} -//} - -//template -//Real VectorBase::LogSumExp(Real prune) const { - //Real sum; - //if (sizeof(sum) == 8) sum = kLogZeroDouble; - //else sum = kLogZeroFloat; - //Real max_elem = Max(), cutoff; - //if (sizeof(Real) == 4) cutoff = max_elem + kMinLogDiffFloat; - //else cutoff = max_elem + kMinLogDiffDouble; - //if (prune > 0.0 && max_elem - prune > cutoff) // explicit pruning... 
- //cutoff = max_elem - prune; - - //double sum_relto_max_elem = 0.0; - - //for (MatrixIndexT i = 0; i < dim_; i++) { - //BaseFloat f = data_[i]; - //if (f >= cutoff) - //sum_relto_max_elem += Exp(f - max_elem); - //} - //return max_elem + Log(sum_relto_max_elem); -//} - -//template -//void VectorBase::InvertElements() { - //for (MatrixIndexT i = 0; i < dim_; i++) { - //data_[i] = static_cast(1 / data_[i]); - //} -//} - -//template -//void VectorBase::ApplyLog() { - //for (MatrixIndexT i = 0; i < dim_; i++) { - //if (data_[i] < 0.0) - //KALDI_ERR << "Trying to take log of a negative number."; - //data_[i] = Log(data_[i]); - //} -//} - -//template -//void VectorBase::ApplyLogAndCopy(const VectorBase &v) { - //KALDI_ASSERT(dim_ == v.Dim()); - //for (MatrixIndexT i = 0; i < dim_; i++) { - //data_[i] = Log(v(i)); - //} -//} - -//template -//void VectorBase::ApplyExp() { - //for (MatrixIndexT i = 0; i < dim_; i++) { - //data_[i] = Exp(data_[i]); - //} -//} - -//template -//void VectorBase::ApplyAbs() { - //for (MatrixIndexT i = 0; i < dim_; i++) { data_[i] = std::abs(data_[i]); } -//} - -//template -//void VectorBase::Floor(const VectorBase &v, Real floor_val, MatrixIndexT *floored_count) { - //KALDI_ASSERT(dim_ == v.dim_); - //if (floored_count == nullptr) { - //for (MatrixIndexT i = 0; i < dim_; i++) { - //data_[i] = std::max(v.data_[i], floor_val); - //} - //} else { - //MatrixIndexT num_floored = 0; - //for (MatrixIndexT i = 0; i < dim_; i++) { - //if (v.data_[i] < floor_val) { - //data_[i] = floor_val; - //num_floored++; - //} else { - //data_[i] = v.data_[i]; - //} - //} - //*floored_count = num_floored; - //} -//} - -//template -//void VectorBase::Ceiling(const VectorBase &v, Real ceil_val, MatrixIndexT *ceiled_count) { - //KALDI_ASSERT(dim_ == v.dim_); - //if (ceiled_count == nullptr) { - //for (MatrixIndexT i = 0; i < dim_; i++) { - //data_[i] = std::min(v.data_[i], ceil_val); - //} - //} else { - //MatrixIndexT num_changed = 0; - //for (MatrixIndexT i = 0; i < 
dim_; i++) { - //if (v.data_[i] > ceil_val) { - //data_[i] = ceil_val; - //num_changed++; - //} else { - //data_[i] = v.data_[i]; - //} - //} - //*ceiled_count = num_changed; - //} -//} - -//template -//MatrixIndexT VectorBase::ApplyFloor(const VectorBase &floor_vec) { - //KALDI_ASSERT(floor_vec.Dim() == dim_); - //MatrixIndexT num_floored = 0; - //for (MatrixIndexT i = 0; i < dim_; i++) { - //if (data_[i] < floor_vec(i)) { - //data_[i] = floor_vec(i); - //num_floored++; - //} - //} - //return num_floored; -//} - -//template -//Real VectorBase::ApplySoftMax() { - //Real max = this->Max(), sum = 0.0; - //for (MatrixIndexT i = 0; i < dim_; i++) { - //sum += (data_[i] = Exp(data_[i] - max)); - //} - //this->Scale(1.0 / sum); - //return max + Log(sum); -//} - -//template -//Real VectorBase::ApplyLogSoftMax() { - //Real max = this->Max(), sum = 0.0; - //for (MatrixIndexT i = 0; i < dim_; i++) { - //sum += Exp((data_[i] -= max)); - //} - //sum = Log(sum); - //this->Add(-1.0 * sum); - //return max + sum; +template void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col); +template void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col); +template void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col); +template void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col); + +// template +// void VectorBase::CopyDiagFromMat(const MatrixBase &M) { +// KALDI_ASSERT(dim_ == std::min(M.NumRows(), M.NumCols())); +// cblas_Xcopy(dim_, M.Data(), M.Stride() + 1, data_, 1); +//} + +// template +// void VectorBase::CopyDiagFromPacked(const PackedMatrix &M) { +// KALDI_ASSERT(dim_ == M.NumCols()); +// for (MatrixIndexT i = 0; i < dim_; i++) +// data_[i] = M(i, i); +//// could make this more efficient. +//} + +// template +// Real VectorBase::Sum() const { +//// Do a dot-product with a size-1 array with a stride of 0 to +//// implement sum. 
This allows us to access SIMD operations in a +//// cross-platform way via your BLAS library. +// Real one(1); +// return cblas_Xdot(dim_, data_, 1, &one, 0); +//} + +// template +// Real VectorBase::SumLog() const { +// double sum_log = 0.0; +// double prod = 1.0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// prod *= data_[i]; +//// Possible future work (arnab): change these magic values to pre-defined +//// constants +// if (prod < 1.0e-10 || prod > 1.0e+10) { +// sum_log += Log(prod); +// prod = 1.0; +//} +//} +// if (prod != 1.0) sum_log += Log(prod); +// return sum_log; +//} + +// template +// void VectorBase::AddRowSumMat(Real alpha, const MatrixBase &M, +// Real beta) { +// KALDI_ASSERT(dim_ == M.NumCols()); +// MatrixIndexT num_rows = M.NumRows(), stride = M.Stride(), dim = dim_; +// Real *data = data_; + +//// implement the function according to a dimension cutoff for computation +///efficiency +// if (num_rows <= 64) { +// cblas_Xscal(dim, beta, data, 1); +// const Real *m_data = M.Data(); +// for (MatrixIndexT i = 0; i < num_rows; i++, m_data += stride) +// cblas_Xaxpy(dim, alpha, m_data, 1, data, 1); + +//} else { +// Vector ones(M.NumRows()); +// ones.Set(1.0); +// this->AddMatVec(alpha, M, kTrans, ones, beta); +//} +//} + +// template +// void VectorBase::AddColSumMat(Real alpha, const MatrixBase &M, +// Real beta) { +// KALDI_ASSERT(dim_ == M.NumRows()); +// MatrixIndexT num_cols = M.NumCols(); + +//// implement the function according to a dimension cutoff for computation +///efficiency +// if (num_cols <= 64) { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// double sum = 0.0; +// const Real *src = M.RowData(i); +// for (MatrixIndexT j = 0; j < num_cols; j++) +// sum += src[j]; +// data_[i] = alpha * sum + beta * data_[i]; +//} +//} else { +// Vector ones(M.NumCols()); +// ones.Set(1.0); +// this->AddMatVec(alpha, M, kNoTrans, ones, beta); +//} +//} + +// template +// Real VectorBase::LogSumExp(Real prune) const { +// Real sum; +// if 
(sizeof(sum) == 8) sum = kLogZeroDouble; +// else sum = kLogZeroFloat; +// Real max_elem = Max(), cutoff; +// if (sizeof(Real) == 4) cutoff = max_elem + kMinLogDiffFloat; +// else cutoff = max_elem + kMinLogDiffDouble; +// if (prune > 0.0 && max_elem - prune > cutoff) // explicit pruning... +// cutoff = max_elem - prune; + +// double sum_relto_max_elem = 0.0; + +// for (MatrixIndexT i = 0; i < dim_; i++) { +// BaseFloat f = data_[i]; +// if (f >= cutoff) +// sum_relto_max_elem += Exp(f - max_elem); +//} +// return max_elem + Log(sum_relto_max_elem); +//} + +// template +// void VectorBase::InvertElements() { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = static_cast(1 / data_[i]); +//} +//} + +// template +// void VectorBase::ApplyLog() { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// if (data_[i] < 0.0) +// KALDI_ERR << "Trying to take log of a negative number."; +// data_[i] = Log(data_[i]); +//} +//} + +// template +// void VectorBase::ApplyLogAndCopy(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.Dim()); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = Log(v(i)); +//} +//} + +// template +// void VectorBase::ApplyExp() { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = Exp(data_[i]); +//} +//} + +// template +// void VectorBase::ApplyAbs() { +// for (MatrixIndexT i = 0; i < dim_; i++) { data_[i] = std::abs(data_[i]); } +//} + +// template +// void VectorBase::Floor(const VectorBase &v, Real floor_val, +// MatrixIndexT *floored_count) { +// KALDI_ASSERT(dim_ == v.dim_); +// if (floored_count == nullptr) { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = std::max(v.data_[i], floor_val); +//} +//} else { +// MatrixIndexT num_floored = 0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// if (v.data_[i] < floor_val) { +// data_[i] = floor_val; +// num_floored++; +//} else { +// data_[i] = v.data_[i]; +//} +//} +//*floored_count = num_floored; +//} +//} + +// template +// void VectorBase::Ceiling(const 
VectorBase &v, Real ceil_val, +// MatrixIndexT *ceiled_count) { +// KALDI_ASSERT(dim_ == v.dim_); +// if (ceiled_count == nullptr) { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = std::min(v.data_[i], ceil_val); +//} +//} else { +// MatrixIndexT num_changed = 0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// if (v.data_[i] > ceil_val) { +// data_[i] = ceil_val; +// num_changed++; +//} else { +// data_[i] = v.data_[i]; +//} +//} +//*ceiled_count = num_changed; +//} +//} + +// template +// MatrixIndexT VectorBase::ApplyFloor(const VectorBase &floor_vec) +// { +// KALDI_ASSERT(floor_vec.Dim() == dim_); +// MatrixIndexT num_floored = 0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// if (data_[i] < floor_vec(i)) { +// data_[i] = floor_vec(i); +// num_floored++; +//} +//} +// return num_floored; +//} + +// template +// Real VectorBase::ApplySoftMax() { +// Real max = this->Max(), sum = 0.0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// sum += (data_[i] = Exp(data_[i] - max)); +//} +// this->Scale(1.0 / sum); +// return max + Log(sum); +//} + +// template +// Real VectorBase::ApplyLogSoftMax() { +// Real max = this->Max(), sum = 0.0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// sum += Exp((data_[i] -= max)); +//} +// sum = Log(sum); +// this->Add(-1.0 * sum); +// return max + sum; //} //#ifdef HAVE_MKL -//template<> -//void VectorBase::Tanh(const VectorBase &src) { - //KALDI_ASSERT(dim_ == src.dim_); - //vsTanh(dim_, src.data_, data_); +// template<> +// void VectorBase::Tanh(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// vsTanh(dim_, src.data_, data_); //} -//template<> -//void VectorBase::Tanh(const VectorBase &src) { - //KALDI_ASSERT(dim_ == src.dim_); - //vdTanh(dim_, src.data_, data_); +// template<> +// void VectorBase::Tanh(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// vdTanh(dim_, src.data_, data_); //} //#else -//template -//void VectorBase::Tanh(const VectorBase &src) { - //KALDI_ASSERT(dim_ == 
src.dim_); - //for (MatrixIndexT i = 0; i < dim_; i++) { - //Real x = src.data_[i]; - //if (x > 0.0) { - //Real inv_expx = Exp(-x); - //x = -1.0 + 2.0 / (1.0 + inv_expx * inv_expx); - //} else { - //Real expx = Exp(x); - //x = 1.0 - 2.0 / (1.0 + expx * expx); - //} - //data_[i] = x; - //} +// template +// void VectorBase::Tanh(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// Real x = src.data_[i]; +// if (x > 0.0) { +// Real inv_expx = Exp(-x); +// x = -1.0 + 2.0 / (1.0 + inv_expx * inv_expx); +//} else { +// Real expx = Exp(x); +// x = 1.0 - 2.0 / (1.0 + expx * expx); +//} +// data_[i] = x; +//} //} //#endif //#ifdef HAVE_MKL //// Implementing sigmoid based on tanh. -//template<> -//void VectorBase::Sigmoid(const VectorBase &src) { - //KALDI_ASSERT(dim_ == src.dim_); - //this->CopyFromVec(src); - //this->Scale(0.5); - //vsTanh(dim_, data_, data_); - //this->Add(1.0); - //this->Scale(0.5); -//} -//template<> -//void VectorBase::Sigmoid(const VectorBase &src) { - //KALDI_ASSERT(dim_ == src.dim_); - //this->CopyFromVec(src); - //this->Scale(0.5); - //vdTanh(dim_, data_, data_); - //this->Add(1.0); - //this->Scale(0.5); +// template<> +// void VectorBase::Sigmoid(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// this->CopyFromVec(src); +// this->Scale(0.5); +// vsTanh(dim_, data_, data_); +// this->Add(1.0); +// this->Scale(0.5); +//} +// template<> +// void VectorBase::Sigmoid(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// this->CopyFromVec(src); +// this->Scale(0.5); +// vdTanh(dim_, data_, data_); +// this->Add(1.0); +// this->Scale(0.5); //} //#else -//template -//void VectorBase::Sigmoid(const VectorBase &src) { - //KALDI_ASSERT(dim_ == src.dim_); - //for (MatrixIndexT i = 0; i < dim_; i++) { - //Real x = src.data_[i]; - //// We aim to avoid floating-point overflow here. 
- //if (x > 0.0) { - //x = 1.0 / (1.0 + Exp(-x)); - //} else { - //Real ex = Exp(x); - //x = ex / (ex + 1.0); - //} - //data_[i] = x; - //} +// template +// void VectorBase::Sigmoid(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// Real x = src.data_[i]; +//// We aim to avoid floating-point overflow here. +// if (x > 0.0) { +// x = 1.0 / (1.0 + Exp(-x)); +//} else { +// Real ex = Exp(x); +// x = ex / (ex + 1.0); +//} +// data_[i] = x; +//} //} //#endif -//template -//void VectorBase::Add(Real c) { - //for (MatrixIndexT i = 0; i < dim_; i++) { - //data_[i] += c; - //} +// template +// void VectorBase::Add(Real c) { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] += c; +//} //} -//template -//void VectorBase::Scale(Real alpha) { - //cblas_Xscal(dim_, alpha, data_, 1); +// template +// void VectorBase::Scale(Real alpha) { +// cblas_Xscal(dim_, alpha, data_, 1); //} -//template -//void VectorBase::MulElements(const VectorBase &v) { - //KALDI_ASSERT(dim_ == v.dim_); - //for (MatrixIndexT i = 0; i < dim_; i++) { - //data_[i] *= v.data_[i]; - //} +// template +// void VectorBase::MulElements(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] *= v.data_[i]; +//} //} -//template // Set each element to y = (x == orig ? changed : x). -//void VectorBase::ReplaceValue(Real orig, Real changed) { - //Real *data = data_; - //for (MatrixIndexT i = 0; i < dim_; i++) - //if (data[i] == orig) data[i] = changed; +// template // Set each element to y = (x == orig ? changed : +// x). 
+// void VectorBase::ReplaceValue(Real orig, Real changed) { +// Real *data = data_; +// for (MatrixIndexT i = 0; i < dim_; i++) +// if (data[i] == orig) data[i] = changed; //} -//template -//template -//void VectorBase::MulElements(const VectorBase &v) { - //KALDI_ASSERT(dim_ == v.Dim()); - //const OtherReal *other_ptr = v.Data(); - //for (MatrixIndexT i = 0; i < dim_; i++) { - //data_[i] *= other_ptr[i]; - //} +// template +// template +// void VectorBase::MulElements(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.Dim()); +// const OtherReal *other_ptr = v.Data(); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] *= other_ptr[i]; +//} //} //// instantiate template. -//template -//void VectorBase::MulElements(const VectorBase &v); -//template -//void VectorBase::MulElements(const VectorBase &v); - - -//template -//void VectorBase::AddVecVec(Real alpha, const VectorBase &v, - //const VectorBase &r, Real beta) { - //KALDI_ASSERT(v.data_ != this->data_ && r.data_ != this->data_); - //// We pretend that v is a band-diagonal matrix. - //KALDI_ASSERT(dim_ == v.dim_ && dim_ == r.dim_); - //cblas_Xgbmv(kNoTrans, dim_, dim_, 0, 0, alpha, v.data_, 1, - //r.data_, 1, beta, this->data_, 1); +// template +// void VectorBase::MulElements(const VectorBase &v); +// template +// void VectorBase::MulElements(const VectorBase &v); + + +// template +// void VectorBase::AddVecVec(Real alpha, const VectorBase &v, +// const VectorBase &r, Real beta) { +// KALDI_ASSERT(v.data_ != this->data_ && r.data_ != this->data_); +//// We pretend that v is a band-diagonal matrix. 
+// KALDI_ASSERT(dim_ == v.dim_ && dim_ == r.dim_); +// cblas_Xgbmv(kNoTrans, dim_, dim_, 0, 0, alpha, v.data_, 1, +// r.data_, 1, beta, this->data_, 1); //} -//template -//void VectorBase::DivElements(const VectorBase &v) { - //KALDI_ASSERT(dim_ == v.dim_); - //for (MatrixIndexT i = 0; i < dim_; i++) { - //data_[i] /= v.data_[i]; - //} +// template +// void VectorBase::DivElements(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] /= v.data_[i]; +//} //} -//template -//template -//void VectorBase::DivElements(const VectorBase &v) { - //KALDI_ASSERT(dim_ == v.Dim()); - //const OtherReal *other_ptr = v.Data(); - //for (MatrixIndexT i = 0; i < dim_; i++) { - //data_[i] /= other_ptr[i]; - //} +// template +// template +// void VectorBase::DivElements(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.Dim()); +// const OtherReal *other_ptr = v.Data(); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] /= other_ptr[i]; +//} //} //// instantiate template. -//template -//void VectorBase::DivElements(const VectorBase &v); -//template -//void VectorBase::DivElements(const VectorBase &v); - -//template -//void VectorBase::AddVecDivVec(Real alpha, const VectorBase &v, - //const VectorBase &rr, Real beta) { - //KALDI_ASSERT((dim_ == v.dim_ && dim_ == rr.dim_)); - //for (MatrixIndexT i = 0; i < dim_; i++) { - //data_[i] = alpha * v.data_[i]/rr.data_[i] + beta * data_[i] ; - //} -//} - -//template -//template -//void VectorBase::AddVec(const Real alpha, const VectorBase &v) { - //KALDI_ASSERT(dim_ == v.dim_); - //// remove __restrict__ if it causes compilation problems. 
- //Real *__restrict__ data = data_; - //OtherReal *__restrict__ other_data = v.data_; - //MatrixIndexT dim = dim_; - //if (alpha != 1.0) - //for (MatrixIndexT i = 0; i < dim; i++) - //data[i] += alpha * other_data[i]; - //else - //for (MatrixIndexT i = 0; i < dim; i++) - //data[i] += other_data[i]; -//} - -//template -//void VectorBase::AddVec(const float alpha, const VectorBase &v); -//template -//void VectorBase::AddVec(const double alpha, const VectorBase &v); - -//template -//template -//void VectorBase::AddVec2(const Real alpha, const VectorBase &v) { - //KALDI_ASSERT(dim_ == v.dim_); - //// remove __restrict__ if it causes compilation problems. - //Real *__restrict__ data = data_; - //OtherReal *__restrict__ other_data = v.data_; - //MatrixIndexT dim = dim_; - //if (alpha != 1.0) - //for (MatrixIndexT i = 0; i < dim; i++) - //data[i] += alpha * other_data[i] * other_data[i]; - //else - //for (MatrixIndexT i = 0; i < dim; i++) - //data[i] += other_data[i] * other_data[i]; -//} - -//template -//void VectorBase::AddVec2(const float alpha, const VectorBase &v); -//template -//void VectorBase::AddVec2(const double alpha, const VectorBase &v); +// template +// void VectorBase::DivElements(const VectorBase &v); +// template +// void VectorBase::DivElements(const VectorBase &v); + +// template +// void VectorBase::AddVecDivVec(Real alpha, const VectorBase &v, +// const VectorBase &rr, Real beta) { +// KALDI_ASSERT((dim_ == v.dim_ && dim_ == rr.dim_)); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = alpha * v.data_[i]/rr.data_[i] + beta * data_[i] ; +//} +//} +// template +// template +// void VectorBase::AddVec(const Real alpha, const VectorBase +// &v) { +// KALDI_ASSERT(dim_ == v.dim_); +//// remove __restrict__ if it causes compilation problems. 
+// Real *__restrict__ data = data_; +// OtherReal *__restrict__ other_data = v.data_; +// MatrixIndexT dim = dim_; +// if (alpha != 1.0) +// for (MatrixIndexT i = 0; i < dim; i++) +// data[i] += alpha * other_data[i]; +// else +// for (MatrixIndexT i = 0; i < dim; i++) +// data[i] += other_data[i]; +//} -template -void VectorBase::Read(std::istream &is, bool binary) { - // In order to avoid rewriting this, we just declare a Vector and - // use it to read the data, then copy. - Vector tmp; - tmp.Read(is, binary); - if (tmp.Dim() != Dim()) - KALDI_ERR << "VectorBase::Read, size mismatch " - << Dim() << " vs. " << tmp.Dim(); - CopyFromVec(tmp); +// template +// void VectorBase::AddVec(const float alpha, const VectorBase +// &v); +// template +// void VectorBase::AddVec(const double alpha, const VectorBase +// &v); + +// template +// template +// void VectorBase::AddVec2(const Real alpha, const VectorBase +// &v) { +// KALDI_ASSERT(dim_ == v.dim_); +//// remove __restrict__ if it causes compilation problems. +// Real *__restrict__ data = data_; +// OtherReal *__restrict__ other_data = v.data_; +// MatrixIndexT dim = dim_; +// if (alpha != 1.0) +// for (MatrixIndexT i = 0; i < dim; i++) +// data[i] += alpha * other_data[i] * other_data[i]; +// else +// for (MatrixIndexT i = 0; i < dim; i++) +// data[i] += other_data[i] * other_data[i]; +//} + +// template +// void VectorBase::AddVec2(const float alpha, const VectorBase +// &v); +// template +// void VectorBase::AddVec2(const double alpha, const VectorBase +// &v); + + +template +void VectorBase::Read(std::istream &is, bool binary) { + // In order to avoid rewriting this, we just declare a Vector and + // use it to read the data, then copy. + Vector tmp; + tmp.Read(is, binary); + if (tmp.Dim() != Dim()) + KALDI_ERR << "VectorBase::Read, size mismatch " << Dim() + << " vs. 
" << tmp.Dim(); + CopyFromVec(tmp); } -template -void Vector::Read(std::istream &is, bool binary) { - std::ostringstream specific_error; - MatrixIndexT pos_at_start = is.tellg(); - - if (binary) { - int peekval = Peek(is, binary); - const char *my_token = (sizeof(Real) == 4 ? "FV" : "DV"); - char other_token_start = (sizeof(Real) == 4 ? 'D' : 'F'); - if (peekval == other_token_start) { // need to instantiate the other type to read it. - typedef typename OtherReal::Real OtherType; // if Real == float, OtherType == double, and vice versa. - Vector other(this->Dim()); - other.Read(is, binary); // add is false at this point. - if (this->Dim() != other.Dim()) this->Resize(other.Dim()); - this->CopyFromVec(other); - return; - } - std::string token; - ReadToken(is, binary, &token); - if (token != my_token) { - if (token.length() > 20) token = token.substr(0, 17) + "..."; - specific_error << ": Expected token " << my_token << ", got " << token; - goto bad; - } - int32 size; - ReadBasicType(is, binary, &size); // throws on error. - if ((MatrixIndexT)size != this->Dim()) this->Resize(size); - if (size > 0) - is.read(reinterpret_cast(this->data_), sizeof(Real)*size); - if (is.fail()) { - specific_error << "Error reading vector data (binary mode); truncated " - "stream? (size = " << size << ")"; - goto bad; - } - return; - } else { // Text mode reading; format is " [ 1.1 2.0 3.4 ]\n" - std::string s; - is >> s; - // if ((s.compare("DV") == 0) || (s.compare("FV") == 0)) { // Back compatibility. - // is >> s; // get dimension - // is >> s; // get "[" - // } - if (is.fail()) { specific_error << "EOF while trying to read vector."; goto bad; } - if (s.compare("[]") == 0) { Resize(0); return; } // tolerate this variant. - if (s.compare("[")) { - if (s.length() > 20) s = s.substr(0, 17) + "..."; - specific_error << "Expected \"[\" but got " << s; - goto bad; - } - std::vector data; - while (1) { - int i = is.peek(); - if (i == '-' || (i >= '0' && i <= '9')) { // common cases first. 
- Real r; - is >> r; - if (is.fail()) { specific_error << "Failed to read number."; goto bad; } - if (! std::isspace(is.peek()) && is.peek() != ']') { - specific_error << "Expected whitespace after number."; goto bad; +template +void Vector::Read(std::istream &is, bool binary) { + std::ostringstream specific_error; + MatrixIndexT pos_at_start = is.tellg(); + + if (binary) { + int peekval = Peek(is, binary); + const char *my_token = (sizeof(Real) == 4 ? "FV" : "DV"); + char other_token_start = (sizeof(Real) == 4 ? 'D' : 'F'); + if (peekval == other_token_start) { // need to instantiate the other + // type to read it. + typedef typename OtherReal::Real OtherType; // if Real == + // float, + // OtherType == + // double, and + // vice versa. + Vector other(this->Dim()); + other.Read(is, binary); // add is false at this point. + if (this->Dim() != other.Dim()) this->Resize(other.Dim()); + this->CopyFromVec(other); + return; + } + std::string token; + ReadToken(is, binary, &token); + if (token != my_token) { + if (token.length() > 20) token = token.substr(0, 17) + "..."; + specific_error << ": Expected token " << my_token << ", got " + << token; + goto bad; } - data.push_back(r); - // But don't eat whitespace... we want to check that it's not newlines - // which would be valid only for a matrix. - } else if (i == ' ' || i == '\t') { - is.get(); - } else if (i == ']') { - is.get(); // eat the ']' - this->Resize(data.size()); - for (size_t j = 0; j < data.size(); j++) - this->data_[j] = data[j]; - i = is.peek(); - if (static_cast(i) == '\r') { - is.get(); - is.get(); // get \r\n (must eat what we wrote) - } else if (static_cast(i) == '\n') { is.get(); } // get \n (must eat what we wrote) + int32 size; + ReadBasicType(is, binary, &size); // throws on error. 
+ if ((MatrixIndexT)size != this->Dim()) this->Resize(size); + if (size > 0) + is.read(reinterpret_cast(this->data_), sizeof(Real) * size); if (is.fail()) { - KALDI_WARN << "After end of vector data, read error."; - // we got the data we needed, so just warn for this error. + specific_error + << "Error reading vector data (binary mode); truncated " + "stream? (size = " + << size << ")"; + goto bad; } - return; // success. - } else if (i == -1) { - specific_error << "EOF while reading vector data."; - goto bad; - } else if (i == '\n' || i == '\r') { - specific_error << "Newline found while reading vector (maybe it's a matrix?)"; - goto bad; - } else { - is >> s; // read string. - if (!KALDI_STRCASECMP(s.c_str(), "inf") || - !KALDI_STRCASECMP(s.c_str(), "infinity")) { - data.push_back(std::numeric_limits::infinity()); - KALDI_WARN << "Reading infinite value into vector."; - } else if (!KALDI_STRCASECMP(s.c_str(), "nan")) { - data.push_back(std::numeric_limits::quiet_NaN()); - KALDI_WARN << "Reading NaN value into vector."; - } else { - if (s.length() > 20) s = s.substr(0, 17) + "..."; - specific_error << "Expecting numeric vector data, got " << s; - goto bad; + return; + } else { // Text mode reading; format is " [ 1.1 2.0 3.4 ]\n" + std::string s; + is >> s; + // if ((s.compare("DV") == 0) || (s.compare("FV") == 0)) { // Back + // compatibility. + // is >> s; // get dimension + // is >> s; // get "[" + // } + if (is.fail()) { + specific_error << "EOF while trying to read vector."; + goto bad; + } + if (s.compare("[]") == 0) { + Resize(0); + return; + } // tolerate this variant. + if (s.compare("[")) { + if (s.length() > 20) s = s.substr(0, 17) + "..."; + specific_error << "Expected \"[\" but got " << s; + goto bad; + } + std::vector data; + while (1) { + int i = is.peek(); + if (i == '-' || (i >= '0' && i <= '9')) { // common cases first. 
+ Real r; + is >> r; + if (is.fail()) { + specific_error << "Failed to read number."; + goto bad; + } + if (!std::isspace(is.peek()) && is.peek() != ']') { + specific_error << "Expected whitespace after number."; + goto bad; + } + data.push_back(r); + // But don't eat whitespace... we want to check that it's not + // newlines + // which would be valid only for a matrix. + } else if (i == ' ' || i == '\t') { + is.get(); + } else if (i == ']') { + is.get(); // eat the ']' + this->Resize(data.size()); + for (size_t j = 0; j < data.size(); j++) + this->data_[j] = data[j]; + i = is.peek(); + if (static_cast(i) == '\r') { + is.get(); + is.get(); // get \r\n (must eat what we wrote) + } else if (static_cast(i) == '\n') { + is.get(); + } // get \n (must eat what we wrote) + if (is.fail()) { + KALDI_WARN << "After end of vector data, read error."; + // we got the data we needed, so just warn for this error. + } + return; // success. + } else if (i == -1) { + specific_error << "EOF while reading vector data."; + goto bad; + } else if (i == '\n' || i == '\r') { + specific_error << "Newline found while reading vector (maybe " + "it's a matrix?)"; + goto bad; + } else { + is >> s; // read string. + if (!KALDI_STRCASECMP(s.c_str(), "inf") || + !KALDI_STRCASECMP(s.c_str(), "infinity")) { + data.push_back(std::numeric_limits::infinity()); + KALDI_WARN << "Reading infinite value into vector."; + } else if (!KALDI_STRCASECMP(s.c_str(), "nan")) { + data.push_back(std::numeric_limits::quiet_NaN()); + KALDI_WARN << "Reading NaN value into vector."; + } else { + if (s.length() > 20) s = s.substr(0, 17) + "..."; + specific_error << "Expecting numeric vector data, got " + << s; + goto bad; + } + } } - } } - } - // we never reach this line (the while loop returns directly). +// we never reach this line (the while loop returns directly). bad: - KALDI_ERR << "Failed to read vector from stream. 
" << specific_error.str() - << " File position at start is " - << pos_at_start<<", currently "< -void VectorBase::Write(std::ostream & os, bool binary) const { - if (!os.good()) { - KALDI_ERR << "Failed to write vector to stream: stream not good"; - } - if (binary) { - std::string my_token = (sizeof(Real) == 4 ? "FV" : "DV"); - WriteToken(os, binary, my_token); - - int32 size = Dim(); // make the size 32-bit on disk. - KALDI_ASSERT(Dim() == (MatrixIndexT) size); - WriteBasicType(os, binary, size); - os.write(reinterpret_cast(Data()), sizeof(Real) * size); - } else { - os << " [ "; - for (MatrixIndexT i = 0; i < Dim(); i++) - os << (*this)(i) << " "; - os << "]\n"; - } - if (!os.good()) - KALDI_ERR << "Failed to write vector to stream"; +template +void VectorBase::Write(std::ostream &os, bool binary) const { + if (!os.good()) { + KALDI_ERR << "Failed to write vector to stream: stream not good"; + } + if (binary) { + std::string my_token = (sizeof(Real) == 4 ? "FV" : "DV"); + WriteToken(os, binary, my_token); + + int32 size = Dim(); // make the size 32-bit on disk. + KALDI_ASSERT(Dim() == (MatrixIndexT)size); + WriteBasicType(os, binary, size); + os.write(reinterpret_cast(Data()), sizeof(Real) * size); + } else { + os << " [ "; + for (MatrixIndexT i = 0; i < Dim(); i++) os << (*this)(i) << " "; + os << "]\n"; + } + if (!os.good()) KALDI_ERR << "Failed to write vector to stream"; } -//template -//void VectorBase::AddVec2(const Real alpha, const VectorBase &v) { - //KALDI_ASSERT(dim_ == v.dim_); - //for (MatrixIndexT i = 0; i < dim_; i++) - //data_[i] += alpha * v.data_[i] * v.data_[i]; +// template +// void VectorBase::AddVec2(const Real alpha, const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) +// data_[i] += alpha * v.data_[i] * v.data_[i]; //} //// this <-- beta*this + alpha*M*v. 
-//template -//void VectorBase::AddTpVec(const Real alpha, const TpMatrix &M, - //const MatrixTransposeType trans, - //const VectorBase &v, - //const Real beta) { - //KALDI_ASSERT(dim_ == v.dim_ && dim_ == M.NumRows()); - //if (beta == 0.0) { - //if (&v != this) CopyFromVec(v); - //MulTp(M, trans); - //if (alpha != 1.0) Scale(alpha); - //} else { - //Vector tmp(v); - //tmp.MulTp(M, trans); - //if (beta != 1.0) Scale(beta); // *this <-- beta * *this - //AddVec(alpha, tmp); // *this += alpha * M * v - //} -//} - -//template -//Real VecMatVec(const VectorBase &v1, const MatrixBase &M, - //const VectorBase &v2) { - //KALDI_ASSERT(v1.Dim() == M.NumRows() && v2.Dim() == M.NumCols()); - //Vector vtmp(M.NumRows()); - //vtmp.AddMatVec(1.0, M, kNoTrans, v2, 0.0); - //return VecVec(v1, vtmp); -//} - -//template -//float VecMatVec(const VectorBase &v1, const MatrixBase &M, - //const VectorBase &v2); -//template -//double VecMatVec(const VectorBase &v1, const MatrixBase &M, - //const VectorBase &v2); +// template +// void VectorBase::AddTpVec(const Real alpha, const TpMatrix &M, +// const MatrixTransposeType trans, +// const VectorBase &v, +// const Real beta) { +// KALDI_ASSERT(dim_ == v.dim_ && dim_ == M.NumRows()); +// if (beta == 0.0) { +// if (&v != this) CopyFromVec(v); +// MulTp(M, trans); +// if (alpha != 1.0) Scale(alpha); +//} else { +// Vector tmp(v); +// tmp.MulTp(M, trans); +// if (beta != 1.0) Scale(beta); // *this <-- beta * *this +// AddVec(alpha, tmp); // *this += alpha * M * v +//} +//} -template +// template +// Real VecMatVec(const VectorBase &v1, const MatrixBase &M, +// const VectorBase &v2) { +// KALDI_ASSERT(v1.Dim() == M.NumRows() && v2.Dim() == M.NumCols()); +// Vector vtmp(M.NumRows()); +// vtmp.AddMatVec(1.0, M, kNoTrans, v2, 0.0); +// return VecVec(v1, vtmp); +//} + +// template +// float VecMatVec(const VectorBase &v1, const MatrixBase &M, +// const VectorBase &v2); +// template +// double VecMatVec(const VectorBase &v1, const MatrixBase &M, +// 
const VectorBase &v2); + +template void Vector::Swap(Vector *other) { - std::swap(this->data_, other->data_); - std::swap(this->dim_, other->dim_); + std::swap(this->data_, other->data_); + std::swap(this->dim_, other->dim_); } -//template -//void VectorBase::AddDiagMat2( - //Real alpha, const MatrixBase &M, - //MatrixTransposeType trans, Real beta) { - //if (trans == kNoTrans) { - //KALDI_ASSERT(this->dim_ == M.NumRows()); - //MatrixIndexT rows = this->dim_, cols = M.NumCols(), - //mat_stride = M.Stride(); - //Real *data = this->data_; - //const Real *mat_data = M.Data(); - //for (MatrixIndexT i = 0; i < rows; i++, mat_data += mat_stride, data++) - //*data = beta * *data + alpha * cblas_Xdot(cols,mat_data,1,mat_data,1); - //} else { - //KALDI_ASSERT(this->dim_ == M.NumCols()); - //MatrixIndexT rows = M.NumRows(), cols = this->dim_, - //mat_stride = M.Stride(); - //Real *data = this->data_; - //const Real *mat_data = M.Data(); - //for (MatrixIndexT i = 0; i < cols; i++, mat_data++, data++) - //*data = beta * *data + alpha * cblas_Xdot(rows, mat_data, mat_stride, - //mat_data, mat_stride); - //} -//} - -//template -//void VectorBase::AddDiagMatMat( - //Real alpha, - //const MatrixBase &M, MatrixTransposeType transM, - //const MatrixBase &N, MatrixTransposeType transN, - //Real beta) { - //MatrixIndexT dim = this->dim_, - //M_col_dim = (transM == kTrans ? M.NumRows() : M.NumCols()), - //N_row_dim = (transN == kTrans ? 
N.NumCols() : N.NumRows()); - //KALDI_ASSERT(M_col_dim == N_row_dim); // this is the dimension we sum over - //MatrixIndexT M_row_stride = M.Stride(), M_col_stride = 1; - //if (transM == kTrans) std::swap(M_row_stride, M_col_stride); - //MatrixIndexT N_row_stride = N.Stride(), N_col_stride = 1; - //if (transN == kTrans) std::swap(N_row_stride, N_col_stride); - - //Real *data = this->data_; - //const Real *Mdata = M.Data(), *Ndata = N.Data(); - //for (MatrixIndexT i = 0; i < dim; i++, Mdata += M_row_stride, Ndata += N_col_stride, data++) { - //*data = beta * *data + alpha * cblas_Xdot(M_col_dim, Mdata, M_col_stride, Ndata, N_row_stride); - //} +// template +// void VectorBase::AddDiagMat2( +// Real alpha, const MatrixBase &M, +// MatrixTransposeType trans, Real beta) { +// if (trans == kNoTrans) { +// KALDI_ASSERT(this->dim_ == M.NumRows()); +// MatrixIndexT rows = this->dim_, cols = M.NumCols(), +// mat_stride = M.Stride(); +// Real *data = this->data_; +// const Real *mat_data = M.Data(); +// for (MatrixIndexT i = 0; i < rows; i++, mat_data += mat_stride, data++) +//*data = beta * *data + alpha * cblas_Xdot(cols,mat_data,1,mat_data,1); +//} else { +// KALDI_ASSERT(this->dim_ == M.NumCols()); +// MatrixIndexT rows = M.NumRows(), cols = this->dim_, +// mat_stride = M.Stride(); +// Real *data = this->data_; +// const Real *mat_data = M.Data(); +// for (MatrixIndexT i = 0; i < cols; i++, mat_data++, data++) +//*data = beta * *data + alpha * cblas_Xdot(rows, mat_data, mat_stride, +// mat_data, mat_stride); +//} +//} + +// template +// void VectorBase::AddDiagMatMat( +// Real alpha, +// const MatrixBase &M, MatrixTransposeType transM, +// const MatrixBase &N, MatrixTransposeType transN, +// Real beta) { +// MatrixIndexT dim = this->dim_, +// M_col_dim = (transM == kTrans ? M.NumRows() : M.NumCols()), +// N_row_dim = (transN == kTrans ? 
N.NumCols() : N.NumRows()); +// KALDI_ASSERT(M_col_dim == N_row_dim); // this is the dimension we sum over +// MatrixIndexT M_row_stride = M.Stride(), M_col_stride = 1; +// if (transM == kTrans) std::swap(M_row_stride, M_col_stride); +// MatrixIndexT N_row_stride = N.Stride(), N_col_stride = 1; +// if (transN == kTrans) std::swap(N_row_stride, N_col_stride); + +// Real *data = this->data_; +// const Real *Mdata = M.Data(), *Ndata = N.Data(); +// for (MatrixIndexT i = 0; i < dim; i++, Mdata += M_row_stride, Ndata += +// N_col_stride, data++) { +//*data = beta * *data + alpha * cblas_Xdot(M_col_dim, Mdata, M_col_stride, +//Ndata, N_row_stride); +//} //} diff --git a/runtime/engine/common/matrix/kaldi-vector.h b/runtime/engine/common/matrix/kaldi-vector.h index 5bcbeda9..461e026d 100644 --- a/runtime/engine/common/matrix/kaldi-vector.h +++ b/runtime/engine/common/matrix/kaldi-vector.h @@ -37,265 +37,274 @@ namespace kaldi { /// Provides a vector abstraction class. /// This class provides a way to work with vectors in kaldi. /// It encapsulates basic operations and memory optimizations. -template +template class VectorBase { - public: - /// Set vector to all zeros. - void SetZero(); - - /// Returns true if matrix is all zeros. - bool IsZero(Real cutoff = 1.0e-06) const; // replace magic number - - /// Set all members of a vector to a specified value. - void Set(Real f); - - /// Returns the dimension of the vector. - inline MatrixIndexT Dim() const { return dim_; } - - /// Returns the size in memory of the vector, in bytes. - inline MatrixIndexT SizeInBytes() const { return (dim_*sizeof(Real)); } - - /// Returns a pointer to the start of the vector's data. - inline Real* Data() { return data_; } - - /// Returns a pointer to the start of the vector's data (const). - inline const Real* Data() const { return data_; } - - /// Indexing operator (const). 
- inline Real operator() (MatrixIndexT i) const { - KALDI_PARANOID_ASSERT(static_cast(i) < - static_cast(dim_)); - return *(data_ + i); - } - - /// Indexing operator (non-const). - inline Real & operator() (MatrixIndexT i) { - KALDI_PARANOID_ASSERT(static_cast(i) < - static_cast(dim_)); - return *(data_ + i); - } - - /** @brief Returns a sub-vector of a vector (a range of elements). - * @param o [in] Origin, 0 < o < Dim() - * @param l [in] Length 0 < l < Dim()-o - * @return A SubVector object that aliases the data of the Vector object. - * See @c SubVector class for details */ - SubVector Range(const MatrixIndexT o, const MatrixIndexT l) { - return SubVector(*this, o, l); - } - - /** @brief Returns a const sub-vector of a vector (a range of elements). - * @param o [in] Origin, 0 < o < Dim() - * @param l [in] Length 0 < l < Dim()-o - * @return A SubVector object that aliases the data of the Vector object. - * See @c SubVector class for details */ - const SubVector Range(const MatrixIndexT o, - const MatrixIndexT l) const { - return SubVector(*this, o, l); - } - - /// Copy data from another vector (must match own size). - void CopyFromVec(const VectorBase &v); - - /// Copy data from another vector of different type (double vs. float) - template - void CopyFromVec(const VectorBase &v); - - /// Performs a row stack of the matrix M - void CopyRowsFromMat(const MatrixBase &M); - template - void CopyRowsFromMat(const MatrixBase &M); - - /// Performs a column stack of the matrix M - void CopyColsFromMat(const MatrixBase &M); - - /// Extracts a row of the matrix M. Could also do this with - /// this->Copy(M[row]). - void CopyRowFromMat(const MatrixBase &M, MatrixIndexT row); - /// Extracts a row of the matrix M with type conversion. - template - void CopyRowFromMat(const MatrixBase &M, MatrixIndexT row); - - /// Extracts a column of the matrix M. 
- template - void CopyColFromMat(const MatrixBase &M , MatrixIndexT col); - - /// Reads from C++ stream (option to add to existing contents). - /// Throws exception on failure - void Read(std::istream &in, bool binary); - - /// Writes to C++ stream (option to write in binary). - void Write(std::ostream &Out, bool binary) const; - - friend class VectorBase; - friend class VectorBase; - protected: - /// Destructor; does not deallocate memory, this is handled by child classes. - /// This destructor is protected so this object can only be - /// deleted via a child. - ~VectorBase() {} - - /// Empty initializer, corresponds to vector of zero size. - explicit VectorBase(): data_(NULL), dim_(0) { - KALDI_ASSERT_IS_FLOATING_TYPE(Real); - } - - /// data memory area - Real* data_; - /// dimension of vector - MatrixIndexT dim_; - KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase); -}; // class VectorBase + public: + /// Set vector to all zeros. + void SetZero(); + + /// Returns true if matrix is all zeros. + bool IsZero(Real cutoff = 1.0e-06) const; // replace magic number + + /// Set all members of a vector to a specified value. + void Set(Real f); + + /// Returns the dimension of the vector. + inline MatrixIndexT Dim() const { return dim_; } + + /// Returns the size in memory of the vector, in bytes. + inline MatrixIndexT SizeInBytes() const { return (dim_ * sizeof(Real)); } + + /// Returns a pointer to the start of the vector's data. + inline Real *Data() { return data_; } + + /// Returns a pointer to the start of the vector's data (const). + inline const Real *Data() const { return data_; } + + /// Indexing operator (const). + inline Real operator()(MatrixIndexT i) const { + KALDI_PARANOID_ASSERT(static_cast(i) < + static_cast(dim_)); + return *(data_ + i); + } + + /// Indexing operator (non-const). 
+ inline Real &operator()(MatrixIndexT i) { + KALDI_PARANOID_ASSERT(static_cast(i) < + static_cast(dim_)); + return *(data_ + i); + } + + /** @brief Returns a sub-vector of a vector (a range of elements). + * @param o [in] Origin, 0 < o < Dim() + * @param l [in] Length 0 < l < Dim()-o + * @return A SubVector object that aliases the data of the Vector object. + * See @c SubVector class for details */ + SubVector Range(const MatrixIndexT o, const MatrixIndexT l) { + return SubVector(*this, o, l); + } + + /** @brief Returns a const sub-vector of a vector (a range of elements). + * @param o [in] Origin, 0 < o < Dim() + * @param l [in] Length 0 < l < Dim()-o + * @return A SubVector object that aliases the data of the Vector object. + * See @c SubVector class for details */ + const SubVector Range(const MatrixIndexT o, + const MatrixIndexT l) const { + return SubVector(*this, o, l); + } + + /// Copy data from another vector (must match own size). + void CopyFromVec(const VectorBase &v); + + /// Copy data from another vector of different type (double vs. float) + template + void CopyFromVec(const VectorBase &v); + + /// Performs a row stack of the matrix M + void CopyRowsFromMat(const MatrixBase &M); + template + void CopyRowsFromMat(const MatrixBase &M); + + /// Performs a column stack of the matrix M + void CopyColsFromMat(const MatrixBase &M); + + /// Extracts a row of the matrix M. Could also do this with + /// this->Copy(M[row]). + void CopyRowFromMat(const MatrixBase &M, MatrixIndexT row); + /// Extracts a row of the matrix M with type conversion. + template + void CopyRowFromMat(const MatrixBase &M, MatrixIndexT row); + + /// Extracts a column of the matrix M. + template + void CopyColFromMat(const MatrixBase &M, MatrixIndexT col); + + /// Reads from C++ stream (option to add to existing contents). + /// Throws exception on failure + void Read(std::istream &in, bool binary); + + /// Writes to C++ stream (option to write in binary). 
+ void Write(std::ostream &Out, bool binary) const; + + friend class VectorBase; + friend class VectorBase; + + protected: + /// Destructor; does not deallocate memory, this is handled by child + /// classes. + /// This destructor is protected so this object can only be + /// deleted via a child. + ~VectorBase() {} + + /// Empty initializer, corresponds to vector of zero size. + explicit VectorBase() : data_(NULL), dim_(0) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } + + /// data memory area + Real *data_; + /// dimension of vector + MatrixIndexT dim_; + KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase); +}; // class VectorBase /** @brief A class representing a vector. * * This class provides a way to work with vectors in kaldi. * It encapsulates basic operations and memory optimizations. */ -template -class Vector: public VectorBase { - public: - /// Constructor that takes no arguments. Initializes to empty. - Vector(): VectorBase() {} - - /// Constructor with specific size. Sets to all-zero by default - /// if set_zero == false, memory contents are undefined. - explicit Vector(const MatrixIndexT s, - MatrixResizeType resize_type = kSetZero) - : VectorBase() { Resize(s, resize_type); } - - /// Copy constructor from CUDA vector - /// This is defined in ../cudamatrix/cu-vector.h - //template - //explicit Vector(const CuVectorBase &cu); - - /// Copy constructor. The need for this is controversial. - Vector(const Vector &v) : VectorBase() { // (cannot be explicit) - Resize(v.Dim(), kUndefined); - this->CopyFromVec(v); - } - - /// Copy-constructor from base-class, needed to copy from SubVector. - explicit Vector(const VectorBase &v) : VectorBase() { - Resize(v.Dim(), kUndefined); - this->CopyFromVec(v); - } - - /// Type conversion constructor. 
- template - explicit Vector(const VectorBase &v): VectorBase() { - Resize(v.Dim(), kUndefined); - this->CopyFromVec(v); - } - -// Took this out since it is unsafe : Arnab -// /// Constructor from a pointer and a size; copies the data to a location -// /// it owns. -// Vector(const Real* Data, const MatrixIndexT s): VectorBase() { -// Resize(s); - // CopyFromPtr(Data, s); -// } - - - /// Swaps the contents of *this and *other. Shallow swap. - void Swap(Vector *other); - - /// Destructor. Deallocates memory. - ~Vector() { Destroy(); } - - /// Read function using C++ streams. Can also add to existing contents - /// of matrix. - void Read(std::istream &in, bool binary); - - /// Set vector to a specified size (can be zero). - /// The value of the new data depends on resize_type: - /// -if kSetZero, the new data will be zero - /// -if kUndefined, the new data will be undefined - /// -if kCopyData, the new data will be the same as the old data in any - /// shared positions, and zero elsewhere. - /// This function takes time proportional to the number of data elements. - void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero); - - /// Remove one element and shifts later elements down. - void RemoveElement(MatrixIndexT i); - - /// Assignment operator. - Vector &operator = (const Vector &other) { - Resize(other.Dim(), kUndefined); - this->CopyFromVec(other); - return *this; - } - - /// Assignment operator that takes VectorBase. - Vector &operator = (const VectorBase &other) { - Resize(other.Dim(), kUndefined); - this->CopyFromVec(other); - return *this; - } - private: - /// Init assumes the current contents of the class are invalid (i.e. junk or - /// has already been freed), and it sets the vector to newly allocated memory - /// with the specified dimension. dim == 0 is acceptable. The memory contents - /// pointed to by data_ will be undefined. - void Init(const MatrixIndexT dim); - - /// Destroy function, called internally. 
- void Destroy(); - +template +class Vector : public VectorBase { + public: + /// Constructor that takes no arguments. Initializes to empty. + Vector() : VectorBase() {} + + /// Constructor with specific size. Sets to all-zero by default + /// if set_zero == false, memory contents are undefined. + explicit Vector(const MatrixIndexT s, + MatrixResizeType resize_type = kSetZero) + : VectorBase() { + Resize(s, resize_type); + } + + /// Copy constructor from CUDA vector + /// This is defined in ../cudamatrix/cu-vector.h + // template + // explicit Vector(const CuVectorBase &cu); + + /// Copy constructor. The need for this is controversial. + Vector(const Vector &v) + : VectorBase() { // (cannot be explicit) + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + + /// Copy-constructor from base-class, needed to copy from SubVector. + explicit Vector(const VectorBase &v) : VectorBase() { + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + + /// Type conversion constructor. + template + explicit Vector(const VectorBase &v) : VectorBase() { + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + + // Took this out since it is unsafe : Arnab + // /// Constructor from a pointer and a size; copies the data to a location + // /// it owns. + // Vector(const Real* Data, const MatrixIndexT s): VectorBase() { + // Resize(s); + // CopyFromPtr(Data, s); + // } + + + /// Swaps the contents of *this and *other. Shallow swap. + void Swap(Vector *other); + + /// Destructor. Deallocates memory. + ~Vector() { Destroy(); } + + /// Read function using C++ streams. Can also add to existing contents + /// of matrix. + void Read(std::istream &in, bool binary); + + /// Set vector to a specified size (can be zero). 
+ /// The value of the new data depends on resize_type: + /// -if kSetZero, the new data will be zero + /// -if kUndefined, the new data will be undefined + /// -if kCopyData, the new data will be the same as the old data in any + /// shared positions, and zero elsewhere. + /// This function takes time proportional to the number of data elements. + void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero); + + /// Remove one element and shifts later elements down. + void RemoveElement(MatrixIndexT i); + + /// Assignment operator. + Vector &operator=(const Vector &other) { + Resize(other.Dim(), kUndefined); + this->CopyFromVec(other); + return *this; + } + + /// Assignment operator that takes VectorBase. + Vector &operator=(const VectorBase &other) { + Resize(other.Dim(), kUndefined); + this->CopyFromVec(other); + return *this; + } + + private: + /// Init assumes the current contents of the class are invalid (i.e. junk or + /// has already been freed), and it sets the vector to newly allocated + /// memory + /// with the specified dimension. dim == 0 is acceptable. The memory + /// contents + /// pointed to by data_ will be undefined. + void Init(const MatrixIndexT dim); + + /// Destroy function, called internally. + void Destroy(); }; /// Represents a non-allocating general vector which can be defined /// as a sub-vector of higher-level vector [or as the row of a matrix]. -template +template class SubVector : public VectorBase { - public: - /// Constructor from a Vector or SubVector. - /// SubVectors are not const-safe and it's very hard to make them - /// so for now we just give up. This function contains const_cast. 
- SubVector(const VectorBase &t, const MatrixIndexT origin, - const MatrixIndexT length) : VectorBase() { - // following assert equiv to origin>=0 && length>=0 && - // origin+length <= rt.dim_ - KALDI_ASSERT(static_cast(origin)+ - static_cast(length) <= - static_cast(t.Dim())); - VectorBase::data_ = const_cast (t.Data()+origin); - VectorBase::dim_ = length; - } - - /// This constructor initializes the vector to point at the contents - /// of this packed matrix (SpMatrix or TpMatrix). - // SubVector(const PackedMatrix &M) { - //VectorBase::data_ = const_cast (M.Data()); - //VectorBase::dim_ = (M.NumRows()*(M.NumRows()+1))/2; - //} - - /// Copy constructor - SubVector(const SubVector &other) : VectorBase () { - // this copy constructor needed for Range() to work in base class. - VectorBase::data_ = other.data_; - VectorBase::dim_ = other.dim_; - } - - /// Constructor from a pointer to memory and a length. Keeps a pointer - /// to the data but does not take ownership (will never delete). - /// Caution: this constructor enables you to evade const constraints. - SubVector(const Real *data, MatrixIndexT length) : VectorBase () { - VectorBase::data_ = const_cast(data); - VectorBase::dim_ = length; - } - - /// This operation does not preserve const-ness, so be careful. - SubVector(const MatrixBase &matrix, MatrixIndexT row) { - VectorBase::data_ = const_cast(matrix.RowData(row)); - VectorBase::dim_ = matrix.NumCols(); - } - - ~SubVector() {} ///< Destructor (does nothing; no pointers are owned here). - - private: - /// Disallow assignment operator. - SubVector & operator = (const SubVector &other) {} + public: + /// Constructor from a Vector or SubVector. + /// SubVectors are not const-safe and it's very hard to make them + /// so for now we just give up. This function contains const_cast. 
+ SubVector(const VectorBase &t, + const MatrixIndexT origin, + const MatrixIndexT length) + : VectorBase() { + // following assert equiv to origin>=0 && length>=0 && + // origin+length <= rt.dim_ + KALDI_ASSERT(static_cast(origin) + + static_cast(length) <= + static_cast(t.Dim())); + VectorBase::data_ = const_cast(t.Data() + origin); + VectorBase::dim_ = length; + } + + /// This constructor initializes the vector to point at the contents + /// of this packed matrix (SpMatrix or TpMatrix). + // SubVector(const PackedMatrix &M) { + // VectorBase::data_ = const_cast (M.Data()); + // VectorBase::dim_ = (M.NumRows()*(M.NumRows()+1))/2; + //} + + /// Copy constructor + SubVector(const SubVector &other) : VectorBase() { + // this copy constructor needed for Range() to work in base class. + VectorBase::data_ = other.data_; + VectorBase::dim_ = other.dim_; + } + + /// Constructor from a pointer to memory and a length. Keeps a pointer + /// to the data but does not take ownership (will never delete). + /// Caution: this constructor enables you to evade const constraints. + SubVector(const Real *data, MatrixIndexT length) : VectorBase() { + VectorBase::data_ = const_cast(data); + VectorBase::dim_ = length; + } + + /// This operation does not preserve const-ness, so be careful. + SubVector(const MatrixBase &matrix, MatrixIndexT row) { + VectorBase::data_ = const_cast(matrix.RowData(row)); + VectorBase::dim_ = matrix.NumCols(); + } + + ~SubVector() {} ///< Destructor (does nothing; no pointers are owned here). + + private: + /// Disallow assignment operator. + SubVector &operator=(const SubVector &other) {} }; /// @} end of "addtogroup matrix_group" @@ -303,43 +312,41 @@ class SubVector : public VectorBase { /// @{ /// Output to a C++ stream. Non-binary by default (use Write for /// binary output). 
-template -std::ostream & operator << (std::ostream & out, const VectorBase & v); +template +std::ostream &operator<<(std::ostream &out, const VectorBase &v); /// Input from a C++ stream. Will automatically read text or /// binary data from the stream. -template -std::istream & operator >> (std::istream & in, VectorBase & v); +template +std::istream &operator>>(std::istream &in, VectorBase &v); /// Input from a C++ stream. Will automatically read text or /// binary data from the stream. -template -std::istream & operator >> (std::istream & in, Vector & v); +template +std::istream &operator>>(std::istream &in, Vector &v); /// @} end of \addtogroup matrix_funcs_io /// \addtogroup matrix_funcs_scalar /// @{ -//template -//bool ApproxEqual(const VectorBase &a, - //const VectorBase &b, Real tol = 0.01) { - //return a.ApproxEqual(b, tol); +// template +// bool ApproxEqual(const VectorBase &a, +// const VectorBase &b, Real tol = 0.01) { +// return a.ApproxEqual(b, tol); //} -//template -//inline void AssertEqual(VectorBase &a, VectorBase &b, - //float tol = 0.01) { - //KALDI_ASSERT(a.ApproxEqual(b, tol)); +// template +// inline void AssertEqual(VectorBase &a, VectorBase &b, +// float tol = 0.01) { +// KALDI_ASSERT(a.ApproxEqual(b, tol)); //} - } // namespace kaldi // we need to include the implementation #include "matrix/kaldi-vector-inl.h" - #endif // KALDI_MATRIX_KALDI_VECTOR_H_ diff --git a/runtime/engine/common/matrix/matrix-common.h b/runtime/engine/common/matrix/matrix-common.h index b7bdbbc8..512beb20 100644 --- a/runtime/engine/common/matrix/matrix-common.h +++ b/runtime/engine/common/matrix/matrix-common.h @@ -27,52 +27,58 @@ namespace kaldi { // this enums equal to CblasTrans and CblasNoTrans constants from CBLAS library -// we are writing them as literals because we don't want to include here matrix/kaldi-blas.h, -// which puts many symbols into global scope (like "real") via the header f2c.h +// we are writing them as literals because we don't want to include 
here +// matrix/kaldi-blas.h, +// which puts many symbols into global scope (like "real") via the header f2c.h typedef enum { - kTrans = 112, // = CblasTrans - kNoTrans = 111 // = CblasNoTrans + kTrans = 112, // = CblasTrans + kNoTrans = 111 // = CblasNoTrans } MatrixTransposeType; -typedef enum { - kSetZero, - kUndefined, - kCopyData -} MatrixResizeType; +typedef enum { kSetZero, kUndefined, kCopyData } MatrixResizeType; typedef enum { - kDefaultStride, - kStrideEqualNumCols, + kDefaultStride, + kStrideEqualNumCols, } MatrixStrideType; typedef enum { - kTakeLower, - kTakeUpper, - kTakeMean, - kTakeMeanAndCheck + kTakeLower, + kTakeUpper, + kTakeMean, + kTakeMeanAndCheck } SpCopyType; -template class VectorBase; -template class Vector; -template class SubVector; -template class MatrixBase; -template class SubMatrix; -template class Matrix; +template +class VectorBase; +template +class Vector; +template +class SubVector; +template +class MatrixBase; +template +class SubMatrix; +template +class Matrix; /// This class provides a way for switching between double and float types. -template class OtherReal { }; // useful in reading+writing routines - // to switch double and float. +template +class OtherReal {}; // useful in reading+writing routines + // to switch double and float. /// A specialized class for switching from float to double. -template<> class OtherReal { - public: - typedef double Real; +template <> +class OtherReal { + public: + typedef double Real; }; /// A specialized class for switching from double to float. 
-template<> class OtherReal { - public: - typedef float Real; +template <> +class OtherReal { + public: + typedef float Real; }; @@ -81,12 +87,10 @@ typedef int32 SignedMatrixIndexT; typedef uint32 UnsignedMatrixIndexT; // If you want to use size_t for the index type, do as follows instead: -//typedef size_t MatrixIndexT; -//typedef ssize_t SignedMatrixIndexT; -//typedef size_t UnsignedMatrixIndexT; - +// typedef size_t MatrixIndexT; +// typedef ssize_t SignedMatrixIndexT; +// typedef size_t UnsignedMatrixIndexT; } - #endif // KALDI_MATRIX_MATRIX_COMMON_H_ diff --git a/runtime/engine/kaldi/CMakeLists.txt b/runtime/engine/kaldi/CMakeLists.txt index f9b42e06..e55cecbb 100644 --- a/runtime/engine/kaldi/CMakeLists.txt +++ b/runtime/engine/kaldi/CMakeLists.txt @@ -1,14 +1,15 @@ -project(kaldi) include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ) add_subdirectory(base) add_subdirectory(util) -add_subdirectory(lat) -add_subdirectory(fstext) -add_subdirectory(decoder) -add_subdirectory(lm) +if(WITH_ASR) + add_subdirectory(lat) + add_subdirectory(fstext) + add_subdirectory(decoder) + add_subdirectory(lm) -add_subdirectory(fstbin) -add_subdirectory(lmbin) + add_subdirectory(fstbin) + add_subdirectory(lmbin) +endif() diff --git a/runtime/engine/kaldi/base/kaldi-types.h b/runtime/engine/kaldi/base/kaldi-types.h index c6a3e1ae..f371e3da 100644 --- a/runtime/engine/kaldi/base/kaldi-types.h +++ b/runtime/engine/kaldi/base/kaldi-types.h @@ -44,7 +44,19 @@ typedef float BaseFloat; #ifndef COMPILE_WITHOUT_OPENFST +#ifdef WITH_ASR #include +#else +using int8 = int8_t; +using int16 = int16_t; +using int32 = int32_t; +using int64 = int64_t; + +using uint8 = uint8_t; +using uint16 = uint16_t; +using uint32 = uint32_t; +using uint64 = uint64_t; +#endif namespace kaldi { using ::int16; diff --git a/runtime/engine/vad/CMakeLists.txt b/runtime/engine/vad/CMakeLists.txt new file mode 100644 index 00000000..d13cc407 --- /dev/null +++ b/runtime/engine/vad/CMakeLists.txt @@ -0,0 +1,18 @@ +# 
set(CMAKE_CXX_STANDARD 11) + +# # 指定下载解压后的fastdeploy库路径 +# set(FASTDEPLOY_INSTALL_DIR "fdlib/fastdeploy-linux-x64-1.0.4" CACHE STRING force) + +# if(NOT EXISTS ${FASTDEPLOY_INSTALL_DIR}) +# message(FATAL_ERROR "Please using cmake -B build -DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR}") +# endif() + +# include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) + +# # 添加FastDeploy依赖头文件 +# include_directories(${FASTDEPLOY_INCS}) + +add_executable(infer_onnx_silero_vad ${CMAKE_CURRENT_SOURCE_DIR}/infer_onnx_silero_vad.cc wav.h vad.cc vad.h) + +# 添加FastDeploy库依赖 +target_link_libraries(infer_onnx_silero_vad ${FASTDEPLOY_LIBS}) diff --git a/runtime/engine/vad/README.md b/runtime/engine/vad/README.md new file mode 100644 index 00000000..f032be86 --- /dev/null +++ b/runtime/engine/vad/README.md @@ -0,0 +1,121 @@ +English | [简体中文](README_CN.md) + +# Silero VAD Deployment Example + +This directory provides examples that `infer_onnx_silero_vad` fast finishes the deployment of VAD models on CPU/GPU. + +Before deployment, two steps require confirmation. + +- 1. Software and hardware should meet the requirements. Please refer to [FastDeploy Environment Requirements](../../../../docs/en/build_and_install/download_prebuilt_libraries.md). +- 2. Download the precompiled deployment library and samples code according to your development environment. Refer to [FastDeploy Precompiled Library](../../../../docs/en/build_and_install/download_prebuilt_libraries.md). + +Taking VAD inference on Linux as an example, the compilation test can be completed by executing the following command in this directory. + +```bash +mkdir build +cd build +# Download the FastDeploy precompiled library. Users can choose your appropriate version in the `FastDeploy Precompiled Library` mentioned above +wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz +tar xvf fastdeploy-linux-x64-x.x.x.tgz +cmake .. 
-DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x
+make -j
+
+# Download the VAD model file and test audio. After decompression, place the model and test audio in the infer_onnx_silero_vad.cc peer directory
+wget https://bj.bcebos.com/paddlehub/fastdeploy/silero_vad.tgz
+wget https://bj.bcebos.com/paddlehub/fastdeploy/silero_vad_sample.wav
+
+# inference
+./infer_onnx_silero_vad ../silero_vad.onnx ../silero_vad_sample.wav
+```
+
+- The above command works for Linux or MacOS. Refer to:
+  - [How to use FastDeploy C++ SDK in Windows](../../../../docs/en/faq/use_sdk_on_windows.md) for SDK use-pattern in Windows
+
+## VAD C++ Interface
+
+### Vad Class
+
+```c++
+Vad::Vad(const std::string& model_file,
+         const fastdeploy::RuntimeOption& custom_option = fastdeploy::RuntimeOption())
+```
+
+**Parameter**
+
+> * **model_file**(str): Model file path
+> * **runtime_option**(RuntimeOption): Backend inference configuration. None by default. (use the default configuration)
+
+### setAudioCofig function
+
+**Must be called before the `init` function**
+
+```c++
+void Vad::setAudioCofig(int sr, int frame_ms, float threshold, int min_silence_duration_ms, int speech_pad_ms);
+```
+
+**Parameter**
+
+> * **sr**(int): sampling rate
+> * **frame_ms**(int): The length of each detection frame, and it is used to calculate the detection window size
+> * **threshold**(float): Result probability judgment threshold
+> * **min_silence_duration_ms**(int): The threshold used to calculate whether it is silence
+> * **speech_pad_ms**(int): Used to calculate the end time of the speech
+
+### init function
+
+Used to initialize audio-related parameters.
+
+```c++
+void Vad::init();
+```
+
+### loadAudio function
+
+Load audio.
+
+```c++
+void Vad::loadAudio(const std::string& wavPath)
+```
+
+**Parameter**
+
+> * **wavPath**(str): Audio file path
+
+### Predict function
+
+Used to start model inference.
+
+```c++
+bool Vad::Predict();
+```
+
+### getResult function
+
+**Used to obtain inference results**
+
+```c++
+std::vector<std::map<std::string, float>> Vad::getResult(
+    float removeThreshold = 1.6, float expandHeadThreshold = 0.32, float expandTailThreshold = 0,
+    float mergeThreshold = 0.3);
+```
+
+**Parameter**
+
+> * **removeThreshold**(float): Discard result fragment threshold; If some recognition results are too short, they will be discarded according to this threshold
+> * **expandHeadThreshold**(float): Offset at the beginning of the segment; The recognized start time may be too close to the voice part, so move forward the start time accordingly
+> * **expandTailThreshold**(float): Offset at the end of the segment; The recognized end time may be too close to the voice part, so the end time is moved back accordingly
+> * **mergeThreshold**(float): Some result segments are very close and can be combined into one, and the vocal segments can be combined accordingly
+
+**The output result format is**`std::vector<std::map<std::string, float>>`
+
+> Output a list, each element is a speech fragment
+>
+> Each clip can use 'start' to get the start time and 'end' to get the end time
+
+### Tips
+
+1. The `setAudioCofig` function must be called before the `init` function
+2. The sampling rate of the input audio file must be consistent with that set in the code
+
+- [Model Description](../)
+- [How to switch the model inference backend engine](../../../../docs/en/faq/how_to_change_backend.md)
diff --git a/runtime/engine/vad/README_CN.md b/runtime/engine/vad/README_CN.md
new file mode 100644
index 00000000..c45d9896
--- /dev/null
+++ b/runtime/engine/vad/README_CN.md
@@ -0,0 +1,119 @@
+[English](README.md) | 简体中文
+# Silero VAD 部署示例
+
+本目录下提供`infer_onnx_silero_vad`快速完成 Silero VAD 模型在CPU/GPU上的部署。
+
+在部署前,需确认以下两个步骤
+
+- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+- 2. 
根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+
+以Linux上 VAD 推理为例,在本目录执行如下命令即可完成编译测试。
+
+```bash
+mkdir build
+cd build
+# 下载FastDeploy预编译库,用户可在上文提到的`FastDeploy预编译库`中自行选择合适的版本使用
+wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz
+tar xvf fastdeploy-linux-x64-x.x.x.tgz
+cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x
+make -j
+
+# 下载 VAD 模型文件和测试音频,解压后将模型和测试音频放置在与 infer_onnx_silero_vad.cc 同级目录下
+wget https://bj.bcebos.com/paddlehub/fastdeploy/silero_vad.tgz
+wget https://bj.bcebos.com/paddlehub/fastdeploy/silero_vad_sample.wav
+
+# 推理
+./infer_onnx_silero_vad ../silero_vad.onnx ../silero_vad_sample.wav
+```
+
+以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考:
+- [如何在Windows中使用FastDeploy C++ SDK](../../../../docs/cn/faq/use_sdk_on_windows.md)
+
+## VAD C++ 接口
+### Vad 类
+
+```c++
+Vad::Vad(const std::string& model_file,
+         const fastdeploy::RuntimeOption& custom_option = fastdeploy::RuntimeOption())
+```
+
+**参数**
+
+> * **model_file**(str): 模型文件路径
+> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
+
+### setAudioCofig 函数
+
+**必须在`init`函数前调用**
+
+```c++
+void Vad::setAudioCofig(int sr, int frame_ms, float threshold, int min_silence_duration_ms, int speech_pad_ms);
+```
+
+**参数**
+
+> * **sr**(int): 采样率
+> * **frame_ms**(int): 每次检测帧长,用于计算检测窗口大小
+> * **threshold**(float): 结果概率判断阈值
+> * **min_silence_duration_ms**(int): 用于计算判断是否是 silence 的阈值
+> * **speech_pad_ms**(int): 用于计算 speech 结束时刻
+
+### init 函数
+
+用于初始化音频相关参数
+
+```c++
+void Vad::init();
+```
+
+### loadAudio 函数
+
+加载音频
+
+```c++
+void Vad::loadAudio(const std::string& wavPath)
+```
+
+**参数**
+
+> * **wavPath**(str): 音频文件路径
+
+### Predict 函数
+
+用于开始模型推理
+
+```c++
+bool Vad::Predict();
+```
+
+### getResult 函数
+
+**用于获取推理结果**
+
+```c++
+std::vector<std::map<std::string, float>> Vad::getResult(
+    float removeThreshold = 1.6, float expandHeadThreshold = 0.32, float expandTailThreshold = 0,
+    float mergeThreshold = 0.3);
+```
+ +**参数** + +> * **removeThreshold**(float): 丢弃结果片段阈值;部分识别结果太短则根据此阈值丢弃 +> * **expandHeadThreshold**(float): 结果片段开始时刻偏移;识别到的开始时刻可能过于贴近发声部分,因此据此前移开始时刻 +> * **expandTailThreshold**(float): 结果片段结束时刻偏移;识别到的结束时刻可能过于贴近发声部分,因此据此后移结束时刻 +> * **mergeThreshold**(float): 有的结果片段十分靠近,可以合并成一个,据此合并发声片段 + +**输出结果格式为**`std::vector>` + +> 输出一个列表,每个元素是一个讲话片段 +> +> 每个片段可以用 'start' 获取到开始时刻,用 'end' 获取到结束时刻 + +### 提示 + +1. `setAudioCofig`函数必须在`init`函数前调用 +2. 输入的音频文件的采样率必须与代码中设置的保持一致 + +- [模型介绍](../) +- [如何切换模型推理后端引擎](../../../../docs/cn/faq/how_to_change_backend.md) diff --git a/runtime/engine/vad/infer_onnx_silero_vad.cc b/runtime/engine/vad/infer_onnx_silero_vad.cc new file mode 100644 index 00000000..7fb52406 --- /dev/null +++ b/runtime/engine/vad/infer_onnx_silero_vad.cc @@ -0,0 +1,65 @@ + +#include "vad.h" + +int main(int argc, char* argv[]) { + if (argc < 3) { + std::cout << "Usage: infer_onnx_silero_vad path/to/model path/to/audio " + "run_option, " + "e.g ./infer_onnx_silero_vad silero_vad.onnx sample.wav" + << std::endl; + return -1; + } + + std::string model_file = argv[1]; + std::string audio_file = argv[2]; + + int sr = 16000; + Vad vad(model_file); + // custom config, but must be set before init + vad.SetConfig(sr, 32, 0.45f, 200, 0, 0); + vad.Init(); + + std::vector inputWav; // [0, 1] + wav::WavReader wav_reader = wav::WavReader(audio_file); + assert(wav_reader.sample_rate() == sr); + + + auto num_samples = wav_reader.num_samples(); + inputWav.resize(num_samples); + for (int i = 0; i < num_samples; i++) { + inputWav[i] = wav_reader.data()[i] / 32768; + } + + int window_size_samples = vad.WindowSizeSamples(); + for (int64_t j = 0; j < num_samples; j += window_size_samples) { + auto start = j; + auto end = start + window_size_samples >= num_samples + ? 
num_samples + : start + window_size_samples; + auto current_chunk_size = end - start; + + std::vector r{&inputWav[0] + start, &inputWav[0] + end}; + assert(r.size() == current_chunk_size); + + if (!vad.ForwardChunk(r)) { + std::cerr << "Failed to inference while using model:" + << vad.ModelName() << "." << std::endl; + return false; + } + + Vad::State s = vad.Postprocess(); + std::cout << s << " "; + } + std::cout << std::endl; + + std::vector> result = vad.GetResult(); + for (auto& res : result) { + std::cout << "speak start: " << res["start"] + << " s, end: " << res["end"] << " s | "; + } + std::cout << "\b\b " << std::endl; + + vad.Reset(); + + return 0; +} diff --git a/runtime/engine/vad/vad.cc b/runtime/engine/vad/vad.cc new file mode 100644 index 00000000..7630b98d --- /dev/null +++ b/runtime/engine/vad/vad.cc @@ -0,0 +1,306 @@ +// Copyright (c) 2023 Chen Qianhe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "vad.h" +#include +#include + + +#ifdef NDEBUG +#define LOG_DEBUG \ + ::fastdeploy::FDLogger(true, "[DEBUG]") << __REL_FILE__ << "(" << __LINE__ \ + << ")::" << __FUNCTION__ << "\t" +#else +#define LOG_DEBUG \ + ::fastdeploy::FDLogger(false, "[DEBUG]") \ + << __REL_FILE__ << "(" << __LINE__ << ")::" << __FUNCTION__ << "\t" +#endif + +Vad::Vad(const std::string& model_file, + const fastdeploy::RuntimeOption& + custom_option /* = fastdeploy::RuntimeOption() */) { + valid_cpu_backends = {fastdeploy::Backend::ORT, + fastdeploy::Backend::OPENVINO}; + valid_gpu_backends = {fastdeploy::Backend::ORT, fastdeploy::Backend::TRT}; + + runtime_option = custom_option; + // ORT backend + runtime_option.UseCpu(); + runtime_option.UseOrtBackend(); + runtime_option.model_format = fastdeploy::ModelFormat::ONNX; + // grap opt level + runtime_option.ort_option.graph_optimization_level = 99; + // one-thread + runtime_option.ort_option.intra_op_num_threads = 1; + runtime_option.ort_option.inter_op_num_threads = 1; + // model path + runtime_option.model_file = model_file; +} + +void Vad::Init() { + std::call_once(init_, [&]() { initialized = Initialize(); }); +} + +std::string Vad::ModelName() const { return "VAD"; } + +void Vad::SetConfig(int sr, + int frame_ms, + float threshold, + int min_silence_duration_ms, + int speech_pad_left_ms, + int speech_pad_right_ms) { + if (initialized) { + fastdeploy::FDERROR << "SetConfig must be called before init" + << std::endl; + throw std::runtime_error("SetConfig must be called before init"); + } + sample_rate_ = sr; + sr_per_ms_ = sr / 1000; + threshold_ = threshold; + frame_ms_ = frame_ms; + min_silence_samples_ = min_silence_duration_ms * sr_per_ms_; + speech_pad_left_samples_ = speech_pad_left_ms * sr_per_ms_; + speech_pad_right_samples_ = speech_pad_right_ms * sr_per_ms_; + + // init chunk size + window_size_samples_ = frame_ms * sr_per_ms_; + current_chunk_size_ = window_size_samples_; + + fastdeploy::FDINFO << "sr=" << sr << " 
threshold=" << threshold + << " frame_ms=" << frame_ms + << " min_silence_duration_ms=" << min_silence_duration_ms + << " speech_pad_left_ms=" << speech_pad_left_ms + << " speech_pad_right_ms=" << speech_pad_right_ms; +} + +void Vad::Reset() { + std::memset(h_.data(), 0.0f, h_.size() * sizeof(float)); + std::memset(c_.data(), 0.0f, c_.size() * sizeof(float)); + + triggerd_ = false; + temp_end_ = 0; + current_sample_ = 0; + + speakStart_.clear(); + speakEnd_.clear(); + + states_.clear(); +} + +bool Vad::Initialize() { + // input & output holder + inputTensors_.resize(4); + outputTensors_.resize(3); + + // input shape + input_node_dims_.emplace_back(1); + input_node_dims_.emplace_back(window_size_samples_); + // sr buffer + sr_.resize(1); + sr_[0] = sample_rate_; + // hidden state buffer + h_.resize(size_hc_); + c_.resize(size_hc_); + + Reset(); + + // InitRuntime + if (!InitRuntime()) { + fastdeploy::FDERROR << "Failed to initialize fastdeploy backend." + << std::endl; + return false; + } + fastdeploy::FDINFO << "init done."; + return true; +} + +bool Vad::ForwardChunk(std::vector& chunk) { + // last chunk may not be window_size_samples_ + input_node_dims_.back() = chunk.size(); + assert(window_size_samples_ >= chunk.size()); + current_chunk_size_ = chunk.size(); + + inputTensors_[0].name = "input"; + inputTensors_[0].SetExternalData( + input_node_dims_, fastdeploy::FDDataType::FP32, chunk.data()); + inputTensors_[1].name = "sr"; + inputTensors_[1].SetExternalData( + sr_node_dims_, fastdeploy::FDDataType::INT64, sr_.data()); + inputTensors_[2].name = "h"; + inputTensors_[2].SetExternalData( + hc_node_dims_, fastdeploy::FDDataType::FP32, h_.data()); + inputTensors_[3].name = "c"; + inputTensors_[3].SetExternalData( + hc_node_dims_, fastdeploy::FDDataType::FP32, c_.data()); + + if (!Infer(inputTensors_, &outputTensors_)) { + return false; + } + + // Push forward sample index + current_sample_ += current_chunk_size_; + return true; +} + +const Vad::State& 
Vad::Postprocess() { + // update prob, h, c + outputProb_ = *(float*)outputTensors_[0].Data(); + auto* hn = static_cast(outputTensors_[1].MutableData()); + std::memcpy(h_.data(), hn, h_.size() * sizeof(float)); + auto* cn = static_cast(outputTensors_[2].MutableData()); + std::memcpy(c_.data(), cn, c_.size() * sizeof(float)); + + if (outputProb_ < threshold_ && !triggerd_) { + // 1. Silence + LOG_DEBUG << "{ silence: " << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; + states_.emplace_back(Vad::State::SIL); + } else if (outputProb_ >= threshold_ && !triggerd_) { + // 2. Start + triggerd_ = true; + speech_start_ = + current_sample_ - current_chunk_size_ - speech_pad_left_samples_; + float start_sec = 1.0 * speech_start_ / sample_rate_; + speakStart_.emplace_back(start_sec); + LOG_DEBUG << "{ speech start: " << start_sec + << " s; prob: " << outputProb_ << " }"; + states_.emplace_back(Vad::State::START); + } else if (outputProb_ >= threshold_ - 0.15 && triggerd_) { + // 3. Continue + + if (temp_end_ != 0) { + // speech prob relaxation, speech continues again + LOG_DEBUG << "{ speech fake end(sil < min_silence_ms) to continue: " + << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; + temp_end_ = 0; + } else { + // speech prob relaxation, keep tracking speech + LOG_DEBUG << "{ speech continue: " + << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; + } + + states_.emplace_back(Vad::State::SPEECH); + } else if (outputProb_ < threshold_ - 0.15 && triggerd_) { + // 4. End + if (temp_end_ == 0) { + temp_end_ = current_sample_; + } + + // check possible speech end + if (current_sample_ - temp_end_ < min_silence_samples_) { + // a. silence < min_slience_samples, continue speaking + LOG_DEBUG << "{ speech fake end(sil < min_silence_ms): " + << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; + states_.emplace_back(Vad::State::SIL); + } else { + // b. 
silence >= min_slience_samples, end speaking + speech_end_ = current_sample_ + speech_pad_right_samples_; + temp_end_ = 0; + triggerd_ = false; + auto end_sec = 1.0 * speech_end_ / sample_rate_; + speakEnd_.emplace_back(end_sec); + LOG_DEBUG << "{ speech end: " << end_sec + << " s; prob: " << outputProb_ << " }"; + states_.emplace_back(Vad::State::END); + } + } + + return states_.back(); +} + +const std::vector> Vad::GetResult( + float removeThreshold, + float expandHeadThreshold, + float expandTailThreshold, + float mergeThreshold) const { + float audioLength = 1.0 * current_sample_ / sample_rate_; + if (speakStart_.empty() && speakEnd_.empty()) { + return {}; + } + if (speakEnd_.size() != speakStart_.size()) { + // set the audio length as the last end + speakEnd_.emplace_back(audioLength); + } + // Remove too short segments + // auto startIter = speakStart_.begin(); + // auto endIter = speakEnd_.begin(); + // while (startIter != speakStart_.end()) { + // if (removeThreshold < audioLength && + // *endIter - *startIter < removeThreshold) { + // startIter = speakStart_.erase(startIter); + // endIter = speakEnd_.erase(endIter); + // } else { + // startIter++; + // endIter++; + // } + // } + // // Expand to avoid to tight cut. 
+ // startIter = speakStart_.begin(); + // endIter = speakEnd_.begin(); + // *startIter = std::fmax(0.f, *startIter - expandHeadThreshold); + // *endIter = std::fmin(*endIter + expandTailThreshold, *(startIter + 1)); + // endIter = speakEnd_.end() - 1; + // startIter = speakStart_.end() - 1; + // *startIter = fmax(*startIter - expandHeadThreshold, *(endIter - 1)); + // *endIter = std::fmin(*endIter + expandTailThreshold, audioLength); + // for (int i = 1; i < speakStart_.size() - 1; ++i) { + // speakStart_[i] = std::fmax(speakStart_[i] - expandHeadThreshold, + // speakEnd_[i - 1]); + // speakEnd_[i] = std::fmin(speakEnd_[i] + expandTailThreshold, + // speakStart_[i + 1]); + // } + // // Merge very closed segments + // startIter = speakStart_.begin() + 1; + // endIter = speakEnd_.begin(); + // while (startIter != speakStart_.end()) { + // if (*startIter - *endIter < mergeThreshold) { + // startIter = speakStart_.erase(startIter); + // endIter = speakEnd_.erase(endIter); + // } else { + // startIter++; + // endIter++; + // } + // } + + std::vector> result; + for (int i = 0; i < speakStart_.size(); ++i) { + result.emplace_back(std::map( + {{"start", speakStart_[i]}, {"end", speakEnd_[i]}})); + } + return result; +} + +std::ostream& operator<<(std::ostream& os, const Vad::State& s) { + switch (s) { + case Vad::State::SIL: + os << "[SIL]"; + break; + case Vad::State::START: + os << "[STA]"; + break; + case Vad::State::SPEECH: + os << "[SPE]"; + break; + case Vad::State::END: + os << "[END]"; + break; + default: + // illegal state + os << "[ILL]"; + break; + } + return os; +} \ No newline at end of file diff --git a/runtime/engine/vad/vad.h b/runtime/engine/vad/vad.h new file mode 100644 index 00000000..6eed7d1c --- /dev/null +++ b/runtime/engine/vad/vad.h @@ -0,0 +1,124 @@ +// Copyright (c) 2023 Chen Qianhe Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include +#include +#include "./wav.h" +#include "fastdeploy/fastdeploy_model.h" +#include "fastdeploy/runtime.h" + +class Vad : public fastdeploy::FastDeployModel { + public: + enum class State { SIL = 0, START, SPEECH, END }; + friend std::ostream& operator<<(std::ostream& os, const Vad::State& s); + + Vad(const std::string& model_file, + const fastdeploy::RuntimeOption& custom_option = + fastdeploy::RuntimeOption()); + + void Init(); + + void Reset(); + + void SetConfig(int sr, + int frame_ms, + float threshold, + int min_silence_duration_ms, + int speech_pad_left_ms, + int speech_pad_right_ms); + + bool ForwardChunk(std::vector& chunk); + + const State& Postprocess(); + + const std::vector> GetResult( + float removeThreshold = 0.0, + float expandHeadThreshold = 0.0, + float expandTailThreshold = 0, + float mergeThreshold = 0.0) const; + + const std::vector GetStates() const { return states_; } + + int SampleRate() const { return sample_rate_; } + + int FrameMs() const { return frame_ms_; } + int64_t WindowSizeSamples() const { return window_size_samples_; } + + float Threshold() const { return threshold_; } + + int MinSilenceDurationMs() const { + return min_silence_samples_ / sample_rate_; + } + int SpeechPadLeftMs() const { + return speech_pad_left_samples_ / sample_rate_; + } + int SpeechPadRightMs() const { + return speech_pad_right_samples_ / sample_rate_; + } + + int 
MinSilenceSamples() const { return min_silence_samples_; } + int SpeechPadLeftSamples() const { return speech_pad_left_samples_; } + int SpeechPadRightSamples() const { return speech_pad_right_samples_; } + + std::string ModelName() const override; + + private: + bool Initialize(); + + private: + std::once_flag init_; + // input and output + std::vector inputTensors_; + std::vector outputTensors_; + + // model states + bool triggerd_ = false; + unsigned int speech_start_ = 0; + unsigned int speech_end_ = 0; + unsigned int temp_end_ = 0; + unsigned int current_sample_ = 0; + unsigned int current_chunk_size_ = 0; + // MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes + float outputProb_; + + std::vector speakStart_; + mutable std::vector speakEnd_; + + std::vector states_; + + /* ======================================================================== + */ + int sample_rate_ = 16000; + int frame_ms_ = 32; // 32, 64, 96 for 16k + float threshold_ = 0.5f; + + int64_t window_size_samples_; // support 256 512 768 for 8k; 512 1024 1536 + // for 16k. + int sr_per_ms_; // support 8 or 16 + int min_silence_samples_; // sr_per_ms_ * frame_ms_ + int speech_pad_left_samples_{0}; // usually 250ms + int speech_pad_right_samples_{0}; // usually 0 + + /* ======================================================================== + */ + std::vector sr_; + const size_t size_hc_ = 2 * 1 * 64; // It's FIXED. + std::vector h_; + std::vector c_; + + std::vector input_node_dims_; + const std::vector sr_node_dims_ = {1}; + const std::vector hc_node_dims_ = {2, 1, 64}; +}; diff --git a/runtime/engine/vad/wav.h b/runtime/engine/vad/wav.h new file mode 100644 index 00000000..6d1a6f72 --- /dev/null +++ b/runtime/engine/vad/wav.h @@ -0,0 +1,197 @@ +// Copyright (c) 2016 Personal (Binbin Zhang) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include +#include +#include +#include +#include + +namespace wav { + +struct WavHeader { + char riff[4]; // "riff" + unsigned int size; + char wav[4]; // "WAVE" + char fmt[4]; // "fmt " + unsigned int fmt_size; + uint16_t format; + uint16_t channels; + unsigned int sample_rate; + unsigned int bytes_per_second; + uint16_t block_size; + uint16_t bit; + char data[4]; // "data" + unsigned int data_size; +}; + +class WavReader { + public: + WavReader() : data_(nullptr) {} + explicit WavReader(const std::string& filename) { Open(filename); } + + bool Open(const std::string& filename) { + FILE* fp = fopen(filename.c_str(), "rb"); + if (NULL == fp) { + std::cout << "Error in read " << filename; + return false; + } + + WavHeader header; + fread(&header, 1, sizeof(header), fp); + if (header.fmt_size < 16) { + fprintf(stderr, + "WaveData: expect PCM format data " + "to have fmt chunk of at least size 16.\n"); + return false; + } else if (header.fmt_size > 16) { + int offset = 44 - 8 + header.fmt_size - 16; + fseek(fp, offset, SEEK_SET); + fread(header.data, 8, sizeof(char), fp); + } + // check "riff" "WAVE" "fmt " "data" + + // Skip any sub-chunks between "fmt" and "data". Usually there will + // be a single "fact" sub chunk, but on Windows there can also be a + // "list" sub chunk. + while (0 != strncmp(header.data, "data", 4)) { + // We will just ignore the data in these chunks. 
+ fseek(fp, header.data_size, SEEK_CUR); + // read next sub chunk + fread(header.data, 8, sizeof(char), fp); + } + + num_channel_ = header.channels; + sample_rate_ = header.sample_rate; + bits_per_sample_ = header.bit; + int num_data = header.data_size / (bits_per_sample_ / 8); + data_ = new float[num_data]; // Create 1-dim array + num_samples_ = num_data / num_channel_; + + for (int i = 0; i < num_data; ++i) { + switch (bits_per_sample_) { + case 8: { + char sample; + fread(&sample, 1, sizeof(char), fp); + data_[i] = static_cast(sample); + break; + } + case 16: { + int16_t sample; + fread(&sample, 1, sizeof(int16_t), fp); + // std::cout << sample; + data_[i] = static_cast(sample); + // std::cout << data_[i]; + break; + } + case 32: { + int sample; + fread(&sample, 1, sizeof(int), fp); + data_[i] = static_cast(sample); + break; + } + default: + fprintf(stderr, "unsupported quantization bits"); + exit(1); + } + } + fclose(fp); + return true; + } + + int num_channel() const { return num_channel_; } + int sample_rate() const { return sample_rate_; } + int bits_per_sample() const { return bits_per_sample_; } + int num_samples() const { return num_samples_; } + const float* data() const { return data_; } + + private: + int num_channel_; + int sample_rate_; + int bits_per_sample_; + int num_samples_; // sample points per channel + float* data_; +}; + +class WavWriter { + public: + WavWriter(const float* data, + int num_samples, + int num_channel, + int sample_rate, + int bits_per_sample) + : data_(data), + num_samples_(num_samples), + num_channel_(num_channel), + sample_rate_(sample_rate), + bits_per_sample_(bits_per_sample) {} + + void Write(const std::string& filename) { + FILE* fp = fopen(filename.c_str(), "w"); + // init char 'riff' 'WAVE' 'fmt ' 'data' + WavHeader header; + char wav_header[44] = { + 0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x41, 0x56, + 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00}; + memcpy(&header, wav_header, sizeof(header)); + header.channels = num_channel_; + header.bit = bits_per_sample_; + header.sample_rate = sample_rate_; + header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8); + header.size = sizeof(header) - 8 + header.data_size; + header.bytes_per_second = + sample_rate_ * num_channel_ * (bits_per_sample_ / 8); + header.block_size = num_channel_ * (bits_per_sample_ / 8); + + fwrite(&header, 1, sizeof(header), fp); + + for (int i = 0; i < num_samples_; ++i) { + for (int j = 0; j < num_channel_; ++j) { + switch (bits_per_sample_) { + case 8: { + char sample = + static_cast(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + case 16: { + int16_t sample = + static_cast(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + case 32: { + int sample = + static_cast(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + } + } + } + fclose(fp); + } + + private: + const float* data_; + int num_samples_; // total float points in data_ + int num_channel_; + int sample_rate_; + int bits_per_sample_; +}; + +} // namespace wav From e9da7e0e07894e4607514ec41ec290e100f78ee2 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sun, 5 Mar 2023 15:12:35 +0800 Subject: [PATCH 14/50] [runtime] add logging module, build on linux and android, normalize option name (#2986) * rename option with WITH_; add logging module to replace with glog * refactor logging module; build pass on linux and andorid --- runtime/.gitignore | 1 + runtime/CMakeLists.txt | 33 ++-- runtime/build_android.sh | 36 ++++ runtime/cmake/fastdeploy.cmake | 30 ++-- runtime/cmake/glog.cmake | 27 ++- runtime/cmake/gtest.cmake | 26 ++- runtime/cmake/summary.cmake | 58 +++++++ .../decoder/ctc_prefix_beam_search_decoder.cc | 4 +- runtime/engine/asr/nnet/u2_nnet.cc | 46 +++--- 
runtime/engine/cls/CMakeLists.txt | 6 - runtime/engine/cls/nnet/CMakeLists.txt | 9 +- runtime/engine/cls/nnet/panns_nnet.cc | 20 +-- runtime/engine/codelab/CMakeLists.txt | 5 +- runtime/engine/codelab/glog/CMakeLists.txt | 4 +- runtime/engine/common/CMakeLists.txt | 5 +- runtime/engine/common/base/CMakeLists.txt | 28 +++- runtime/engine/common/base/common.h | 2 +- runtime/engine/common/base/glog_utils.cc | 12 ++ runtime/engine/common/base/glog_utils.h | 9 + runtime/engine/common/base/log_impl.cc | 105 ++++++++++++ runtime/engine/common/base/log_impl.h | 156 ++++++++++++++++++ runtime/engine/common/base/macros.h | 4 +- runtime/engine/common/frontend/CMakeLists.txt | 4 +- runtime/engine/common/matrix/kaldi-vector.cc | 6 +- runtime/engine/common/utils/CMakeLists.txt | 25 ++- runtime/engine/vad/CMakeLists.txt | 19 +-- ..._onnx_silero_vad.cc => silero_vad_main.cc} | 0 runtime/examples/silero_vad/.gitignore | 1 + .../vad => examples/silero_vad}/README.md | 0 .../vad => examples/silero_vad}/README_CN.md | 0 runtime/examples/silero_vad/local/build.sh | 14 ++ .../silero_vad/local/build_android.sh | 31 ++++ runtime/examples/silero_vad/local/decode.sh | 0 runtime/examples/silero_vad/local/download.sh | 10 ++ runtime/examples/silero_vad/path.sh | 18 ++ runtime/examples/silero_vad/run.sh | 38 +++++ runtime/examples/silero_vad/utils | 1 + 37 files changed, 660 insertions(+), 133 deletions(-) create mode 100755 runtime/build_android.sh create mode 100644 runtime/cmake/summary.cmake create mode 100644 runtime/engine/common/base/glog_utils.cc create mode 100644 runtime/engine/common/base/glog_utils.h create mode 100644 runtime/engine/common/base/log_impl.cc create mode 100644 runtime/engine/common/base/log_impl.h rename runtime/engine/vad/{infer_onnx_silero_vad.cc => silero_vad_main.cc} (100%) create mode 100644 runtime/examples/silero_vad/.gitignore rename runtime/{engine/vad => examples/silero_vad}/README.md (100%) rename runtime/{engine/vad => examples/silero_vad}/README_CN.md 
(100%) create mode 100755 runtime/examples/silero_vad/local/build.sh create mode 100755 runtime/examples/silero_vad/local/build_android.sh create mode 100755 runtime/examples/silero_vad/local/decode.sh create mode 100755 runtime/examples/silero_vad/local/download.sh create mode 100644 runtime/examples/silero_vad/path.sh create mode 100644 runtime/examples/silero_vad/run.sh create mode 120000 runtime/examples/silero_vad/utils diff --git a/runtime/.gitignore b/runtime/.gitignore index 9aa98ef7..a654dae4 100644 --- a/runtime/.gitignore +++ b/runtime/.gitignore @@ -4,3 +4,4 @@ engine/common/base/log.h tools/valgrind* *log fc_patch/* +test diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt index 015a1088..bdce2046 100644 --- a/runtime/CMakeLists.txt +++ b/runtime/CMakeLists.txt @@ -16,11 +16,14 @@ if(NOT CMAKE_BUILD_TYPE) FORCE) endif() -project(paddlespeech VERSION 0.1) - -set(CMAKE_VERBOSE_MAKEFILE on) +project(paddlespeech VERSION 0.1) +# if(ANDROID) +# # when cross compile with ndk under linux, +# # UNIX and ANROID are all True +# set(UNIX) +# endif() include(FetchContent) include(ExternalProject) @@ -30,12 +33,12 @@ set(FETCHCONTENT_QUIET off) get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}") set(FETCHCONTENT_BASE_DIR ${fc_patch}) -set(CMAKE_CXX_FLAGS) -set(CMAKE_CXX_FLAGS_DEBUG) -set(CMAKE_CXX_FLAGS_RELEASE) +set(CMAKE_VERBOSE_MAKEFILE ON) +set(PPS_CXX_STANDARD 14) # set std-14 -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD}) +add_compile_options(-fPIC) # compiler option # Keep the same with openfst, -fPIC or -fpic @@ -43,8 +46,6 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb") SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall") - -add_compile_options(-fPIC) ############################################################################### # Option 
Configurations ############################################################################### @@ -52,11 +53,10 @@ option(WITH_ASR "build asr" ON) option(WITH_CLS "build cls" ON) option(WITH_VAD "build vad" ON) -option(TEST_DEBUG "option for debug" OFF) -option(USE_PROFILING "enable c++ profling" OFF) -option(WITH_TESTING "unit test" ON) +option(WITH_GPU "NNet using GPU." OFF) -option(USING_GPU "u2 compute on GPU." OFF) +option(WITH_PROFILING "enable c++ profling" OFF) +option(WITH_TESTING "unit test" ON) ############################################################################### # Include Third Party @@ -70,7 +70,6 @@ if(WITH_TESTING) include(gtest) # download, build, install gtest endif() - # fastdeploy include(fastdeploy) @@ -161,15 +160,11 @@ if(WITH_ASR) message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS}) endif() +include(summary) ############################################################################### # Add local library ############################################################################### set(ENGINE_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/engine) -message(STATUS "CMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}") -message(STATUS "CMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}") -message(STATUS "CMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}") - - add_subdirectory(engine) diff --git a/runtime/build_android.sh b/runtime/build_android.sh new file mode 100755 index 00000000..64a33762 --- /dev/null +++ b/runtime/build_android.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +set -ex + +ANDROID_NDK=/workspace/zhanghui/android-sdk/android-ndk-r25c + +# Setting up Android toolchanin +ANDROID_ABI=arm64-v8a # 'arm64-v8a', 'armeabi-v7a' +ANDROID_PLATFORM="android-21" # API >= 21 +ANDROID_STL=c++_shared # 'c++_shared', 'c++_static' +ANDROID_TOOLCHAIN=clang # 'clang' only +TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake + +# Create build directory +BUILD_ROOT=build/Android +BUILD_DIR=${BUILD_ROOT}/${ANDROID_ABI}-api-21 
+#FASDEPLOY_INSTALL_DIR="${BUILD_DIR}/install" +#mkdir build && mkdir ${BUILD_ROOT} && mkdir ${BUILD_DIR} +mkdir -p ${BUILD_DIR} +cd ${BUILD_DIR} + +# CMake configuration with Android toolchain +cmake -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} \ + -DCMAKE_BUILD_TYPE=MinSizeRel \ + -DANDROID_ABI=${ANDROID_ABI} \ + -DANDROID_NDK=${ANDROID_NDK} \ + -DANDROID_PLATFORM=${ANDROID_PLATFORM} \ + -DANDROID_STL=${ANDROID_STL} \ + -DANDROID_TOOLCHAIN=${ANDROID_TOOLCHAIN} \ + -DWITH_ASR=OFF \ + -DWITH_CLS=OFF \ + -Wno-dev ../../.. + #-DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \ + +# Build FastDeploy Android C++ SDK +make diff --git a/runtime/cmake/fastdeploy.cmake b/runtime/cmake/fastdeploy.cmake index cb9ceacd..463a8e8e 100644 --- a/runtime/cmake/fastdeploy.cmake +++ b/runtime/cmake/fastdeploy.cmake @@ -1,44 +1,42 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) - set(ARCH "mserver_x86_64" CACHE STRING "Target Architecture: android_arm, android_armv7, android_armv8, android_x86, android_x86_64, mserver_x86_64, ubuntu_x86_64, ios_armv7, ios_armv7s, ios_armv8, ios_x86_64, ios_x86, windows_x86") -set(CMAKE_VERBOSE_MAKEFILE ON) - set(FASTDEPLOY_DIR ${CMAKE_SOURCE_DIR}/fc_patch/fastdeploy) if(NOT EXISTS ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.4.tgz) exec_program("mkdir -p ${FASTDEPLOY_DIR} && - wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-1.0.4.tgz -P ${FASTDEPLOY_DIR} && + wget -c https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-1.0.4.tgz -P ${FASTDEPLOY_DIR} && tar xzvf ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.4.tgz -C ${FASTDEPLOY_DIR} && mv ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.4 ${FASTDEPLOY_DIR}/linux-x64") endif() -if(NOT EXISTS ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.0-shared.tgz) +if(NOT EXISTS ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.4-shared.tgz) exec_program("mkdir -p ${FASTDEPLOY_DIR} && - wget https://bj.bcebos.com/fastdeploy/release/android/fastdeploy-android-1.0.0-shared.tgz -P ${FASTDEPLOY_DIR} && - 
tar xzvf ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.0-shared.tgz -C ${FASTDEPLOY_DIR} && - mv ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.0-shared ${FASTDEPLOY_DIR}/android-armv7v8") + wget -c https://bj.bcebos.com/fastdeploy/release/android/fastdeploy-android-1.0.4-shared.tgz -P ${FASTDEPLOY_DIR} && + tar xzvf ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.4-shared.tgz -C ${FASTDEPLOY_DIR} && + mv ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.4-shared ${FASTDEPLOY_DIR}/android-armv7v8") endif() -if (ARCH STREQUAL "mserver_x86_64") + +if(ANDROID) + set(FASTDEPLOY_INSTALL_DIR ${FASTDEPLOY_DIR}/android-armv7v8) + add_definitions("-DUSE_PADDLE_LITE_BAKEND") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -g0 -O3 -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE") +elseif(UNIX) set(FASTDEPLOY_INSTALL_DIR ${FASTDEPLOY_DIR}/linux-x64) add_definitions("-DUSE_PADDLE_INFERENCE_BACKEND") # add_definitions("-DUSE_ORT_BACKEND") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -msse -msse2") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -msse -msse2 -mavx -O3") -elseif (ARCH STREQUAL "android_armv7") - set(FASTDEPLOY_INSTALL_DIR ${FASTDEPLOY_DIR}/android-armv7v8) - add_definitions("-DUSE_PADDLE_LITE_BAKEND") - set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE") - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -g0 -O3 -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE") endif() +message(STATUS "FASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} ${UNIX}") include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) # fix compiler flags conflict, since fastdeploy using c++11 for project -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD}) include_directories(${FASTDEPLOY_INCS}) message(STATUS "FASTDEPLOY_INCS=${FASTDEPLOY_INCS}") \ No newline at end of file 
diff --git a/runtime/cmake/glog.cmake b/runtime/cmake/glog.cmake index 8cc9999b..cbb97d2d 100644 --- a/runtime/cmake/glog.cmake +++ b/runtime/cmake/glog.cmake @@ -1,8 +1,21 @@ include(FetchContent) -FetchContent_Declare( - glog - URL https://paddleaudio.bj.bcebos.com/build/glog-0.4.0.zip - URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc -) -FetchContent_MakeAvailable(glog) -include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src) + +if(ANDROID) +else() # UNIX + add_definitions(-DWITH_GLOG) + FetchContent_Declare( + glog + URL https://paddleaudio.bj.bcebos.com/build/glog-0.4.0.zip + URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc + ) + FetchContent_MakeAvailable(glog) + include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src) +endif() + + +if(ANDROID) + add_library(extern_glog INTERFACE) +else() # UNIX + add_dependencies(glog gflags) + add_library(extern_glog ALIAS glog) +endif() \ No newline at end of file diff --git a/runtime/cmake/gtest.cmake b/runtime/cmake/gtest.cmake index f3e72d26..6b1eda40 100644 --- a/runtime/cmake/gtest.cmake +++ b/runtime/cmake/gtest.cmake @@ -1,14 +1,26 @@ include(FetchContent) -FetchContent_Declare( - gtest - URL https://paddleaudio.bj.bcebos.com/build/gtest-release-1.11.0.zip - URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a -) -FetchContent_MakeAvailable(gtest) -include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src) +if(ANDROID) +else() # UNIX + FetchContent_Declare( + gtest + URL https://paddleaudio.bj.bcebos.com/build/gtest-release-1.11.0.zip + URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a + ) + FetchContent_MakeAvailable(gtest) + include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src) +endif() + + + +if(ANDROID) + add_library(extern_gtest INTERFACE) +else() # UNIX + add_dependencies(gtest gflags gflog) + add_library(extern_gtest ALIAS gtest) +endif() if(WITH_TESTING) 
enable_testing() diff --git a/runtime/cmake/summary.cmake b/runtime/cmake/summary.cmake new file mode 100644 index 00000000..fd47c6bd --- /dev/null +++ b/runtime/cmake/summary.cmake @@ -0,0 +1,58 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +function(pps_summary) + message(STATUS "") + message(STATUS "*************PaddleSpeech Building Summary**********") + message(STATUS " CMake version : ${CMAKE_VERSION}") + message(STATUS " CMake command : ${CMAKE_COMMAND}") + message(STATUS " UNIX : ${UNIX}") + message(STATUS " ANDROID : ${ANDROID}") + message(STATUS " System : ${CMAKE_SYSTEM_NAME}") + message(STATUS " C++ compiler : ${CMAKE_CXX_COMPILER}") + message(STATUS " C++ compiler version : ${CMAKE_CXX_COMPILER_VERSION}") + message(STATUS " CXX flags : ${CMAKE_CXX_FLAGS}") + message(STATUS " Build type : ${CMAKE_BUILD_TYPE}") + get_directory_property(tmp DIRECTORY ${PROJECT_SOURCE_DIR} COMPILE_DEFINITIONS) + message(STATUS " Compile definitions : ${tmp}") + message(STATUS " CMAKE_PREFIX_PATH : ${CMAKE_PREFIX_PATH}") + message(STATUS " CMAKE_INSTALL_PREFIX : ${CMAKE_INSTALL_PREFIX}") + message(STATUS " CMAKE_MODULE_PATH : ${CMAKE_MODULE_PATH}") + message(STATUS " CMAKE_SYSTEM_NAME : ${CMAKE_SYSTEM_NAME}") + message(STATUS "") + + message(STATUS " WITH_ASR : ${WITH_ASR}") + message(STATUS " WITH_CLS : ${WITH_CLS}") + message(STATUS " WITH_VAD : ${WITH_VAD}") + message(STATUS " WITH_GPU : ${WITH_GPU}") + 
message(STATUS " WITH_TESTING : ${WITH_TESTING}") + message(STATUS " WITH_PROFILING : ${WITH_PROFILING}") + message(STATUS " FASTDEPLOY_INSTALL_DIR : ${FASTDEPLOY_INSTALL_DIR}") + if(WITH_GPU) + message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}") + endif() + + if(ANDROID) + message(STATUS " ANDROID_ABI : ${ANDROID_ABI}") + message(STATUS " ANDROID_PLATFORM : ${ANDROID_PLATFORM}") + message(STATUS " ANDROID_NDK : ${ANDROID_NDK}") + message(STATUS " ANDROID_NDK_VERSION : ${CMAKE_ANDROID_NDK_VERSION}") + endif() + if (WITH_ASR) + message(STATUS " Python executable : ${PYTHON_EXECUTABLE}") + message(STATUS " Python includes : ${PYTHON_INCLUDE_DIR}") + endif() +endfunction() + +pps_summary() \ No newline at end of file diff --git a/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc index 3e3ca2c2..fda8aab0 100644 --- a/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc +++ b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc @@ -22,7 +22,7 @@ #include "decoder/ctc_prefix_beam_search_score.h" #include "utils/math.h" -#ifdef USE_PROFILING +#ifdef WITH_PROFILING #include "paddle/fluid/platform/profiler.h" using paddle::platform::RecordEvent; using paddle::platform::TracerEventType; @@ -103,7 +103,7 @@ static bool PrefixScoreCompare( void CTCPrefixBeamSearch::AdvanceDecoding( const std::vector>& logp) { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event("CtcPrefixBeamSearch::AdvanceDecoding", TracerEventType::UserDefined, 1); diff --git a/runtime/engine/asr/nnet/u2_nnet.cc b/runtime/engine/asr/nnet/u2_nnet.cc index 0795c836..c2062033 100644 --- a/runtime/engine/asr/nnet/u2_nnet.cc +++ b/runtime/engine/asr/nnet/u2_nnet.cc @@ -18,11 +18,11 @@ #include "nnet/u2_nnet.h" -#ifdef USE_PROFILING +#ifdef WITH_PROFILING #include "paddle/fluid/platform/profiler.h" using paddle::platform::RecordEvent; using paddle::platform::TracerEventType; -#endif // end USE_PROFILING +#endif // 
end WITH_PROFILING namespace ppspeech { @@ -30,7 +30,7 @@ namespace ppspeech { void U2Nnet::LoadModel(const std::string& model_path_w_prefix) { paddle::jit::utils::InitKernelSignatureMap(); -#ifdef USE_GPU +#ifdef WITH_GPU dev_ = phi::GPUPlace(); #else dev_ = phi::CPUPlace(); @@ -62,12 +62,12 @@ void U2Nnet::LoadModel(const std::string& model_path_w_prefix) { } void U2Nnet::Warmup() { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event("warmup", TracerEventType::UserDefined, 1); #endif { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event( "warmup-encoder-ctc", TracerEventType::UserDefined, 1); #endif @@ -91,7 +91,7 @@ void U2Nnet::Warmup() { } { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event("warmup-decoder", TracerEventType::UserDefined, 1); #endif auto hyps = @@ -194,7 +194,7 @@ void U2Nnet::ForwardEncoderChunkImpl( const int32& feat_dim, std::vector* out_prob, int32* vocab_dim) { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event( "ForwardEncoderChunkImpl", TracerEventType::UserDefined, 1); #endif @@ -222,7 +222,7 @@ void U2Nnet::ForwardEncoderChunkImpl( VLOG(3) << "feats shape: " << feats.shape()[0] << ", " << feats.shape()[1] << ", " << feats.shape()[2]; -#ifdef TEST_DEBUG +#ifndef NDEBUG { std::stringstream path("feat", std::ios_base::app | std::ios_base::out); path << offset_; @@ -241,7 +241,7 @@ void U2Nnet::ForwardEncoderChunkImpl( #endif // Endocer chunk forward -#ifdef USE_GPU +#ifdef WITH_GPU feats = feats.copy_to(paddle::GPUPlace(), /*blocking*/ false); att_cache_ = att_cache_.copy_to(paddle::GPUPlace()), /*blocking*/ false; cnn_cache_ = cnn_cache_.copy_to(Paddle::GPUPlace(), /*blocking*/ false); @@ -258,7 +258,7 @@ void U2Nnet::ForwardEncoderChunkImpl( std::vector outputs = forward_encoder_chunk_(inputs); CHECK_EQ(outputs.size(), 3); -#ifdef USE_GPU +#ifdef WITH_GPU paddle::Tensor chunk_out = outputs[0].copy_to(paddle::CPUPlace()); att_cache_ = outputs[1].copy_to(paddle::CPUPlace()); cnn_cache_ = 
outputs[2].copy_to(paddle::CPUPlace()); @@ -268,7 +268,7 @@ void U2Nnet::ForwardEncoderChunkImpl( cnn_cache_ = outputs[2]; #endif -#ifdef TEST_DEBUG +#ifndef NDEBUG { std::stringstream path("encoder_logits", std::ios_base::app | std::ios_base::out); @@ -298,7 +298,7 @@ void U2Nnet::ForwardEncoderChunkImpl( encoder_outs_.push_back(chunk_out); VLOG(2) << "encoder_outs_ size: " << encoder_outs_.size(); -#ifdef TEST_DEBUG +#ifndef NDEBUG { std::stringstream path("encoder_logits_list", std::ios_base::app | std::ios_base::out); @@ -317,7 +317,7 @@ void U2Nnet::ForwardEncoderChunkImpl( } #endif // end TEST_DEBUG -#ifdef USE_GPU +#ifdef WITH_GPU #error "Not implementation." @@ -331,7 +331,7 @@ void U2Nnet::ForwardEncoderChunkImpl( CHECK_EQ(outputs.size(), 1); paddle::Tensor ctc_log_probs = outputs[0]; -#ifdef TEST_DEBUG +#ifndef NDEBUG { std::stringstream path("encoder_logprob", std::ios_base::app | std::ios_base::out); @@ -353,7 +353,7 @@ void U2Nnet::ForwardEncoderChunkImpl( } #endif // end TEST_DEBUG -#endif // end USE_GPU +#endif // end WITH_GPU // Copy to output, (B=1,T,D) std::vector ctc_log_probs_shape = ctc_log_probs.shape(); @@ -370,7 +370,7 @@ void U2Nnet::ForwardEncoderChunkImpl( std::memcpy( out_prob->data(), ctc_log_probs_ptr, T * D * sizeof(kaldi::BaseFloat)); -#ifdef TEST_DEBUG +#ifndef NDEBUG { std::stringstream path("encoder_logits_list_ctc", std::ios_base::app | std::ios_base::out); @@ -419,7 +419,7 @@ float U2Nnet::ComputePathScore(const paddle::Tensor& prob, void U2Nnet::AttentionRescoring(const std::vector>& hyps, float reverse_weight, std::vector* rescoring_score) { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event("AttentionRescoring", TracerEventType::UserDefined, 1); #endif CHECK(rescoring_score != nullptr); @@ -461,7 +461,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } } -#ifdef TEST_DEBUG +#ifndef NDEBUG { std::stringstream path("encoder_logits_concat", std::ios_base::app | std::ios_base::out); @@ -485,7 +485,7 @@ 
void U2Nnet::AttentionRescoring(const std::vector>& hyps, paddle::Tensor encoder_out = paddle::concat(encoder_outs_, 1); VLOG(2) << "encoder_outs_ size: " << encoder_outs_.size(); -#ifdef TEST_DEBUG +#ifndef NDEBUG { std::stringstream path("encoder_out0", std::ios_base::app | std::ios_base::out); @@ -504,7 +504,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } #endif // end TEST_DEBUG -#ifdef TEST_DEBUG +#ifndef NDEBUG { std::stringstream path("encoder_out", std::ios_base::app | std::ios_base::out); @@ -535,7 +535,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, CHECK_EQ(probs_shape[0], num_hyps); CHECK_EQ(probs_shape[1], max_hyps_len); -#ifdef TEST_DEBUG +#ifndef NDEBUG { std::stringstream path("decoder_logprob", std::ios_base::app | std::ios_base::out); @@ -553,7 +553,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } #endif // end TEST_DEBUG -#ifdef TEST_DEBUG +#ifndef NDEBUG { std::stringstream path("hyps_lens", std::ios_base::app | std::ios_base::out); @@ -569,7 +569,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } #endif // end TEST_DEBUG -#ifdef TEST_DEBUG +#ifndef NDEBUG { std::stringstream path("hyps_tensor", std::ios_base::app | std::ios_base::out); diff --git a/runtime/engine/cls/CMakeLists.txt b/runtime/engine/cls/CMakeLists.txt index 4d5e0cff..2f3fd22a 100644 --- a/runtime/engine/cls/CMakeLists.txt +++ b/runtime/engine/cls/CMakeLists.txt @@ -1,7 +1 @@ -project(cls) - -include(fastdeploy) -# add_definitions("-DTEST_DEBUG") -# add_definitions("-DPRINT_TIME") - add_subdirectory(nnet) \ No newline at end of file diff --git a/runtime/engine/cls/nnet/CMakeLists.txt b/runtime/engine/cls/nnet/CMakeLists.txt index b4b76120..27f24434 100644 --- a/runtime/engine/cls/nnet/CMakeLists.txt +++ b/runtime/engine/cls/nnet/CMakeLists.txt @@ -1,8 +1,11 @@ -set(srcs panns_nnet.cc panns_interface.cc) +set(srcs + panns_nnet.cc + panns_interface.cc +) add_library(cls SHARED ${srcs}) -target_link_libraries(cls 
-static-libstdc++;-Wl,-Bsymbolic ${FASTDEPLOY_LIBS} kaldi-matrix kaldi-base frontend utils) +target_link_libraries(cls INTERFACE -static-libstdc++;-Wl,-Bsymbolic ${FASTDEPLOY_LIBS} kaldi-matrix kaldi-base frontend utils ) set(bin_name panns_nnet_main) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -target_link_libraries(${bin_name} -static-libstdc++;-Wl,-Bsymbolic cls gflags glog) \ No newline at end of file +target_link_libraries(${bin_name} -static-libstdc++;-Wl,-Bsymbolic gflags glog) diff --git a/runtime/engine/cls/nnet/panns_nnet.cc b/runtime/engine/cls/nnet/panns_nnet.cc index 6b8213f6..bd2265d4 100644 --- a/runtime/engine/cls/nnet/panns_nnet.cc +++ b/runtime/engine/cls/nnet/panns_nnet.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "cls/nnet/panns_nnet.h" -#ifdef PRINT_TIME +#ifdef WITH_PROFILING #include "kaldi/base/timer.h" #endif @@ -86,7 +86,7 @@ int ClsNnet::Forward(const char* wav_path, int topk, char* result, int result_max_len) { -#ifdef PRINT_TIME +#ifdef WITH_PROFILING kaldi::Timer timer; timer.Reset(); #endif @@ -105,7 +105,7 @@ int ClsNnet::Forward(const char* wav_path, conf_.wav_normal_, conf_.wav_normal_type_, conf_.wav_norm_mul_factor_); -#ifdef TEST_DEBUG +#ifndef NDEBUG { std::ofstream fp("cls.wavform", std::ios::out); for (int i = 0; i < wavform.size(); ++i) { @@ -114,11 +114,11 @@ int ClsNnet::Forward(const char* wav_path, fp << "\n"; } #endif -#ifdef PRINT_TIME +#ifdef WITH_PROFILING printf("wav read consume: %fs\n", timer.Elapsed()); #endif -#ifdef PRINT_TIME +#ifdef WITH_PROFILING timer.Reset(); #endif @@ -138,7 +138,7 @@ int ClsNnet::Forward(const char* wav_path, feats[i * feat_dim + j] = PowerTodb(feats[i * feat_dim + j]); } } -#ifdef TEST_DEBUG +#ifndef NDEBUG { std::ofstream fp("cls.feat", std::ios::out); for (int i = 0; i < num_frames; ++i) { @@ -149,20 +149,20 @@ int ClsNnet::Forward(const char* wav_path, } } #endif -#ifdef PRINT_TIME +#ifdef WITH_PROFILING printf("extract fbank consume: 
%fs\n", timer.Elapsed()); #endif // infer std::vector model_out; -#ifdef PRINT_TIME +#ifdef WITH_PROFILING timer.Reset(); #endif ModelForward(feats.data(), num_frames, feat_dim, &model_out); -#ifdef PRINT_TIME +#ifdef WITH_PROFILING printf("fast deploy infer consume: %fs\n", timer.Elapsed()); #endif -#ifdef TEST_DEBUG +#ifndef NDEBUG { std::ofstream fp("cls.logits", std::ios::out); for (int i = 0; i < model_out.size(); ++i) { diff --git a/runtime/engine/codelab/CMakeLists.txt b/runtime/engine/codelab/CMakeLists.txt index c8445fb8..13aa5efb 100644 --- a/runtime/engine/codelab/CMakeLists.txt +++ b/runtime/engine/codelab/CMakeLists.txt @@ -1,3 +1,6 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) -add_subdirectory(glog) +if(ANDROID) +else() #Unix + add_subdirectory(glog) +endif() \ No newline at end of file diff --git a/runtime/engine/codelab/glog/CMakeLists.txt b/runtime/engine/codelab/glog/CMakeLists.txt index 08a98641..492e33c6 100644 --- a/runtime/engine/codelab/glog/CMakeLists.txt +++ b/runtime/engine/codelab/glog/CMakeLists.txt @@ -1,8 +1,8 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) add_executable(glog_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_main.cc) -target_link_libraries(glog_main glog) +target_link_libraries(glog_main extern_glog) add_executable(glog_logtostderr_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_main.cc) -target_link_libraries(glog_logtostderr_main glog) +target_link_libraries(glog_logtostderr_main extern_glog) diff --git a/runtime/engine/common/CMakeLists.txt b/runtime/engine/common/CMakeLists.txt index 4f399eea..a2f56f7f 100644 --- a/runtime/engine/common/CMakeLists.txt +++ b/runtime/engine/common/CMakeLists.txt @@ -2,11 +2,14 @@ include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/../ ) -add_subdirectory(utils) add_subdirectory(base) +add_subdirectory(utils) add_subdirectory(matrix) include_directories( ${CMAKE_CURRENT_SOURCE_DIR}/frontend ) add_subdirectory(frontend) + +add_library(common INTERFACE) 
+add_definitions(common base utils kaldi-matrix frontend) \ No newline at end of file diff --git a/runtime/engine/common/base/CMakeLists.txt b/runtime/engine/common/base/CMakeLists.txt index ab710874..a49b78bd 100644 --- a/runtime/engine/common/base/CMakeLists.txt +++ b/runtime/engine/common/base/CMakeLists.txt @@ -1,10 +1,20 @@ + + if(WITH_ASR) add_compile_options(-DWITH_ASR) set(PPS_FLAGS_LIB "fst/flags.h") - set(PPS_GLOB_LIB "fst/log.h") else() set(PPS_FLAGS_LIB "gflags/gflags.h") - set(PPS_GLOB_LIB "glog/logging.h") +endif() + +if(ANDROID) + set(PPS_GLOB_LIB "base/log_impl.h") +else() #UNIX + if(WITH_ASR) + set(PPS_GLOB_LIB "fst/log.h") + else() + set(PPS_GLOB_LIB "glog/logging.h") + endif() endif() configure_file( @@ -17,4 +27,16 @@ configure_file( ${CMAKE_CURRENT_SOURCE_DIR}/log.h.in ${CMAKE_CURRENT_SOURCE_DIR}/log.h @ONLY ) -message(STATUS "Generated ${CMAKE_CURRENT_SOURCE_DIR}/log.h") \ No newline at end of file +message(STATUS "Generated ${CMAKE_CURRENT_SOURCE_DIR}/log.h") + + +if(ANDROID) + set(csrc + log_impl.cc + glog_utils.cc + ) + add_library(base ${csrc}) +else() # UNIX + set(csrc) + add_library(base INTERFACE) +endif() \ No newline at end of file diff --git a/runtime/engine/common/base/common.h b/runtime/engine/common/base/common.h index d94dc8a8..17560102 100644 --- a/runtime/engine/common/base/common.h +++ b/runtime/engine/common/base/common.h @@ -50,4 +50,4 @@ #include "base/log.h" #include "base/macros.h" #include "utils/file_utils.h" -#include "utils/math.h" +#include "utils/math.h" \ No newline at end of file diff --git a/runtime/engine/common/base/glog_utils.cc b/runtime/engine/common/base/glog_utils.cc new file mode 100644 index 00000000..4ab3c251 --- /dev/null +++ b/runtime/engine/common/base/glog_utils.cc @@ -0,0 +1,12 @@ + +#include "base/glog_utils.h" + +namespace google { +void InitGoogleLogging(const char* name) { + LOG(INFO) << "dummpy InitGoogleLogging."; +} + +void InstallFailureSignalHandler() { + LOG(INFO) << "dummpy 
InstallFailureSignalHandler."; +} +} // namespace google diff --git a/runtime/engine/common/base/glog_utils.h b/runtime/engine/common/base/glog_utils.h new file mode 100644 index 00000000..9cffcafb --- /dev/null +++ b/runtime/engine/common/base/glog_utils.h @@ -0,0 +1,9 @@ +#pragma once + +#include "base/common.h" + +namespace google { +void InitGoogleLogging(const char* name); + +void InstallFailureSignalHandler(); +} // namespace google \ No newline at end of file diff --git a/runtime/engine/common/base/log_impl.cc b/runtime/engine/common/base/log_impl.cc new file mode 100644 index 00000000..8286f1e7 --- /dev/null +++ b/runtime/engine/common/base/log_impl.cc @@ -0,0 +1,105 @@ +#include "base/log.h" + +DEFINE_int32(logtostderr, 0, "logging to stderr"); + +namespace ppspeech { + +static char __progname[] = "paddlespeech"; + +namespace log { + +std::mutex LogMessage::lock_; +std::string LogMessage::s_debug_logfile_(""); +std::string LogMessage::s_info_logfile_(""); +std::string LogMessage::s_warning_logfile_(""); +std::string LogMessage::s_error_logfile_(""); +std::string LogMessage::s_fatal_logfile_(""); + +void LogMessage::get_curr_proc_info(std::string* pid, std::string* proc_name) { + std::stringstream ss; + ss << getpid(); + ss >> *pid; + *proc_name = ::ppspeech::__progname; +} + +LogMessage::LogMessage(const char* file, + int line, + Severity level, + bool verbose, + bool out_to_file /* = false */) + : level_(level), verbose_(verbose), out_to_file_(out_to_file) { + if (FLAGS_logtostderr == 0) { + stream_ = std::shared_ptr(&std::cout); + } else if (FLAGS_logtostderr == 1) { + stream_ = std::shared_ptr(&std::cerr); + } else if (out_to_file_) { + // logfile + lock_.lock(); + init(file, line); + } +} + +LogMessage::~LogMessage() { + stream() << std::endl; + + if (out_to_file_) { + lock_.unlock(); + } + + if (level_ == FATAL) { + std::abort(); + } +} + +void LogMessage::init(const char* file, int line) { + time_t t = time(0); + char tmp[100]; + strftime(tmp, 
sizeof(tmp), "%Y%m%d-%H%M%S", localtime(&t)); + + if (s_info_logfile_.empty()) { + std::string pid; + std::string proc_name; + get_curr_proc_info(&pid, &proc_name); + + s_debug_logfile_ = + std::string("log." + proc_name + ".log.DEBUG." + tmp + "." + pid); + s_info_logfile_ = + std::string("log." + proc_name + ".log.INFO." + tmp + "." + pid); + s_warning_logfile_ = + std::string("log." + proc_name + ".log.WARNING." + tmp + "." + pid); + s_error_logfile_ = + std::string("log." + proc_name + ".log.ERROR." + tmp + "." + pid); + s_fatal_logfile_ = + std::string("log." + proc_name + ".log.FATAL." + tmp + "." + pid); + } + + std::ofstream ofs; + if (level_ == DEBUG) { + stream_ = std::make_shared( + s_debug_logfile_.c_str(), std::ios::out | std::ios::app); + // ofs.open(s_debug_logfile_.c_str(), std::ios::out | std::ios::app); + } else if (level_ == INFO) { + // ofs.open(s_warning_logfile_.c_str(), std::ios::out | std::ios::app); + stream_ = std::make_shared( + s_warning_logfile_.c_str(), std::ios::out | std::ios::app); + } else if (level_ == WARNING) { + // ofs.open(s_warning_logfile_.c_str(), std::ios::out | std::ios::app); + stream_ = std::make_shared( + s_warning_logfile_.c_str(), std::ios::out | std::ios::app); + } else if (level_ == ERROR) { + // ofs.open(s_error_logfile_.c_str(), std::ios::out | std::ios::app); + stream_ = std::make_shared( + s_error_logfile_.c_str(), std::ios::out | std::ios::app); + } else { + // ofs.open(s_fatal_logfile_.c_str(), std::ios::out | std::ios::app); + stream_ = std::make_shared( + s_fatal_logfile_.c_str(), std::ios::out | std::ios::app); + } + + // stream_ = &ofs; + + stream() << tmp << " " << file << " line " << line << "; "; + stream() << std::flush; +} +} // namespace log +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/common/base/log_impl.h b/runtime/engine/common/base/log_impl.h new file mode 100644 index 00000000..93573620 --- /dev/null +++ b/runtime/engine/common/base/log_impl.h @@ -0,0 +1,156 
@@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// modified from https://github.com/Dounm/dlog +// modified form +// https://android.googlesource.com/platform/art/+/806defa/src/logging.h + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "base/common.h" +#include "base/macros.h" +#ifndef WITH_GLOG +#include "base/glog_utils.h" +#endif + +DECLARE_int32(logtostderr); + +namespace ppspeech { + +namespace log { + +enum Severity { + DEBUG, + INFO, + WARNING, + ERROR, + FATAL, + NUM_SEVERITIES, +}; + +class LogMessage { + public: + static void get_curr_proc_info(std::string* pid, std::string* proc_name); + + LogMessage(const char* file, + int line, + Severity level, + bool verbose, + bool out_to_file = false); + + ~LogMessage(); + + std::ostream& stream() { return *stream_; } + + private: + void init(const char* file, int line); + + private: + std::shared_ptr stream_; + Severity level_; + bool verbose_; + bool out_to_file_; + + static std::mutex lock_; // stream write lock + static std::string s_debug_logfile_; + static std::string s_info_logfile_; + static std::string s_warning_logfile_; + static std::string s_error_logfile_; + static std::string s_fatal_logfile_; + + DISALLOW_COPY_AND_ASSIGN(LogMessage); +}; + + +} // namespace log + +} // namespace ppspeech + + +#ifndef NDEBUG +#define DLOG_DEBUG \ + 
ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::DEBUG, false) +#else +#define DLOG_DEBUG \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::DEBUG, true) +#endif + +#define DLOG_INFO \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, true) +#define DLOG_WARNING \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, true) +#define DLOG_ERROR \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, true) +#define DLOG_FATAL \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, true) + +#define DLOG_0 DLOG_DEBUG +#define DLOG_1 DLOG_INFO +#define DLOG_2 DLOG_WARNING +#define DLOG_3 DLOG_ERROR +#define DLOG_4 DLOG_FATAL + +#define LOG(level) DLOG_##level.stream() + +#define VLOG(verboselevel) LOG(verboselevel) + + +#define CHECK(exp) \ + ppspeech::log::LogMessage( \ + __FILE__, __LINE__, ppspeech::log::FATAL, !(exp)) \ + .stream() \ + << "Check Failed: " #exp + +#define CHECK_EQ(x, y) CHECK((x) == (y)) +#define CHECK_NE(x, y) CHECK((x) != (y)) +#define CHECK_LE(x, y) CHECK((x) <= (y)) +#define CHECK_LT(x, y) CHECK((x) < (y)) +#define CHECK_GE(x, y) CHECK((x) >= (y)) +#define CHECK_GT(x, y) CHECK((x) > (y)) +#ifndef NDEBUG +#define DCHECK(x) CHECK(x) +#define DCHECK_EQ(x, y) CHECK_EQ(x, y) +#define DCHECK_NE(x, y) CHECK_NE(x, y) +#define DCHECK_LE(x, y) CHECK_LE(x, y) +#define DCHECK_LT(x, y) CHECK_LT(x, y) +#define DCHECK_GE(x, y) CHECK_GE(x, y) +#define DCHECK_GT(x, y) CHECK_GT(x, y) +#else // NDEBUG +#define DCHECK(condition) \ + while (false) CHECK(condition) +#define DCHECK_EQ(val1, val2) \ + while (false) CHECK_EQ(val1, val2) +#define DCHECK_NE(val1, val2) \ + while (false) CHECK_NE(val1, val2) +#define DCHECK_LE(val1, val2) \ + while (false) CHECK_LE(val1, val2) +#define DCHECK_LT(val1, val2) \ + while (false) CHECK_LT(val1, val2) +#define DCHECK_GE(val1, val2) \ + while (false) CHECK_GE(val1, val2) +#define DCHECK_GT(val1, val2) \ + while (false) CHECK_GT(val1, 
val2) +#define DCHECK_STREQ(str1, str2) \ + while (false) CHECK_STREQ(str1, str2) +#endif \ No newline at end of file diff --git a/runtime/engine/common/base/macros.h b/runtime/engine/common/base/macros.h index db989812..e60baf55 100644 --- a/runtime/engine/common/base/macros.h +++ b/runtime/engine/common/base/macros.h @@ -17,14 +17,14 @@ #include #include -namespace ppspeech { - #ifndef DISALLOW_COPY_AND_ASSIGN #define DISALLOW_COPY_AND_ASSIGN(TypeName) \ TypeName(const TypeName&) = delete; \ void operator=(const TypeName&) = delete #endif +namespace ppspeech { + // kSpaceSymbol in UTF-8 is: ▁ const char kSpaceSymbo[] = "\xe2\x96\x81"; diff --git a/runtime/engine/common/frontend/CMakeLists.txt b/runtime/engine/common/frontend/CMakeLists.txt index 617c35e1..4ff3117c 100644 --- a/runtime/engine/common/frontend/CMakeLists.txt +++ b/runtime/engine/common/frontend/CMakeLists.txt @@ -24,5 +24,5 @@ set(BINS foreach(bin_name IN LISTS BINS) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} PUBLIC frontend utils kaldi-util gflags glog) -endforeach() + target_link_libraries(${bin_name} PUBLIC frontend base utils kaldi-util gflags extern_glog) +endforeach() \ No newline at end of file diff --git a/runtime/engine/common/matrix/kaldi-vector.cc b/runtime/engine/common/matrix/kaldi-vector.cc index 90817f0d..1d0b55b9 100644 --- a/runtime/engine/common/matrix/kaldi-vector.cc +++ b/runtime/engine/common/matrix/kaldi-vector.cc @@ -584,7 +584,7 @@ template void VectorBase::CopyColFromMat(const MatrixBase &mat, // Real *data = data_; //// implement the function according to a dimension cutoff for computation -///efficiency +/// efficiency // if (num_rows <= 64) { // cblas_Xscal(dim, beta, data, 1); // const Real *m_data = M.Data(); @@ -605,7 +605,7 @@ template void VectorBase::CopyColFromMat(const MatrixBase &mat, // MatrixIndexT num_cols = 
M.NumCols(); //// implement the function according to a dimension cutoff for computation -///efficiency +/// efficiency // if (num_cols <= 64) { // for (MatrixIndexT i = 0; i < dim_; i++) { // double sum = 0.0; @@ -1224,7 +1224,7 @@ void Vector::Swap(Vector *other) { // for (MatrixIndexT i = 0; i < dim; i++, Mdata += M_row_stride, Ndata += // N_col_stride, data++) { //*data = beta * *data + alpha * cblas_Xdot(M_col_dim, Mdata, M_col_stride, -//Ndata, N_row_stride); +// Ndata, N_row_stride); //} //} diff --git a/runtime/engine/common/utils/CMakeLists.txt b/runtime/engine/common/utils/CMakeLists.txt index 8589b19a..eb3c7197 100644 --- a/runtime/engine/common/utils/CMakeLists.txt +++ b/runtime/engine/common/utils/CMakeLists.txt @@ -1,20 +1,27 @@ -add_library(utils + +set(csrc file_utils.cc math.cc strings.cc audio_process.cc ) +add_library(utils ${csrc}) if(WITH_TESTING) enable_testing() - link_libraries(gtest_main gmock) + + if(ANDROID) + else() # UNIX + link_libraries(gtest_main gmock) + + add_executable(strings_test strings_test.cc) + target_link_libraries(strings_test PUBLIC utils) + add_test( + NAME strings_test + COMMAND strings_test + ) + endif() +endif() - add_executable(strings_test strings_test.cc) - target_link_libraries(strings_test PUBLIC utils) - add_test( - NAME strings_test - COMMAND strings_test - ) -endif() \ No newline at end of file diff --git a/runtime/engine/vad/CMakeLists.txt b/runtime/engine/vad/CMakeLists.txt index d13cc407..4e9f448c 100644 --- a/runtime/engine/vad/CMakeLists.txt +++ b/runtime/engine/vad/CMakeLists.txt @@ -1,18 +1,5 @@ -# set(CMAKE_CXX_STANDARD 11) -# # 指定下载解压后的fastdeploy库路径 -# set(FASTDEPLOY_INSTALL_DIR "fdlib/fastdeploy-linux-x64-1.0.4" CACHE STRING force) -# if(NOT EXISTS ${FASTDEPLOY_INSTALL_DIR}) -# message(FATAL_ERROR "Please using cmake -B build -DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR}") -# endif() - -# include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) - -# # 添加FastDeploy依赖头文件 -# 
include_directories(${FASTDEPLOY_INCS}) - -add_executable(infer_onnx_silero_vad ${CMAKE_CURRENT_SOURCE_DIR}/infer_onnx_silero_vad.cc wav.h vad.cc vad.h) - -# 添加FastDeploy库依赖 -target_link_libraries(infer_onnx_silero_vad ${FASTDEPLOY_LIBS}) +set(bin_name silero_vad_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc vad.cc) +target_link_libraries(${bin_name} ${FASTDEPLOY_LIBS} gflags extern_glog) diff --git a/runtime/engine/vad/infer_onnx_silero_vad.cc b/runtime/engine/vad/silero_vad_main.cc similarity index 100% rename from runtime/engine/vad/infer_onnx_silero_vad.cc rename to runtime/engine/vad/silero_vad_main.cc diff --git a/runtime/examples/silero_vad/.gitignore b/runtime/examples/silero_vad/.gitignore new file mode 100644 index 00000000..1269488f --- /dev/null +++ b/runtime/examples/silero_vad/.gitignore @@ -0,0 +1 @@ +data diff --git a/runtime/engine/vad/README.md b/runtime/examples/silero_vad/README.md similarity index 100% rename from runtime/engine/vad/README.md rename to runtime/examples/silero_vad/README.md diff --git a/runtime/engine/vad/README_CN.md b/runtime/examples/silero_vad/README_CN.md similarity index 100% rename from runtime/engine/vad/README_CN.md rename to runtime/examples/silero_vad/README_CN.md diff --git a/runtime/examples/silero_vad/local/build.sh b/runtime/examples/silero_vad/local/build.sh new file mode 100755 index 00000000..d35de5a2 --- /dev/null +++ b/runtime/examples/silero_vad/local/build.sh @@ -0,0 +1,14 @@ +ANDROID_NDK=/workspace/zhanghui/android-sdk/android-ndk-r25c +ANDROID_TOOLCHAIN=clang +FASTDEPLOY_INSTALL_DIR=./fdlib/fastdeploy-android-1.0.3-shared/ +TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake + +cmake -B build -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} \ + -DCMAKE_BUILD_TYPE=Release \ + -DANDROID_ABI="arm64-v8a" \ + -DANDROID_NDK=${ANDROID_NDK} \ + -DANDROID_PLATFORM="android-21" \ + -DANDROID_STL=c++_shared \ + -DANDROID_TOOLCHAIN=${ANDROID_TOOLCHAIN} \ + 
-DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \ + -Wno-dev diff --git a/runtime/examples/silero_vad/local/build_android.sh b/runtime/examples/silero_vad/local/build_android.sh new file mode 100755 index 00000000..02aa6f9f --- /dev/null +++ b/runtime/examples/silero_vad/local/build_android.sh @@ -0,0 +1,31 @@ +ANDROID_NDK=/workspace/zhanghui/android-sdk/android-ndk-r25c +FASTDEPLOY_INSTALL_DIR=./fdlib/fastdeploy-android-1.0.4-shared/ + +# Setting up Android toolchanin +ANDROID_ABI=arm64-v8a # 'arm64-v8a', 'armeabi-v7a' +ANDROID_PLATFORM="android-21" # API >= 21 +ANDROID_STL=c++_shared # 'c++_shared', 'c++_static' +ANDROID_TOOLCHAIN=clang # 'clang' only +TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake + +# Create build directory +BUILD_ROOT=build/Android +BUILD_DIR=${BUILD_ROOT}/${ANDROID_ABI}-api-21 +#FASDEPLOY_INSTALL_DIR="${BUILD_DIR}/install" +#mkdir build && mkdir ${BUILD_ROOT} && mkdir ${BUILD_DIR} +mkdir -p ${BUILD_DIR} +cd ${BUILD_DIR} + +# CMake configuration with Android toolchain +cmake -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} \ + -DCMAKE_BUILD_TYPE=MinSizeRel \ + -DANDROID_ABI=${ANDROID_ABI} \ + -DANDROID_NDK=${ANDROID_NDK} \ + -DANDROID_PLATFORM=${ANDROID_PLATFORM} \ + -DANDROID_STL=${ANDROID_STL} \ + -DANDROID_TOOLCHAIN=${ANDROID_TOOLCHAIN} \ + -DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \ + -Wno-dev ../../.. 
+ +# Build FastDeploy Android C++ SDK +make -j8 diff --git a/runtime/examples/silero_vad/local/decode.sh b/runtime/examples/silero_vad/local/decode.sh new file mode 100755 index 00000000..e69de29b diff --git a/runtime/examples/silero_vad/local/download.sh b/runtime/examples/silero_vad/local/download.sh new file mode 100755 index 00000000..2f55e20a --- /dev/null +++ b/runtime/examples/silero_vad/local/download.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +mkdir -p data +cd data + +wget -c https://bj.bcebos.com/paddlehub/fastdeploy/silero_vad.tgz + +test -e silero_vad || tar zxvf silero_vad.tgz + +wget -c https://bj.bcebos.com/paddlehub/fastdeploy/silero_vad_sample.wav diff --git a/runtime/examples/silero_vad/path.sh b/runtime/examples/silero_vad/path.sh new file mode 100644 index 00000000..ad3a7358 --- /dev/null +++ b/runtime/examples/silero_vad/path.sh @@ -0,0 +1,18 @@ +# This contains the locations of binarys build required for running the examples. + +unset GREP_OPTIONS + +ENGINE_ROOT=$PWD/../../../ +ENGINE_BUILD=$ENGINE_ROOT/build/engine/asr + +ENGINE_TOOLS=$ENGINE_ROOT/tools +TOOLS_BIN=$ENGINE_TOOLS/valgrind/install/bin + +[ -d $ENGINE_BUILD ] || { echo "Error: 'build/runtime' directory not found. please ensure that the project build successfully"; } + +export LC_AL=C + +export PATH=$PATH:$TOOLS_BIN:$ENGINE_BUILD/nnet:$ENGINE_BUILD/decoder:$ENGINE_BUILD/../common/frontend/audio:$ENGINE_BUILD/recognizer + +#PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);") +export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH diff --git a/runtime/examples/silero_vad/run.sh b/runtime/examples/silero_vad/run.sh new file mode 100644 index 00000000..9707df1b --- /dev/null +++ b/runtime/examples/silero_vad/run.sh @@ -0,0 +1,38 @@ +#!/bin/bash +set -e + +. 
path.sh + +nj=40 +stage=-1 +stop_stage=100 + +. utils/parse_options.sh + +# input +data=data +exp=exp +mkdir -p $exp $data + +# 1. compile +if [ ! -d ${SPEECHX_BUILD} ]; then + pushd ${SPEECHX_ROOT} + bash build.sh + + # build for android armv8/armv7 + # bash build_android.sh + popd +fi + +ckpt_dir=$data/silero_vad +wav=$data/silero_vad_sample.wav + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then + ./local/download.sh +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + ./local/decode.sh +fi + diff --git a/runtime/examples/silero_vad/utils b/runtime/examples/silero_vad/utils new file mode 120000 index 00000000..973afe67 --- /dev/null +++ b/runtime/examples/silero_vad/utils @@ -0,0 +1 @@ +../../../utils \ No newline at end of file From 2beb7ffce0fce64f92d417c4828b155bad73e4b6 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 6 Mar 2023 13:43:46 +0800 Subject: [PATCH 15/50] fix asr compile bug (#2993) --- runtime/CMakeLists.txt | 94 +++++++++++++++++-------------------- runtime/build.sh | 2 +- runtime/cmake/gtest.cmake | 2 +- runtime/cmake/openfst.cmake | 1 + 4 files changed, 46 insertions(+), 53 deletions(-) diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt index bdce2046..af970526 100644 --- a/runtime/CMakeLists.txt +++ b/runtime/CMakeLists.txt @@ -16,15 +16,8 @@ if(NOT CMAKE_BUILD_TYPE) FORCE) endif() - project(paddlespeech VERSION 0.1) -# if(ANDROID) -# # when cross compile with ndk under linux, -# # UNIX and ANROID are all True -# set(UNIX) -# endif() - include(FetchContent) include(ExternalProject) @@ -71,60 +64,60 @@ if(WITH_TESTING) endif() # fastdeploy -include(fastdeploy) +if(NOT WITH_ASR) + include(fastdeploy) +endif() if(WITH_ASR) # openfst include(openfst) - add_dependencies(openfst gflags glog) + add_dependencies(openfst gflags extern_glog) endif() ############################################################################### # Find Package ############################################################################### 
- -# python/pybind11/threads find_package(Threads REQUIRED) -# https://cmake.org/cmake/help/latest/module/FindPython3.html#module:FindPython3 -find_package(Python3 COMPONENTS Interpreter Development) -find_package(pybind11 CONFIG) - if(WITH_ASR) + # https://cmake.org/cmake/help/latest/module/FindPython3.html#module:FindPython3 + find_package(Python3 COMPONENTS Interpreter Development) + find_package(pybind11 CONFIG) + if(Python3_FOUND) - message(STATUS "Python3_FOUND = ${Python3_FOUND}") - message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}") - message(STATUS "Python3_LIBRARIES = ${Python3_LIBRARIES}") - message(STATUS "Python3_INCLUDE_DIRS = ${Python3_INCLUDE_DIRS}") - message(STATUS "Python3_LINK_OPTIONS = ${Python3_LINK_OPTIONS}") - set(PYTHON_LIBRARIES ${Python3_LIBRARIES} CACHE STRING "python lib" FORCE) - set(PYTHON_INCLUDE_DIR ${Python3_INCLUDE_DIRS} CACHE STRING "python inc" FORCE) + message(STATUS "Python3_FOUND = ${Python3_FOUND}") + message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}") + message(STATUS "Python3_LIBRARIES = ${Python3_LIBRARIES}") + message(STATUS "Python3_INCLUDE_DIRS = ${Python3_INCLUDE_DIRS}") + message(STATUS "Python3_LINK_OPTIONS = ${Python3_LINK_OPTIONS}") + set(PYTHON_LIBRARIES ${Python3_LIBRARIES} CACHE STRING "python lib" FORCE) + set(PYTHON_INCLUDE_DIR ${Python3_INCLUDE_DIRS} CACHE STRING "python inc" FORCE) endif() message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}") message(STATUS "PYTHON_INCLUDE_DIR = ${PYTHON_INCLUDE_DIR}") if(pybind11_FOUND) - message(STATUS "pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}") - message(STATUS "pybind11_LIBRARIES=${pybind11_LIBRARIES}") - message(STATUS "pybind11_DEFINITIONS=${pybind11_DEFINITIONS}") + message(STATUS "pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}") + message(STATUS "pybind11_LIBRARIES=${pybind11_LIBRARIES}") + message(STATUS "pybind11_DEFINITIONS=${pybind11_DEFINITIONS}") endif() # paddle libpaddle.so # paddle include and link option # 
-L/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so + set(EXECUTE_COMMAND "import os" + "import paddle" + "include_dir = paddle.sysconfig.get_include()" + "paddle_dir=os.path.split(include_dir)[0]" + "libs_dir=os.path.join(paddle_dir, 'libs')" + "fluid_dir=os.path.join(paddle_dir, 'fluid')" + "out=' '.join([\"-L\" + libs_dir, \"-L\" + fluid_dir])" + "out += \" -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so\"; print(out)" + ) execute_process( - COMMAND python -c "\ - import os;\ - import paddle;\ - include_dir=paddle.sysconfig.get_include();\ - paddle_dir=os.path.split(include_dir)[0];\ - libs_dir=os.path.join(paddle_dir, 'libs');\ - fluid_dir=os.path.join(paddle_dir, 'fluid');\ - out=' '.join([\"-L\" + libs_dir, \"-L\" + fluid_dir]);\ - out += \" -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so\"; print(out);\ - " + COMMAND python -c "${EXECUTE_COMMAND}" OUTPUT_VARIABLE PADDLE_LINK_FLAGS RESULT_VARIABLE SUCESS) @@ -133,29 +126,28 @@ if(WITH_ASR) # paddle compile option # -I/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/include + set(EXECUTE_COMMAND "import paddle" + "include_dir = paddle.sysconfig.get_include()" + "print(f\"-I{include_dir}\")" + ) execute_process( - COMMAND python -c "\ - import paddle; \ - include_dir = paddle.sysconfig.get_include(); \ - print(f\"-I{include_dir}\"); \ - " + COMMAND python -c "${EXECUTE_COMMAND}" OUTPUT_VARIABLE PADDLE_COMPILE_FLAGS) message(STATUS PADDLE_COMPILE_FLAGS= ${PADDLE_COMPILE_FLAGS}) string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS) - # for LD_LIBRARY_PATH # set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/) + set(EXECUTE_COMMAND "import os" + "import paddle" + "include_dir=paddle.sysconfig.get_include()" + 
"paddle_dir=os.path.split(include_dir)[0]" + "libs_dir=os.path.join(paddle_dir, 'libs')" + "fluid_dir=os.path.join(paddle_dir, 'fluid')" + "out=':'.join([libs_dir, fluid_dir]); print(out)" + ) execute_process( - COMMAND python -c "\ - import os; \ - import paddle; \ - include_dir=paddle.sysconfig.get_include(); \ - paddle_dir=os.path.split(include_dir)[0]; \ - libs_dir=os.path.join(paddle_dir, 'libs'); \ - fluid_dir=os.path.join(paddle_dir, 'fluid'); \ - out=':'.join([libs_dir, fluid_dir]); print(out); \ - " + COMMAND python -c "${EXECUTE_COMMAND}" OUTPUT_VARIABLE PADDLE_LIB_DIRS) message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS}) endif() @@ -167,4 +159,4 @@ include(summary) ############################################################################### set(ENGINE_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/engine) -add_subdirectory(engine) +add_subdirectory(engine) \ No newline at end of file diff --git a/runtime/build.sh b/runtime/build.sh index 131fb7f1..f7d0a2b2 100755 --- a/runtime/build.sh +++ b/runtime/build.sh @@ -4,5 +4,5 @@ set -xe # the build script had verified in the paddlepaddle docker image. # please follow the instruction below to install PaddlePaddle image. 
# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html -cmake -B build -DWITH_ASR=OFF -DWITH_CLS=OFF +cmake -B build -DWITH_ASR=ON -DWITH_CLS=OFF -DWITH_VAD=OFF cmake --build build -j diff --git a/runtime/cmake/gtest.cmake b/runtime/cmake/gtest.cmake index 6b1eda40..a311721f 100644 --- a/runtime/cmake/gtest.cmake +++ b/runtime/cmake/gtest.cmake @@ -18,7 +18,7 @@ endif() if(ANDROID) add_library(extern_gtest INTERFACE) else() # UNIX - add_dependencies(gtest gflags gflog) + add_dependencies(gtest gflags extern_glog) add_library(extern_gtest ALIAS gtest) endif() diff --git a/runtime/cmake/openfst.cmake b/runtime/cmake/openfst.cmake index 2e2f82f2..06697156 100644 --- a/runtime/cmake/openfst.cmake +++ b/runtime/cmake/openfst.cmake @@ -16,6 +16,7 @@ ExternalProject_Add(openfst PREFIX ${openfst_PREFIX_DIR} SOURCE_DIR ${openfst_SOURCE_DIR} BINARY_DIR ${openfst_BINARY_DIR} + BUILD_ALWAYS 0 CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR} "CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}" "LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}" From bf914a9c8b5b01c37932cfc63fae46ce3fa83928 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 13 Mar 2023 14:45:22 +0800 Subject: [PATCH 16/50] [runtime] optimization compile and add vad interface (#3026) * vad recipe ok * refactor vad, add vad conf, vad inerface, vad recipe * format * install vad lib/bin/inc * using cpack * add vad doc, fix vad state name * add comment * refactor fastdeploy download * add vad jni; format code * add timer; compute vad rtf; vad add beam param * andorid find library * fix log; add vad rtf * fix glog * fix BUILD_TYPE bug * update doc * rm jni --- runtime/CMakeLists.txt | 67 +++-- runtime/build.sh | 16 +- runtime/build_android.sh | 9 +- runtime/cmake/fastdeploy.cmake | 139 +++++++--- runtime/cmake/gflags.cmake | 2 + runtime/cmake/glog.cmake | 16 +- runtime/cmake/openfst.cmake | 10 + 
runtime/cmake/summary.cmake | 6 + .../ctc_prefix_beam_search_decoder_main.cc | 8 +- .../asr/decoder/ctc_tlg_decoder_main.cc | 2 +- runtime/engine/asr/nnet/nnet_producer.cc | 3 +- runtime/engine/asr/nnet/u2_nnet_main.cc | 8 +- .../engine/asr/nnet/u2_nnet_thread_main.cc | 2 +- .../recognizer/u2_recognizer_batch_main.cc | 2 +- .../asr/recognizer/u2_recognizer_main.cc | 2 +- .../recognizer/u2_recognizer_thread_main.cc | 2 +- .../server/websocket/websocket_client_main.cc | 2 +- .../server/websocket/websocket_server_main.cc | 2 +- runtime/engine/cls/nnet/CMakeLists.txt | 2 +- runtime/engine/cls/nnet/panns_interface.cc | 1 + runtime/engine/cls/nnet/panns_nnet_main.cc | 1 + runtime/engine/common/CMakeLists.txt | 6 +- runtime/engine/common/base/CMakeLists.txt | 1 + runtime/engine/common/base/basic_types.h | 2 +- runtime/engine/common/base/common.h | 3 +- runtime/engine/common/base/config.h | 13 +- runtime/engine/common/base/log_impl.cc | 40 +-- runtime/engine/common/base/log_impl.h | 55 ++-- runtime/engine/common/frontend/CMakeLists.txt | 6 +- runtime/engine/common/frontend/assembler.cc | 3 +- runtime/engine/common/frontend/fftsg.c | 24 +- runtime/engine/common/frontend/rfft.cc | 3 +- runtime/engine/common/frontend/wave-reader.cc | 10 +- runtime/engine/common/matrix/kaldi-matrix.h | 2 +- runtime/engine/common/matrix/kaldi-vector.cc | 2 + runtime/engine/common/matrix/matrix-common.h | 2 +- runtime/engine/common/utils/CMakeLists.txt | 1 + runtime/engine/common/utils/timer.cc | 63 +++++ runtime/engine/common/utils/timer.h | 39 +++ runtime/engine/vad/CMakeLists.txt | 8 +- runtime/engine/vad/{ => frontend}/wav.h | 2 + runtime/engine/vad/interface/CMakeLists.txt | 25 ++ runtime/engine/vad/interface/vad_interface.cc | 94 +++++++ runtime/engine/vad/interface/vad_interface.h | 46 +++ .../vad/interface/vad_interface_main.cc | 71 +++++ runtime/engine/vad/nnet/CMakeLists.txt | 16 ++ runtime/engine/vad/{ => nnet}/vad.cc | 93 ++++--- runtime/engine/vad/{ => nnet}/vad.h | 50 +++- 
.../vad_nnet_main.cc} | 33 ++- runtime/examples/silero_vad/README.md | 121 -------- runtime/examples/silero_vad/README_CN.md | 119 -------- runtime/examples/silero_vad/local/decode.sh | 0 runtime/examples/silero_vad/path.sh | 18 -- runtime/examples/u2pp_ol/wenetspeech/path.sh | 2 +- .../examples/{silero_vad => vad}/.gitignore | 0 runtime/examples/vad/README.md | 261 ++++++++++++++++++ runtime/examples/vad/conf/vad.ini | 11 + .../{silero_vad => vad}/local/build.sh | 0 .../local/build_android.sh | 0 runtime/examples/vad/local/decode.sh | 23 ++ .../{silero_vad => vad}/local/download.sh | 0 runtime/examples/vad/path.sh | 17 ++ runtime/examples/{silero_vad => vad}/run.sh | 6 +- runtime/examples/{silero_vad => vad}/utils | 0 64 files changed, 1128 insertions(+), 465 deletions(-) create mode 100644 runtime/engine/common/utils/timer.cc create mode 100644 runtime/engine/common/utils/timer.h rename runtime/engine/vad/{ => frontend}/wav.h (99%) create mode 100644 runtime/engine/vad/interface/CMakeLists.txt create mode 100644 runtime/engine/vad/interface/vad_interface.cc create mode 100644 runtime/engine/vad/interface/vad_interface.h create mode 100644 runtime/engine/vad/interface/vad_interface_main.cc create mode 100644 runtime/engine/vad/nnet/CMakeLists.txt rename runtime/engine/vad/{ => nnet}/vad.cc (80%) rename runtime/engine/vad/{ => nnet}/vad.h (78%) rename runtime/engine/vad/{silero_vad_main.cc => nnet/vad_nnet_main.cc} (58%) delete mode 100644 runtime/examples/silero_vad/README.md delete mode 100644 runtime/examples/silero_vad/README_CN.md delete mode 100755 runtime/examples/silero_vad/local/decode.sh delete mode 100644 runtime/examples/silero_vad/path.sh rename runtime/examples/{silero_vad => vad}/.gitignore (100%) create mode 100644 runtime/examples/vad/README.md create mode 100644 runtime/examples/vad/conf/vad.ini rename runtime/examples/{silero_vad => vad}/local/build.sh (100%) rename runtime/examples/{silero_vad => vad}/local/build_android.sh (100%) create mode 
100755 runtime/examples/vad/local/decode.sh rename runtime/examples/{silero_vad => vad}/local/download.sh (100%) create mode 100644 runtime/examples/vad/path.sh rename runtime/examples/{silero_vad => vad}/run.sh (77%) mode change 100644 => 100755 rename runtime/examples/{silero_vad => vad}/utils (100%) diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt index af970526..efeb6218 100644 --- a/runtime/CMakeLists.txt +++ b/runtime/CMakeLists.txt @@ -1,4 +1,5 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) +# >=3.17 support -DCMAKE_FIND_DEBUG_MODE=ON +cmake_minimum_required(VERSION 3.17 FATAL_ERROR) set(CMAKE_PROJECT_INCLUDE_BEFORE "${CMAKE_CURRENT_SOURCE_DIR}/cmake/EnableCMP0048.cmake") @@ -6,20 +7,12 @@ set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") include(system) -# Ninja Generator will set CMAKE_BUILD_TYPE to Debug -if(NOT CMAKE_BUILD_TYPE) - set(CMAKE_BUILD_TYPE - "Release" - CACHE - STRING - "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" - FORCE) -endif() - project(paddlespeech VERSION 0.1) -include(FetchContent) -include(ExternalProject) +set(PPS_VERSION_MAJOR 1) +set(PPS_VERSION_MINOR 0) +set(PPS_VERSION_PATCH 0) +set(PPS_VERSION "${PPS_VERSION_MAJOR}.${PPS_VERSION_MINOR}.${PPS_VERSION_PATCH}") # fc_patch dir set(FETCHCONTENT_QUIET off) @@ -27,21 +20,36 @@ get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR set(FETCHCONTENT_BASE_DIR ${fc_patch}) set(CMAKE_VERBOSE_MAKEFILE ON) +set(CMAKE_FIND_DEBUG_MODE OFF) set(PPS_CXX_STANDARD 14) # set std-14 set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD}) -add_compile_options(-fPIC) -# compiler option -# Keep the same with openfst, -fPIC or -fpic -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ldl") -SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb") -SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall") +# Ninja 
Generator will set CMAKE_BUILD_TYPE to Debug +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" FORCE) +endif() +# find_* e.g. find_library work when Cross-Compiling +if(ANDROID) + set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH) +endif() + +# install dir into `build/install` +set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/install) + +include(FetchContent) +include(ExternalProject) ############################################################################### # Option Configurations ############################################################################### +# https://github.com/google/brotli/pull/655 +option(BUILD_SHARED_LIBS "Build shared libraries" ON) + option(WITH_ASR "build asr" ON) option(WITH_CLS "build cls" ON) option(WITH_VAD "build vad" ON) @@ -77,6 +85,7 @@ endif() ############################################################################### # Find Package ############################################################################### +# https://github.com/Kitware/CMake/blob/v3.1.0/Modules/FindThreads.cmake#L207 find_package(Threads REQUIRED) if(WITH_ASR) @@ -157,6 +166,22 @@ include(summary) ############################################################################### # Add local library ############################################################################### -set(ENGINE_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/engine) +set(ENGINE_ROOT ${CMAKE_SOURCE_DIR}/engine) -add_subdirectory(engine) \ No newline at end of file +add_subdirectory(engine) + + +############################################################################### +# CPack library +############################################################################### +# build a CPack driven installer package +include (InstallRequiredSystemLibraries) 
+set(CPACK_PACKAGE_NAME "paddlespeech_library") +set(CPACK_PACKAGE_VENDOR "paddlespeech") +set(CPACK_PACKAGE_VERSION_MAJOR 1) +set(CPACK_PACKAGE_VERSION_MINOR 0) +set(CPACK_PACKAGE_VERSION_PATCH 0) +set(CPACK_PACKAGE_DESCRIPTION "paddlespeech library") +set(CPACK_PACKAGE_CONTACT "paddlespeech@baidu.com") +set(CPACK_SOURCE_GENERATOR "TGZ") +include (CPack) \ No newline at end of file diff --git a/runtime/build.sh b/runtime/build.sh index f7d0a2b2..4a27766a 100755 --- a/runtime/build.sh +++ b/runtime/build.sh @@ -1,8 +1,20 @@ #!/usr/bin/env bash set -xe +BUILD_ROOT=build/Linux +BUILD_DIR=${BUILD_ROOT}/x86_64 + +mkdir -p ${BUILD_DIR} + # the build script had verified in the paddlepaddle docker image. # please follow the instruction below to install PaddlePaddle image. # https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html -cmake -B build -DWITH_ASR=ON -DWITH_CLS=OFF -DWITH_VAD=OFF -cmake --build build -j +#cmake -B build -DBUILD_SHARED_LIBS=OFF -DWITH_ASR=OFF -DWITH_CLS=OFF -DWITH_VAD=ON -DFASTDEPLOY_INSTALL_DIR=/workspace/zhanghui/paddle/FastDeploy/build/Android/arm64-v8a-api-21/install +cmake -B ${BUILD_DIR} \ + -DCMAKE_BUILD_TYPE=Release \ + -DBUILD_SHARED_LIBS=OFF \ + -DWITH_ASR=OFF \ + -DWITH_CLS=OFF \ + -DWITH_VAD=ON \ + -DFASTDEPLOY_INSTALL_DIR=/workspace/zhanghui/paddle/FastDeploy/build/Linux/x86_64/install +cmake --build ${BUILD_DIR} -j diff --git a/runtime/build_android.sh b/runtime/build_android.sh index 64a33762..ac3980a8 100755 --- a/runtime/build_android.sh +++ b/runtime/build_android.sh @@ -14,8 +14,8 @@ TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake # Create build directory BUILD_ROOT=build/Android BUILD_DIR=${BUILD_ROOT}/${ANDROID_ABI}-api-21 -#FASDEPLOY_INSTALL_DIR="${BUILD_DIR}/install" -#mkdir build && mkdir ${BUILD_ROOT} && mkdir ${BUILD_DIR} +FASTDEPLOY_INSTALL_DIR="/workspace/zhanghui/paddle/FastDeploy/build/Android/arm64-v8a-api-21/install" + mkdir -p ${BUILD_DIR} cd ${BUILD_DIR} @@ -27,10 
+27,13 @@ cmake -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} \ -DANDROID_PLATFORM=${ANDROID_PLATFORM} \ -DANDROID_STL=${ANDROID_STL} \ -DANDROID_TOOLCHAIN=${ANDROID_TOOLCHAIN} \ + -DBUILD_SHARED_LIBS=OFF \ -DWITH_ASR=OFF \ -DWITH_CLS=OFF \ + -DWITH_VAD=ON \ + -DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \ + -DCMAKE_FIND_DEBUG_MODE=OFF \ -Wno-dev ../../.. - #-DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \ # Build FastDeploy Android C++ SDK make diff --git a/runtime/cmake/fastdeploy.cmake b/runtime/cmake/fastdeploy.cmake index 463a8e8e..b7c9a8dd 100644 --- a/runtime/cmake/fastdeploy.cmake +++ b/runtime/cmake/fastdeploy.cmake @@ -1,42 +1,119 @@ -set(ARCH "mserver_x86_64" CACHE STRING "Target Architecture: -android_arm, android_armv7, android_armv8, android_x86, android_x86_64, -mserver_x86_64, ubuntu_x86_64, ios_armv7, ios_armv7s, ios_armv8, ios_x86_64, ios_x86, -windows_x86") - -set(FASTDEPLOY_DIR ${CMAKE_SOURCE_DIR}/fc_patch/fastdeploy) -if(NOT EXISTS ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.4.tgz) - exec_program("mkdir -p ${FASTDEPLOY_DIR} && - wget -c https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-1.0.4.tgz -P ${FASTDEPLOY_DIR} && - tar xzvf ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.4.tgz -C ${FASTDEPLOY_DIR} && - mv ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.4 ${FASTDEPLOY_DIR}/linux-x64") -endif() +include(FetchContent) -if(NOT EXISTS ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.4-shared.tgz) - exec_program("mkdir -p ${FASTDEPLOY_DIR} && - wget -c https://bj.bcebos.com/fastdeploy/release/android/fastdeploy-android-1.0.4-shared.tgz -P ${FASTDEPLOY_DIR} && - tar xzvf ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.4-shared.tgz -C ${FASTDEPLOY_DIR} && - mv ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.4-shared ${FASTDEPLOY_DIR}/android-armv7v8") -endif() +set(EXTERNAL_PROJECT_LOG_ARGS + LOG_DOWNLOAD 1 # Wrap download in script to log output + LOG_UPDATE 1 # Wrap update in script to log output + LOG_PATCH 1 + LOG_CONFIGURE 1# Wrap configure in script to 
log output + LOG_BUILD 1 # Wrap build in script to log output + LOG_INSTALL 1 + LOG_TEST 1 # Wrap test in script to log output + LOG_MERGED_STDOUTERR 1 + LOG_OUTPUT_ON_FAILURE 1 +) + +if(NOT FASTDEPLOY_INSTALL_DIR) + if(ANDROID) + FetchContent_Declare( + fastdeploy + URL https://bj.bcebos.com/fastdeploy/release/android/fastdeploy-android-1.0.4-shared.tgz + URL_HASH MD5=2a15301158e9eb157a4f11283689e7ba + ${EXTERNAL_PROJECT_LOG_ARGS} + ) + add_definitions("-DUSE_PADDLE_LITE_BAKEND") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -g0 -O3 -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE") + else() # Linux + FetchContent_Declare( + fastdeploy + URL https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-1.0.4.tgz + URL_HASH MD5=125df3bfce603521960cc5c8b47faab0 + ${EXTERNAL_PROJECT_LOG_ARGS} + ) + add_definitions("-DUSE_PADDLE_INFERENCE_BACKEND") + # add_definitions("-DUSE_ORT_BACKEND") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -msse -msse2") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -msse -msse2 -mavx -O3") + endif() -if(ANDROID) - set(FASTDEPLOY_INSTALL_DIR ${FASTDEPLOY_DIR}/android-armv7v8) - add_definitions("-DUSE_PADDLE_LITE_BAKEND") - set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE") - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -g0 -O3 -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE") -elseif(UNIX) - set(FASTDEPLOY_INSTALL_DIR ${FASTDEPLOY_DIR}/linux-x64) - add_definitions("-DUSE_PADDLE_INFERENCE_BACKEND") - # add_definitions("-DUSE_ORT_BACKEND") - set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -msse -msse2") - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -msse -msse2 -mavx -O3") + FetchContent_MakeAvailable(fastdeploy) + + set(FASTDEPLOY_INSTALL_DIR ${fc_patch}/fastdeploy-src) endif() 
-message(STATUS "FASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} ${UNIX}") include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) # fix compiler flags conflict, since fastdeploy using c++11 for project +# this line must after `include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)` set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD}) include_directories(${FASTDEPLOY_INCS}) -message(STATUS "FASTDEPLOY_INCS=${FASTDEPLOY_INCS}") \ No newline at end of file + +# install fastdeploy and dependents lib +# install_fastdeploy_libraries(${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}) +# No dynamic libs need to install while using +# FastDeploy static lib. +if(ANDROID AND WITH_ANDROID_STATIC_LIB) + return() +endif() + +set(DYN_LIB_SUFFIX "*.so*") +if(WIN32) + set(DYN_LIB_SUFFIX "*.dll") +elseif(APPLE) + set(DYN_LIB_SUFFIX "*.dylib*") +endif() + +if(FastDeploy_DIR) + set(DYN_SEARCH_DIR ${FastDeploy_DIR}) +elseif(FASTDEPLOY_INSTALL_DIR) + set(DYN_SEARCH_DIR ${FASTDEPLOY_INSTALL_DIR}) +else() + message(FATAL_ERROR "Please set FastDeploy_DIR/FASTDEPLOY_INSTALL_DIR before call install_fastdeploy_libraries.") +endif() + +file(GLOB_RECURSE ALL_NEED_DYN_LIBS ${DYN_SEARCH_DIR}/lib/${DYN_LIB_SUFFIX}) +file(GLOB_RECURSE ALL_DEPS_DYN_LIBS ${DYN_SEARCH_DIR}/third_libs/${DYN_LIB_SUFFIX}) + +if(ENABLE_VISION) + # OpenCV + if(ANDROID) + file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${DYN_LIB_SUFFIX}) + else() + file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_DIR}/../../${DYN_LIB_SUFFIX}) + endif() + + list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_OPENCV_DYN_LIBS}) + + if(WIN32) + file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/x64/vc15/bin/${DYN_LIB_SUFFIX}) + install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib) + elseif(ANDROID AND (NOT WITH_ANDROID_OPENCV_STATIC)) + file(GLOB OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${ANDROID_ABI}/${DYN_LIB_SUFFIX}) + install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib) + else() # linux/mac + file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/lib/${DYN_LIB_SUFFIX}) + install(FILES 
${OPENCV_DYN_LIBS} DESTINATION lib) + endif() + + # FlyCV + if(ENABLE_FLYCV) + file(GLOB_RECURSE ALL_FLYCV_DYN_LIBS ${FLYCV_LIB_DIR}/${DYN_LIB_SUFFIX}) + list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_FLYCV_DYN_LIBS}) + if(ANDROID AND (NOT WITH_ANDROID_FLYCV_STATIC)) + install(FILES ${ALL_FLYCV_DYN_LIBS} DESTINATION lib) + endif() + endif() +endif() + +if(ENABLE_OPENVINO_BACKEND) + # need plugins.xml for openvino backend + set(OPENVINO_RUNTIME_BIN_DIR ${OPENVINO_DIR}/bin) + file(GLOB OPENVINO_PLUGIN_XML ${OPENVINO_RUNTIME_BIN_DIR}/*.xml) + install(FILES ${OPENVINO_PLUGIN_XML} DESTINATION lib) +endif() + +# Install other libraries +install(FILES ${ALL_NEED_DYN_LIBS} DESTINATION lib) +install(FILES ${ALL_DEPS_DYN_LIBS} DESTINATION lib) diff --git a/runtime/cmake/gflags.cmake b/runtime/cmake/gflags.cmake index d01eaf60..8ddf6635 100644 --- a/runtime/cmake/gflags.cmake +++ b/runtime/cmake/gflags.cmake @@ -9,3 +9,5 @@ FetchContent_MakeAvailable(gflags) # openfst need include_directories(${gflags_BINARY_DIR}/include) + +install(FILES ${gflags_BINARY_DIR}/libgflags_nothreads.a DESTINATION lib) \ No newline at end of file diff --git a/runtime/cmake/glog.cmake b/runtime/cmake/glog.cmake index cbb97d2d..51d0ef06 100644 --- a/runtime/cmake/glog.cmake +++ b/runtime/cmake/glog.cmake @@ -7,6 +7,19 @@ else() # UNIX glog URL https://paddleaudio.bj.bcebos.com/build/glog-0.4.0.zip URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc + CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} + -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DCMAKE_CXX_FLAGS=${GLOG_CMAKE_CXX_FLAGS} + -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} + -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG} + -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} + -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG} + -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE} + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DWITH_GFLAGS=OFF + -DBUILD_TESTING=OFF + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + ${EXTERNAL_OPTIONAL_ARGS} ) 
FetchContent_MakeAvailable(glog) include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src) @@ -15,7 +28,8 @@ endif() if(ANDROID) add_library(extern_glog INTERFACE) + add_dependencies(extern_glog gflags) else() # UNIX - add_dependencies(glog gflags) add_library(extern_glog ALIAS glog) + add_dependencies(extern_glog gflags) endif() \ No newline at end of file diff --git a/runtime/cmake/openfst.cmake b/runtime/cmake/openfst.cmake index 06697156..a859076f 100644 --- a/runtime/cmake/openfst.cmake +++ b/runtime/cmake/openfst.cmake @@ -10,9 +10,19 @@ include(FetchContent) #Application of Automata, (CIAA 2007), volume 4783 of Lecture Notes in #Computer Science, pages 11-23. Springer, 2007. http://www.openfst.org. +set(EXTERNAL_PROJECT_LOG_ARGS + LOG_DOWNLOAD 1 # Wrap download in script to log output + LOG_UPDATE 1 # Wrap update in script to log output + LOG_CONFIGURE 1# Wrap configure in script to log output + LOG_BUILD 1 # Wrap build in script to log output + LOG_TEST 1 # Wrap test in script to log output + LOG_INSTALL 1 # Wrap install in script to log output +) + ExternalProject_Add(openfst URL https://paddleaudio.bj.bcebos.com/build/openfst_1.7.2.zip URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6 + ${EXTERNAL_PROJECT_LOG_ARGS} PREFIX ${openfst_PREFIX_DIR} SOURCE_DIR ${openfst_SOURCE_DIR} BINARY_DIR ${openfst_BINARY_DIR} diff --git a/runtime/cmake/summary.cmake b/runtime/cmake/summary.cmake index fd47c6bd..95ee324a 100644 --- a/runtime/cmake/summary.cmake +++ b/runtime/cmake/summary.cmake @@ -15,6 +15,7 @@ function(pps_summary) message(STATUS "") message(STATUS "*************PaddleSpeech Building Summary**********") + message(STATUS " PPS_VERSION : ${PPS_VERSION}") message(STATUS " CMake version : ${CMAKE_VERSION}") message(STATUS " CMake command : ${CMAKE_COMMAND}") message(STATUS " UNIX : ${UNIX}") @@ -24,10 +25,13 @@ function(pps_summary) message(STATUS " C++ compiler version : ${CMAKE_CXX_COMPILER_VERSION}") message(STATUS " CXX 
flags : ${CMAKE_CXX_FLAGS}") message(STATUS " Build type : ${CMAKE_BUILD_TYPE}") + message(STATUS " BUILD_SHARED_LIBS : ${BUILD_SHARED_LIBS}") get_directory_property(tmp DIRECTORY ${PROJECT_SOURCE_DIR} COMPILE_DEFINITIONS) message(STATUS " Compile definitions : ${tmp}") message(STATUS " CMAKE_PREFIX_PATH : ${CMAKE_PREFIX_PATH}") + message(STATUS " CMAKE_CURRENT_BINARY_DIR : ${CMAKE_CURRENT_BINARY_DIR}") message(STATUS " CMAKE_INSTALL_PREFIX : ${CMAKE_INSTALL_PREFIX}") + message(STATUS " CMAKE_INSTALL_LIBDIR : ${CMAKE_INSTALL_LIBDIR}") message(STATUS " CMAKE_MODULE_PATH : ${CMAKE_MODULE_PATH}") message(STATUS " CMAKE_SYSTEM_NAME : ${CMAKE_SYSTEM_NAME}") message(STATUS "") @@ -39,6 +43,8 @@ function(pps_summary) message(STATUS " WITH_TESTING : ${WITH_TESTING}") message(STATUS " WITH_PROFILING : ${WITH_PROFILING}") message(STATUS " FASTDEPLOY_INSTALL_DIR : ${FASTDEPLOY_INSTALL_DIR}") + message(STATUS " FASTDEPLOY_INCS : ${FASTDEPLOY_INCS}") + message(STATUS " FASTDEPLOY_LIBS : ${FASTDEPLOY_LIBS}") if(WITH_GPU) message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}") endif() diff --git a/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc index bd73b3ac..1673bdad 100644 --- a/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc +++ b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "decoder/ctc_prefix_beam_search_decoder.h" #include "base/common.h" +#include "decoder/ctc_prefix_beam_search_decoder.h" #include "frontend/data_cache.h" #include "fst/symbol-table.h" #include "kaldi/util/table-types.h" @@ -117,9 +117,9 @@ int main(int argc, char* argv[]) { ori_feature_len - chunk_idx * chunk_stride, chunk_size); } if (this_chunk_size < receptive_field_length) { - LOG(WARNING) << "utt: " << utt << " skip last " - << this_chunk_size << " frames, expect is " - << receptive_field_length; + LOG(WARNING) + << "utt: " << utt << " skip last " << this_chunk_size + << " frames, expect is " << receptive_field_length; break; } diff --git a/runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc b/runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc index 148ee15e..410574dc 100644 --- a/runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc +++ b/runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc @@ -14,8 +14,8 @@ // todo refactor, repalce with gtest -#include "decoder/ctc_tlg_decoder.h" #include "base/common.h" +#include "decoder/ctc_tlg_decoder.h" #include "decoder/param.h" #include "frontend/data_cache.h" #include "kaldi/util/table-types.h" diff --git a/runtime/engine/asr/nnet/nnet_producer.cc b/runtime/engine/asr/nnet/nnet_producer.cc index 29daa709..1e481e30 100644 --- a/runtime/engine/asr/nnet/nnet_producer.cc +++ b/runtime/engine/asr/nnet/nnet_producer.cc @@ -13,12 +13,13 @@ // limitations under the License. #include "nnet/nnet_producer.h" + #include "matrix/kaldi-matrix.h" namespace ppspeech { -using std::vector; using kaldi::BaseFloat; +using std::vector; NnetProducer::NnetProducer(std::shared_ptr nnet, std::shared_ptr frontend) diff --git a/runtime/engine/asr/nnet/u2_nnet_main.cc b/runtime/engine/asr/nnet/u2_nnet_main.cc index 699f4258..e60ae7e8 100644 --- a/runtime/engine/asr/nnet/u2_nnet_main.cc +++ b/runtime/engine/asr/nnet/u2_nnet_main.cc @@ -13,13 +13,13 @@ // limitations under the License. 
-#include "nnet/u2_nnet.h" #include "base/common.h" #include "decoder/param.h" #include "frontend/assembler.h" #include "frontend/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" +#include "nnet/u2_nnet.h" DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); @@ -93,9 +93,9 @@ int main(int argc, char* argv[]) { ori_feature_len - chunk_idx * chunk_stride, chunk_size); } if (this_chunk_size < receptive_field_length) { - LOG(WARNING) << "utt: " << utt << " skip last " - << this_chunk_size << " frames, expect is " - << receptive_field_length; + LOG(WARNING) + << "utt: " << utt << " skip last " << this_chunk_size + << " frames, expect is " << receptive_field_length; break; } diff --git a/runtime/engine/asr/nnet/u2_nnet_thread_main.cc b/runtime/engine/asr/nnet/u2_nnet_thread_main.cc index 4339bdbe..c3f291ce 100644 --- a/runtime/engine/asr/nnet/u2_nnet_thread_main.cc +++ b/runtime/engine/asr/nnet/u2_nnet_thread_main.cc @@ -13,7 +13,6 @@ // limitations under the License. -#include "nnet/u2_nnet.h" #include "base/common.h" #include "decoder/param.h" #include "frontend/feature_pipeline.h" @@ -21,6 +20,7 @@ #include "kaldi/util/table-types.h" #include "nnet/decodable.h" #include "nnet/nnet_producer.h" +#include "nnet/u2_nnet.h" DEFINE_string(wav_rspecifier, "", "test wav rspecifier"); DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); diff --git a/runtime/engine/asr/recognizer/u2_recognizer_batch_main.cc b/runtime/engine/asr/recognizer/u2_recognizer_batch_main.cc index 709e5aa6..8d1532bd 100644 --- a/runtime/engine/asr/recognizer/u2_recognizer_batch_main.cc +++ b/runtime/engine/asr/recognizer/u2_recognizer_batch_main.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "recognizer/u2_recognizer.h" #include "common/base/thread_pool.h" #include "common/utils/file_utils.h" #include "common/utils/strings.h" @@ -20,6 +19,7 @@ #include "frontend/wave-reader.h" #include "kaldi/util/table-types.h" #include "nnet/u2_nnet.h" +#include "recognizer/u2_recognizer.h" DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); diff --git a/runtime/engine/asr/recognizer/u2_recognizer_main.cc b/runtime/engine/asr/recognizer/u2_recognizer_main.cc index fb37d050..178c91db 100644 --- a/runtime/engine/asr/recognizer/u2_recognizer_main.cc +++ b/runtime/engine/asr/recognizer/u2_recognizer_main.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "recognizer/u2_recognizer.h" #include "decoder/param.h" #include "frontend/wave-reader.h" #include "kaldi/util/table-types.h" +#include "recognizer/u2_recognizer.h" DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); diff --git a/runtime/engine/asr/recognizer/u2_recognizer_thread_main.cc b/runtime/engine/asr/recognizer/u2_recognizer_thread_main.cc index b86853fa..272defc6 100644 --- a/runtime/engine/asr/recognizer/u2_recognizer_thread_main.cc +++ b/runtime/engine/asr/recognizer/u2_recognizer_thread_main.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "recognizer/u2_recognizer.h" #include "decoder/param.h" #include "frontend/wave-reader.h" #include "kaldi/util/table-types.h" +#include "recognizer/u2_recognizer.h" DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); diff --git a/runtime/engine/asr/server/websocket/websocket_client_main.cc b/runtime/engine/asr/server/websocket/websocket_client_main.cc index 7ad36e3a..7c5a4f2f 100644 --- a/runtime/engine/asr/server/websocket/websocket_client_main.cc +++ b/runtime/engine/asr/server/websocket/websocket_client_main.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "websocket/websocket_client.h" #include "kaldi/feat/wave-reader.h" #include "kaldi/util/kaldi-io.h" #include "kaldi/util/table-types.h" +#include "websocket/websocket_client.h" DEFINE_string(host, "127.0.0.1", "host of websocket server"); DEFINE_int32(port, 8082, "port of websocket server"); diff --git a/runtime/engine/asr/server/websocket/websocket_server_main.cc b/runtime/engine/asr/server/websocket/websocket_server_main.cc index 5f805ac9..5c32caf2 100644 --- a/runtime/engine/asr/server/websocket/websocket_server_main.cc +++ b/runtime/engine/asr/server/websocket/websocket_server_main.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "websocket/websocket_server.h" #include "decoder/param.h" +#include "websocket/websocket_server.h" DEFINE_int32(port, 8082, "websocket listening port"); diff --git a/runtime/engine/cls/nnet/CMakeLists.txt b/runtime/engine/cls/nnet/CMakeLists.txt index 27f24434..d331d31a 100644 --- a/runtime/engine/cls/nnet/CMakeLists.txt +++ b/runtime/engine/cls/nnet/CMakeLists.txt @@ -3,7 +3,7 @@ set(srcs panns_interface.cc ) -add_library(cls SHARED ${srcs}) +add_library(cls ${srcs}) target_link_libraries(cls INTERFACE -static-libstdc++;-Wl,-Bsymbolic ${FASTDEPLOY_LIBS} kaldi-matrix kaldi-base frontend utils ) set(bin_name panns_nnet_main) diff --git a/runtime/engine/cls/nnet/panns_interface.cc b/runtime/engine/cls/nnet/panns_interface.cc index 257ee44f..cfff3f92 100644 --- a/runtime/engine/cls/nnet/panns_interface.cc +++ b/runtime/engine/cls/nnet/panns_interface.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "cls/nnet/panns_interface.h" + #include "cls/nnet/panns_nnet.h" #include "common/base/config.h" diff --git a/runtime/engine/cls/nnet/panns_nnet_main.cc b/runtime/engine/cls/nnet/panns_nnet_main.cc index 4280d14c..14f91fc7 100644 --- a/runtime/engine/cls/nnet/panns_nnet_main.cc +++ b/runtime/engine/cls/nnet/panns_nnet_main.cc @@ -14,6 +14,7 @@ #include #include + #include "base/flags.h" #include "cls/nnet/panns_interface.h" diff --git a/runtime/engine/common/CMakeLists.txt b/runtime/engine/common/CMakeLists.txt index a2f56f7f..405479ae 100644 --- a/runtime/engine/common/CMakeLists.txt +++ b/runtime/engine/common/CMakeLists.txt @@ -12,4 +12,8 @@ ${CMAKE_CURRENT_SOURCE_DIR}/frontend add_subdirectory(frontend) add_library(common INTERFACE) -add_definitions(common base utils kaldi-matrix frontend) \ No newline at end of file +target_link_libraries(common INTERFACE base utils kaldi-matrix frontend) +install(TARGETS base DESTINATION lib) +install(TARGETS utils DESTINATION lib) +install(TARGETS kaldi-matrix DESTINATION lib) +install(TARGETS frontend 
DESTINATION lib) \ No newline at end of file diff --git a/runtime/engine/common/base/CMakeLists.txt b/runtime/engine/common/base/CMakeLists.txt index a49b78bd..f4171a18 100644 --- a/runtime/engine/common/base/CMakeLists.txt +++ b/runtime/engine/common/base/CMakeLists.txt @@ -36,6 +36,7 @@ if(ANDROID) glog_utils.cc ) add_library(base ${csrc}) + target_link_libraries(base gflags) else() # UNIX set(csrc) add_library(base INTERFACE) diff --git a/runtime/engine/common/base/basic_types.h b/runtime/engine/common/base/basic_types.h index c7fdc924..2b15a61f 100644 --- a/runtime/engine/common/base/basic_types.h +++ b/runtime/engine/common/base/basic_types.h @@ -28,7 +28,7 @@ typedef int int32; // NOLINT #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) typedef long int64; // NOLINT #else -typedef long long int64; // NOLINT +typedef long long int64; // NOLINT #endif typedef unsigned char uint8; // NOLINT diff --git a/runtime/engine/common/base/common.h b/runtime/engine/common/base/common.h index 17560102..b31fc53e 100644 --- a/runtime/engine/common/base/common.h +++ b/runtime/engine/common/base/common.h @@ -50,4 +50,5 @@ #include "base/log.h" #include "base/macros.h" #include "utils/file_utils.h" -#include "utils/math.h" \ No newline at end of file +#include "utils/math.h" +#include "utils/timer.h" \ No newline at end of file diff --git a/runtime/engine/common/base/config.h b/runtime/engine/common/base/config.h index c59c3ab8..c8eae5e2 100644 --- a/runtime/engine/common/base/config.h +++ b/runtime/engine/common/base/config.h @@ -10,11 +10,14 @@ using namespace std; #pragma once +#ifdef _MSC_VER #pragma region ParseIniFile +#endif + /* -* \brief Generic configuration Class -* -*/ + * \brief Generic configuration Class + * + */ class Config { // Data protected: @@ -32,7 +35,7 @@ class Config { std::string comment = "#"); Config(); template - T Read(const std::string& in_key) const; //! 
template @@ -335,4 +338,6 @@ void Config::ReadFile(string filename, string delimiter, string comment) { in >> (*this); } +#ifdef _MSC_VER #pragma endregion ParseIniFIle +#endif diff --git a/runtime/engine/common/base/log_impl.cc b/runtime/engine/common/base/log_impl.cc index 8286f1e7..d8295590 100644 --- a/runtime/engine/common/base/log_impl.cc +++ b/runtime/engine/common/base/log_impl.cc @@ -29,9 +29,9 @@ LogMessage::LogMessage(const char* file, bool out_to_file /* = false */) : level_(level), verbose_(verbose), out_to_file_(out_to_file) { if (FLAGS_logtostderr == 0) { - stream_ = std::shared_ptr(&std::cout); + stream_ = static_cast(&std::cout); } else if (FLAGS_logtostderr == 1) { - stream_ = std::shared_ptr(&std::cerr); + stream_ = static_cast(&std::cerr); } else if (out_to_file_) { // logfile lock_.lock(); @@ -46,11 +46,21 @@ LogMessage::~LogMessage() { lock_.unlock(); } - if (level_ == FATAL) { + if (verbose_ && level_ == FATAL) { std::abort(); } } +std::ostream* LogMessage::nullstream() { + thread_local static std::ofstream os; + thread_local static bool flag_set = false; + if (!flag_set) { + os.setstate(std::ios_base::badbit); + flag_set = true; + } + return &os; +} + void LogMessage::init(const char* file, int line) { time_t t = time(0); char tmp[100]; @@ -73,30 +83,20 @@ void LogMessage::init(const char* file, int line) { std::string("log." + proc_name + ".log.FATAL." + tmp + "." 
+ pid); } - std::ofstream ofs; + thread_local static std::ofstream ofs; if (level_ == DEBUG) { - stream_ = std::make_shared( - s_debug_logfile_.c_str(), std::ios::out | std::ios::app); - // ofs.open(s_debug_logfile_.c_str(), std::ios::out | std::ios::app); + ofs.open(s_debug_logfile_.c_str(), std::ios::out | std::ios::app); } else if (level_ == INFO) { - // ofs.open(s_warning_logfile_.c_str(), std::ios::out | std::ios::app); - stream_ = std::make_shared( - s_warning_logfile_.c_str(), std::ios::out | std::ios::app); + ofs.open(s_info_logfile_.c_str(), std::ios::out | std::ios::app); } else if (level_ == WARNING) { - // ofs.open(s_warning_logfile_.c_str(), std::ios::out | std::ios::app); - stream_ = std::make_shared( - s_warning_logfile_.c_str(), std::ios::out | std::ios::app); + ofs.open(s_warning_logfile_.c_str(), std::ios::out | std::ios::app); } else if (level_ == ERROR) { - // ofs.open(s_error_logfile_.c_str(), std::ios::out | std::ios::app); - stream_ = std::make_shared( - s_error_logfile_.c_str(), std::ios::out | std::ios::app); + ofs.open(s_error_logfile_.c_str(), std::ios::out | std::ios::app); } else { - // ofs.open(s_fatal_logfile_.c_str(), std::ios::out | std::ios::app); - stream_ = std::make_shared( - s_fatal_logfile_.c_str(), std::ios::out | std::ios::app); + ofs.open(s_fatal_logfile_.c_str(), std::ios::out | std::ios::app); } - // stream_ = &ofs; + stream_ = &ofs; stream() << tmp << " " << file << " line " << line << "; "; stream() << std::flush; diff --git a/runtime/engine/common/base/log_impl.h b/runtime/engine/common/base/log_impl.h index 93573620..2cc96c45 100644 --- a/runtime/engine/common/base/log_impl.h +++ b/runtime/engine/common/base/log_impl.h @@ -18,6 +18,9 @@ #pragma once +#include +#include + #include #include #include @@ -25,9 +28,6 @@ #include #include -#include -#include - #include "base/common.h" #include "base/macros.h" #ifndef WITH_GLOG @@ -61,13 +61,15 @@ class LogMessage { ~LogMessage(); - std::ostream& stream() { return *stream_; 
} + std::ostream& stream() { return verbose_ ? *stream_ : *nullstream(); } private: void init(const char* file, int line); + std::ostream* nullstream(); private: - std::shared_ptr stream_; + std::ostream* stream_; + std::ostream* null_stream_; Severity level_; bool verbose_; bool out_to_file_; @@ -88,14 +90,16 @@ class LogMessage { } // namespace ppspeech -#ifndef NDEBUG -#define DLOG_DEBUG \ - ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::DEBUG, false) +#ifdef NDEBUG +#define DLOG_INFO \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, false) +#define DLOG_WARNING \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, false) +#define DLOG_ERROR \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, false) +#define DLOG_FATAL \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, false) #else -#define DLOG_DEBUG \ - ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::DEBUG, true) -#endif - #define DLOG_INFO \ ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, true) #define DLOG_WARNING \ @@ -104,17 +108,30 @@ class LogMessage { ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, true) #define DLOG_FATAL \ ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, true) +#endif -#define DLOG_0 DLOG_DEBUG -#define DLOG_1 DLOG_INFO -#define DLOG_2 DLOG_WARNING -#define DLOG_3 DLOG_ERROR -#define DLOG_4 DLOG_FATAL -#define LOG(level) DLOG_##level.stream() +#define LOG_INFO \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, true) +#define LOG_WARNING \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, true) +#define LOG_ERROR \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, true) +#define LOG_FATAL \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, true) -#define VLOG(verboselevel) LOG(verboselevel) +#define LOG_0 LOG_DEBUG +#define LOG_1 LOG_INFO 
+#define LOG_2 LOG_WARNING +#define LOG_3 LOG_ERROR +#define LOG_4 LOG_FATAL + +#define LOG(level) LOG_##level.stream() + +#define DLOG(level) DLOG_##level.stream() + +#define VLOG(verboselevel) LOG(verboselevel) #define CHECK(exp) \ ppspeech::log::LogMessage( \ diff --git a/runtime/engine/common/frontend/CMakeLists.txt b/runtime/engine/common/frontend/CMakeLists.txt index 4ff3117c..5d78e7ea 100644 --- a/runtime/engine/common/frontend/CMakeLists.txt +++ b/runtime/engine/common/frontend/CMakeLists.txt @@ -6,6 +6,7 @@ add_library(kaldi-native-fbank-core mel-computations.cc rfft.cc ) +target_link_libraries(kaldi-native-fbank-core PUBLIC utils base) add_library(frontend STATIC cmvn.cc @@ -15,7 +16,7 @@ add_library(frontend STATIC assembler.cc wave-reader.cc ) -target_link_libraries(frontend PUBLIC kaldi-native-fbank-core utils) +target_link_libraries(frontend PUBLIC kaldi-native-fbank-core utils base) set(BINS compute_fbank_main @@ -24,5 +25,6 @@ set(BINS foreach(bin_name IN LISTS BINS) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} PUBLIC frontend base utils kaldi-util gflags extern_glog) + # https://github.com/Kitware/CMake/blob/v3.1.0/Modules/FindThreads.cmake#L207 + target_link_libraries(${bin_name} PUBLIC frontend base utils kaldi-util gflags Threads::Threads extern_glog) endforeach() \ No newline at end of file diff --git a/runtime/engine/common/frontend/assembler.cc b/runtime/engine/common/frontend/assembler.cc index 487951cd..70e1a43e 100644 --- a/runtime/engine/common/frontend/assembler.cc +++ b/runtime/engine/common/frontend/assembler.cc @@ -17,9 +17,8 @@ namespace ppspeech { using kaldi::BaseFloat; -using std::vector; -using std::vector; using std::unique_ptr; +using std::vector; Assembler::Assembler(AssemblerOptions opts, unique_ptr base_extractor) { diff --git a/runtime/engine/common/frontend/fftsg.c 
b/runtime/engine/common/frontend/fftsg.c index ec8217a2..30b81604 100644 --- a/runtime/engine/common/frontend/fftsg.c +++ b/runtime/engine/common/frontend/fftsg.c @@ -821,12 +821,12 @@ void cftfsub(int n, double *a, int *ip, int nw, double *w) { } else #endif /* USE_CDFT_THREADS */ if (n > 512) { - cftrec4(n, a, nw, w); - } else if (n > 128) { - cftleaf(n, 1, a, nw, w); - } else { - cftfx41(n, a, nw, w); - } + cftrec4(n, a, nw, w); + } else if (n > 128) { + cftleaf(n, 1, a, nw, w); + } else { + cftfx41(n, a, nw, w); + } bitrv2(n, ip, a); } else if (n == 32) { cftf161(a, &w[nw - 8]); @@ -868,12 +868,12 @@ void cftbsub(int n, double *a, int *ip, int nw, double *w) { } else #endif /* USE_CDFT_THREADS */ if (n > 512) { - cftrec4(n, a, nw, w); - } else if (n > 128) { - cftleaf(n, 1, a, nw, w); - } else { - cftfx41(n, a, nw, w); - } + cftrec4(n, a, nw, w); + } else if (n > 128) { + cftleaf(n, 1, a, nw, w); + } else { + cftfx41(n, a, nw, w); + } bitrv2conj(n, ip, a); } else if (n == 32) { cftf161(a, &w[nw - 8]); diff --git a/runtime/engine/common/frontend/rfft.cc b/runtime/engine/common/frontend/rfft.cc index 8cdb634f..9ce6a172 100644 --- a/runtime/engine/common/frontend/rfft.cc +++ b/runtime/engine/common/frontend/rfft.cc @@ -17,12 +17,13 @@ */ #include "frontend/rfft.h" -#include "base/log.h" #include #include #include +#include "base/log.h" + // see fftsg.c #ifdef __cplusplus extern "C" void rdft(int n, int isgn, double *a, int *ip, double *w); diff --git a/runtime/engine/common/frontend/wave-reader.cc b/runtime/engine/common/frontend/wave-reader.cc index b64dcc9e..e94aafef 100644 --- a/runtime/engine/common/frontend/wave-reader.cc +++ b/runtime/engine/common/frontend/wave-reader.cc @@ -19,6 +19,8 @@ // See the Apache 2 License for the specific language governing permissions and // limitations under the License. 
+#include "frontend/wave-reader.h" + #include #include #include @@ -27,7 +29,6 @@ #include "base/kaldi-error.h" #include "base/kaldi-utils.h" -#include "frontend/wave-reader.h" namespace kaldi { @@ -243,10 +244,9 @@ void WaveInfo::Read(std::istream &is) { << ", data chunk size: " << data_chunk_size << ". Assume 'stream mode' (reading data to EOF)."; - if (!is_stream_mode && - std::abs(static_cast(riff_chunk_read) + - static_cast(data_chunk_size) - - static_cast(riff_chunk_size)) > 1) { + if (!is_stream_mode && std::abs(static_cast(riff_chunk_read) + + static_cast(data_chunk_size) - + static_cast(riff_chunk_size)) > 1) { // We allow the size to be off by one without warning, because there is // a // weirdness in the format of RIFF files that means that the input may diff --git a/runtime/engine/common/matrix/kaldi-matrix.h b/runtime/engine/common/matrix/kaldi-matrix.h index c082a731..d614f36f 100644 --- a/runtime/engine/common/matrix/kaldi-matrix.h +++ b/runtime/engine/common/matrix/kaldi-matrix.h @@ -590,7 +590,7 @@ class MatrixBase { * SpMatrix and use Eig() function there, which uses eigenvalue * decomposition * directly rather than SVD. - */ + */ /// stream read. /// Use instead of stream<<*this, if you want to add to existing contents. diff --git a/runtime/engine/common/matrix/kaldi-vector.cc b/runtime/engine/common/matrix/kaldi-vector.cc index 1d0b55b9..3ab9a7ff 100644 --- a/runtime/engine/common/matrix/kaldi-vector.cc +++ b/runtime/engine/common/matrix/kaldi-vector.cc @@ -24,8 +24,10 @@ // limitations under the License. 
#include "matrix/kaldi-vector.h" + #include #include + #include "matrix/kaldi-matrix.h" namespace kaldi { diff --git a/runtime/engine/common/matrix/matrix-common.h b/runtime/engine/common/matrix/matrix-common.h index 512beb20..e915db0a 100644 --- a/runtime/engine/common/matrix/matrix-common.h +++ b/runtime/engine/common/matrix/matrix-common.h @@ -90,7 +90,7 @@ typedef uint32 UnsignedMatrixIndexT; // typedef size_t MatrixIndexT; // typedef ssize_t SignedMatrixIndexT; // typedef size_t UnsignedMatrixIndexT; -} +} // namespace kaldi #endif // KALDI_MATRIX_MATRIX_COMMON_H_ diff --git a/runtime/engine/common/utils/CMakeLists.txt b/runtime/engine/common/utils/CMakeLists.txt index eb3c7197..14733648 100644 --- a/runtime/engine/common/utils/CMakeLists.txt +++ b/runtime/engine/common/utils/CMakeLists.txt @@ -5,6 +5,7 @@ set(csrc math.cc strings.cc audio_process.cc + timer.cc ) add_library(utils ${csrc}) diff --git a/runtime/engine/common/utils/timer.cc b/runtime/engine/common/utils/timer.cc new file mode 100644 index 00000000..ff43cd04 --- /dev/null +++ b/runtime/engine/common/utils/timer.cc @@ -0,0 +1,63 @@ +// Copyright 2020 Xiaomi Corporation (authors: Haowen Qiu) +// Mobvoi Inc. (authors: Fangjun Kuang) +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ + +#include + +#include "common/utils/timer.h" + +namespace ppspeech{ + +struct TimerImpl{ + TimerImpl() = default; + virtual ~TimerImpl() = default; + virtual void Reset() = 0; + // time in seconds + virtual double Elapsed() = 0; +}; + +class CpuTimerImpl : public TimerImpl { + public: + CpuTimerImpl() { Reset(); } + + using high_resolution_clock = std::chrono::high_resolution_clock; + + void Reset() override { begin_ = high_resolution_clock::now(); } + + // time in seconds + double Elapsed() override { + auto end = high_resolution_clock::now(); + auto dur = + std::chrono::duration_cast(end - begin_); + return dur.count() / 1000000.0; + } + + private: + high_resolution_clock::time_point begin_; +}; + +Timer::Timer() { + impl_ = std::make_unique(); +} + +Timer::~Timer() = default; + +void Timer::Reset() const { impl_->Reset(); } + +double Timer::Elapsed() const { return impl_->Elapsed(); } + + +} //namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/common/utils/timer.h b/runtime/engine/common/utils/timer.h new file mode 100644 index 00000000..6f4ae1f8 --- /dev/null +++ b/runtime/engine/common/utils/timer.h @@ -0,0 +1,39 @@ +// Copyright 2020 Xiaomi Corporation (authors: Haowen Qiu) +// Mobvoi Inc. (authors: Fangjun Kuang) +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +namespace ppspeech { + +struct TimerImpl; + +class Timer { + public: + Timer(); + ~Timer(); + + void Reset() const; + + // time in seconds + double Elapsed() const; + + private: + std::unique_ptr impl_; +}; + +} //namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/vad/CMakeLists.txt b/runtime/engine/vad/CMakeLists.txt index 4e9f448c..f61c5a9a 100644 --- a/runtime/engine/vad/CMakeLists.txt +++ b/runtime/engine/vad/CMakeLists.txt @@ -1,5 +1,7 @@ +include_directories( +${CMAKE_CURRENT_SOURCE_DIR}/../ +) +add_subdirectory(nnet) -set(bin_name silero_vad_main) -add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc vad.cc) -target_link_libraries(${bin_name} ${FASTDEPLOY_LIBS} gflags extern_glog) +add_subdirectory(interface) \ No newline at end of file diff --git a/runtime/engine/vad/wav.h b/runtime/engine/vad/frontend/wav.h similarity index 99% rename from runtime/engine/vad/wav.h rename to runtime/engine/vad/frontend/wav.h index 6d1a6f72..f9b7bee2 100644 --- a/runtime/engine/vad/wav.h +++ b/runtime/engine/vad/frontend/wav.h @@ -17,6 +17,8 @@ #include #include #include + +#include #include namespace wav { diff --git a/runtime/engine/vad/interface/CMakeLists.txt b/runtime/engine/vad/interface/CMakeLists.txt new file mode 100644 index 00000000..30700027 --- /dev/null +++ b/runtime/engine/vad/interface/CMakeLists.txt @@ -0,0 +1,25 @@ +set(srcs + vad_interface.cc +) + +add_library(pps_vad_interface ${srcs}) +target_link_libraries(pps_vad_interface PUBLIC pps_vad extern_glog) + + +set(bin_name vad_interface_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_link_libraries(${bin_name} pps_vad_interface) +# set_target_properties(${bin_name} PROPERTIES PUBLIC_HEADER "vad_interface.h;../frontend/wav.h") + + +file(RELATIVE_PATH DEST_DIR ${ENGINE_ROOT} ${CMAKE_CURRENT_SOURCE_DIR}) +install(TARGETS pps_vad_interface DESTINATION lib) +install(FILES vad_interface.h DESTINATION 
include/${DEST_DIR}) + +install(TARGETS vad_interface_main + RUNTIME DESTINATION bin + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + PUBLIC_HEADER DESTINATION include/${DEST_DIR} +) +install(FILES vad_interface_main.cc DESTINATION demo/${DEST_DIR}) \ No newline at end of file diff --git a/runtime/engine/vad/interface/vad_interface.cc b/runtime/engine/vad/interface/vad_interface.cc new file mode 100644 index 00000000..4c3877ff --- /dev/null +++ b/runtime/engine/vad/interface/vad_interface.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "vad/interface/vad_interface.h" + +#include "common/base/config.h" +#include "vad/nnet/vad.h" + + +PPSHandle_t PPSVadCreateInstance(const char* conf_path) { + Config conf(conf_path); + ppspeech::VadNnetConf nnet_conf; + nnet_conf.sr = conf.Read("sr", 16000); + nnet_conf.frame_ms = conf.Read("frame_ms", 32); + nnet_conf.threshold = conf.Read("threshold", 0.45f); + nnet_conf.beam = conf.Read("beam", 0.15f); + nnet_conf.min_silence_duration_ms = + conf.Read("min_silence_duration_ms", 200); + nnet_conf.speech_pad_left_ms = conf.Read("speech_pad_left_ms", 0); + nnet_conf.speech_pad_right_ms = conf.Read("speech_pad_right_ms", 0); + + nnet_conf.model_file_path = conf.Read("model_path", std::string("")); + nnet_conf.param_file_path = conf.Read("param_path", std::string("")); + nnet_conf.num_cpu_thread = conf.Read("num_cpu_thread", 1); + + ppspeech::Vad* model = new ppspeech::Vad(nnet_conf.model_file_path); + + // custom config, but must be set before init + model->SetConfig(nnet_conf); + model->Init(); + + return static_cast(model); +} + + +int PPSVadDestroyInstance(PPSHandle_t instance) { + ppspeech::Vad* model = static_cast(instance); + if (model != nullptr) { + delete model; + model = nullptr; + } + return 0; +} + +int PPSVadChunkSizeSamples(PPSHandle_t instance) { + ppspeech::Vad* model = static_cast(instance); + if (model == nullptr) { + printf("instance is null\n"); + return -1; + } + + return model->WindowSizeSamples(); +} + +PPSVadState_t PPSVadFeedForward(PPSHandle_t instance, + float* chunk, + int num_element) { + ppspeech::Vad* model = static_cast(instance); + if (model == nullptr) { + printf("instance is null\n"); + return PPS_VAD_ILLEGAL; + } + + std::vector chunk_in(chunk, chunk + num_element); + if (!model->ForwardChunk(chunk_in)) { + printf("forward chunk failed\n"); + return PPS_VAD_ILLEGAL; + } + ppspeech::Vad::State s = model->Postprocess(); + PPSVadState_t ret = (PPSVadState_t)s; + return ret; +} + +int PPSVadReset(PPSHandle_t instance) { 
+ ppspeech::Vad* model = static_cast(instance); + if (model == nullptr) { + printf("instance is null\n"); + return -1; + } + model->Reset(); + return 0; +} \ No newline at end of file diff --git a/runtime/engine/vad/interface/vad_interface.h b/runtime/engine/vad/interface/vad_interface.h new file mode 100644 index 00000000..5d7ca709 --- /dev/null +++ b/runtime/engine/vad/interface/vad_interface.h @@ -0,0 +1,46 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void* PPSHandle_t; + +typedef enum { + PPS_VAD_ILLEGAL = 0, // error + PPS_VAD_SIL, // silence + PPS_VAD_START, // start speech + PPS_VAD_SPEECH, // in speech + PPS_VAD_END, // end speech + PPS_VAD_NUMSTATES, // number of states +} PPSVadState_t; + +PPSHandle_t PPSVadCreateInstance(const char* conf_path); + +int PPSVadDestroyInstance(PPSHandle_t instance); + +int PPSVadReset(PPSHandle_t instance); + +int PPSVadChunkSizeSamples(PPSHandle_t instance); + +PPSVadState_t PPSVadFeedForward(PPSHandle_t instance, + float* chunk, + int num_element); + +#ifdef __cplusplus +} +#endif // __cplusplus \ No newline at end of file diff --git a/runtime/engine/vad/interface/vad_interface_main.cc b/runtime/engine/vad/interface/vad_interface_main.cc new file mode 100644 index 00000000..16059c41 --- /dev/null +++ b/runtime/engine/vad/interface/vad_interface_main.cc @@ -0,0 +1,71 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ + +#include +#include + +#include "common/base/common.h" +#include "vad/frontend/wav.h" +#include "vad/interface/vad_interface.h" + +int main(int argc, char* argv[]) { + if (argc < 3) { + std::cout << "Usage: vad_interface_main path/to/config path/to/audio " + "run_option, " + "e.g ./vad_interface_main config sample.wav" + << std::endl; + return -1; + } + + std::string config_path = argv[1]; + std::string audio_file = argv[2]; + + PPSHandle_t handle = PPSVadCreateInstance(config_path.c_str()); + + std::vector inputWav; // [0, 1] + wav::WavReader wav_reader = wav::WavReader(audio_file); + auto sr = wav_reader.sample_rate(); + CHECK(sr == 16000) << " sr is " << sr << " expect 16000"; + + auto num_samples = wav_reader.num_samples(); + inputWav.resize(num_samples); + for (int i = 0; i < num_samples; i++) { + inputWav[i] = wav_reader.data()[i] / 32768; + } + + ppspeech::Timer timer; + int window_size_samples = PPSVadChunkSizeSamples(handle); + for (int64_t j = 0; j < num_samples; j += window_size_samples) { + auto start = j; + auto end = start + window_size_samples >= num_samples + ? 
num_samples + : start + window_size_samples; + auto current_chunk_size = end - start; + + std::vector r{&inputWav[0] + start, &inputWav[0] + end}; + assert(r.size() == static_cast(current_chunk_size)); + + PPSVadState_t s = PPSVadFeedForward(handle, r.data(), r.size()); + std::cout << s << " "; + } + std::cout << std::endl; + + std::cout << "RTF=" << timer.Elapsed() / double(num_samples / sr) + << std::endl; + + PPSVadReset(handle); + + return 0; +} diff --git a/runtime/engine/vad/nnet/CMakeLists.txt b/runtime/engine/vad/nnet/CMakeLists.txt new file mode 100644 index 00000000..22c9f760 --- /dev/null +++ b/runtime/engine/vad/nnet/CMakeLists.txt @@ -0,0 +1,16 @@ +set(srcs + vad.cc +) + +add_library(pps_vad ${srcs}) +target_link_libraries(pps_vad PUBLIC ${FASTDEPLOY_LIBS} common extern_glog) + + +set(bin_name vad_nnet_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_link_libraries(${bin_name} pps_vad) + + +file(RELATIVE_PATH DEST_DIR ${ENGINE_ROOT} ${CMAKE_CURRENT_SOURCE_DIR}) +install(TARGETS pps_vad DESTINATION lib) +install(TARGETS extern_glog DESTINATION lib) \ No newline at end of file diff --git a/runtime/engine/vad/vad.cc b/runtime/engine/vad/nnet/vad.cc similarity index 80% rename from runtime/engine/vad/vad.cc rename to runtime/engine/vad/nnet/vad.cc index 7630b98d..0b77e632 100644 --- a/runtime/engine/vad/vad.cc +++ b/runtime/engine/vad/nnet/vad.cc @@ -1,4 +1,5 @@ // Copyright (c) 2023 Chen Qianhe Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,20 +12,15 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-#include "vad.h" +#include "vad/nnet/vad.h" + #include #include +#include "common/base/common.h" + -#ifdef NDEBUG -#define LOG_DEBUG \ - ::fastdeploy::FDLogger(true, "[DEBUG]") << __REL_FILE__ << "(" << __LINE__ \ - << ")::" << __FUNCTION__ << "\t" -#else -#define LOG_DEBUG \ - ::fastdeploy::FDLogger(false, "[DEBUG]") \ - << __REL_FILE__ << "(" << __LINE__ << ")::" << __FUNCTION__ << "\t" -#endif +namespace ppspeech { Vad::Vad(const std::string& model_file, const fastdeploy::RuntimeOption& @@ -48,18 +44,30 @@ Vad::Vad(const std::string& model_file, } void Vad::Init() { - std::call_once(init_, [&]() { initialized = Initialize(); }); + std::lock_guard lock(init_lock_); + Initialize(); } std::string Vad::ModelName() const { return "VAD"; } -void Vad::SetConfig(int sr, - int frame_ms, - float threshold, - int min_silence_duration_ms, - int speech_pad_left_ms, - int speech_pad_right_ms) { - if (initialized) { +void Vad::SetConfig(const VadNnetConf conf) { + SetConfig(conf.sr, + conf.frame_ms, + conf.threshold, + conf.beam, + conf.min_silence_duration_ms, + conf.speech_pad_left_ms, + conf.speech_pad_right_ms); +} + +void Vad::SetConfig(const int& sr, + const int& frame_ms, + const float& threshold, + const float& beam, + const int& min_silence_duration_ms, + const int& speech_pad_left_ms, + const int& speech_pad_right_ms) { + if (initialized_) { fastdeploy::FDERROR << "SetConfig must be called before init" << std::endl; throw std::runtime_error("SetConfig must be called before init"); @@ -67,6 +75,7 @@ void Vad::SetConfig(int sr, sample_rate_ = sr; sr_per_ms_ = sr / 1000; threshold_ = threshold; + beam_ = beam; frame_ms_ = frame_ms; min_silence_samples_ = min_silence_duration_ms * sr_per_ms_; speech_pad_left_samples_ = speech_pad_left_ms * sr_per_ms_; @@ -76,8 +85,8 @@ void Vad::SetConfig(int sr, window_size_samples_ = frame_ms * sr_per_ms_; current_chunk_size_ = window_size_samples_; - fastdeploy::FDINFO << "sr=" << sr << " threshold=" << threshold - << " frame_ms=" << 
frame_ms + fastdeploy::FDINFO << "sr=" << sr_per_ms_ << " threshold=" << threshold_ + << " beam=" << beam_ << " frame_ms=" << frame_ms_ << " min_silence_duration_ms=" << min_silence_duration_ms << " speech_pad_left_ms=" << speech_pad_left_ms << " speech_pad_right_ms=" << speech_pad_right_ms; @@ -114,12 +123,17 @@ bool Vad::Initialize() { Reset(); + // InitRuntime if (!InitRuntime()) { fastdeploy::FDERROR << "Failed to initialize fastdeploy backend." << std::endl; return false; } + + initialized_ = true; + + fastdeploy::FDINFO << "init done."; return true; } @@ -162,8 +176,8 @@ const Vad::State& Vad::Postprocess() { if (outputProb_ < threshold_ && !triggerd_) { // 1. Silence - LOG_DEBUG << "{ silence: " << 1.0 * current_sample_ / sample_rate_ - << " s; prob: " << outputProb_ << " }"; + DLOG(INFO) << "{ silence: " << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; states_.emplace_back(Vad::State::SIL); } else if (outputProb_ >= threshold_ && !triggerd_) { // 2. Start @@ -172,27 +186,28 @@ const Vad::State& Vad::Postprocess() { current_sample_ - current_chunk_size_ - speech_pad_left_samples_; float start_sec = 1.0 * speech_start_ / sample_rate_; speakStart_.emplace_back(start_sec); - LOG_DEBUG << "{ speech start: " << start_sec - << " s; prob: " << outputProb_ << " }"; + DLOG(INFO) << "{ speech start: " << start_sec + << " s; prob: " << outputProb_ << " }"; states_.emplace_back(Vad::State::START); - } else if (outputProb_ >= threshold_ - 0.15 && triggerd_) { + } else if (outputProb_ >= threshold_ - beam_ && triggerd_) { // 3. 
Continue if (temp_end_ != 0) { // speech prob relaxation, speech continues again - LOG_DEBUG << "{ speech fake end(sil < min_silence_ms) to continue: " - << 1.0 * current_sample_ / sample_rate_ - << " s; prob: " << outputProb_ << " }"; + DLOG(INFO) + << "{ speech fake end(sil < min_silence_ms) to continue: " + << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; temp_end_ = 0; } else { // speech prob relaxation, keep tracking speech - LOG_DEBUG << "{ speech continue: " - << 1.0 * current_sample_ / sample_rate_ - << " s; prob: " << outputProb_ << " }"; + DLOG(INFO) << "{ speech continue: " + << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; } states_.emplace_back(Vad::State::SPEECH); - } else if (outputProb_ < threshold_ - 0.15 && triggerd_) { + } else if (outputProb_ < threshold_ - beam_ && triggerd_) { // 4. End if (temp_end_ == 0) { temp_end_ = current_sample_; @@ -201,9 +216,9 @@ const Vad::State& Vad::Postprocess() { // check possible speech end if (current_sample_ - temp_end_ < min_silence_samples_) { // a. silence < min_slience_samples, continue speaking - LOG_DEBUG << "{ speech fake end(sil < min_silence_ms): " - << 1.0 * current_sample_ / sample_rate_ - << " s; prob: " << outputProb_ << " }"; + DLOG(INFO) << "{ speech fake end(sil < min_silence_ms): " + << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; states_.emplace_back(Vad::State::SIL); } else { // b. 
silence >= min_slience_samples, end speaking @@ -212,8 +227,8 @@ const Vad::State& Vad::Postprocess() { triggerd_ = false; auto end_sec = 1.0 * speech_end_ / sample_rate_; speakEnd_.emplace_back(end_sec); - LOG_DEBUG << "{ speech end: " << end_sec - << " s; prob: " << outputProb_ << " }"; + DLOG(INFO) << "{ speech end: " << end_sec + << " s; prob: " << outputProb_ << " }"; states_.emplace_back(Vad::State::END); } } @@ -303,4 +318,6 @@ std::ostream& operator<<(std::ostream& os, const Vad::State& s) { break; } return os; -} \ No newline at end of file +} + +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/vad/vad.h b/runtime/engine/vad/nnet/vad.h similarity index 78% rename from runtime/engine/vad/vad.h rename to runtime/engine/vad/nnet/vad.h index 6eed7d1c..de557ec6 100644 --- a/runtime/engine/vad/vad.h +++ b/runtime/engine/vad/nnet/vad.h @@ -1,4 +1,5 @@ // Copyright (c) 2023 Chen Qianhe Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,33 +12,59 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + #pragma once #include #include #include -#include "./wav.h" + #include "fastdeploy/fastdeploy_model.h" #include "fastdeploy/runtime.h" +#include "vad/frontend/wav.h" + +namespace ppspeech { + +struct VadNnetConf { + // wav + int sr; + int frame_ms; + float threshold; + float beam; + int min_silence_duration_ms; + int speech_pad_left_ms; + int speech_pad_right_ms; + + // model + std::string model_file_path; + std::string param_file_path; + std::string dict_file_path; + int num_cpu_thread; // 1 thred + std::string backend; // ort,lite, etc. 
+}; class Vad : public fastdeploy::FastDeployModel { public: - enum class State { SIL = 0, START, SPEECH, END }; + enum class State { ILLEGAL = 0, SIL, START, SPEECH, END }; friend std::ostream& operator<<(std::ostream& os, const Vad::State& s); Vad(const std::string& model_file, const fastdeploy::RuntimeOption& custom_option = fastdeploy::RuntimeOption()); + virtual ~Vad() {} + void Init(); void Reset(); - void SetConfig(int sr, - int frame_ms, - float threshold, - int min_silence_duration_ms, - int speech_pad_left_ms, - int speech_pad_right_ms); + void SetConfig(const int& sr, + const int& frame_ms, + const float& threshold, + const float& beam, + const int& min_silence_duration_ms, + const int& speech_pad_left_ms, + const int& speech_pad_right_ms); + void SetConfig(const VadNnetConf conf); bool ForwardChunk(std::vector& chunk); @@ -78,7 +105,9 @@ class Vad : public fastdeploy::FastDeployModel { bool Initialize(); private: - std::once_flag init_; + std::mutex init_lock_; + bool initialized_{false}; + // input and output std::vector inputTensors_; std::vector outputTensors_; @@ -103,6 +132,7 @@ class Vad : public fastdeploy::FastDeployModel { int sample_rate_ = 16000; int frame_ms_ = 32; // 32, 64, 96 for 16k float threshold_ = 0.5f; + float beam_ = 0.15f; int64_t window_size_samples_; // support 256 512 768 for 8k; 512 1024 1536 // for 16k. @@ -122,3 +152,5 @@ class Vad : public fastdeploy::FastDeployModel { const std::vector sr_node_dims_ = {1}; const std::vector hc_node_dims_ = {2, 1, 64}; }; + +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/vad/silero_vad_main.cc b/runtime/engine/vad/nnet/vad_nnet_main.cc similarity index 58% rename from runtime/engine/vad/silero_vad_main.cc rename to runtime/engine/vad/nnet/vad_nnet_main.cc index 7fb52406..7b89d1af 100644 --- a/runtime/engine/vad/silero_vad_main.cc +++ b/runtime/engine/vad/nnet/vad_nnet_main.cc @@ -1,11 +1,26 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include "vad.h" + +#include "common/base/common.h" +#include "vad/nnet/vad.h" int main(int argc, char* argv[]) { if (argc < 3) { - std::cout << "Usage: infer_onnx_silero_vad path/to/model path/to/audio " + std::cout << "Usage: vad_nnet_main path/to/model path/to/audio " "run_option, " - "e.g ./infer_onnx_silero_vad silero_vad.onnx sample.wav" + "e.g ./vad_nnet_main silero_vad.onnx sample.wav" << std::endl; return -1; } @@ -14,9 +29,9 @@ int main(int argc, char* argv[]) { std::string audio_file = argv[2]; int sr = 16000; - Vad vad(model_file); + ppspeech::Vad vad(model_file); // custom config, but must be set before init - vad.SetConfig(sr, 32, 0.45f, 200, 0, 0); + vad.SetConfig(sr, 32, 0.5f, 0.15, 200, 0, 0); vad.Init(); std::vector inputWav; // [0, 1] @@ -30,6 +45,7 @@ int main(int argc, char* argv[]) { inputWav[i] = wav_reader.data()[i] / 32768; } + ppspeech::Timer timer; int window_size_samples = vad.WindowSizeSamples(); for (int64_t j = 0; j < num_samples; j += window_size_samples) { auto start = j; @@ -39,7 +55,7 @@ int main(int argc, char* argv[]) { auto current_chunk_size = end - start; std::vector r{&inputWav[0] + start, &inputWav[0] + end}; - assert(r.size() == current_chunk_size); + assert(r.size() == static_cast(current_chunk_size)); if (!vad.ForwardChunk(r)) { std::cerr << "Failed to inference while using model:" @@ -47,11 +63,14 @@ int main(int argc, char* argv[]) { return false; } - Vad::State 
s = vad.Postprocess(); + ppspeech::Vad::State s = vad.Postprocess(); std::cout << s << " "; } std::cout << std::endl; + std::cout << "RTF=" << timer.Elapsed() / double(num_samples / sr) + << std::endl; + std::vector> result = vad.GetResult(); for (auto& res : result) { std::cout << "speak start: " << res["start"] diff --git a/runtime/examples/silero_vad/README.md b/runtime/examples/silero_vad/README.md deleted file mode 100644 index f032be86..00000000 --- a/runtime/examples/silero_vad/README.md +++ /dev/null @@ -1,121 +0,0 @@ -English | [简体中文](README_CN.md) - -# Silero VAD Deployment Example - -This directory provides examples that `infer_onnx_silero_vad` fast finishes the deployment of VAD models on CPU/GPU. - -Before deployment, two steps require confirmation. - -- 1. Software and hardware should meet the requirements. Please refer to [FastDeploy Environment Requirements](../../../../docs/en/build_and_install/download_prebuilt_libraries.md). -- 2. Download the precompiled deployment library and samples code according to your development environment. Refer to [FastDeploy Precompiled Library](../../../../docs/en/build_and_install/download_prebuilt_libraries.md). - -Taking VAD inference on Linux as an example, the compilation test can be completed by executing the following command in this directory. - -```bash -mkdir build -cd build -# Download the FastDeploy precompiled library. Users can choose your appropriate version in the `FastDeploy Precompiled Library` mentioned above -wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz -tar xvf fastdeploy-linux-x64-x.x.x.tgz -cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x -make -j - -# Download the VAD model file and test audio. 
After decompression, place the model and test audio in the infer_onnx_silero_vad.cc peer directory -wget https://bj.bcebos.com/paddlehub/fastdeploy/silero_vad.tgz -wget https://bj.bcebos.com/paddlehub/fastdeploy/silero_vad_sample.wav - -# inference -./infer_onnx_silero_vad ../silero_vad.onnx ../silero_vad_sample.wav -``` - -- The above command works for Linux or MacOS. Refer to: - - [How to use FastDeploy C++ SDK in Windows](../../../../docs/en/faq/use_sdk_on_windows.md) for SDK use-pattern in Windows - -## VAD C++ Interface - -### Vad Class - -```c++ -Vad::Vad(const std::string& model_file, - const fastdeploy::RuntimeOption& custom_option = fastdeploy::RuntimeOption()) -``` - -**Parameter** - -> * **model_file**(str): Model file path -> * **runtime_option**(RuntimeOption): Backend inference configuration. None by default. (use the default configuration) - -### setAudioCofig function - -**Must be called before the `init` function** - -```c++ -void Vad::setAudioCofig(int sr, int frame_ms, float threshold, int min_silence_duration_ms, int speech_pad_ms); -``` - -**Parameter** - -> * **sr**(int): sampling rate -> * **frame_ms**(int): The length of each detection frame, and it is used to calculate the detection window size -> * **threshold**(float): Result probability judgment threshold -> * **min_silence_duration_ms**(int): The threshold used to calculate whether it is silence -> * **speech_pad_ms**(int): Used to calculate the end time of the speech - -### init function - -Used to initialize audio-related parameters. - -```c++ -void Vad::init(); -``` - -### loadAudio function - -Load audio. - -```c++ -void Vad::loadAudio(const std::string& wavPath) -``` - -**Parameter** - -> * **wavPath**(str): Audio file path - -### Predict function - -Used to start model reasoning. 
- -```c++ -bool Vad::Predict(); -``` - -### getResult function - -**Used to obtain reasoning results** - -```c++ -std::vector> Vad::getResult( - float removeThreshold = 1.6, float expandHeadThreshold = 0.32, float expandTailThreshold = 0, - float mergeThreshold = 0.3); -``` - -**Parameter** - -> * **removeThreshold**(float): Discard result fragment threshold; If some recognition results are too short, they will be discarded according to this threshold -> * **expandHeadThreshold**(float): Offset at the beginning of the segment; The recognized start time may be too close to the voice part, so move forward the start time accordingly -> * **expandTailThreshold**(float): Offset at the end of the segment; The recognized end time may be too close to the voice part, so the end time is moved back accordingly -> * **mergeThreshold**(float): Some result segments are very close and can be combined into one, and the vocal segments can be combined accordingly - -**The output result format is**`std::vector>` - -> Output a list, each element is a speech fragment -> -> Each clip can use 'start' to get the start time and 'end' to get the end time - -### Tips - -1. `The setAudioCofig`function must be called before the `init` function -2. The sampling rate of the input audio file must be consistent with that set in the code - -- [Model Description](../) -- [How to switch the model inference backend engine](../../../../docs/en/faq/how_to_change_backend.md) diff --git a/runtime/examples/silero_vad/README_CN.md b/runtime/examples/silero_vad/README_CN.md deleted file mode 100644 index c45d9896..00000000 --- a/runtime/examples/silero_vad/README_CN.md +++ /dev/null @@ -1,119 +0,0 @@ -[English](README.md) | 简体中文 -# Silero VAD 部署示例 - -本目录下提供`infer_onnx_silero_vad`快速完成 Silero VAD 模型在CPU/GPU。 - -在部署前,需确认以下两个步骤 - -- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) -- 2. 
根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) - -以Linux上 VAD 推理为例,在本目录执行如下命令即可完成编译测试。 - -```bash -mkdir build -cd build -# 下载FastDeploy预编译库,用户可在上文提到的`FastDeploy预编译库`中自行选择合适的版本使用 -wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz -tar xvf fastdeploy-linux-x64-x.x.x.tgz -cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x -make -j - -# 下载 VAD 模型文件和测试音频,解压后将模型和测试音频放置在与 infer_onnx_silero_vad.cc 同级目录下 -wget https://bj.bcebos.com/paddlehub/fastdeploy/silero_vad.tgz -wget https://bj.bcebos.com/paddlehub/fastdeploy/silero_vad_sample.wav - -# 推理 -./infer_onnx_silero_vad ../silero_vad.onnx ../silero_vad_sample.wav -``` - -以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考: -- [如何在Windows中使用FastDeploy C++ SDK](../../../../docs/cn/faq/use_sdk_on_windows.md) - -## VAD C++ 接口 -### Vad 类 - -```c++ -Vad::Vad(const std::string& model_file, - const fastdeploy::RuntimeOption& custom_option = fastdeploy::RuntimeOption()) -``` - -**参数** - -> * **model_file**(str): 模型文件路径 -> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 - -### setAudioCofig 函数 - -**必须在`init`函数前调用** - -```c++ -void Vad::setAudioCofig(int sr, int frame_ms, float threshold, int min_silence_duration_ms, int speech_pad_ms); -``` - -**参数** - -> * **sr**(int): 采样率 -> * **frame_ms**(int): 每次检测帧长,用于计算检测窗口大小 -> * **threshold**(float): 结果概率判断阈值 -> * **min_silence_duration_ms**(int): 用于计算判断是否是 silence 的阈值 -> * **speech_pad_ms**(int): 用于计算 speach 结束时刻 - -### init 函数 - -用于初始化音频相关参数 - -```c++ -void Vad::init(); -``` - -### loadAudio 函数 - -加载音频 - -```c++ -void Vad::loadAudio(const std::string& wavPath) -``` - -**参数** - -> * **wavPath**(str): 音频文件路径 - -### Predict 函数 - -用于开始模型推理 - -```c++ -bool Vad::Predict(); -``` - -### getResult 函数 - -**用于获取推理结果** - -```c++ -std::vector> Vad::getResult( - float removeThreshold = 1.6, float expandHeadThreshold = 0.32, float expandTailThreshold = 0, - float mergeThreshold = 0.3); -``` 
- -**参数** - -> * **removeThreshold**(float): 丢弃结果片段阈值;部分识别结果太短则根据此阈值丢弃 -> * **expandHeadThreshold**(float): 结果片段开始时刻偏移;识别到的开始时刻可能过于贴近发声部分,因此据此前移开始时刻 -> * **expandTailThreshold**(float): 结果片段结束时刻偏移;识别到的结束时刻可能过于贴近发声部分,因此据此后移结束时刻 -> * **mergeThreshold**(float): 有的结果片段十分靠近,可以合并成一个,据此合并发声片段 - -**输出结果格式为**`std::vector>` - -> 输出一个列表,每个元素是一个讲话片段 -> -> 每个片段可以用 'start' 获取到开始时刻,用 'end' 获取到结束时刻 - -### 提示 - -1. `setAudioCofig`函数必须在`init`函数前调用 -2. 输入的音频文件的采样率必须与代码中设置的保持一致 - -- [模型介绍](../) -- [如何切换模型推理后端引擎](../../../../docs/cn/faq/how_to_change_backend.md) diff --git a/runtime/examples/silero_vad/local/decode.sh b/runtime/examples/silero_vad/local/decode.sh deleted file mode 100755 index e69de29b..00000000 diff --git a/runtime/examples/silero_vad/path.sh b/runtime/examples/silero_vad/path.sh deleted file mode 100644 index ad3a7358..00000000 --- a/runtime/examples/silero_vad/path.sh +++ /dev/null @@ -1,18 +0,0 @@ -# This contains the locations of binarys build required for running the examples. - -unset GREP_OPTIONS - -ENGINE_ROOT=$PWD/../../../ -ENGINE_BUILD=$ENGINE_ROOT/build/engine/asr - -ENGINE_TOOLS=$ENGINE_ROOT/tools -TOOLS_BIN=$ENGINE_TOOLS/valgrind/install/bin - -[ -d $ENGINE_BUILD ] || { echo "Error: 'build/runtime' directory not found. 
please ensure that the project build successfully"; } - -export LC_AL=C - -export PATH=$PATH:$TOOLS_BIN:$ENGINE_BUILD/nnet:$ENGINE_BUILD/decoder:$ENGINE_BUILD/../common/frontend/audio:$ENGINE_BUILD/recognizer - -#PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);") -export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH diff --git a/runtime/examples/u2pp_ol/wenetspeech/path.sh b/runtime/examples/u2pp_ol/wenetspeech/path.sh index ad3a7358..544e2048 100644 --- a/runtime/examples/u2pp_ol/wenetspeech/path.sh +++ b/runtime/examples/u2pp_ol/wenetspeech/path.sh @@ -3,7 +3,7 @@ unset GREP_OPTIONS ENGINE_ROOT=$PWD/../../../ -ENGINE_BUILD=$ENGINE_ROOT/build/engine/asr +ENGINE_BUILD=$ENGINE_ROOT/build/Linux/x86_64/engine/asr ENGINE_TOOLS=$ENGINE_ROOT/tools TOOLS_BIN=$ENGINE_TOOLS/valgrind/install/bin diff --git a/runtime/examples/silero_vad/.gitignore b/runtime/examples/vad/.gitignore similarity index 100% rename from runtime/examples/silero_vad/.gitignore rename to runtime/examples/vad/.gitignore diff --git a/runtime/examples/vad/README.md b/runtime/examples/vad/README.md new file mode 100644 index 00000000..b521063b --- /dev/null +++ b/runtime/examples/vad/README.md @@ -0,0 +1,261 @@ +# Silero VAD - pre-trained enterprise-grade Voice Activity Detector + +This directory provides VAD models on CPU/GPU. + +![](https://user-images.githubusercontent.com/36505480/198026365-8da383e0-5398-4a12-b7f8-22c2c0059512.png) + + +## VAD Interface + +For vad interface please see [](../../engine/vad/interface/). 
+ +### Create Handdle + +```c++ +PPSHandle_t PPSVadCreateInstance(const char* conf_path); +``` + +### Destroy Handdle + +```c++ +int PPSVadDestroyInstance(PPSHandle_t instance); +``` + +### Reset Vad State + +```c++ +int PPSVadReset(PPSHandle_t instance); +``` + +Reset Vad state before processing next `wav`. + +### Get Chunk Size + +```c++ +int PPSVadChunkSizeSamples(PPSHandle_t instance); +``` + +This API will return chunk size in `sample` unit. +When do forward, we need feed `chunk size` samples, except last chunk. + +### Vad Forward + +```c++ +PPSVadState_t PPSVadFeedForward(PPSHandle_t instance, + float* chunk, + int num_element); +``` + +Vad has below states: +```c++ +typedef enum { + PPS_VAD_ILLEGAL = 0, // error + PPS_VAD_SIL, // silence + PPS_VAD_START, // start speech + PPS_VAD_SPEECH, // in speech + PPS_VAD_END, // end speech + PPS_VAD_NUMSTATES, // number of states +} PPSVadState_t; +``` + +If `PPSVadFeedForward` occur an error will return `PPS_VAD_ILLEGAL` state. + + +## Linux + +### Build Runtime +```bash +# cd /path/to/paddlespeech/runtime +cmake -B build -DBUILD_SHARED_LIBS=OFF -DWITH_ASR=OFF -DWITH_CLS=OFF -DWITH_VAD=ON +cmake --build build +``` + +Since VAD using FastDeploy runtime, if you have another FastDeploy Library, you can using this command to build: + +```bash +# cd /path/to/paddlespeech/runtime +cmake -B build -DBUILD_SHARED_LIBS=OFF -DWITH_ASR=OFF -DWITH_CLS=OFF -DWITH_VAD=ON -DFASTDEPLOY_INSTALL_DIR=/workspace//paddle/FastDeploy/build/Linux/x86_64/install +cmake --build build +``` + +`DFASTDEPLOY_INSTALL_DIR` is the directory of FastDeploy Library. 
+ +### Run Demo + +After building success, we can do this to run demo under this example dir: + +```bash +bash run.sh +``` + +The output like these: + +```bash +/workspace//PaddleSpeech/runtime/engine/vad/nnet/vad.cc(88)::SetConfig sr=16 threshold=0.5 beam=0.15 frame_ms=32 min_silence_duration_ms=200 speech_pad_left_ms=0 speech_pad_right_ms=0[INFO] fastdeploy/runtime/runtime.cc(293)::CreateOrtBackend Runtime initialized with Backend::ORT in Device::CPU./workspace//PaddleSpeech/runtime/engine/vad/nnet/vad.cc(137)::Initialize init done.[SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [STA] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [END] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [STA] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SIL] [SIL] [SIL] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [END] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [STA] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SIL] [SIL] [SIL] [SIL] [SIL] 
[SIL] [SIL] [END] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [STA] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SIL] [SIL] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SPE] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [END] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] [SIL] +RTF=0.00774591 +speak start: 0.32 s, end: 2.464 s | speak start: 3.296 s, end: 4.64 s | speak start: 5.408 s, end: 7.872 s | speak start: 8.192 s, end: 10.72 s +vad_nnet_main done! +sr = 16000 +frame_ms = 32 +threshold = 0.5 +beam = 0.15 +min_silence_duration_ms = 200 +speech_pad_left_ms = 0 +speech_pad_right_ms = 0 +model_path = ./data/silero_vad/silero_vad.onnx +param_path = (default)num_cpu_thread = 1(default)/workspace//PaddleSpeech/runtime/engine/vad/nnet/vad.cc(88)::SetConfig sr=16 threshold=0.5 beam=0.15 frame_ms=32 min_silence_duration_ms=200 speech_pad_left_ms=0 speech_pad_right_ms=0[INFO] fastdeploy/runtime/runtime.cc(293)::CreateOrtBackend Runtime initialized with Backend::ORT in Device::CPU./workspace//PaddleSpeech/runtime/engine/vad/nnet/vad.cc(137)::Initialize init done. 
+1 1 1 1 1 1 1 1 1 1 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 1 1 1 1 1 1 1 4 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 1 1 1 3 3 3 3 3 3 3 3 3 3 3 3 3 1 1 1 1 1 1 1 4 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 1 1 1 1 1 1 1 4 1 1 1 1 1 1 1 1 1 1 2 3 3 3 3 3 3 3 3 3 3 3 1 1 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 1 1 1 1 1 1 1 4 1 1 1 1 1 1 1 1 1 +RTF=0.00778218 +vad_interface_main done! +``` + +## Android + +When to using on Android, please setup your `NDK` enverment before, then do as below: + +```bash +# cd /path/to/paddlespeech/runtime +bash build_android.sh +``` + +## Result + +| Arch | RTF | Runtime Size | +|--|--|--| +| x86_64 | 0.00778218 | | +| arm64-v8a | 0.00744745 | ~10.532MB | + +## Machine Information + +#### x86_64 + +The environment as below: + +```text +Architecture: x86_64 +CPU op-mode(s): 32-bit, 64-bit +Byte Order: Little Endian +CPU(s): 80 +On-line CPU(s) list: 0-79 +Thread(s) per core: 2 +Core(s) per socket: 20 +Socket(s): 2 +NUMA node(s): 2 +Vendor ID: GenuineIntel +CPU family: 6 +Model: 85 +Model name: Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz +Stepping: 7 +CPU MHz: 2599.998 +BogoMIPS: 5199.99 +Hypervisor vendor: KVM +Virtualization type: full +L1d cache: 32K +L1i cache: 32K +L2 cache: 1024K +L3 cache: 33792K +NUMA node0 CPU(s): 0-39 +NUMA node1 CPU(s): 40-79 +Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc eagerfpu pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 
3dnowprefetch invpcid_single ssbd ibrs ibpb ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 arat umip pku ospke avx512_vnni spec_ctrl arch_capabilities +``` + +#### arm64-v8a + +```text +Processor : AArch64 Processor rev 14 (aarch64) +processor : 0 +BogoMIPS : 38.40 +Features : fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm lrcpc dcpop +CPU implementer : 0x51 +CPU architecture: 8 +CPU variant : 0xd +CPU part : 0x805 +CPU revision : 14 + +processor : 1 +BogoMIPS : 38.40 +Features : fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm lrcpc dcpop +CPU implementer : 0x51 +CPU architecture: 8 +CPU variant : 0xd +CPU part : 0x805 +CPU revision : 14 + +processor : 2 +BogoMIPS : 38.40 +Features : fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm lrcpc dcpop +CPU implementer : 0x51 +CPU architecture: 8 +CPU variant : 0xd +CPU part : 0x805 +CPU revision : 14 + +processor : 3 +BogoMIPS : 38.40 +Features : fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm lrcpc dcpop +CPU implementer : 0x51 +CPU architecture: 8 +CPU variant : 0xd +CPU part : 0x805 +CPU revision : 14 + +processor : 4 +BogoMIPS : 38.40 +Features : fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm lrcpc dcpop +CPU implementer : 0x51 +CPU architecture: 8 +CPU variant : 0xd +CPU part : 0x804 +CPU revision : 14 + +processor : 5 +BogoMIPS : 38.40 +Features : fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm lrcpc dcpop +CPU implementer : 0x51 +CPU architecture: 8 +CPU variant : 0xd +CPU part : 0x804 +CPU revision : 14 + +processor : 6 +BogoMIPS : 38.40 +Features : fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm lrcpc dcpop +CPU implementer : 0x51 +CPU architecture: 8 +CPU variant : 0xd +CPU 
part : 0x804 +CPU revision : 14 + +processor : 7 +BogoMIPS : 38.40 +Features : fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm lrcpc dcpop +CPU implementer : 0x51 +CPU architecture: 8 +CPU variant : 0xd +CPU part : 0x804 +CPU revision : 14 + +Hardware : Qualcomm Technologies, Inc SM8150 +``` + + +## Download Pre-trained ONNX Model + +For developers' testing, model exported by VAD are provided below. Developers can download them directly. + +| 模型 | 大小 | 备注 | +| :----------------------------------------------------------- | :---- | :----------------------------------------------------------- | +| [silero-vad](https://bj.bcebos.com/paddlehub/fastdeploy/silero_vad.tgz) | 1.8MB | This model file is sourced from [snakers4/silero-vad](https://github.com/snakers4/silero-vad),MIT License | + + +## FastDeploy Runtime + +For FastDeploy software and hardware requements, and pre-released library please to see [FastDeploy](https://github.com/PaddlePaddle/FastDeploy): + +- 1. [FastDeploy Environment Requirements](https://github.com/PaddlePaddle/FastDeploy/docs/en/build_and_install/download_prebuilt_libraries.md). +- 2. [FastDeploy Precompiled Library](https://github.com/PaddlePaddle/FastDeploy/docs/en/build_and_install/download_prebuilt_libraries.md). 
+ + +## Reference +* https://github.com/snakers4/silero-vad +* https://github.com/PaddlePaddle/FastDeploy/blob/develop/examples/audio/silero-vad/README.md diff --git a/runtime/examples/vad/conf/vad.ini b/runtime/examples/vad/conf/vad.ini new file mode 100644 index 00000000..c168c73b --- /dev/null +++ b/runtime/examples/vad/conf/vad.ini @@ -0,0 +1,11 @@ +[model] +model_path=./data/silero_vad/silero_vad.onnx + +[vad] +sr = 16000 # 16k +frame_ms = 32 # 32, 64, 96 for 16k +threshold = 0.5 +beam = 0.15 +min_silence_duration_ms = 200 +speech_pad_left_ms = 0 +speech_pad_right_ms = 0 diff --git a/runtime/examples/silero_vad/local/build.sh b/runtime/examples/vad/local/build.sh similarity index 100% rename from runtime/examples/silero_vad/local/build.sh rename to runtime/examples/vad/local/build.sh diff --git a/runtime/examples/silero_vad/local/build_android.sh b/runtime/examples/vad/local/build_android.sh similarity index 100% rename from runtime/examples/silero_vad/local/build_android.sh rename to runtime/examples/vad/local/build_android.sh diff --git a/runtime/examples/vad/local/decode.sh b/runtime/examples/vad/local/decode.sh new file mode 100755 index 00000000..ff0a0d44 --- /dev/null +++ b/runtime/examples/vad/local/decode.sh @@ -0,0 +1,23 @@ +#!/bin/bash +set -e + +conf=conf +data=data +exp=exp + +. utils/parse_options.sh + +mkdir -p $exp +ckpt_dir=$data/silero_vad +model=$ckpt_dir/silero_vad.onnx +test_wav=$data/silero_vad_sample.wav +conf_file=$conf/vad.ini + + +vad_nnet_main $model $test_wav +echo "vad_nnet_main done!" + +vad_interface_main $conf_file $test_wav +echo "vad_interface_main done!" 
+ + diff --git a/runtime/examples/silero_vad/local/download.sh b/runtime/examples/vad/local/download.sh similarity index 100% rename from runtime/examples/silero_vad/local/download.sh rename to runtime/examples/vad/local/download.sh diff --git a/runtime/examples/vad/path.sh b/runtime/examples/vad/path.sh new file mode 100644 index 00000000..b4991111 --- /dev/null +++ b/runtime/examples/vad/path.sh @@ -0,0 +1,17 @@ +# This contains the locations of binarys build required for running the examples. + +unset GREP_OPTIONS + +ENGINE_ROOT=$PWD/../../ +ENGINE_BUILD=$ENGINE_ROOT/build/Linux/x86_64/engine/vad + +ENGINE_TOOLS=$ENGINE_ROOT/tools +TOOLS_BIN=$ENGINE_TOOLS/valgrind/install/bin + +[ -d $ENGINE_BUILD ] || { echo "Error: 'build/runtime' directory not found. please ensure that the project build successfully"; } + +export LC_AL=C + +export PATH=$PATH:$TOOLS_BIN:$ENGINE_BUILD/nnet:$ENGINE_BUILD/interface + +export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH diff --git a/runtime/examples/silero_vad/run.sh b/runtime/examples/vad/run.sh old mode 100644 new mode 100755 similarity index 77% rename from runtime/examples/silero_vad/run.sh rename to runtime/examples/vad/run.sh index 9707df1b..606a44f8 --- a/runtime/examples/silero_vad/run.sh +++ b/runtime/examples/vad/run.sh @@ -15,8 +15,8 @@ exp=exp mkdir -p $exp $data # 1. compile -if [ ! -d ${SPEECHX_BUILD} ]; then - pushd ${SPEECHX_ROOT} +if [ ! -d ${ENGINE_BUILD} ]; then + pushd ${ENGINE_ROOT} bash build.sh # build for android armv8/armv7 @@ -24,8 +24,6 @@ if [ ! 
-d ${SPEECHX_BUILD} ]; then popd fi -ckpt_dir=$data/silero_vad -wav=$data/silero_vad_sample.wav if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then ./local/download.sh diff --git a/runtime/examples/silero_vad/utils b/runtime/examples/vad/utils similarity index 100% rename from runtime/examples/silero_vad/utils rename to runtime/examples/vad/utils From f0ef6f1cafad81cdf25ba7de549224009579b006 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 13 Mar 2023 15:17:24 +0800 Subject: [PATCH 17/50] [runtime] vad jni demo (#3027) * android vad jni demo * update src * rename --- runtime/examples/android/VadJni/.gitignore | 15 ++ .../examples/android/VadJni/.idea/.gitignore | 3 + runtime/examples/android/VadJni/.idea/.name | 1 + .../android/VadJni/.idea/compiler.xml | 6 + .../VadJni/.idea/deploymentTargetDropDown.xml | 17 ++ .../examples/android/VadJni/.idea/gradle.xml | 19 ++ .../examples/android/VadJni/.idea/misc.xml | 10 + runtime/examples/android/VadJni/.idea/vcs.xml | 6 + .../examples/android/VadJni/app/.gitignore | 2 + .../examples/android/VadJni/app/build.gradle | 129 ++++++++++++ .../examples/android/VadJni/app/libs/.gitkeep | 0 .../android/VadJni/app/proguard-rules.pro | 21 ++ .../vadjni/ExampleInstrumentedTest.java | 26 +++ .../VadJni/app/src/main/AndroidManifest.xml | 25 +++ .../VadJni/app/src/main/assets/.gitkeep | 0 .../VadJni/app/src/main/cpp/CMakeLists.txt | 59 ++++++ .../VadJni/app/src/main/cpp/native-lib.cpp | 57 ++++++ .../VadJni/app/src/main/cpp/vad_interface.h | 46 +++++ .../paddlespeech/vadjni/MainActivity.java | 50 +++++ .../drawable-v24/ic_launcher_foreground.xml | 30 +++ .../res/drawable/ic_launcher_background.xml | 170 ++++++++++++++++ .../app/src/main/res/layout/activity_main.xml | 28 +++ .../res/mipmap-anydpi-v26/ic_launcher.xml | 5 + .../mipmap-anydpi-v26/ic_launcher_round.xml | 5 + .../res/mipmap-anydpi-v33/ic_launcher.xml | 6 + .../src/main/res/mipmap-hdpi/ic_launcher.webp | Bin 0 -> 1404 bytes .../res/mipmap-hdpi/ic_launcher_round.webp | Bin 
0 -> 2898 bytes .../src/main/res/mipmap-mdpi/ic_launcher.webp | Bin 0 -> 982 bytes .../res/mipmap-mdpi/ic_launcher_round.webp | Bin 0 -> 1772 bytes .../main/res/mipmap-xhdpi/ic_launcher.webp | Bin 0 -> 1900 bytes .../res/mipmap-xhdpi/ic_launcher_round.webp | Bin 0 -> 3918 bytes .../main/res/mipmap-xxhdpi/ic_launcher.webp | Bin 0 -> 2884 bytes .../res/mipmap-xxhdpi/ic_launcher_round.webp | Bin 0 -> 5914 bytes .../main/res/mipmap-xxxhdpi/ic_launcher.webp | Bin 0 -> 3844 bytes .../res/mipmap-xxxhdpi/ic_launcher_round.webp | Bin 0 -> 7778 bytes .../app/src/main/res/values-night/themes.xml | 16 ++ .../VadJni/app/src/main/res/values/colors.xml | 10 + .../app/src/main/res/values/strings.xml | 3 + .../VadJni/app/src/main/res/values/themes.xml | 16 ++ .../app/src/main/res/xml/backup_rules.xml | 13 ++ .../main/res/xml/data_extraction_rules.xml | 19 ++ runtime/examples/android/VadJni/build.gradle | 5 + .../examples/android/VadJni/gradle.properties | 21 ++ .../VadJni/gradle/wrapper/gradle-wrapper.jar | Bin 0 -> 59203 bytes .../gradle/wrapper/gradle-wrapper.properties | 6 + runtime/examples/android/VadJni/gradlew | 185 ++++++++++++++++++ runtime/examples/android/VadJni/gradlew.bat | 89 +++++++++ .../examples/android/VadJni/settings.gradle | 16 ++ 48 files changed, 1135 insertions(+) create mode 100644 runtime/examples/android/VadJni/.gitignore create mode 100644 runtime/examples/android/VadJni/.idea/.gitignore create mode 100644 runtime/examples/android/VadJni/.idea/.name create mode 100644 runtime/examples/android/VadJni/.idea/compiler.xml create mode 100644 runtime/examples/android/VadJni/.idea/deploymentTargetDropDown.xml create mode 100644 runtime/examples/android/VadJni/.idea/gradle.xml create mode 100644 runtime/examples/android/VadJni/.idea/misc.xml create mode 100644 runtime/examples/android/VadJni/.idea/vcs.xml create mode 100644 runtime/examples/android/VadJni/app/.gitignore create mode 100644 runtime/examples/android/VadJni/app/build.gradle create mode 100644 
runtime/examples/android/VadJni/app/libs/.gitkeep create mode 100644 runtime/examples/android/VadJni/app/proguard-rules.pro create mode 100644 runtime/examples/android/VadJni/app/src/androidTest/java/com/baidu/paddlespeech/vadjni/ExampleInstrumentedTest.java create mode 100644 runtime/examples/android/VadJni/app/src/main/AndroidManifest.xml create mode 100644 runtime/examples/android/VadJni/app/src/main/assets/.gitkeep create mode 100644 runtime/examples/android/VadJni/app/src/main/cpp/CMakeLists.txt create mode 100644 runtime/examples/android/VadJni/app/src/main/cpp/native-lib.cpp create mode 100644 runtime/examples/android/VadJni/app/src/main/cpp/vad_interface.h create mode 100644 runtime/examples/android/VadJni/app/src/main/java/com/baidu/paddlespeech/vadjni/MainActivity.java create mode 100644 runtime/examples/android/VadJni/app/src/main/res/drawable-v24/ic_launcher_foreground.xml create mode 100644 runtime/examples/android/VadJni/app/src/main/res/drawable/ic_launcher_background.xml create mode 100644 runtime/examples/android/VadJni/app/src/main/res/layout/activity_main.xml create mode 100644 runtime/examples/android/VadJni/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml create mode 100644 runtime/examples/android/VadJni/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml create mode 100644 runtime/examples/android/VadJni/app/src/main/res/mipmap-anydpi-v33/ic_launcher.xml create mode 100644 runtime/examples/android/VadJni/app/src/main/res/mipmap-hdpi/ic_launcher.webp create mode 100644 runtime/examples/android/VadJni/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp create mode 100644 runtime/examples/android/VadJni/app/src/main/res/mipmap-mdpi/ic_launcher.webp create mode 100644 runtime/examples/android/VadJni/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp create mode 100644 runtime/examples/android/VadJni/app/src/main/res/mipmap-xhdpi/ic_launcher.webp create mode 100644 
runtime/examples/android/VadJni/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp create mode 100644 runtime/examples/android/VadJni/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp create mode 100644 runtime/examples/android/VadJni/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp create mode 100644 runtime/examples/android/VadJni/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp create mode 100644 runtime/examples/android/VadJni/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp create mode 100644 runtime/examples/android/VadJni/app/src/main/res/values-night/themes.xml create mode 100644 runtime/examples/android/VadJni/app/src/main/res/values/colors.xml create mode 100644 runtime/examples/android/VadJni/app/src/main/res/values/strings.xml create mode 100644 runtime/examples/android/VadJni/app/src/main/res/values/themes.xml create mode 100644 runtime/examples/android/VadJni/app/src/main/res/xml/backup_rules.xml create mode 100644 runtime/examples/android/VadJni/app/src/main/res/xml/data_extraction_rules.xml create mode 100644 runtime/examples/android/VadJni/build.gradle create mode 100644 runtime/examples/android/VadJni/gradle.properties create mode 100644 runtime/examples/android/VadJni/gradle/wrapper/gradle-wrapper.jar create mode 100644 runtime/examples/android/VadJni/gradle/wrapper/gradle-wrapper.properties create mode 100755 runtime/examples/android/VadJni/gradlew create mode 100644 runtime/examples/android/VadJni/gradlew.bat create mode 100644 runtime/examples/android/VadJni/settings.gradle diff --git a/runtime/examples/android/VadJni/.gitignore b/runtime/examples/android/VadJni/.gitignore new file mode 100644 index 00000000..aa724b77 --- /dev/null +++ b/runtime/examples/android/VadJni/.gitignore @@ -0,0 +1,15 @@ +*.iml +.gradle +/local.properties +/.idea/caches +/.idea/libraries +/.idea/modules.xml +/.idea/workspace.xml +/.idea/navEditor.xml +/.idea/assetWizardSettings.xml +.DS_Store +/build +/captures +.externalNativeBuild +.cxx +local.properties 
diff --git a/runtime/examples/android/VadJni/.idea/.gitignore b/runtime/examples/android/VadJni/.idea/.gitignore new file mode 100644 index 00000000..26d33521 --- /dev/null +++ b/runtime/examples/android/VadJni/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/runtime/examples/android/VadJni/.idea/.name b/runtime/examples/android/VadJni/.idea/.name new file mode 100644 index 00000000..b5712d1e --- /dev/null +++ b/runtime/examples/android/VadJni/.idea/.name @@ -0,0 +1 @@ +VadJni \ No newline at end of file diff --git a/runtime/examples/android/VadJni/.idea/compiler.xml b/runtime/examples/android/VadJni/.idea/compiler.xml new file mode 100644 index 00000000..fb7f4a8a --- /dev/null +++ b/runtime/examples/android/VadJni/.idea/compiler.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/runtime/examples/android/VadJni/.idea/deploymentTargetDropDown.xml b/runtime/examples/android/VadJni/.idea/deploymentTargetDropDown.xml new file mode 100644 index 00000000..f26362be --- /dev/null +++ b/runtime/examples/android/VadJni/.idea/deploymentTargetDropDown.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/runtime/examples/android/VadJni/.idea/gradle.xml b/runtime/examples/android/VadJni/.idea/gradle.xml new file mode 100644 index 00000000..a2d7c213 --- /dev/null +++ b/runtime/examples/android/VadJni/.idea/gradle.xml @@ -0,0 +1,19 @@ + + + + + + + \ No newline at end of file diff --git a/runtime/examples/android/VadJni/.idea/misc.xml b/runtime/examples/android/VadJni/.idea/misc.xml new file mode 100644 index 00000000..bdd92780 --- /dev/null +++ b/runtime/examples/android/VadJni/.idea/misc.xml @@ -0,0 +1,10 @@ + + + + + + + + + \ No newline at end of file diff --git a/runtime/examples/android/VadJni/.idea/vcs.xml b/runtime/examples/android/VadJni/.idea/vcs.xml new file mode 100644 index 00000000..4fce1d86 --- /dev/null +++ b/runtime/examples/android/VadJni/.idea/vcs.xml 
@@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/.gitignore b/runtime/examples/android/VadJni/app/.gitignore new file mode 100644 index 00000000..44399f1d --- /dev/null +++ b/runtime/examples/android/VadJni/app/.gitignore @@ -0,0 +1,2 @@ +/build +/cache diff --git a/runtime/examples/android/VadJni/app/build.gradle b/runtime/examples/android/VadJni/app/build.gradle new file mode 100644 index 00000000..f2025a21 --- /dev/null +++ b/runtime/examples/android/VadJni/app/build.gradle @@ -0,0 +1,129 @@ +plugins { + id 'com.android.application' +} + +android { + namespace 'com.baidu.paddlespeech.vadjni' + compileSdk 33 + ndkVersion '23.1.7779620' + + defaultConfig { + applicationId "com.baidu.paddlespeech.vadjni" + minSdk 21 + targetSdk 33 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + externalNativeBuild { + cmake { + arguments '-DANDROID_PLATFORM=android-21', '-DANDROID_STL=c++_shared', "-DANDROID_TOOLCHAIN=clang" + abiFilters 'arm64-v8a' + cppFlags "-std=c++11" + } + } + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } + externalNativeBuild { + cmake { + path file('src/main/cpp/CMakeLists.txt') + version '3.22.1' + } + } + buildFeatures { + viewBinding true + } + sourceSets { + main { + jniLibs.srcDirs = ['libs'] + } + } +} + +dependencies { + // Dependency on local binaries + implementation fileTree(dir: 'libs', include: ['*.jar']) + // Dependency on a remote binary + implementation 'androidx.appcompat:appcompat:1.4.1' + implementation 'com.google.android.material:material:1.5.0' + implementation 'androidx.constraintlayout:constraintlayout:2.1.3' + testImplementation 'junit:junit:4.13.2' + androidTestImplementation 
'androidx.test.ext:junit:1.1.3' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0' +} + +def CXX_LIB = [ +// [ +// 'src' : 'https://bj.bcebos.com/fastdeploy/dev/android/fastdeploy-android-with-text-0.0.0-shared.tgz', +// 'dest': 'libs', +// 'name': 'fastdeploy-android-latest-shared-dev' +// ] +] + +task downloadAndExtractLibs(type: DefaultTask) { + doFirst { + println "[INFO] Downloading and extracting fastdeploy android c++ lib ..." + } + doLast { + String cachePath = "cache" + if (!file("${cachePath}").exists()) { + mkdir "${cachePath}" + } + + CXX_LIB.eachWithIndex { lib, index -> + + String[] libPaths = lib.src.split("/") + String sdkName = lib.name + String libName = libPaths[libPaths.length - 1] + libName = libName.substring(0, libName.indexOf("tgz") - 1) + String cacheName = cachePath + "/" + "${libName}.tgz" + + String libDir = lib.dest + "/" + libName + String sdkDir = lib.dest + "/" + sdkName + + boolean copyFiles = false + if (!file("${sdkDir}").exists()) { + // Download lib and rename to sdk name later. + if (!file("${cacheName}").exists()) { + println "[INFO] Downloading ${lib.src} -> ${cacheName}" + ant.get(src: lib.src, dest: file("${cacheName}")) + } + copyFiles = true + } + + if (copyFiles) { + println "[INFO] Taring ${cacheName} -> ${libDir}" + copy { from(tarTree("${cacheName}")) into("${lib.dest}") } + if (!libName.equals(sdkName)) { + if (file("${sdkDir}").exists()) { + delete("${sdkDir}") + println "[INFO] Remove old ${sdkDir}" + } + mkdir "${sdkDir}" + println "[INFO] Coping ${libDir} -> ${sdkDir}" + copy { from("${libDir}") into("${sdkDir}") } + delete("${libDir}") + println "[INFO] Removed ${libDir}" + println "[INFO] Update ${sdkDir} done!" + } + } else { + println "[INFO] ${sdkDir} already exists!" + println "[WARN] Please delete ${cacheName} and ${sdkDir} " + + "if you want to UPDATE ${sdkName} c++ lib. Then, rebuild this sdk." 
+ } + } + } +} + +preBuild.dependsOn downloadAndExtractLibs \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/libs/.gitkeep b/runtime/examples/android/VadJni/app/libs/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/runtime/examples/android/VadJni/app/proguard-rules.pro b/runtime/examples/android/VadJni/app/proguard-rules.pro new file mode 100644 index 00000000..481bb434 --- /dev/null +++ b/runtime/examples/android/VadJni/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/androidTest/java/com/baidu/paddlespeech/vadjni/ExampleInstrumentedTest.java b/runtime/examples/android/VadJni/app/src/androidTest/java/com/baidu/paddlespeech/vadjni/ExampleInstrumentedTest.java new file mode 100644 index 00000000..5c02120b --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/androidTest/java/com/baidu/paddlespeech/vadjni/ExampleInstrumentedTest.java @@ -0,0 +1,26 @@ +package com.baidu.paddlespeech.vadjni; + +import android.content.Context; + +import androidx.test.platform.app.InstrumentationRegistry; +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import static org.junit.Assert.*; + +/** + * Instrumented test, which will execute on an Android device. + * + * @see Testing documentation + */ +@RunWith(AndroidJUnit4.class) +public class ExampleInstrumentedTest { + @Test + public void useAppContext() { + // Context of the app under test. 
+ Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext(); + assertEquals("com.baidu.paddlespeech.vadjni", appContext.getPackageName()); + } +} \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/main/AndroidManifest.xml b/runtime/examples/android/VadJni/app/src/main/AndroidManifest.xml new file mode 100644 index 00000000..d8076922 --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/AndroidManifest.xml @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/main/assets/.gitkeep b/runtime/examples/android/VadJni/app/src/main/assets/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/runtime/examples/android/VadJni/app/src/main/cpp/CMakeLists.txt b/runtime/examples/android/VadJni/app/src/main/cpp/CMakeLists.txt new file mode 100644 index 00000000..5eaa053b --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/cpp/CMakeLists.txt @@ -0,0 +1,59 @@ +# For more information about using CMake with Android Studio, read the +# documentation: https://d.android.com/studio/projects/add-native-code.html + +# Sets the minimum version of CMake required to build the native library. + +cmake_minimum_required(VERSION 3.22.1) + +# Declares and names the project. + +project("vadjni") + + +set(PPS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../libs/${ANDROID_ABI}) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +# Creates and names a library, sets it as either STATIC +# or SHARED, and provides the relative paths to its source code. +# You can define multiple libraries, and CMake builds them for you. +# Gradle automatically packages shared libraries with your APK. + +add_library( # Sets the name of the library. + vadjni + + # Sets the library as a shared library. + SHARED + + # Provides a relative path to your source file(s). 
+ native-lib.cpp) + +# Searches for a specified prebuilt library and stores the path as a +# variable. Because CMake includes system libraries in the search path by +# default, you only need to specify the name of the public NDK library +# you want to add. CMake verifies that the library exists before +# completing its build. + +find_library( # Sets the name of the path variable. + log-lib + + # Specifies the name of the NDK library that + # you want CMake to locate. + log) + +# Specifies libraries CMake should link to your target library. You +# can link multiple libraries, such as libraries you define in this +# build script, prebuilt third-party libraries, or system libraries. + +message(STATUS "PPS_DIR=${PPS_DIR}") +target_link_libraries( # Specifies the target library. + vadjni + ${PPS_DIR}/libfastdeploy.so + ${PPS_DIR}/libonnxruntime.so + ${PPS_DIR}/libgflags_nothreads.a + ${PPS_DIR}/libbase.a + ${PPS_DIR}/libpps_vad.a + ${PPS_DIR}/libpps_vad_interface.a + # Links the target library to the log library + # included in the NDK. 
+ ${log-lib}) \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/main/cpp/native-lib.cpp b/runtime/examples/android/VadJni/app/src/main/cpp/native-lib.cpp new file mode 100644 index 00000000..e80ac2e4 --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/cpp/native-lib.cpp @@ -0,0 +1,57 @@ + +#include +#include "vad_interface.h" +#include + +extern "C" +JNIEXPORT jstring JNICALL +Java_com_baidu_paddlespeech_vadjni_MainActivity_stringFromJNI( + JNIEnv* env, + jobject /* this */) { + std::string hello = "Hello from C++"; + return env->NewStringUTF(hello.c_str()); +} + +extern "C" +JNIEXPORT jlong JNICALL +Java_com_baidu_paddlespeech_vadjni_MainActivity_createInstance( + JNIEnv* env, + jobject thiz, + jstring conf_path){ + const char* path = env->GetStringUTFChars(conf_path, JNI_FALSE); + PPSHandle_t handle = PPSVadCreateInstance(path); + + return (jlong)(handle); + return 0; +} + + +extern "C" +JNIEXPORT jint JNICALL +Java_com_baidu_paddlespeech_vadjni_MainActivity_destroyInstance(JNIEnv *env, jobject thiz, + jlong instance) { + PPSHandle_t handle = (PPSHandle_t)(instance); + return (jint)PPSVadDestroyInstance(handle); +} +extern "C" +JNIEXPORT jint JNICALL +Java_com_baidu_paddlespeech_vadjni_MainActivity_reset(JNIEnv *env, jobject thiz, jlong instance) { + PPSHandle_t handle = (PPSHandle_t)(instance); + return (jint)PPSVadReset(handle); +} +extern "C" +JNIEXPORT jint JNICALL +Java_com_baidu_paddlespeech_vadjni_MainActivity_chunkSizeSamples(JNIEnv *env, jobject thiz, + jlong instance) { + PPSHandle_t handle = (PPSHandle_t)(instance); + return (jint)PPSVadChunkSizeSamples(handle); +} +extern "C" +JNIEXPORT jint JNICALL +Java_com_baidu_paddlespeech_vadjni_MainActivity_feedForward(JNIEnv *env, jobject thiz, + jlong instance, jfloatArray chunk) { + PPSHandle_t handle = (PPSHandle_t)(instance); + jsize num_elms = env->GetArrayLength(chunk); + jfloat* chunk_ptr = env->GetFloatArrayElements(chunk, JNI_FALSE); + return 
(jint)PPSVadFeedForward(handle, (float*)chunk_ptr, (int)num_elms); +} \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/main/cpp/vad_interface.h b/runtime/examples/android/VadJni/app/src/main/cpp/vad_interface.h new file mode 100644 index 00000000..5d7ca709 --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/cpp/vad_interface.h @@ -0,0 +1,46 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void* PPSHandle_t; + +typedef enum { + PPS_VAD_ILLEGAL = 0, // error + PPS_VAD_SIL, // silence + PPS_VAD_START, // start speech + PPS_VAD_SPEECH, // in speech + PPS_VAD_END, // end speech + PPS_VAD_NUMSTATES, // number of states +} PPSVadState_t; + +PPSHandle_t PPSVadCreateInstance(const char* conf_path); + +int PPSVadDestroyInstance(PPSHandle_t instance); + +int PPSVadReset(PPSHandle_t instance); + +int PPSVadChunkSizeSamples(PPSHandle_t instance); + +PPSVadState_t PPSVadFeedForward(PPSHandle_t instance, + float* chunk, + int num_element); + +#ifdef __cplusplus +} +#endif // __cplusplus \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/main/java/com/baidu/paddlespeech/vadjni/MainActivity.java b/runtime/examples/android/VadJni/app/src/main/java/com/baidu/paddlespeech/vadjni/MainActivity.java new file mode 100644 index 00000000..3b463280 --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/java/com/baidu/paddlespeech/vadjni/MainActivity.java @@ -0,0 +1,50 @@ +package com.baidu.paddlespeech.vadjni; + +import androidx.appcompat.app.AppCompatActivity; + +import android.os.Bundle; +import android.widget.Button; +import android.widget.TextView; + +import com.baidu.paddlespeech.vadjni.databinding.ActivityMainBinding; + +public class MainActivity extends AppCompatActivity { + + // Used to load the 'vadjni' library on application startup. 
+ static { + System.loadLibrary("vadjni"); + } + + private ActivityMainBinding binding; + private long instance; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + + binding = ActivityMainBinding.inflate(getLayoutInflater()); + setContentView(binding.getRoot()); + + // Example of a call to a native method + TextView tv = binding.sampleText; + tv.setText(stringFromJNI()); + + Button lw = binding.loadWav; + } + + /** + * A native method that is implemented by the 'vadjni' native library, + * which is packaged with this application. + */ + public native String stringFromJNI(); + + public static native long createInstance(String config_path); + + public static native int destroyInstance(long instance); + + public static native int reset(long instance); + + public static native int chunkSizeSamples(long instance); + + public static native int feedForward(long instance, float[] chunk); +} \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/main/res/drawable-v24/ic_launcher_foreground.xml b/runtime/examples/android/VadJni/app/src/main/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 00000000..2b068d11 --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,30 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/runtime/examples/android/VadJni/app/src/main/res/drawable/ic_launcher_background.xml b/runtime/examples/android/VadJni/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 00000000..07d5da9c --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/runtime/examples/android/VadJni/app/src/main/res/layout/activity_main.xml b/runtime/examples/android/VadJni/app/src/main/res/layout/activity_main.xml new file mode 
100644 index 00000000..c9938516 --- /dev/null +++ b/runtime/examples/android/VadJni/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,28 @@ + + + + + +