From f37f34d3cee5841469084a4f0ef955faf3d434f5 Mon Sep 17 00:00:00 2001 From: YangZhou Date: Fri, 30 Dec 2022 15:28:06 +0800 Subject: [PATCH] rm ds2 && rm boost --- .pre-commit-config.yaml | 4 +- speechx/CMakeLists.txt | 18 - speechx/build.sh | 17 +- speechx/examples/u2pp_ol/wenetspeech/path.sh | 4 +- speechx/speechx/asr/decoder/CMakeLists.txt | 59 +- .../asr/decoder/ctc_beam_search_decoder.cc | 313 ----- .../asr/decoder/ctc_beam_search_decoder.h | 73 - .../decoder/ctc_beam_search_decoder_main.cc | 167 --- .../asr/decoder/ctc_decoders/.gitignore | 9 - .../ctc_decoders/ctc_beam_search_decoder.cpp | 607 --------- .../ctc_decoders/ctc_beam_search_decoder.h | 175 --- .../ctc_decoders/ctc_greedy_decoder.cpp | 61 - .../decoder/ctc_decoders/ctc_greedy_decoder.h | 35 - .../decoder/ctc_decoders/decoder_utils.cpp | 193 --- .../asr/decoder/ctc_decoders/decoder_utils.h | 111 -- .../asr/decoder/ctc_decoders/path_trie.cpp | 164 --- .../asr/decoder/ctc_decoders/path_trie.h | 82 -- .../asr/decoder/ctc_decoders/scorer.cpp | 232 ---- .../speechx/asr/decoder/ctc_decoders/scorer.h | 114 -- .../asr/decoder/nnet_logprob_decoder_main.cc | 77 -- speechx/speechx/asr/decoder/param.h | 3 +- speechx/speechx/asr/nnet/CMakeLists.txt | 24 +- speechx/speechx/asr/nnet/ds2_nnet.cc | 218 --- speechx/speechx/asr/nnet/ds2_nnet.h | 97 -- speechx/speechx/asr/nnet/ds2_nnet_main.cc | 142 -- speechx/speechx/asr/nnet/nnet_producer.cc | 1 - speechx/speechx/asr/recognizer/CMakeLists.txt | 49 +- speechx/speechx/asr/recognizer/recognizer.cc | 70 - speechx/speechx/asr/recognizer/recognizer.h | 70 - .../speechx/asr/recognizer/recognizer_main.cc | 105 -- .../frontend/audio/cmvn_json2kaldi_main.cc | 46 +- speechx/speechx/common/utils/picojson.h | 1202 +++++++++++++++++ 32 files changed, 1265 insertions(+), 3277 deletions(-) delete mode 100644 speechx/speechx/asr/decoder/ctc_beam_search_decoder.cc delete mode 100644 speechx/speechx/asr/decoder/ctc_beam_search_decoder.h delete mode 100644 speechx/speechx/asr/decoder/ctc_beam_search_decoder_main.cc delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/.gitignore delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.h delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.cpp delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/path_trie.h delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp delete mode 100644 speechx/speechx/asr/decoder/ctc_decoders/scorer.h delete mode 100644 speechx/speechx/asr/decoder/nnet_logprob_decoder_main.cc delete mode 100644 speechx/speechx/asr/nnet/ds2_nnet.cc delete mode 100644 speechx/speechx/asr/nnet/ds2_nnet.h delete mode 100644 speechx/speechx/asr/nnet/ds2_nnet_main.cc delete mode 100644 speechx/speechx/asr/recognizer/recognizer.cc delete mode 100644 speechx/speechx/asr/recognizer/recognizer.h delete mode 100644 speechx/speechx/asr/recognizer/recognizer_main.cc create mode 100644 speechx/speechx/common/utils/picojson.h diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 15b842d55..994619478 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -57,13 +57,13 @@ repos: entry: bash .pre-commit-hooks/clang-format.hook -i language: system files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$ - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ + exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders|speechx/speechx/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ - id: cpplint name: cpplint description: Static code analysis of C/C++ files language: python files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$ - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ + exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders|speechx/speechx/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent - repo: https://github.com/asottile/reorder_python_imports rev: v2.4.0 diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index 45bf54194..cfce63dd9 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -44,9 +44,6 @@ option(TEST_DEBUG "option for debug" OFF) option(USE_PROFILING "enable c++ profling" OFF) option(WITH_TESTING "unit test" ON) -option(USING_U2 "compile u2 model." ON) -option(USING_DS2 "compile with ds2 model." OFF) - option(USING_GPU "u2 compute on GPU." OFF) ############################################################################### @@ -56,21 +53,6 @@ include(gflags) include(glog) -# boost -# include(boost) # not work -set(boost_SOURCE_DIR ${fc_patch}/boost-src) -set(BOOST_ROOT ${boost_SOURCE_DIR}) -include_directories(${boost_SOURCE_DIR}) -link_directories(${boost_SOURCE_DIR}/stage/lib) - -# Eigen -include(eigen) -find_package(Eigen3 REQUIRED) - -# Kenlm -include(kenlm) -add_dependencies(kenlm eigen boost) - #openblas include(openblas) diff --git a/speechx/build.sh b/speechx/build.sh index 7655f9635..94d250f5a 100755 --- a/speechx/build.sh +++ b/speechx/build.sh @@ -4,20 +4,5 @@ set -xe # the build script had verified in the paddlepaddle docker image. # please follow the instruction below to install PaddlePaddle image. # https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html -boost_SOURCE_DIR=$PWD/fc_patch/boost-src -if [ ! -d ${boost_SOURCE_DIR} ]; then wget -c https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz - tar xzfv boost_1_75_0.tar.gz - mkdir -p $PWD/fc_patch - mv boost_1_75_0 ${boost_SOURCE_DIR} - cd ${boost_SOURCE_DIR} - bash ./bootstrap.sh - ./b2 - cd - - echo -e "\n" -fi - -#rm -rf build -mkdir -p build - -cmake -B build -DBOOST_ROOT:STRING=${boost_SOURCE_DIR} +cmake -B build cmake --build build -j diff --git a/speechx/examples/u2pp_ol/wenetspeech/path.sh b/speechx/examples/u2pp_ol/wenetspeech/path.sh index ec278bd3d..9518db116 100644 --- a/speechx/examples/u2pp_ol/wenetspeech/path.sh +++ b/speechx/examples/u2pp_ol/wenetspeech/path.sh @@ -3,7 +3,7 @@ unset GREP_OPTIONS SPEECHX_ROOT=$PWD/../../../ -SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx +SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx/asr SPEECHX_TOOLS=$SPEECHX_ROOT/tools TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin @@ -12,7 +12,7 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin export LC_AL=C -export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio:$SPEECHX_BUILD/recognizer +export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/../common/frontend/audio:$SPEECHX_BUILD/recognizer PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);") export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH diff --git a/speechx/speechx/asr/decoder/CMakeLists.txt b/speechx/speechx/asr/decoder/CMakeLists.txt index 93014fb90..b2f507080 100644 --- a/speechx/speechx/asr/decoder/CMakeLists.txt +++ b/speechx/speechx/asr/decoder/CMakeLists.txt @@ -1,55 +1,22 @@ -include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders}) - set(srcs) - -if (USING_DS2) list(APPEND srcs - ctc_decoders/decoder_utils.cpp - ctc_decoders/path_trie.cpp - ctc_decoders/scorer.cpp - ctc_beam_search_decoder.cc - ctc_tlg_decoder.cc + ctc_prefix_beam_search_decoder.cc ) -endif() - -if (USING_U2) - list(APPEND srcs - ctc_prefix_beam_search_decoder.cc - ) -endif() add_library(decoder STATIC ${srcs}) -target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder) +target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder) # test -if (USING_DS2) - set(BINS - ctc_beam_search_decoder_main - nnet_logprob_decoder_main - ctc_tlg_decoder_main - ) - - foreach(bin_name IN LISTS BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) - endforeach() -endif() - - -if (USING_U2) - set(TEST_BINS - ctc_prefix_beam_search_decoder_main - ) - - foreach(bin_name IN LISTS TEST_BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) - target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) - target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) - target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) - endforeach() +set(TEST_BINS + ctc_prefix_beam_search_decoder_main +) -endif() +foreach(bin_name IN LISTS TEST_BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) + target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) +endforeach() diff --git a/speechx/speechx/asr/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/asr/decoder/ctc_beam_search_decoder.cc deleted file mode 100644 index 6e3a0d136..000000000 --- a/speechx/speechx/asr/decoder/ctc_beam_search_decoder.cc +++ /dev/null @@ -1,313 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - - -#include "decoder/ctc_beam_search_decoder.h" - -#include "base/common.h" -#include "decoder/ctc_decoders/decoder_utils.h" -#include "utils/file_utils.h" - -namespace ppspeech { - -using std::vector; -using FSTMATCH = fst::SortedMatcher; - -CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) - : opts_(opts), init_ext_scorer_(nullptr), space_id_(-1), root_(nullptr) { - LOG(INFO) << "dict path: " << opts_.dict_file; - if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { - LOG(INFO) << "load the dict failed"; - } - LOG(INFO) << "read the vocabulary success, dict size: " - << vocabulary_.size(); - - LOG(INFO) << "language model path: " << opts_.lm_path; - if (opts_.lm_path != "") { - init_ext_scorer_ = std::make_shared( - opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); - } - - CHECK_EQ(opts_.blank, 0); - - auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " "); - space_id_ = it - vocabulary_.begin(); - // if no space in vocabulary - if (static_cast(space_id_) >= vocabulary_.size()) { - space_id_ = -2; - } -} - -void CTCBeamSearch::Reset() { - // num_frame_decoded_ = 0; - // ResetPrefixes(); - InitDecoder(); -} - -void CTCBeamSearch::InitDecoder() { - num_frame_decoded_ = 0; - // ResetPrefixes(); - prefixes_.clear(); - - root_ = std::make_shared(); - root_->score = root_->log_prob_b_prev = 0.0; - prefixes_.push_back(root_.get()); - if (init_ext_scorer_ != nullptr && - !init_ext_scorer_->is_character_based()) { - auto fst_dict = - static_cast(init_ext_scorer_->dictionary); - fst::StdVectorFst* dict_ptr = fst_dict->Copy(true); - root_->set_dictionary(dict_ptr); - - auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); - root_->set_matcher(matcher); - } -} - -void CTCBeamSearch::Decode( - std::shared_ptr decodable) { - return; -} - -// todo rename, refactor -void CTCBeamSearch::AdvanceDecode( - const std::shared_ptr& decodable) { - while (1) { - vector> likelihood; - vector frame_prob; - bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); - if (flag == false) break; - likelihood.push_back(frame_prob); - AdvanceDecoding(likelihood); - } -} - -void CTCBeamSearch::ResetPrefixes() { - for (size_t i = 0; i < prefixes_.size(); i++) { - if (prefixes_[i] != nullptr) { - delete prefixes_[i]; - prefixes_[i] = nullptr; - } - } - prefixes_.clear(); -} - -int CTCBeamSearch::DecodeLikelihoods(const vector>& probs, - const vector& nbest_words) { - kaldi::Timer timer; - AdvanceDecoding(probs); - LOG(INFO) << "ctc decoding elapsed time(s) " - << static_cast(timer.Elapsed()) / 1000.0f; - return 0; -} - -vector> CTCBeamSearch::GetNBestPath(int n) { - int beam_size = n == -1 ? opts_.beam_size : std::min(n, opts_.beam_size); - return get_beam_search_result(prefixes_, vocabulary_, beam_size); -} - -vector> CTCBeamSearch::GetNBestPath() { - return GetNBestPath(-1); -} - -string CTCBeamSearch::GetBestPath() { - std::vector> result; - result = get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size); - return result[0].second; -} - -string CTCBeamSearch::GetFinalBestPath() { - CalculateApproxScore(); - LMRescore(); - return GetBestPath(); -} - -void CTCBeamSearch::AdvanceDecoding(const vector>& probs) { - size_t num_time_steps = probs.size(); - size_t beam_size = opts_.beam_size; - double cutoff_prob = opts_.cutoff_prob; - size_t cutoff_top_n = opts_.cutoff_top_n; - - vector> probs_seq(probs.size(), - vector(probs[0].size(), 0)); - - int row = probs.size(); - int col = probs[0].size(); - for (int i = 0; i < row; i++) { - for (int j = 0; j < col; j++) { - probs_seq[i][j] = static_cast(probs[i][j]); - } - } - - for (size_t time_step = 0; time_step < num_time_steps; time_step++) { - const auto& prob = probs_seq[time_step]; - - float min_cutoff = -NUM_FLT_INF; - bool full_beam = false; - if (init_ext_scorer_ != nullptr) { - size_t num_prefixes_ = std::min(prefixes_.size(), beam_size); - std::sort(prefixes_.begin(), - prefixes_.begin() + num_prefixes_, - prefix_compare); - - if (num_prefixes_ == 0) { - continue; - } - min_cutoff = prefixes_[num_prefixes_ - 1]->score + - std::log(prob[opts_.blank]) - - std::max(0.0, init_ext_scorer_->beta); - - full_beam = (num_prefixes_ == beam_size); - } - - vector> log_prob_idx = - get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); - - // loop over chars - size_t log_prob_idx_len = log_prob_idx.size(); - for (size_t index = 0; index < log_prob_idx_len; index++) { - SearchOneChar(full_beam, log_prob_idx[index], min_cutoff); - } - - prefixes_.clear(); - - // update log probs - root_->iterate_to_vec(prefixes_); - // only preserve top beam_size prefixes_ - if (prefixes_.size() >= beam_size) { - std::nth_element(prefixes_.begin(), - prefixes_.begin() + beam_size, - prefixes_.end(), - prefix_compare); - for (size_t i = beam_size; i < prefixes_.size(); ++i) { - prefixes_[i]->remove(); - } - } // end if - num_frame_decoded_++; - } // end for probs_seq -} - -int32 CTCBeamSearch::SearchOneChar( - const bool& full_beam, - const std::pair& log_prob_idx, - const BaseFloat& min_cutoff) { - size_t beam_size = opts_.beam_size; - const auto& c = log_prob_idx.first; - const auto& log_prob_c = log_prob_idx.second; - size_t prefixes_len = std::min(prefixes_.size(), beam_size); - - for (size_t i = 0; i < prefixes_len; ++i) { - auto prefix = prefixes_[i]; - if (full_beam && log_prob_c + prefix->score < min_cutoff) { - break; - } - - if (c == opts_.blank) { - prefix->log_prob_b_cur = - log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); - continue; - } - - // repeated character - if (c == prefix->character) { - // p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1}) - prefix->log_prob_nb_cur = log_sum_exp( - prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); - } - - // get new prefix - auto prefix_new = prefix->get_path_trie(c); - if (prefix_new != nullptr) { - float log_p = -NUM_FLT_INF; - if (c == prefix->character && - prefix->log_prob_b_prev > -NUM_FLT_INF) { - // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1}) - log_p = log_prob_c + prefix->log_prob_b_prev; - } else if (c != prefix->character) { - // p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1}) - log_p = log_prob_c + prefix->score; - } - - // language model scoring - if (init_ext_scorer_ != nullptr && - (c == space_id_ || init_ext_scorer_->is_character_based())) { - PathTrie* prefix_to_score = nullptr; - // skip scoring the space - if (init_ext_scorer_->is_character_based()) { - prefix_to_score = prefix_new; - } else { - prefix_to_score = prefix; - } - - float score = 0.0; - vector ngram; - ngram = init_ext_scorer_->make_ngram(prefix_to_score); - // lm score: p_{lm}(W)^{\alpha} + \beta - score = init_ext_scorer_->get_log_cond_prob(ngram) * - init_ext_scorer_->alpha; - log_p += score; - log_p += init_ext_scorer_->beta; - } - // p_{nb}(l;x_{1:t}) - prefix_new->log_prob_nb_cur = - log_sum_exp(prefix_new->log_prob_nb_cur, log_p); - } - } // end of loop over prefix - return 0; -} - -void CTCBeamSearch::CalculateApproxScore() { - size_t beam_size = opts_.beam_size; - size_t num_prefixes_ = std::min(prefixes_.size(), beam_size); - std::sort( - prefixes_.begin(), prefixes_.begin() + num_prefixes_, prefix_compare); - - // compute aproximate ctc score as the return score, without affecting the - // return order of decoding result. To delete when decoder gets stable. - for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) { - double approx_ctc = prefixes_[i]->score; - if (init_ext_scorer_ != nullptr) { - vector output; - prefixes_[i]->get_path_vec(output); - auto prefix_length = output.size(); - auto words = init_ext_scorer_->split_labels(output); - // remove word insert - approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta; - // remove language model weight: - approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) * - init_ext_scorer_->alpha; - } - prefixes_[i]->approx_ctc = approx_ctc; - } -} - -void CTCBeamSearch::LMRescore() { - size_t beam_size = opts_.beam_size; - if (init_ext_scorer_ != nullptr && - !init_ext_scorer_->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) { - auto prefix = prefixes_[i]; - if (!prefix->is_empty() && prefix->character != space_id_) { - float score = 0.0; - vector ngram = init_ext_scorer_->make_ngram(prefix); - score = init_ext_scorer_->get_log_cond_prob(ngram) * - init_ext_scorer_->alpha; - score += init_ext_scorer_->beta; - prefix->score += score; - } - } - } -} - -} // namespace ppspeech diff --git a/speechx/speechx/asr/decoder/ctc_beam_search_decoder.h b/speechx/speechx/asr/decoder/ctc_beam_search_decoder.h deleted file mode 100644 index f06d88e32..000000000 --- a/speechx/speechx/asr/decoder/ctc_beam_search_decoder.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// used by deepspeech2 - -#pragma once - -#include "decoder/ctc_beam_search_opt.h" -#include "decoder/ctc_decoders/path_trie.h" -#include "decoder/ctc_decoders/scorer.h" -#include "decoder/decoder_itf.h" - -namespace ppspeech { - -class CTCBeamSearch : public DecoderBase { - public: - explicit CTCBeamSearch(const CTCBeamSearchOptions& opts); - ~CTCBeamSearch() {} - - void InitDecoder(); - - void Reset(); - - void AdvanceDecode( - const std::shared_ptr& decodable); - - void Decode(std::shared_ptr decodable); - - std::string GetBestPath(); - std::vector> GetNBestPath(); - std::vector> GetNBestPath(int n); - std::string GetFinalBestPath(); - - std::string GetPartialResult() { - CHECK(false) << "Not implement."; - return {}; - } - - int DecodeLikelihoods(const std::vector>& probs, - const std::vector& nbest_words); - - private: - void ResetPrefixes(); - - int32 SearchOneChar(const bool& full_beam, - const std::pair& log_prob_idx, - const BaseFloat& min_cutoff); - void CalculateApproxScore(); - void LMRescore(); - void AdvanceDecoding(const std::vector>& probs); - - CTCBeamSearchOptions opts_; - std::shared_ptr init_ext_scorer_; // todo separate later - std::vector vocabulary_; // todo remove later - int space_id_; - std::shared_ptr root_; - std::vector prefixes_; - - DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); -}; - -} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/decoder/ctc_beam_search_decoder_main.cc b/speechx/speechx/asr/decoder/ctc_beam_search_decoder_main.cc deleted file mode 100644 index ab0376b6b..000000000 --- a/speechx/speechx/asr/decoder/ctc_beam_search_decoder_main.cc +++ /dev/null @@ -1,167 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// used by deepspeech2 - -#include "base/flags.h" -#include "base/log.h" -#include "decoder/ctc_beam_search_decoder.h" -#include "frontend/audio/data_cache.h" -#include "kaldi/util/table-types.h" -#include "nnet/decodable.h" -#include "nnet/ds2_nnet.h" - -DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); -DEFINE_string(result_wspecifier, "", "test result wspecifier"); -DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); -DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); -DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); -DEFINE_string(lm_path, "", "language model"); -DEFINE_int32(receptive_field_length, - 7, - "receptive field of two CNN(kernel=3) downsampling module."); -DEFINE_int32(subsampling_rate, - 4, - "two CNN(kernel=3) module downsampling rate."); -DEFINE_string( - model_input_names, - "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box", - "model input names"); -DEFINE_string(model_output_names, - "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0", - "model output names"); -DEFINE_string(model_cache_names, - "chunk_state_h_box,chunk_state_c_box", - "model cache names"); -DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); -DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); - -using kaldi::BaseFloat; -using kaldi::Matrix; -using std::vector; - -// test ds2 online decoder by feeding speech feature -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - CHECK_NE(FLAGS_result_wspecifier, ""); - CHECK_NE(FLAGS_feature_rspecifier, ""); - - kaldi::SequentialBaseFloatMatrixReader feature_reader( - FLAGS_feature_rspecifier); - kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); - std::string model_path = FLAGS_model_path; - std::string model_params = FLAGS_param_path; - std::string dict_file = FLAGS_dict_file; - std::string lm_path = FLAGS_lm_path; - LOG(INFO) << "model path: " << model_path; - LOG(INFO) << "model param: " << model_params; - LOG(INFO) << "dict path: " << dict_file; - LOG(INFO) << "lm path: " << lm_path; - - int32 num_done = 0, num_err = 0; - - ppspeech::CTCBeamSearchOptions opts; - opts.dict_file = dict_file; - opts.lm_path = lm_path; - ppspeech::CTCBeamSearch decoder(opts); - - ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); - - std::shared_ptr nnet( - new ppspeech::PaddleNnet(model_opts)); - std::shared_ptr raw_data(new ppspeech::DataCache()); - std::shared_ptr decodable( - new ppspeech::Decodable(nnet, raw_data)); - - int32 chunk_size = FLAGS_receptive_field_length + - (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate; - int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; - int32 receptive_field_length = FLAGS_receptive_field_length; - LOG(INFO) << "chunk size (frame): " << chunk_size; - LOG(INFO) << "chunk stride (frame): " << chunk_stride; - LOG(INFO) << "receptive field (frame): " << receptive_field_length; - decoder.InitDecoder(); - - kaldi::Timer timer; - for (; !feature_reader.Done(); feature_reader.Next()) { - string utt = feature_reader.Key(); - kaldi::Matrix feature = feature_reader.Value(); - raw_data->SetDim(feature.NumCols()); - LOG(INFO) << "process utt: " << utt; - LOG(INFO) << "rows: " << feature.NumRows(); - LOG(INFO) << "cols: " << feature.NumCols(); - - int32 row_idx = 0; - int32 padding_len = 0; - int32 ori_feature_len = feature.NumRows(); - if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { - padding_len = - chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; - feature.Resize(feature.NumRows() + padding_len, - feature.NumCols(), - kaldi::kCopyData); - } - int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1; - for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { - kaldi::Vector feature_chunk(chunk_size * - feature.NumCols()); - int32 feature_chunk_size = 0; - if (ori_feature_len > chunk_idx * chunk_stride) { - feature_chunk_size = std::min( - ori_feature_len - chunk_idx * chunk_stride, chunk_size); - } - if (feature_chunk_size < receptive_field_length) break; - - int32 start = chunk_idx * chunk_stride; - - for (int row_id = 0; row_id < chunk_size; ++row_id) { - kaldi::SubVector tmp(feature, start); - kaldi::SubVector f_chunk_tmp( - feature_chunk.Data() + row_id * feature.NumCols(), - feature.NumCols()); - f_chunk_tmp.CopyFromVec(tmp); - ++start; - } - raw_data->Accept(feature_chunk); - if (chunk_idx == num_chunks - 1) { - raw_data->SetFinished(); - } - decoder.AdvanceDecode(decodable); - } - std::string result; - result = decoder.GetFinalBestPath(); - decodable->Reset(); - decoder.Reset(); - if (result.empty()) { - // the TokenWriter can not write empty string. - ++num_err; - KALDI_LOG << " the result of " << utt << " is empty"; - continue; - } - KALDI_LOG << " the result of " << utt << " is " << result; - result_writer.Write(utt, result); - ++num_done; - } - - KALDI_LOG << "Done " << num_done << " utterances, " << num_err - << " with errors."; - double elapsed = timer.Elapsed(); - KALDI_LOG << " cost:" << elapsed << " s"; - return (num_done != 0 ? 0 : 1); -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/.gitignore b/speechx/speechx/asr/decoder/ctc_decoders/.gitignore deleted file mode 100644 index 0b1046ae8..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/.gitignore +++ /dev/null @@ -1,9 +0,0 @@ -ThreadPool/ -build/ -dist/ -kenlm/ -openfst-1.6.3/ -openfst-1.6.3.tar.gz -swig_decoders.egg-info/ -decoders_wrap.cxx -swig_decoders.py diff --git a/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp b/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp deleted file mode 100644 index ebea5c222..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.cpp +++ /dev/null @@ -1,607 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ctc_beam_search_decoder.h" - -#include -#include -#include -#include -#include -#include - -#include "ThreadPool.h" -#include "fst/fstlib.h" - -#include "decoder_utils.h" -#include "path_trie.h" - -using FSTMATCH = fst::SortedMatcher; - - -std::vector> ctc_beam_search_decoding( - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id) { - // dimension check - size_t num_time_steps = probs_seq.size(); - for (size_t i = 0; i < num_time_steps; ++i) { - VALID_CHECK_EQ(probs_seq[i].size(), - // vocabulary.size() + 1, - vocabulary.size(), - "The shape of probs_seq does not match with " - "the shape of the vocabulary"); - } - - - // assign space id - auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); - int space_id = it - vocabulary.begin(); - // if no space in vocabulary - if ((size_t)space_id >= vocabulary.size()) { - space_id = -2; - } - // init prefixes' root - PathTrie root; - root.score = root.log_prob_b_prev = 0.0; - std::vector prefixes; - prefixes.push_back(&root); - - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - auto fst_dict = - static_cast(ext_scorer->dictionary); - fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); - root.set_dictionary(dict_ptr); - auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); - root.set_matcher(matcher); - } - - // prefix search over time - for (size_t time_step = 0; time_step < num_time_steps; ++time_step) { - auto &prob = probs_seq[time_step]; - - float min_cutoff = -NUM_FLT_INF; - bool full_beam = false; - if (ext_scorer != nullptr) { - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort(prefixes.begin(), - prefixes.begin() + num_prefixes, - prefix_compare); - min_cutoff = prefixes[num_prefixes - 1]->score + - std::log(prob[blank_id]) - - std::max(0.0, ext_scorer->beta); - full_beam = (num_prefixes == beam_size); - } - - std::vector> log_prob_idx = - get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); - // loop over chars - for (size_t index = 0; index < log_prob_idx.size(); index++) { - auto c = log_prob_idx[index].first; - auto log_prob_c = log_prob_idx[index].second; - - for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { - auto prefix = prefixes[i]; - if (full_beam && log_prob_c + prefix->score < min_cutoff) { - break; - } - // blank - if (c == blank_id) { - prefix->log_prob_b_cur = log_sum_exp( - prefix->log_prob_b_cur, log_prob_c + prefix->score); - continue; - } - // repeated character - if (c == prefix->character) { - prefix->log_prob_nb_cur = - log_sum_exp(prefix->log_prob_nb_cur, - log_prob_c + prefix->log_prob_nb_prev); - } - // get new prefix - auto prefix_new = prefix->get_path_trie(c); - - if (prefix_new != nullptr) { - float log_p = -NUM_FLT_INF; - - if (c == prefix->character && - prefix->log_prob_b_prev > -NUM_FLT_INF) { - log_p = log_prob_c + prefix->log_prob_b_prev; - } else if (c != prefix->character) { - log_p = log_prob_c + prefix->score; - } - - // language model scoring - if (ext_scorer != nullptr && - (c == space_id || ext_scorer->is_character_based())) { - PathTrie *prefix_to_score = nullptr; - // skip scoring the space - if (ext_scorer->is_character_based()) { - prefix_to_score = prefix_new; - } else { - prefix_to_score = prefix; - } - - float score = 0.0; - std::vector ngram; - ngram = ext_scorer->make_ngram(prefix_to_score); - score = ext_scorer->get_log_cond_prob(ngram) * - ext_scorer->alpha; - log_p += score; - log_p += ext_scorer->beta; - } - prefix_new->log_prob_nb_cur = - log_sum_exp(prefix_new->log_prob_nb_cur, log_p); - } - } // end of loop over prefix - } // end of loop over vocabulary - - - prefixes.clear(); - // update log probs - root.iterate_to_vec(prefixes); - - // only preserve top beam_size prefixes - if (prefixes.size() >= beam_size) { - std::nth_element(prefixes.begin(), - prefixes.begin() + beam_size, - prefixes.end(), - prefix_compare); - for (size_t i = beam_size; i < prefixes.size(); ++i) { - prefixes[i]->remove(); - } - } - } // end of loop over time - - // score the last word of each prefix that doesn't end with space - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - auto prefix = prefixes[i]; - if (!prefix->is_empty() && prefix->character != space_id) { - float score = 0.0; - std::vector ngram = ext_scorer->make_ngram(prefix); - score = - ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; - score += ext_scorer->beta; - prefix->score += score; - } - } - } - - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort( - prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); - - // compute approximate ctc score as the return score, without affecting the - // return order of decoding result. To delete when decoder gets stable. - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - double approx_ctc = prefixes[i]->score; - if (ext_scorer != nullptr) { - std::vector output; - prefixes[i]->get_path_vec(output); - auto prefix_length = output.size(); - auto words = ext_scorer->split_labels(output); - // remove word insert - approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; - // remove language model weight: - approx_ctc -= - (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; - } - prefixes[i]->approx_ctc = approx_ctc; - } - - return get_beam_search_result(prefixes, vocabulary, beam_size); -} - - -std::vector>> -ctc_beam_search_decoding_batch( - const std::vector>> &probs_split, - const std::vector &vocabulary, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id) { - VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); - // thread pool - ThreadPool pool(num_processes); - // number of samples - size_t batch_size = probs_split.size(); - - // enqueue the tasks of decoding - std::vector>>> res; - for (size_t i = 0; i < batch_size; ++i) { - res.emplace_back(pool.enqueue(ctc_beam_search_decoding, - probs_split[i], - vocabulary, - beam_size, - cutoff_prob, - cutoff_top_n, - ext_scorer, - blank_id)); - } - - // get decoding results - std::vector>> batch_results; - for (size_t i = 0; i < batch_size; ++i) { - batch_results.emplace_back(res[i].get()); - } - return batch_results; -} - -void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer) { - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - auto fst_dict = - static_cast(ext_scorer->dictionary); - fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); - root->set_dictionary(dict_ptr); - auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); - root->set_matcher(matcher); - } -} - -void ctc_beam_search_decode_chunk( - PathTrie *root, - std::vector &prefixes, - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id) { - // dimension check - size_t num_time_steps = probs_seq.size(); - for (size_t i = 0; i < num_time_steps; ++i) { - VALID_CHECK_EQ(probs_seq[i].size(), - // vocabulary.size() + 1, - vocabulary.size(), - "The shape of probs_seq does not match with " - "the shape of the vocabulary"); - } - - // assign space id - auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); - int space_id = it - vocabulary.begin(); - // if no space in vocabulary - if ((size_t)space_id >= vocabulary.size()) { - space_id = -2; - } - // init prefixes' root - // - // prefix search over time - for (size_t time_step = 0; time_step < num_time_steps; ++time_step) { - auto &prob = probs_seq[time_step]; - - float min_cutoff = -NUM_FLT_INF; - bool full_beam = false; - if (ext_scorer != nullptr) { - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort(prefixes.begin(), - prefixes.begin() + num_prefixes, - prefix_compare); - min_cutoff = prefixes[num_prefixes - 1]->score + - std::log(prob[blank_id]) - - std::max(0.0, ext_scorer->beta); - full_beam = (num_prefixes == beam_size); - } - - std::vector> log_prob_idx = - get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); - // loop over chars - for (size_t index = 0; index < log_prob_idx.size(); index++) { - auto c = log_prob_idx[index].first; - auto log_prob_c = log_prob_idx[index].second; - - for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { - auto prefix = prefixes[i]; - if (full_beam && log_prob_c + prefix->score < min_cutoff) { - break; - } - // blank - if (c == blank_id) { - prefix->log_prob_b_cur = log_sum_exp( - prefix->log_prob_b_cur, log_prob_c + prefix->score); - continue; - } - // repeated character - if (c == prefix->character) { - prefix->log_prob_nb_cur = - log_sum_exp(prefix->log_prob_nb_cur, - log_prob_c + prefix->log_prob_nb_prev); - } - // get new prefix - auto prefix_new = prefix->get_path_trie(c); - - if (prefix_new != nullptr) { - float log_p = -NUM_FLT_INF; - - if (c == prefix->character && - prefix->log_prob_b_prev > -NUM_FLT_INF) { - log_p = log_prob_c + prefix->log_prob_b_prev; - } else if (c != prefix->character) { - log_p = log_prob_c + prefix->score; - } - - // language model scoring - if (ext_scorer != nullptr && - (c == space_id || ext_scorer->is_character_based())) { - PathTrie *prefix_to_score = nullptr; - // skip scoring the space - if (ext_scorer->is_character_based()) { - prefix_to_score = prefix_new; - } else { - prefix_to_score = prefix; - } - - float score = 0.0; - std::vector ngram; - ngram = ext_scorer->make_ngram(prefix_to_score); - score = ext_scorer->get_log_cond_prob(ngram) * - ext_scorer->alpha; - log_p += score; - log_p += ext_scorer->beta; - } - prefix_new->log_prob_nb_cur = - log_sum_exp(prefix_new->log_prob_nb_cur, log_p); - } - } // end of loop over prefix - } // end of loop over vocabulary - - prefixes.clear(); - // update log probs - - root->iterate_to_vec(prefixes); - - // only preserve top beam_size prefixes - if (prefixes.size() >= beam_size) { - std::nth_element(prefixes.begin(), - prefixes.begin() + beam_size, - prefixes.end(), - prefix_compare); - for (size_t i = beam_size; i < prefixes.size(); ++i) { - prefixes[i]->remove(); - } - } - } // end of loop over time - - return; -} - - -std::vector> get_decode_result( - std::vector &prefixes, - const std::vector &vocabulary, - size_t beam_size, - Scorer *ext_scorer) { - auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); - int space_id = it - vocabulary.begin(); - // if no space in vocabulary - if ((size_t)space_id >= vocabulary.size()) { - space_id = -2; - } - // score the last word of each prefix that doesn't end with space - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - auto prefix = prefixes[i]; - if (!prefix->is_empty() && prefix->character != space_id) { - float score = 0.0; - std::vector ngram = ext_scorer->make_ngram(prefix); - score = - ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; - score += ext_scorer->beta; - prefix->score += score; - } - } - } - - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort( - prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); - - // compute aproximate ctc score as the return score, without affecting the - // return order of decoding result. To delete when decoder gets stable. - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - double approx_ctc = prefixes[i]->score; - if (ext_scorer != nullptr) { - std::vector output; - prefixes[i]->get_path_vec(output); - auto prefix_length = output.size(); - auto words = ext_scorer->split_labels(output); - // remove word insert - approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; - // remove language model weight: - approx_ctc -= - (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; - } - prefixes[i]->approx_ctc = approx_ctc; - } - - std::vector> res = - get_beam_search_result(prefixes, vocabulary, beam_size); - - // pay back the last word of each prefix that doesn't end with space (for - // decoding by chunk) - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - auto prefix = prefixes[i]; - if (!prefix->is_empty() && prefix->character != space_id) { - float score = 0.0; - std::vector ngram = ext_scorer->make_ngram(prefix); - score = - ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; - score += ext_scorer->beta; - prefix->score -= score; - } - } - } - return res; -} - - -void free_storage(std::unique_ptr &storage) { - storage = nullptr; -} - - -CtcBeamSearchDecoderBatch::~CtcBeamSearchDecoderBatch() {} - -CtcBeamSearchDecoderBatch::CtcBeamSearchDecoderBatch( - const std::vector &vocabulary, - size_t batch_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id) - : batch_size(batch_size), - beam_size(beam_size), - num_processes(num_processes), - cutoff_prob(cutoff_prob), - cutoff_top_n(cutoff_top_n), - ext_scorer(ext_scorer), - blank_id(blank_id) { - VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!"); - VALID_CHECK_GT( - this->num_processes, 0, "num_processes must be nonnegative!"); - this->vocabulary = vocabulary; - for (size_t i = 0; i < batch_size; i++) { - this->decoder_storage_vector.push_back( - std::unique_ptr( - new CtcBeamSearchDecoderStorage())); - ctc_beam_search_decode_chunk_begin( - this->decoder_storage_vector[i]->root, ext_scorer); - } -}; - -/** - * Input - * probs_split: shape [B, T, D] - */ -void CtcBeamSearchDecoderBatch::next( - const std::vector>> &probs_split, - const std::vector &has_value) { - VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); - // thread pool - size_t num_has_value = 0; - for (int i = 0; i < has_value.size(); i++) - if (has_value[i] == "true") num_has_value += 1; - ThreadPool pool(std::min(num_processes, num_has_value)); - // number of samples - size_t probs_num = probs_split.size(); - VALID_CHECK_EQ(this->batch_size, - probs_num, - "The batch size of the current input data should be same " - "with the input data before"); - - // enqueue the tasks of decoding - std::vector> res; - for (size_t i = 0; i < batch_size; ++i) { - if (has_value[i] == "true") { - res.emplace_back(pool.enqueue( - ctc_beam_search_decode_chunk, - std::ref(this->decoder_storage_vector[i]->root), - std::ref(this->decoder_storage_vector[i]->prefixes), - probs_split[i], - this->vocabulary, - this->beam_size, - this->cutoff_prob, - this->cutoff_top_n, - this->ext_scorer, - this->blank_id)); - } - } - - for (size_t i = 0; i < batch_size; ++i) { - res[i].get(); - } - return; -}; - -/** - * Return - * batch_result: shape[B, beam_size,(-approx_ctc score, string)] - */ -std::vector>> -CtcBeamSearchDecoderBatch::decode() { - VALID_CHECK_GT( - this->num_processes, 0, "num_processes must be nonnegative!"); - // thread pool - ThreadPool pool(this->num_processes); - // number of samples - // enqueue the tasks of decoding - std::vector>>> res; - for (size_t i = 0; i < this->batch_size; ++i) { - res.emplace_back( - pool.enqueue(get_decode_result, - std::ref(this->decoder_storage_vector[i]->prefixes), - this->vocabulary, - this->beam_size, - this->ext_scorer)); - } - // get decoding results - std::vector>> batch_results; - for (size_t i = 0; i < this->batch_size; ++i) { - batch_results.emplace_back(res[i].get()); - } - return batch_results; -} - - -/** - * reset the state of ctcBeamSearchDecoderBatch - */ -void CtcBeamSearchDecoderBatch::reset_state(size_t batch_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n) { - this->batch_size = batch_size; - this->beam_size = beam_size; - this->num_processes = num_processes; - this->cutoff_prob = cutoff_prob; - this->cutoff_top_n = cutoff_top_n; - - VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!"); - VALID_CHECK_GT( - this->num_processes, 0, "num_processes must be nonnegative!"); - // thread pool - ThreadPool pool(this->num_processes); - // number of samples - // enqueue the tasks of decoding - std::vector> res; - size_t storage_size = decoder_storage_vector.size(); - for (size_t i = 0; i < storage_size; i++) { - res.emplace_back(pool.enqueue( - free_storage, std::ref(this->decoder_storage_vector[i]))); - } - for (size_t i = 0; i < storage_size; ++i) { - res[i].get(); - } - std::vector>().swap( - decoder_storage_vector); - for (size_t i = 0; i < this->batch_size; i++) { - this->decoder_storage_vector.push_back( - std::unique_ptr( - new CtcBeamSearchDecoderStorage())); - ctc_beam_search_decode_chunk_begin( - this->decoder_storage_vector[i]->root, this->ext_scorer); - } -} \ No newline at end of file diff --git a/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h b/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h deleted file mode 100644 index 92d2b855f..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/ctc_beam_search_decoder.h +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef CTC_BEAM_SEARCH_DECODER_H_ -#define CTC_BEAM_SEARCH_DECODER_H_ - -#include -#include -#include - -#include "scorer.h" - -/* CTC Beam Search Decoder - - * Parameters: - * probs_seq: 2-D vector that each element is a vector of probabilities - * over vocabulary of one time step. - * vocabulary: A vector of vocabulary. - * beam_size: The width of beam search. - * cutoff_prob: Cutoff probability for pruning. - * cutoff_top_n: Cutoff number for pruning. - * ext_scorer: External scorer to evaluate a prefix, which consists of - * n-gram language model scoring and word insertion term. - * Default null, decoding the input sample without scorer. - * Return: - * A vector that each element is a pair of score and decoding result, - * in desending order. -*/ -std::vector> ctc_beam_search_decoding( - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob = 1.0, - size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr, - size_t blank_id = 0); - - -/* CTC Beam Search Decoder for batch data - - * Parameters: - * probs_seq: 3-D vector that each element is a 2-D vector that can be used - * by ctc_beam_search_decoder(). - * vocabulary: A vector of vocabulary. - * beam_size: The width of beam search. - * num_processes: Number of threads for beam search. - * cutoff_prob: Cutoff probability for pruning. - * cutoff_top_n: Cutoff number for pruning. - * ext_scorer: External scorer to evaluate a prefix, which consists of - * n-gram language model scoring and word insertion term. - * Default null, decoding the input sample without scorer. - * Return: - * A 2-D vector that each element is a vector of beam search decoding - * result for one audio sample. -*/ -std::vector>> -ctc_beam_search_decoding_batch( - const std::vector>> &probs_split, - const std::vector &vocabulary, - size_t beam_size, - size_t num_processes, - double cutoff_prob = 1.0, - size_t cutoff_top_n = 40, - Scorer *ext_scorer = nullptr, - size_t blank_id = 0); - -/** - * Store the root and prefixes for decoder - */ - -class CtcBeamSearchDecoderStorage { - public: - PathTrie *root = nullptr; - std::vector prefixes; - - CtcBeamSearchDecoderStorage() { - // init prefixes' root - this->root = new PathTrie(); - this->root->log_prob_b_prev = 0.0; - // The score of root is in log scale.Since the prob=1.0, the prob score - // in log scale is 0.0 - this->root->score = root->log_prob_b_prev; - // std::vector prefixes; - this->prefixes.push_back(root); - }; - - ~CtcBeamSearchDecoderStorage() { - if (root != nullptr) { - delete root; - root = nullptr; - } - }; -}; - -/** - * The ctc beam search decoder, support batchsize >= 1 - */ -class CtcBeamSearchDecoderBatch { - public: - CtcBeamSearchDecoderBatch(const std::vector &vocabulary, - size_t batch_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id); - - ~CtcBeamSearchDecoderBatch(); - void next(const std::vector>> &probs_split, - const std::vector &has_value); - - std::vector>> decode(); - - void reset_state(size_t batch_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n); - - private: - std::vector vocabulary; - size_t batch_size; - size_t beam_size; - size_t num_processes; - double cutoff_prob; - size_t cutoff_top_n; - Scorer *ext_scorer; - size_t blank_id; - std::vector> - decoder_storage_vector; -}; - -/** - * function for chunk decoding - */ -void ctc_beam_search_decode_chunk( - PathTrie *root, - std::vector &prefixes, - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer, - size_t blank_id); - -std::vector> get_decode_result( - std::vector &prefixes, - const std::vector &vocabulary, - size_t beam_size, - Scorer *ext_scorer); - -/** - * free the CtcBeamSearchDecoderStorage - */ -void free_storage(std::unique_ptr &storage); - -/** - * initialize the root - */ -void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer); - -#endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp b/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp deleted file mode 100644 index 6aa3c9964..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.cpp +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ctc_greedy_decoder.h" -#include "decoder_utils.h" - -std::string ctc_greedy_decoding( - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t blank_id) { - // dimension check - size_t num_time_steps = probs_seq.size(); - for (size_t i = 0; i < num_time_steps; ++i) { - VALID_CHECK_EQ(probs_seq[i].size(), - vocabulary.size(), - "The shape of probs_seq does not match with " - "the shape of the vocabulary"); - } - - // size_t blank_id = vocabulary.size(); - - std::vector max_idx_vec(num_time_steps, 0); - std::vector idx_vec; - for (size_t i = 0; i < num_time_steps; ++i) { - double max_prob = 0.0; - size_t max_idx = 0; - const std::vector &probs_step = probs_seq[i]; - for (size_t j = 0; j < probs_step.size(); ++j) { - if (max_prob < probs_step[j]) { - max_idx = j; - max_prob = probs_step[j]; - } - } - // id with maximum probability in current time step - max_idx_vec[i] = max_idx; - // deduplicate - if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) { - idx_vec.push_back(max_idx_vec[i]); - } - } - - std::string best_path_result; - for (size_t i = 0; i < idx_vec.size(); ++i) { - if (idx_vec[i] != blank_id) { - std::string ch = vocabulary[idx_vec[i]]; - best_path_result += (ch == kSPACE) ? tSPACE : ch; - } - } - return best_path_result; -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.h b/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.h deleted file mode 100644 index 4451600d6..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/ctc_greedy_decoder.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef CTC_GREEDY_DECODER_H -#define CTC_GREEDY_DECODER_H - -#include -#include - -/* CTC Greedy (Best Path) Decoder - * - * Parameters: - * probs_seq: 2-D vector that each element is a vector of probabilities - * over vocabulary of one time step. - * vocabulary: A vector of vocabulary. - * Return: - * The decoding result in string - */ -std::string ctc_greedy_decoding( - const std::vector>& probs_seq, - const std::vector& vocabulary, - size_t blank_id); - -#endif // CTC_GREEDY_DECODER_H diff --git a/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.cpp b/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.cpp deleted file mode 100644 index c7ef65428..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.cpp +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "decoder_utils.h" - -#include -#include -#include - -std::vector> get_pruned_log_probs( - const std::vector &prob_step, - double cutoff_prob, - size_t cutoff_top_n) { - std::vector> prob_idx; - for (size_t i = 0; i < prob_step.size(); ++i) { - prob_idx.push_back(std::pair(i, prob_step[i])); - } - // pruning of vocabulary - size_t cutoff_len = prob_step.size(); - if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) { - std::sort(prob_idx.begin(), - prob_idx.end(), - pair_comp_second_rev); - if (cutoff_prob < 1.0) { - double cum_prob = 0.0; - cutoff_len = 0; - for (size_t i = 0; i < prob_idx.size(); ++i) { - cum_prob += prob_idx[i].second; - cutoff_len += 1; - if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) - break; - } - } - prob_idx = std::vector>( - prob_idx.begin(), prob_idx.begin() + cutoff_len); - } - std::vector> log_prob_idx; - for (size_t i = 0; i < cutoff_len; ++i) { - log_prob_idx.push_back(std::pair( - prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN))); - } - return log_prob_idx; -} - - -std::vector> get_beam_search_result( - const std::vector &prefixes, - const std::vector &vocabulary, - size_t beam_size) { - // allow for the post processing - std::vector space_prefixes; - if (space_prefixes.empty()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - space_prefixes.push_back(prefixes[i]); - } - } - - std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); - std::vector> output_vecs; - for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) { - std::vector output; - space_prefixes[i]->get_path_vec(output); - // convert index to string - std::string output_str; - for (size_t j = 0; j < output.size(); j++) { - std::string ch = vocabulary[output[j]]; - output_str += (ch == kSPACE) ? tSPACE : ch; - } - std::pair output_pair( - -space_prefixes[i]->approx_ctc, output_str); - output_vecs.emplace_back(output_pair); - } - - return output_vecs; -} - -size_t get_utf8_str_len(const std::string &str) { - size_t str_len = 0; - for (char c : str) { - str_len += ((c & 0xc0) != 0x80); - } - return str_len; -} - -std::vector split_utf8_str(const std::string &str) { - std::vector result; - std::string out_str; - - for (char c : str) { - if ((c & 0xc0) != 0x80) // new UTF-8 character - { - if (!out_str.empty()) { - result.push_back(out_str); - out_str.clear(); - } - } - - out_str.append(1, c); - } - result.push_back(out_str); - return result; -} - -std::vector split_str(const std::string &s, - const std::string &delim) { - std::vector result; - std::size_t start = 0, delim_len = delim.size(); - while (true) { - std::size_t end = s.find(delim, start); - if (end == std::string::npos) { - if (start < s.size()) { - result.push_back(s.substr(start)); - } - break; - } - if (end > start) { - result.push_back(s.substr(start, end - start)); - } - start = end + delim_len; - } - return result; -} - -bool prefix_compare(const PathTrie *x, const PathTrie *y) { - if (x->score == y->score) { - if (x->character == y->character) { - return false; - } else { - return (x->character < y->character); - } - } else { - return x->score > y->score; - } -} - -void add_word_to_fst(const std::vector &word, - fst::StdVectorFst *dictionary) { - if (dictionary->NumStates() == 0) { - fst::StdVectorFst::StateId start = dictionary->AddState(); - assert(start == 0); - dictionary->SetStart(start); - } - fst::StdVectorFst::StateId src = dictionary->Start(); - fst::StdVectorFst::StateId dst; - for (auto c : word) { - dst = dictionary->AddState(); - dictionary->AddArc(src, fst::StdArc(c, c, 0, dst)); - src = dst; - } - dictionary->SetFinal(dst, fst::StdArc::Weight::One()); -} - -bool add_word_to_dictionary( - const std::string &word, - const std::unordered_map &char_map, - bool add_space, - int SPACE_ID, - fst::StdVectorFst *dictionary) { - auto characters = split_utf8_str(word); - - std::vector int_word; - - for (auto &c : characters) { - if (c == " ") { - int_word.push_back(SPACE_ID); - } else { - auto int_c = char_map.find(c); - if (int_c != char_map.end()) { - int_word.push_back(int_c->second); - } else { - return false; // return without adding - } - } - } - - if (add_space) { - int_word.push_back(SPACE_ID); - } - - add_word_to_fst(int_word, dictionary); - return true; // return with successful adding -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h b/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h deleted file mode 100644 index 098741552..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/decoder_utils.h +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef DECODER_UTILS_H_ -#define DECODER_UTILS_H_ - -#include -#include -#include "fst/log.h" -#include "path_trie.h" - -const std::string kSPACE = ""; -const std::string tSPACE = " "; -const float NUM_FLT_INF = std::numeric_limits::max(); -const float NUM_FLT_MIN = std::numeric_limits::min(); - -// inline function for validation check -inline void check( - bool x, const char *expr, const char *file, int line, const char *err) { - if (!x) { - std::cout << "[" << file << ":" << line << "] "; - LOG(FATAL) << "\"" << expr << "\" check failed. " << err; - } -} - -#define VALID_CHECK(x, info) \ - check(static_cast(x), #x, __FILE__, __LINE__, info) -#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info) -#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info) -#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info) - - -// Function template for comparing two pairs -template -bool pair_comp_first_rev(const std::pair &a, - const std::pair &b) { - return a.first > b.first; -} - -// Function template for comparing two pairs -template -bool pair_comp_second_rev(const std::pair &a, - const std::pair &b) { - return a.second > b.second; -} - -// Return the sum of two probabilities in log scale -template -T log_sum_exp(const T &x, const T &y) { - static T num_min = -std::numeric_limits::max(); - if (x <= num_min) return y; - if (y <= num_min) return x; - T xmax = std::max(x, y); - return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax; -} - -// Get pruned probability vector for each time step's beam search -std::vector> get_pruned_log_probs( - const std::vector &prob_step, - double cutoff_prob, - size_t cutoff_top_n); - -// Get beam search result from prefixes in trie tree -std::vector> get_beam_search_result( - const std::vector &prefixes, - const std::vector &vocabulary, - size_t beam_size); - -// Functor for prefix comparsion -bool prefix_compare(const PathTrie *x, const PathTrie *y); - -/* Get length of utf8 encoding string - * See: http://stackoverflow.com/a/4063229 - */ -size_t get_utf8_str_len(const std::string &str); - -/* Split a string into a list of strings on a given string - * delimiter. NB: delimiters on beginning / end of string are - * trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"]. - */ -std::vector split_str(const std::string &s, - const std::string &delim); - -/* Splits string into vector of strings representing - * UTF-8 characters (not same as chars) - */ -std::vector split_utf8_str(const std::string &str); - -// Add a word in index to the dicionary of fst -void add_word_to_fst(const std::vector &word, - fst::StdVectorFst *dictionary); - -// Add a word in string to dictionary -bool add_word_to_dictionary( - const std::string &word, - const std::unordered_map &char_map, - bool add_space, - int SPACE_ID, - fst::StdVectorFst *dictionary); -#endif // DECODER_UTILS_H diff --git a/speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp b/speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp deleted file mode 100644 index 777ca0520..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/path_trie.cpp +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "path_trie.h" - -#include -#include -#include -#include -#include - -#include "decoder_utils.h" - -PathTrie::PathTrie() { - log_prob_b_prev = -NUM_FLT_INF; - log_prob_nb_prev = -NUM_FLT_INF; - log_prob_b_cur = -NUM_FLT_INF; - log_prob_nb_cur = -NUM_FLT_INF; - score = -NUM_FLT_INF; - - ROOT_ = -1; - character = ROOT_; - exists_ = true; - parent = nullptr; - - dictionary_ = nullptr; - dictionary_state_ = 0; - has_dictionary_ = false; - - matcher_ = nullptr; -} - -PathTrie::~PathTrie() { - for (auto child : children_) { - delete child.second; - child.second = nullptr; - } -} - -PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { - auto child = children_.begin(); - for (child = children_.begin(); child != children_.end(); ++child) { - if (child->first == new_char) { - break; - } - } - if (child != children_.end()) { - if (!child->second->exists_) { - child->second->exists_ = true; - child->second->log_prob_b_prev = -NUM_FLT_INF; - child->second->log_prob_nb_prev = -NUM_FLT_INF; - child->second->log_prob_b_cur = -NUM_FLT_INF; - child->second->log_prob_nb_cur = -NUM_FLT_INF; - } - return (child->second); - } else { - if (has_dictionary_) { - matcher_->SetState(dictionary_state_); - bool found = matcher_->Find(new_char + 1); - if (!found) { - // Adding this character causes word outside dictionary - auto FSTZERO = fst::TropicalWeight::Zero(); - auto final_weight = dictionary_->Final(dictionary_state_); - bool is_final = (final_weight != FSTZERO); - if (is_final && reset) { - dictionary_state_ = dictionary_->Start(); - } - return nullptr; - } else { - PathTrie* new_path = new PathTrie; - new_path->character = new_char; - new_path->parent = this; - new_path->dictionary_ = dictionary_; - new_path->dictionary_state_ = matcher_->Value().nextstate; - new_path->has_dictionary_ = true; - new_path->matcher_ = matcher_; - children_.push_back(std::make_pair(new_char, new_path)); - return new_path; - } - } else { - PathTrie* new_path = new PathTrie; - new_path->character = new_char; - new_path->parent = this; - children_.push_back(std::make_pair(new_char, new_path)); - return new_path; - } - } -} - -PathTrie* PathTrie::get_path_vec(std::vector& output) { - return get_path_vec(output, ROOT_); -} - -PathTrie* PathTrie::get_path_vec(std::vector& output, - int stop, - size_t max_steps) { - if (character == stop || character == ROOT_ || output.size() == max_steps) { - std::reverse(output.begin(), output.end()); - return this; - } else { - output.push_back(character); - return parent->get_path_vec(output, stop, max_steps); - } -} - -void PathTrie::iterate_to_vec(std::vector& output) { - if (exists_) { - log_prob_b_prev = log_prob_b_cur; - log_prob_nb_prev = log_prob_nb_cur; - - log_prob_b_cur = -NUM_FLT_INF; - log_prob_nb_cur = -NUM_FLT_INF; - - score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); - output.push_back(this); - } - for (auto child : children_) { - child.second->iterate_to_vec(output); - } -} - -void PathTrie::remove() { - exists_ = false; - if (children_.size() == 0) { - if (parent != nullptr) { - auto child = parent->children_.begin(); - for (child = parent->children_.begin(); - child != parent->children_.end(); - ++child) { - if (child->first == character) { - parent->children_.erase(child); - break; - } - } - if (parent->children_.size() == 0 && !parent->exists_) { - parent->remove(); - } - } - delete this; - } -} - - -void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { - dictionary_ = dictionary; - dictionary_state_ = dictionary->Start(); - has_dictionary_ = true; -} - -using FSTMATCH = fst::SortedMatcher; -void PathTrie::set_matcher(std::shared_ptr matcher) { - matcher_ = matcher; -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/path_trie.h b/speechx/speechx/asr/decoder/ctc_decoders/path_trie.h deleted file mode 100644 index 5193e0a47..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/path_trie.h +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "COPYING.APACHE2.0"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef PATH_TRIE_H -#define PATH_TRIE_H - -#include -#include -#include -#include -#include - -#include "fst/fstlib.h" - -/* Trie tree for prefix storing and manipulating, with a dictionary in - * finite-state transducer for spelling correction. - */ -class PathTrie { - public: - PathTrie(); - ~PathTrie(); - - // get new prefix after appending new char - PathTrie* get_path_trie(int new_char, bool reset = true); - - // get the prefix in index from root to current node - PathTrie* get_path_vec(std::vector& output); - - // get the prefix in index from some stop node to current nodel - PathTrie* get_path_vec( - std::vector& output, - int stop, - size_t max_steps = std::numeric_limits::max()); - - // update log probs - void iterate_to_vec(std::vector& output); - - // set dictionary for FST - void set_dictionary(fst::StdVectorFst* dictionary); - - void set_matcher(std::shared_ptr>); - - bool is_empty() { return ROOT_ == character; } - - // remove current path from root - void remove(); - - float log_prob_b_prev; - float log_prob_nb_prev; - float log_prob_b_cur; - float log_prob_nb_cur; - float score; - float approx_ctc; - int character; - PathTrie* parent; - - private: - int ROOT_; - bool exists_; - bool has_dictionary_; - - std::vector> children_; - - // pointer to dictionary of FST - fst::StdVectorFst* dictionary_; - fst::StdVectorFst::StateId dictionary_state_; - // true if finding ars in FST - std::shared_ptr> matcher_; -}; - -#endif // PATH_TRIE_H diff --git a/speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp b/speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp deleted file mode 100644 index 6e7f68cf6..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/scorer.cpp +++ /dev/null @@ -1,232 +0,0 @@ -// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the -// "COPYING.LESSER.3"); - -#include "scorer.h" - -#include -#include - -#include "lm/config.hh" -#include "lm/model.hh" -#include "lm/state.hh" - -#include "decoder_utils.h" - -using namespace lm::ngram; -// if your platform is windows ,you need add the define -#define F_OK 0 -Scorer::Scorer(double alpha, - double beta, - const std::string& lm_path, - const std::vector& vocab_list) { - this->alpha = alpha; - this->beta = beta; - - dictionary = nullptr; - is_character_based_ = true; - language_model_ = nullptr; - - max_order_ = 0; - dict_size_ = 0; - SPACE_ID_ = -1; - - setup(lm_path, vocab_list); -} - -Scorer::~Scorer() { - if (language_model_ != nullptr) { - delete static_cast(language_model_); - } - if (dictionary != nullptr) { - delete static_cast(dictionary); - } -} - -void Scorer::setup(const std::string& lm_path, - const std::vector& vocab_list) { - // load language model - load_lm(lm_path); - // set char map for scorer - set_char_map(vocab_list); - // fill the dictionary for FST - if (!is_character_based()) { - fill_dictionary(true); - } -} - -void Scorer::load_lm(const std::string& lm_path) { - const char* filename = lm_path.c_str(); - VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path"); - - RetriveStrEnumerateVocab enumerate; - lm::ngram::Config config; - config.enumerate_vocab = &enumerate; - language_model_ = lm::ngram::LoadVirtual(filename, config); - max_order_ = static_cast(language_model_)->Order(); - vocabulary_ = enumerate.vocabulary; - for (size_t i = 0; i < vocabulary_.size(); ++i) { - if (is_character_based_ && vocabulary_[i] != UNK_TOKEN && - vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN && - get_utf8_str_len(enumerate.vocabulary[i]) > 1) { - is_character_based_ = false; - } - } -} - -double Scorer::get_log_cond_prob(const std::vector& words) { - lm::base::Model* model = static_cast(language_model_); - double cond_prob; - lm::ngram::State state, tmp_state, out_state; - // avoid to inserting in begin - model->NullContextWrite(&state); - for (size_t i = 0; i < words.size(); ++i) { - lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]); - // encounter OOV - if (word_index == 0) { - return OOV_SCORE; - } - cond_prob = model->BaseScore(&state, word_index, &out_state); - tmp_state = state; - state = out_state; - out_state = tmp_state; - } - // return log10 prob - return cond_prob; -} - -double Scorer::get_sent_log_prob(const std::vector& words) { - std::vector sentence; - if (words.size() == 0) { - for (size_t i = 0; i < max_order_; ++i) { - sentence.push_back(START_TOKEN); - } - } else { - for (size_t i = 0; i < max_order_ - 1; ++i) { - sentence.push_back(START_TOKEN); - } - sentence.insert(sentence.end(), words.begin(), words.end()); - } - sentence.push_back(END_TOKEN); - return get_log_prob(sentence); -} - -double Scorer::get_log_prob(const std::vector& words) { - assert(words.size() > max_order_); - double score = 0.0; - for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) { - std::vector ngram(words.begin() + i, - words.begin() + i + max_order_); - score += get_log_cond_prob(ngram); - } - return score; -} - -void Scorer::reset_params(float alpha, float beta) { - this->alpha = alpha; - this->beta = beta; -} - -std::string Scorer::vec2str(const std::vector& input) { - std::string word; - for (auto ind : input) { - word += char_list_[ind]; - } - return word; -} - -std::vector Scorer::split_labels(const std::vector& labels) { - if (labels.empty()) return {}; - - std::string s = vec2str(labels); - std::vector words; - if (is_character_based_) { - words = split_utf8_str(s); - } else { - words = split_str(s, " "); - } - return words; -} - -void Scorer::set_char_map(const std::vector& char_list) { - char_list_ = char_list; - char_map_.clear(); - - // Set the char map for the FST for spelling correction - for (size_t i = 0; i < char_list_.size(); i++) { - if (char_list_[i] == kSPACE) { - SPACE_ID_ = i; - } - // The initial state of FST is state 0, hence the index of chars in - // the FST should start from 1 to avoid the conflict with the initial - // state, otherwise wrong decoding results would be given. - char_map_[char_list_[i]] = i + 1; - } -} - -std::vector Scorer::make_ngram(PathTrie* prefix) { - std::vector ngram; - PathTrie* current_node = prefix; - PathTrie* new_node = nullptr; - - for (int order = 0; order < max_order_; order++) { - std::vector prefix_vec; - - if (is_character_based_) { - new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1); - current_node = new_node; - } else { - new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_); - current_node = new_node->parent; // Skipping spaces - } - - // reconstruct word - std::string word = vec2str(prefix_vec); - ngram.push_back(word); - - if (new_node->character == -1) { - // No more spaces, but still need order - for (int i = 0; i < max_order_ - order - 1; i++) { - ngram.push_back(START_TOKEN); - } - break; - } - } - std::reverse(ngram.begin(), ngram.end()); - return ngram; -} - -void Scorer::fill_dictionary(bool add_space) { - fst::StdVectorFst dictionary; - // For each unigram convert to ints and put in trie - int dict_size = 0; - for (const auto& word : vocabulary_) { - bool added = add_word_to_dictionary( - word, char_map_, add_space, SPACE_ID_ + 1, &dictionary); - dict_size += added ? 1 : 0; - } - - dict_size_ = dict_size; - - /* Simplify FST - - * This gets rid of "epsilon" transitions in the FST. - * These are transitions that don't require a string input to be taken. - * Getting rid of them is necessary to make the FST deterministic, but - * can greatly increase the size of the FST - */ - fst::RmEpsilon(&dictionary); - fst::StdVectorFst* new_dict = new fst::StdVectorFst; - - /* This makes the FST deterministic, meaning for any string input there's - * only one possible state the FST could be in. It is assumed our - * dictionary is deterministic when using it. - * (lest we'd have to check for multiple transitions at each state) - */ - fst::Determinize(dictionary, new_dict); - - /* Finds the simplest equivalent fst. This is unnecessary but decreases - * memory usage of the dictionary - */ - fst::Minimize(new_dict); - this->dictionary = new_dict; -} diff --git a/speechx/speechx/asr/decoder/ctc_decoders/scorer.h b/speechx/speechx/asr/decoder/ctc_decoders/scorer.h deleted file mode 100644 index 08e109b78..000000000 --- a/speechx/speechx/asr/decoder/ctc_decoders/scorer.h +++ /dev/null @@ -1,114 +0,0 @@ -// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the -// "COPYING.LESSER.3"); - -#ifndef SCORER_H_ -#define SCORER_H_ - -#include -#include -#include -#include - -#include "lm/enumerate_vocab.hh" -#include "lm/virtual_interface.hh" -#include "lm/word_index.hh" - -#include "path_trie.h" - -const double OOV_SCORE = -1000.0; -const std::string START_TOKEN = ""; -const std::string UNK_TOKEN = ""; -const std::string END_TOKEN = ""; - -// Implement a callback to retrive the dictionary of language model. -class RetriveStrEnumerateVocab : public lm::EnumerateVocab { - public: - RetriveStrEnumerateVocab() {} - - void Add(lm::WordIndex index, const StringPiece &str) { - vocabulary.push_back(std::string(str.data(), str.length())); - } - - std::vector vocabulary; -}; - -/* External scorer to query score for n-gram or sentence, including language - * model scoring and word insertion. - * - * Example: - * Scorer scorer(alpha, beta, "path_of_language_model"); - * scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" }); - * scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); - */ -class Scorer { - public: - Scorer(double alpha, - double beta, - const std::string &lm_path, - const std::vector &vocabulary); - ~Scorer(); - - double get_log_cond_prob(const std::vector &words); - - double get_sent_log_prob(const std::vector &words); - - // return the max order - size_t get_max_order() const { return max_order_; } - - // return the dictionary size of language model - size_t get_dict_size() const { return dict_size_; } - - // retrun true if the language model is character based - bool is_character_based() const { return is_character_based_; } - - // reset params alpha & beta - void reset_params(float alpha, float beta); - - // make ngram for a given prefix - std::vector make_ngram(PathTrie *prefix); - - // trransform the labels in index to the vector of words (word based lm) or - // the vector of characters (character based lm) - std::vector split_labels(const std::vector &labels); - - // language model weight - double alpha; - // word insertion weight - double beta; - - // pointer to the dictionary of FST - void *dictionary; - - protected: - // necessary setup: load language model, set char map, fill FST's dictionary - void setup(const std::string &lm_path, - const std::vector &vocab_list); - - // load language model from given path - void load_lm(const std::string &lm_path); - - // fill dictionary for FST - void fill_dictionary(bool add_space); - - // set char map - void set_char_map(const std::vector &char_list); - - double get_log_prob(const std::vector &words); - - // translate the vector in index to string - std::string vec2str(const std::vector &input); - - private: - void *language_model_; - bool is_character_based_; - size_t max_order_; - size_t dict_size_; - - int SPACE_ID_; - std::vector char_list_; - std::unordered_map char_map_; - - std::vector vocabulary_; -}; - -#endif // SCORER_H_ diff --git a/speechx/speechx/asr/decoder/nnet_logprob_decoder_main.cc b/speechx/speechx/asr/decoder/nnet_logprob_decoder_main.cc deleted file mode 100644 index e0acbe77b..000000000 --- a/speechx/speechx/asr/decoder/nnet_logprob_decoder_main.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// todo refactor, repalce with gtest - -#include "base/flags.h" -#include "base/log.h" -#include "decoder/ctc_beam_search_decoder.h" -#include "kaldi/util/table-types.h" -#include "nnet/decodable.h" - -DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier"); -DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); -DEFINE_string(lm_path, "lm.klm", "language model"); - -using kaldi::BaseFloat; -using kaldi::Matrix; -using std::vector; - -// test decoder by feeding nnet posterior probability -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - kaldi::SequentialBaseFloatMatrixReader likelihood_reader( - FLAGS_nnet_prob_respecifier); - std::string dict_file = FLAGS_dict_file; - std::string lm_path = FLAGS_lm_path; - LOG(INFO) << "dict path: " << dict_file; - LOG(INFO) << "lm path: " << lm_path; - - int32 num_done = 0, num_err = 0; - - ppspeech::CTCBeamSearchOptions opts; - opts.dict_file = dict_file; - opts.lm_path = lm_path; - ppspeech::CTCBeamSearch decoder(opts); - - std::shared_ptr decodable( - new ppspeech::Decodable(nullptr, nullptr)); - - decoder.InitDecoder(); - - for (; !likelihood_reader.Done(); likelihood_reader.Next()) { - string utt = likelihood_reader.Key(); - const kaldi::Matrix likelihood = likelihood_reader.Value(); - LOG(INFO) << "process utt: " << utt; - LOG(INFO) << "rows: " << likelihood.NumRows(); - LOG(INFO) << "cols: " << likelihood.NumCols(); - decodable->Acceptlikelihood(likelihood); - decoder.AdvanceDecode(decodable); - std::string result; - result = decoder.GetFinalBestPath(); - KALDI_LOG << " the result of " << utt << " is " << result; - decodable->Reset(); - decoder.Reset(); - ++num_done; - } - - KALDI_LOG << "Done " << num_done << " utterances, " << num_err - << " with errors."; - return (num_done != 0 ? 0 : 1); -} diff --git a/speechx/speechx/asr/decoder/param.h b/speechx/speechx/asr/decoder/param.h index ebdd71197..cad6dbd8d 100644 --- a/speechx/speechx/asr/decoder/param.h +++ b/speechx/speechx/asr/decoder/param.h @@ -15,8 +15,7 @@ #pragma once #include "base/common.h" -#include "decoder/ctc_beam_search_decoder.h" -#include "decoder/ctc_tlg_decoder.h" +//#include "decoder/ctc_tlg_decoder.h" // feature DEFINE_bool(use_fbank, false, "False for fbank; or linear feature"); diff --git a/speechx/speechx/asr/nnet/CMakeLists.txt b/speechx/speechx/asr/nnet/CMakeLists.txt index 2846540ec..819cc2e89 100644 --- a/speechx/speechx/asr/nnet/CMakeLists.txt +++ b/speechx/speechx/asr/nnet/CMakeLists.txt @@ -1,30 +1,12 @@ set(srcs decodable.cc nnet_producer.cc) -if(USING_DS2) - list(APPEND srcs ds2_nnet.cc) -endif() - -if(USING_U2) - list(APPEND srcs u2_nnet.cc) -endif() +list(APPEND srcs u2_nnet.cc) add_library(nnet STATIC ${srcs}) target_link_libraries(nnet utils) -if(USING_U2) - target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS}) - target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) -endif() - - -if(USING_DS2) - set(bin_name ds2_nnet_main) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) - - target_link_libraries(${bin_name} ${DEPS}) -endif() +target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS}) +target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) # test bin #if(USING_U2) diff --git a/speechx/speechx/asr/nnet/ds2_nnet.cc b/speechx/speechx/asr/nnet/ds2_nnet.cc deleted file mode 100644 index f77c0a603..000000000 --- a/speechx/speechx/asr/nnet/ds2_nnet.cc +++ /dev/null @@ -1,218 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "nnet/ds2_nnet.h" - -#include "utils/strings.h" - -namespace ppspeech { - -using kaldi::Matrix; -using kaldi::Vector; -using std::shared_ptr; -using std::string; -using std::vector; - -void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { - std::vector cache_names; - cache_names = StrSplit(opts.cache_names, ","); - std::vector cache_shapes; - cache_shapes = StrSplit(opts.cache_shape, ","); - assert(cache_shapes.size() == cache_names.size()); - - cache_encouts_.clear(); - cache_names_idx_.clear(); - for (size_t i = 0; i < cache_shapes.size(); i++) { - std::vector tmp_shape; - tmp_shape = StrSplit(cache_shapes[i], "-"); - std::vector cur_shape; - std::transform(tmp_shape.begin(), - tmp_shape.end(), - std::back_inserter(cur_shape), - [](const std::string& s) { return atoi(s.c_str()); }); - cache_names_idx_[cache_names[i]] = i; - std::shared_ptr> cache_eout = - std::make_shared>(cur_shape); - cache_encouts_.push_back(cache_eout); - } -} - -PaddleNnet::PaddleNnet(const ModelOptions& opts) : opts_(opts) { - subsampling_rate_ = opts.subsample_rate; - paddle_infer::Config config; - config.SetModel(opts.model_path, opts.param_path); - if (opts.use_gpu) { - config.EnableUseGpu(500, 0); - } - config.SwitchIrOptim(opts.switch_ir_optim); - if (opts.enable_fc_padding == false) { - config.DisableFCPadding(); - } - if (opts.enable_profile) { - config.EnableProfile(); - } - pool.reset( - new paddle_infer::services::PredictorPool(config, opts.thread_num)); - if (pool == nullptr) { - LOG(ERROR) << "create the predictor pool failed"; - } - pool_usages.resize(opts.thread_num); - std::fill(pool_usages.begin(), pool_usages.end(), false); - LOG(INFO) << "load paddle model success"; - - LOG(INFO) << "start to check the predictor input and output names"; - LOG(INFO) << "input names: " << opts.input_names; - LOG(INFO) << "output names: " << opts.output_names; - std::vector input_names_vec = StrSplit(opts.input_names, ","); - std::vector output_names_vec = StrSplit(opts.output_names, ","); - - paddle_infer::Predictor* predictor = GetPredictor(); - - std::vector model_input_names = predictor->GetInputNames(); - assert(input_names_vec.size() == model_input_names.size()); - for (size_t i = 0; i < model_input_names.size(); i++) { - assert(input_names_vec[i] == model_input_names[i]); - } - - std::vector model_output_names = predictor->GetOutputNames(); - assert(output_names_vec.size() == model_output_names.size()); - for (size_t i = 0; i < output_names_vec.size(); i++) { - assert(output_names_vec[i] == model_output_names[i]); - } - - ReleasePredictor(predictor); - InitCacheEncouts(opts); -} - -void PaddleNnet::Reset() { InitCacheEncouts(opts_); } - -paddle_infer::Predictor* PaddleNnet::GetPredictor() { - paddle_infer::Predictor* predictor = nullptr; - - std::lock_guard guard(pool_mutex); - int pred_id = 0; - - while (pred_id < pool_usages.size()) { - if (pool_usages[pred_id] == false) { - predictor = pool->Retrive(pred_id); - break; - } - ++pred_id; - } - - if (predictor) { - pool_usages[pred_id] = true; - predictor_to_thread_id[predictor] = pred_id; - } else { - LOG(INFO) << "Failed to get predictor from pool !!!"; - } - - return predictor; -} - -int PaddleNnet::ReleasePredictor(paddle_infer::Predictor* predictor) { - std::lock_guard guard(pool_mutex); - auto iter = predictor_to_thread_id.find(predictor); - - if (iter == predictor_to_thread_id.end()) { - LOG(INFO) << "there is no such predictor"; - return 0; - } - - pool_usages[iter->second] = false; - predictor_to_thread_id.erase(predictor); - return 0; -} - -shared_ptr> PaddleNnet::GetCacheEncoder(const string& name) { - auto iter = cache_names_idx_.find(name); - if (iter == cache_names_idx_.end()) { - return nullptr; - } - assert(iter->second < cache_encouts_.size()); - return cache_encouts_[iter->second]; -} - -void PaddleNnet::FeedForward(const Vector& features, - const int32& feature_dim, - NnetOut* out) { - paddle_infer::Predictor* predictor = GetPredictor(); - - int feat_row = features.Dim() / feature_dim; - - std::vector input_names = predictor->GetInputNames(); - std::vector output_names = predictor->GetOutputNames(); - - // feed inputs - std::unique_ptr input_tensor = - predictor->GetInputHandle(input_names[0]); - std::vector INPUT_SHAPE = {1, feat_row, feature_dim}; - input_tensor->Reshape(INPUT_SHAPE); - input_tensor->CopyFromCpu(features.Data()); - - std::unique_ptr input_len = - predictor->GetInputHandle(input_names[1]); - std::vector input_len_size = {1}; - input_len->Reshape(input_len_size); - std::vector audio_len; - audio_len.push_back(feat_row); - input_len->CopyFromCpu(audio_len.data()); - - std::unique_ptr state_h = - predictor->GetInputHandle(input_names[2]); - shared_ptr> h_cache = GetCacheEncoder(input_names[2]); - state_h->Reshape(h_cache->get_shape()); - state_h->CopyFromCpu(h_cache->get_data().data()); - - std::unique_ptr state_c = - predictor->GetInputHandle(input_names[3]); - shared_ptr> c_cache = GetCacheEncoder(input_names[3]); - state_c->Reshape(c_cache->get_shape()); - state_c->CopyFromCpu(c_cache->get_data().data()); - - // forward - bool success = predictor->Run(); - - if (success == false) { - LOG(INFO) << "predictor run occurs error"; - } - - // fetch outpus - std::unique_ptr h_out = - predictor->GetOutputHandle(output_names[2]); - assert(h_cache->get_shape() == h_out->shape()); - h_out->CopyToCpu(h_cache->get_data().data()); - - std::unique_ptr c_out = - predictor->GetOutputHandle(output_names[3]); - assert(c_cache->get_shape() == c_out->shape()); - c_out->CopyToCpu(c_cache->get_data().data()); - - std::unique_ptr output_tensor = - predictor->GetOutputHandle(output_names[0]); - std::vector output_shape = output_tensor->shape(); - int32 row = output_shape[1]; - int32 col = output_shape[2]; - - - // inferences->Resize(row * col); - // *inference_dim = col; - out->logprobs.Resize(row * col); - out->vocab_dim = col; - output_tensor->CopyToCpu(out->logprobs.Data()); - - ReleasePredictor(predictor); -} - -} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/nnet/ds2_nnet.h b/speechx/speechx/asr/nnet/ds2_nnet.h deleted file mode 100644 index 420fa1771..000000000 --- a/speechx/speechx/asr/nnet/ds2_nnet.h +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include - -#include "base/common.h" -#include "kaldi/matrix/kaldi-matrix.h" -#include "nnet/nnet_itf.h" -#include "paddle_inference_api.h" - -namespace ppspeech { - - -template -class Tensor { - public: - Tensor() {} - explicit Tensor(const std::vector& shape) : _shape(shape) { - int neml = std::accumulate( - _shape.begin(), _shape.end(), 1, std::multiplies()); - LOG(INFO) << "Tensor neml: " << neml; - _data.resize(neml, 0); - } - - void reshape(const std::vector& shape) { - _shape = shape; - int neml = std::accumulate( - _shape.begin(), _shape.end(), 1, std::multiplies()); - _data.resize(neml, 0); - } - - const std::vector& get_shape() const { return _shape; } - std::vector& get_data() { return _data; } - - private: - std::vector _shape; - std::vector _data; -}; - -class PaddleNnet : public NnetBase { - public: - explicit PaddleNnet(const ModelOptions& opts); - - void FeedForward(const kaldi::Vector& features, - const int32& feature_dim, - NnetOut* out) override; - - void AttentionRescoring(const std::vector>& hyps, - float reverse_weight, - std::vector* rescoring_score) override { - VLOG(2) << "deepspeech2 not has AttentionRescoring."; - } - - void Dim(); - - void Reset() override; - - bool IsLogProb() override { return false; } - - - std::shared_ptr> GetCacheEncoder( - const std::string& name); - - void InitCacheEncouts(const ModelOptions& opts); - - void EncoderOuts(std::vector>* encoder_out) - const override {} - - private: - paddle_infer::Predictor* GetPredictor(); - int ReleasePredictor(paddle_infer::Predictor* predictor); - - std::unique_ptr pool; - std::vector pool_usages; - std::mutex pool_mutex; - std::map predictor_to_thread_id; - std::map cache_names_idx_; - std::vector>> cache_encouts_; - - ModelOptions opts_; - - public: - DISALLOW_COPY_AND_ASSIGN(PaddleNnet); -}; - -} // namespace ppspeech diff --git a/speechx/speechx/asr/nnet/ds2_nnet_main.cc b/speechx/speechx/asr/nnet/ds2_nnet_main.cc deleted file mode 100644 index 6092b8a4c..000000000 --- a/speechx/speechx/asr/nnet/ds2_nnet_main.cc +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "base/common.h" -#include "decoder/param.h" -#include "frontend/audio/assembler.h" -#include "frontend/audio/data_cache.h" -#include "kaldi/util/table-types.h" -#include "nnet/decodable.h" -#include "nnet/ds2_nnet.h" - -DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); -DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); - -using kaldi::BaseFloat; -using kaldi::Matrix; -using std::vector; - -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - kaldi::SequentialBaseFloatMatrixReader feature_reader( - FLAGS_feature_rspecifier); - kaldi::BaseFloatMatrixWriter nnet_writer(FLAGS_nnet_prob_wspecifier); - std::string model_graph = FLAGS_model_path; - std::string model_params = FLAGS_param_path; - LOG(INFO) << "model path: " << model_graph; - LOG(INFO) << "model param: " << model_params; - - int32 num_done = 0, num_err = 0; - - ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); - - std::shared_ptr nnet( - new ppspeech::PaddleNnet(model_opts)); - std::shared_ptr raw_data(new ppspeech::DataCache()); - std::shared_ptr decodable( - new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); - - int32 chunk_size = FLAGS_receptive_field_length + - (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate; - int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk; - int32 receptive_field_length = FLAGS_receptive_field_length; - LOG(INFO) << "chunk size (frame): " << chunk_size; - LOG(INFO) << "chunk stride (frame): " << chunk_stride; - LOG(INFO) << "receptive field (frame): " << receptive_field_length; - kaldi::Timer timer; - for (; !feature_reader.Done(); feature_reader.Next()) { - string utt = feature_reader.Key(); - kaldi::Matrix feature = feature_reader.Value(); - raw_data->SetDim(feature.NumCols()); - LOG(INFO) << "process utt: " << utt; - LOG(INFO) << "rows: " << feature.NumRows(); - LOG(INFO) << "cols: " << feature.NumCols(); - - int32 row_idx = 0; - int32 padding_len = 0; - int32 ori_feature_len = feature.NumRows(); - if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { - padding_len = - chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride; - feature.Resize(feature.NumRows() + padding_len, - feature.NumCols(), - kaldi::kCopyData); - } - int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1; - int32 frame_idx = 0; - std::vector> prob_vec; - for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { - kaldi::Vector feature_chunk(chunk_size * - feature.NumCols()); - int32 feature_chunk_size = 0; - if (ori_feature_len > chunk_idx * chunk_stride) { - feature_chunk_size = std::min( - ori_feature_len - chunk_idx * chunk_stride, chunk_size); - } - if (feature_chunk_size < receptive_field_length) break; - - int32 start = chunk_idx * chunk_stride; - for (int row_id = 0; row_id < chunk_size; ++row_id) { - kaldi::SubVector tmp(feature, start); - kaldi::SubVector f_chunk_tmp( - feature_chunk.Data() + row_id * feature.NumCols(), - feature.NumCols()); - f_chunk_tmp.CopyFromVec(tmp); - ++start; - } - raw_data->Accept(feature_chunk); - if (chunk_idx == num_chunks - 1) { - raw_data->SetFinished(); - } - vector prob; - while (decodable->FrameLikelihood(frame_idx, &prob)) { - kaldi::Vector vec_tmp(prob.size()); - std::memcpy(vec_tmp.Data(), - prob.data(), - sizeof(kaldi::BaseFloat) * prob.size()); - prob_vec.push_back(vec_tmp); - frame_idx++; - } - } - decodable->Reset(); - if (prob_vec.size() == 0) { - // the TokenWriter can not write empty string. - ++num_err; - KALDI_LOG << " the nnet prob of " << utt << " is empty"; - continue; - } - kaldi::Matrix result(prob_vec.size(), - prob_vec[0].Dim()); - for (int row_idx = 0; row_idx < prob_vec.size(); ++row_idx) { - for (int32 col_idx = 0; col_idx < prob_vec[0].Dim(); ++col_idx) { - result(row_idx, col_idx) = prob_vec[row_idx](col_idx); - } - } - - nnet_writer.Write(utt, result); - ++num_done; - } - - double elapsed = timer.Elapsed(); - KALDI_LOG << " cost:" << elapsed << " s"; - - KALDI_LOG << "Done " << num_done << " utterances, " << num_err - << " with errors."; - return (num_done != 0 ? 0 : 1); -} diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/speechx/speechx/asr/nnet/nnet_producer.cc index 3a0c4f188..955075913 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.cc +++ b/speechx/speechx/asr/nnet/nnet_producer.cc @@ -65,7 +65,6 @@ bool NnetProducer::Compute() { size_t nframes = logprobs.Dim() / vocab_dim; VLOG(2) << "Forward out " << nframes << " decoder frames."; std::vector logprob(vocab_dim); - // remove later. for (size_t idx = 0; idx < nframes; ++idx) { for (size_t prob_idx = 0; prob_idx < vocab_dim; ++prob_idx) { logprob[prob_idx] = logprobs(idx * vocab_dim + prob_idx); diff --git a/speechx/speechx/asr/recognizer/CMakeLists.txt b/speechx/speechx/asr/recognizer/CMakeLists.txt index 53e2e58db..e46593f50 100644 --- a/speechx/speechx/asr/recognizer/CMakeLists.txt +++ b/speechx/speechx/asr/recognizer/CMakeLists.txt @@ -1,46 +1,23 @@ set(srcs) -if (USING_DS2) list(APPEND srcs -recognizer.cc + u2_recognizer.cc ) -endif() - -if (USING_U2) - list(APPEND srcs - u2_recognizer.cc - ) -endif() add_library(recognizer STATIC ${srcs}) target_link_libraries(recognizer PUBLIC decoder) -# test -if (USING_DS2) - set(BINS recognizer_main) - - foreach(bin_name IN LISTS BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} PUBLIC recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) - endforeach() -endif() - - -if (USING_U2) - set(TEST_BINS - u2_recognizer_main - u2_recognizer_thread_main - ) - - foreach(bin_name IN LISTS TEST_BINS) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) - target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) - target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) - target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) - endforeach() +set(TEST_BINS + u2_recognizer_main + u2_recognizer_thread_main +) -endif() +foreach(bin_name IN LISTS TEST_BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) + target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) +endforeach() diff --git a/speechx/speechx/asr/recognizer/recognizer.cc b/speechx/speechx/asr/recognizer/recognizer.cc deleted file mode 100644 index c66318131..000000000 --- a/speechx/speechx/asr/recognizer/recognizer.cc +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "recognizer/recognizer.h" - - -namespace ppspeech { - -using kaldi::BaseFloat; -using kaldi::SubVector; -using kaldi::Vector; -using kaldi::VectorBase; -using std::unique_ptr; -using std::vector; - - -Recognizer::Recognizer(const RecognizerResource& resource) { - // resource_ = resource; - const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; - feature_pipeline_.reset(new FeaturePipeline(feature_opts)); - - std::shared_ptr nnet(new PaddleNnet(resource.model_opts)); - - BaseFloat ac_scale = resource.acoustic_scale; - decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale)); - - decoder_.reset(new TLGDecoder(resource.tlg_opts)); - - input_finished_ = false; -} - -void Recognizer::Accept(const Vector& waves) { - feature_pipeline_->Accept(waves); -} - -void Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); } - -std::string Recognizer::GetFinalResult() { - return decoder_->GetFinalBestPath(); -} - -std::string Recognizer::GetPartialResult() { - return decoder_->GetPartialResult(); -} - -void Recognizer::SetFinished() { - feature_pipeline_->SetFinished(); - input_finished_ = true; -} - -bool Recognizer::IsFinished() { return input_finished_; } - -void Recognizer::Reset() { - feature_pipeline_->Reset(); - decodable_->Reset(); - decoder_->Reset(); -} - -} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/recognizer/recognizer.h b/speechx/speechx/asr/recognizer/recognizer.h deleted file mode 100644 index 57d5bb363..000000000 --- a/speechx/speechx/asr/recognizer/recognizer.h +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// todo refactor later (SGoat) - -#pragma once - -#include "decoder/ctc_beam_search_decoder.h" -#include "decoder/ctc_tlg_decoder.h" -#include "frontend/audio/feature_pipeline.h" -#include "nnet/decodable.h" -#include "nnet/ds2_nnet.h" - -DECLARE_double(acoustic_scale); - -namespace ppspeech { - -struct RecognizerResource { - kaldi::BaseFloat acoustic_scale{1.0}; - FeaturePipelineOptions feature_pipeline_opts{}; - ModelOptions model_opts{}; - TLGDecoderOptions tlg_opts{}; - // CTCBeamSearchOptions beam_search_opts; - - static RecognizerResource InitFromFlags() { - RecognizerResource resource; - resource.acoustic_scale = FLAGS_acoustic_scale; - resource.feature_pipeline_opts = - FeaturePipelineOptions::InitFromFlags(); - resource.feature_pipeline_opts.assembler_opts.fill_zero = true; - LOG(INFO) << "ds2 need fill zero be true: " - << resource.feature_pipeline_opts.assembler_opts.fill_zero; - resource.model_opts = ModelOptions::InitFromFlags(); - resource.tlg_opts = TLGDecoderOptions::InitFromFlags(); - return resource; - } -}; - -class Recognizer { - public: - explicit Recognizer(const RecognizerResource& resouce); - void Accept(const kaldi::Vector& waves); - void Decode(); - std::string GetFinalResult(); - std::string GetPartialResult(); - void SetFinished(); - bool IsFinished(); - void Reset(); - - private: - // std::shared_ptr resource_; - // RecognizerResource resource_; - std::shared_ptr feature_pipeline_; - std::shared_ptr decodable_; - std::unique_ptr decoder_; - bool input_finished_; -}; - -} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/recognizer/recognizer_main.cc b/speechx/speechx/asr/recognizer/recognizer_main.cc deleted file mode 100644 index cb0de2d6a..000000000 --- a/speechx/speechx/asr/recognizer/recognizer_main.cc +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "decoder/param.h" -#include "kaldi/feat/wave-reader.h" -#include "kaldi/util/table-types.h" -#include "recognizer/recognizer.h" - -DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); -DEFINE_string(result_wspecifier, "", "test result wspecifier"); -DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); -DEFINE_int32(sample_rate, 16000, "sample rate"); - - -int main(int argc, char* argv[]) { - gflags::SetUsageMessage("Usage:"); - gflags::ParseCommandLineFlags(&argc, &argv, false); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - FLAGS_logtostderr = 1; - - ppspeech::RecognizerResource resource = - ppspeech::RecognizerResource::InitFromFlags(); - ppspeech::Recognizer recognizer(resource); - - kaldi::SequentialTableReader wav_reader( - FLAGS_wav_rspecifier); - kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); - - int sample_rate = FLAGS_sample_rate; - float streaming_chunk = FLAGS_streaming_chunk; - int chunk_sample_size = streaming_chunk * sample_rate; - LOG(INFO) << "sr: " << sample_rate; - LOG(INFO) << "chunk size (s): " << streaming_chunk; - LOG(INFO) << "chunk size (sample): " << chunk_sample_size; - - int32 num_done = 0, num_err = 0; - double tot_wav_duration = 0.0; - - kaldi::Timer timer; - - for (; !wav_reader.Done(); wav_reader.Next()) { - std::string utt = wav_reader.Key(); - const kaldi::WaveData& wave_data = wav_reader.Value(); - - int32 this_channel = 0; - kaldi::SubVector waveform(wave_data.Data(), - this_channel); - int tot_samples = waveform.Dim(); - tot_wav_duration += tot_samples * 1.0 / sample_rate; - LOG(INFO) << "wav len (sample): " << tot_samples; - - int sample_offset = 0; - std::vector> feats; - int feature_rows = 0; - while (sample_offset < tot_samples) { - int cur_chunk_size = - std::min(chunk_sample_size, tot_samples - sample_offset); - - kaldi::Vector wav_chunk(cur_chunk_size); - for (int i = 0; i < cur_chunk_size; ++i) { - wav_chunk(i) = waveform(sample_offset + i); - } - // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); - - recognizer.Accept(wav_chunk); - if (cur_chunk_size < chunk_sample_size) { - recognizer.SetFinished(); - } - recognizer.Decode(); - - // no overlap - sample_offset += cur_chunk_size; - } - - std::string result; - result = recognizer.GetFinalResult(); - recognizer.Reset(); - if (result.empty()) { - // the TokenWriter can not write empty string. - ++num_err; - KALDI_LOG << " the result of " << utt << " is empty"; - continue; - } - KALDI_LOG << " the result of " << utt << " is " << result; - result_writer.Write(utt, result); - ++num_done; - } - double elapsed = timer.Elapsed(); - KALDI_LOG << "Done " << num_done << " out of " << (num_err + num_done); - KALDI_LOG << " cost:" << elapsed << " s"; - KALDI_LOG << "total wav duration is: " << tot_wav_duration << " s"; - KALDI_LOG << "the RTF is: " << elapsed / tot_wav_duration; -} diff --git a/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc b/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc index 713c9ef1e..8c65b3465 100644 --- a/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc +++ b/speechx/speechx/common/frontend/audio/cmvn_json2kaldi_main.cc @@ -20,15 +20,12 @@ #include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/util/kaldi-io.h" #include "utils/file_utils.h" -// #include "boost/json.hpp" -#include +#include "utils/picojson.h" DEFINE_string(json_file, "", "cmvn json file"); DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn"); DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)"); -using namespace boost::json; // from - int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -40,36 +37,49 @@ int main(int argc, char* argv[]) { auto ifs = std::ifstream(FLAGS_json_file); std::string json_str = ppspeech::ReadFile2String(FLAGS_json_file); - auto value = boost::json::parse(json_str); - if (!value.is_object()) { + picojson::value value; + std::string err; + const char* json_end = picojson::parse( + value, json_str.c_str(), json_str.c_str() + json_str.size(), &err); + if (!value.is()) { LOG(ERROR) << "Input json file format error."; } - for (auto obj : value.as_object()) { - if (obj.key() == "mean_stat") { - VLOG(2) << "mean_stat:" << obj.value(); + const picojson::value::object& obj = value.get(); + for (picojson::value::object::const_iterator elem = obj.begin(); + elem != obj.end(); + ++elem) { + if (elem->first == "mean_stat") { + VLOG(2) << "mean_stat:" << elem->second; + // const picojson::value tmp = + // elem->second.get(0);//(); + double tmp = + elem->second.get(0).get(); //(); + VLOG(2) << "tmp: " << tmp; } - if (obj.key() == "var_stat") { - VLOG(2) << "var_stat: " << obj.value(); + if (elem->first == "var_stat") { + VLOG(2) << "var_stat: " << elem->second; } - if (obj.key() == "frame_num") { - VLOG(2) << "frame_num: " << obj.value(); + if (elem->first == "frame_num") { + VLOG(2) << "frame_num: " << elem->second; } } - boost::json::array mean_stat = value.at("mean_stat").as_array(); + const picojson::value::array& mean_stat = + value.get("mean_stat").get(); std::vector mean_stat_vec; for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) { - mean_stat_vec.push_back(it->as_double()); + mean_stat_vec.push_back((*it).get()); } - boost::json::array var_stat = value.at("var_stat").as_array(); + const picojson::value::array& var_stat = + value.get("var_stat").get(); std::vector var_stat_vec; for (auto it = var_stat.begin(); it != var_stat.end(); it++) { - var_stat_vec.push_back(it->as_double()); + var_stat_vec.push_back((*it).get()); } - kaldi::int32 frame_num = uint64_t(value.at("frame_num").as_int64()); + kaldi::int32 frame_num = value.get("frame_num").get(); LOG(INFO) << "nframe: " << frame_num; size_t mean_size = mean_stat_vec.size(); diff --git a/speechx/speechx/common/utils/picojson.h b/speechx/speechx/common/utils/picojson.h new file mode 100644 index 000000000..28c5b7fa8 --- /dev/null +++ b/speechx/speechx/common/utils/picojson.h @@ -0,0 +1,1202 @@ +/* + * Copyright 2009-2010 Cybozu Labs, Inc. + * Copyright 2011-2014 Kazuho Oku + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ +#ifndef picojson_h +#define picojson_h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define PICOJSON_USE_INT64 1 + +// for isnan/isinf +#if __cplusplus >= 201103L +#include +#else +extern "C" { +#ifdef _MSC_VER +#include +#elif defined(__INTEL_COMPILER) +#include +#else +#include +#endif +} +#endif + +#ifndef PICOJSON_USE_RVALUE_REFERENCE +#if (defined(__cpp_rvalue_references) && __cpp_rvalue_references >= 200610) || (defined(_MSC_VER) && _MSC_VER >= 1600) +#define PICOJSON_USE_RVALUE_REFERENCE 1 +#else +#define PICOJSON_USE_RVALUE_REFERENCE 0 +#endif +#endif // PICOJSON_USE_RVALUE_REFERENCE + +#ifndef PICOJSON_NOEXCEPT +#if PICOJSON_USE_RVALUE_REFERENCE +#define PICOJSON_NOEXCEPT noexcept +#else +#define PICOJSON_NOEXCEPT throw() +#endif +#endif + +// experimental support for int64_t (see README.mkdn for detail) +#ifdef PICOJSON_USE_INT64 +#define __STDC_FORMAT_MACROS +#include +#if __cplusplus >= 201103L +#include +#else +extern "C" { +#include +} +#endif +#endif + +// to disable the use of localeconv(3), set PICOJSON_USE_LOCALE to 0 +#ifndef PICOJSON_USE_LOCALE +#define PICOJSON_USE_LOCALE 1 +#endif +#if PICOJSON_USE_LOCALE +extern "C" { +#include +} +#endif + +#ifndef PICOJSON_ASSERT +#define PICOJSON_ASSERT(e) \ + do { \ + if (!(e)) \ + throw std::runtime_error(#e); \ + } while (0) +#endif + +#ifdef _MSC_VER +#define SNPRINTF _snprintf_s +#pragma warning(push) +#pragma warning(disable : 4244) // conversion from int to char +#pragma warning(disable : 4127) // conditional expression is constant +#pragma warning(disable : 4702) // unreachable code +#pragma warning(disable : 4706) // assignment within conditional expression +#else +#define SNPRINTF snprintf +#endif + +namespace picojson { + +enum { + null_type, + boolean_type, + number_type, + string_type, + array_type, + object_type +#ifdef PICOJSON_USE_INT64 + , + int64_type +#endif +}; + +enum { INDENT_WIDTH = 2, DEFAULT_MAX_DEPTHS = 100 }; + +struct null {}; + +class value { +public: + typedef std::vector array; + typedef std::map object; + union _storage { + bool boolean_; + double number_; +#ifdef PICOJSON_USE_INT64 + int64_t int64_; +#endif + std::string *string_; + array *array_; + object *object_; + }; + +protected: + int type_; + _storage u_; + +public: + value(); + value(int type, bool); + explicit value(bool b); +#ifdef PICOJSON_USE_INT64 + explicit value(int64_t i); +#endif + explicit value(double n); + explicit value(const std::string &s); + explicit value(const array &a); + explicit value(const object &o); +#if PICOJSON_USE_RVALUE_REFERENCE + explicit value(std::string &&s); + explicit value(array &&a); + explicit value(object &&o); +#endif + explicit value(const char *s); + value(const char *s, size_t len); + ~value(); + value(const value &x); + value &operator=(const value &x); +#if PICOJSON_USE_RVALUE_REFERENCE + value(value &&x) PICOJSON_NOEXCEPT; + value &operator=(value &&x) PICOJSON_NOEXCEPT; +#endif + void swap(value &x) PICOJSON_NOEXCEPT; + template bool is() const; + template const T &get() const; + template T &get(); + template void set(const T &); +#if PICOJSON_USE_RVALUE_REFERENCE + template void set(T &&); +#endif + bool evaluate_as_boolean() const; + const value &get(const size_t idx) const; + const value &get(const std::string &key) const; + value &get(const size_t idx); + value &get(const std::string &key); + + bool contains(const size_t idx) const; + bool contains(const std::string &key) const; + std::string to_str() const; + template void serialize(Iter os, bool prettify = false) const; + std::string serialize(bool prettify = false) const; + +private: + template value(const T *); // intentionally defined to block implicit conversion of pointer to bool + template static void _indent(Iter os, int indent); + template void _serialize(Iter os, int indent) const; + std::string _serialize(int indent) const; + void clear(); +}; + +typedef value::array array; +typedef value::object object; + +inline value::value() : type_(null_type), u_() { +} + +inline value::value(int type, bool) : type_(type), u_() { + switch (type) { +#define INIT(p, v) \ + case p##type: \ + u_.p = v; \ + break + INIT(boolean_, false); + INIT(number_, 0.0); +#ifdef PICOJSON_USE_INT64 + INIT(int64_, 0); +#endif + INIT(string_, new std::string()); + INIT(array_, new array()); + INIT(object_, new object()); +#undef INIT + default: + break; + } +} + +inline value::value(bool b) : type_(boolean_type), u_() { + u_.boolean_ = b; +} + +#ifdef PICOJSON_USE_INT64 +inline value::value(int64_t i) : type_(int64_type), u_() { + u_.int64_ = i; +} +#endif + +inline value::value(double n) : type_(number_type), u_() { + if ( +#ifdef _MSC_VER + !_finite(n) +#elif __cplusplus >= 201103L + std::isnan(n) || std::isinf(n) +#else + isnan(n) || isinf(n) +#endif + ) { + throw std::overflow_error(""); + } + u_.number_ = n; +} + +inline value::value(const std::string &s) : type_(string_type), u_() { + u_.string_ = new std::string(s); +} + +inline value::value(const array &a) : type_(array_type), u_() { + u_.array_ = new array(a); +} + +inline value::value(const object &o) : type_(object_type), u_() { + u_.object_ = new object(o); +} + +#if PICOJSON_USE_RVALUE_REFERENCE +inline value::value(std::string &&s) : type_(string_type), u_() { + u_.string_ = new std::string(std::move(s)); +} + +inline value::value(array &&a) : type_(array_type), u_() { + u_.array_ = new array(std::move(a)); +} + +inline value::value(object &&o) : type_(object_type), u_() { + u_.object_ = new object(std::move(o)); +} +#endif + +inline value::value(const char *s) : type_(string_type), u_() { + u_.string_ = new std::string(s); +} + +inline value::value(const char *s, size_t len) : type_(string_type), u_() { + u_.string_ = new std::string(s, len); +} + +inline void value::clear() { + switch (type_) { +#define DEINIT(p) \ + case p##type: \ + delete u_.p; \ + break + DEINIT(string_); + DEINIT(array_); + DEINIT(object_); +#undef DEINIT + default: + break; + } +} + +inline value::~value() { + clear(); +} + +inline value::value(const value &x) : type_(x.type_), u_() { + switch (type_) { +#define INIT(p, v) \ + case p##type: \ + u_.p = v; \ + break + INIT(string_, new std::string(*x.u_.string_)); + INIT(array_, new array(*x.u_.array_)); + INIT(object_, new object(*x.u_.object_)); +#undef INIT + default: + u_ = x.u_; + break; + } +} + +inline value &value::operator=(const value &x) { + if (this != &x) { + value t(x); + swap(t); + } + return *this; +} + +#if PICOJSON_USE_RVALUE_REFERENCE +inline value::value(value &&x) PICOJSON_NOEXCEPT : type_(null_type), u_() { + swap(x); +} +inline value &value::operator=(value &&x) PICOJSON_NOEXCEPT { + swap(x); + return *this; +} +#endif +inline void value::swap(value &x) PICOJSON_NOEXCEPT { + std::swap(type_, x.type_); + std::swap(u_, x.u_); +} + +#define IS(ctype, jtype) \ + template <> inline bool value::is() const { \ + return type_ == jtype##_type; \ + } +IS(null, null) +IS(bool, boolean) +#ifdef PICOJSON_USE_INT64 +IS(int64_t, int64) +#endif +IS(std::string, string) +IS(array, array) +IS(object, object) +#undef IS +template <> inline bool value::is() const { + return type_ == number_type +#ifdef PICOJSON_USE_INT64 + || type_ == int64_type +#endif + ; +} + +#define GET(ctype, var) \ + template <> inline const ctype &value::get() const { \ + PICOJSON_ASSERT("type mismatch! call is() before get()" && is()); \ + return var; \ + } \ + template <> inline ctype &value::get() { \ + PICOJSON_ASSERT("type mismatch! call is() before get()" && is()); \ + return var; \ + } +GET(bool, u_.boolean_) +GET(std::string, *u_.string_) +GET(array, *u_.array_) +GET(object, *u_.object_) +#ifdef PICOJSON_USE_INT64 +GET(double, + (type_ == int64_type && (const_cast(this)->type_ = number_type, (const_cast(this)->u_.number_ = u_.int64_)), + u_.number_)) +GET(int64_t, u_.int64_) +#else +GET(double, u_.number_) +#endif +#undef GET + +#define SET(ctype, jtype, setter) \ + template <> inline void value::set(const ctype &_val) { \ + clear(); \ + type_ = jtype##_type; \ + setter \ + } +SET(bool, boolean, u_.boolean_ = _val;) +SET(std::string, string, u_.string_ = new std::string(_val);) +SET(array, array, u_.array_ = new array(_val);) +SET(object, object, u_.object_ = new object(_val);) +SET(double, number, u_.number_ = _val;) +#ifdef PICOJSON_USE_INT64 +SET(int64_t, int64, u_.int64_ = _val;) +#endif +#undef SET + +#if PICOJSON_USE_RVALUE_REFERENCE +#define MOVESET(ctype, jtype, setter) \ + template <> inline void value::set(ctype && _val) { \ + clear(); \ + type_ = jtype##_type; \ + setter \ + } +MOVESET(std::string, string, u_.string_ = new std::string(std::move(_val));) +MOVESET(array, array, u_.array_ = new array(std::move(_val));) +MOVESET(object, object, u_.object_ = new object(std::move(_val));) +#undef MOVESET +#endif + +inline bool value::evaluate_as_boolean() const { + switch (type_) { + case null_type: + return false; + case boolean_type: + return u_.boolean_; + case number_type: + return u_.number_ != 0; +#ifdef PICOJSON_USE_INT64 + case int64_type: + return u_.int64_ != 0; +#endif + case string_type: + return !u_.string_->empty(); + default: + return true; + } +} + +inline const value &value::get(const size_t idx) const { + static value s_null; + PICOJSON_ASSERT(is()); + return idx < u_.array_->size() ? (*u_.array_)[idx] : s_null; +} + +inline value &value::get(const size_t idx) { + static value s_null; + PICOJSON_ASSERT(is()); + return idx < u_.array_->size() ? (*u_.array_)[idx] : s_null; +} + +inline const value &value::get(const std::string &key) const { + static value s_null; + PICOJSON_ASSERT(is()); + object::const_iterator i = u_.object_->find(key); + return i != u_.object_->end() ? i->second : s_null; +} + +inline value &value::get(const std::string &key) { + static value s_null; + PICOJSON_ASSERT(is()); + object::iterator i = u_.object_->find(key); + return i != u_.object_->end() ? i->second : s_null; +} + +inline bool value::contains(const size_t idx) const { + PICOJSON_ASSERT(is()); + return idx < u_.array_->size(); +} + +inline bool value::contains(const std::string &key) const { + PICOJSON_ASSERT(is()); + object::const_iterator i = u_.object_->find(key); + return i != u_.object_->end(); +} + +inline std::string value::to_str() const { + switch (type_) { + case null_type: + return "null"; + case boolean_type: + return u_.boolean_ ? "true" : "false"; +#ifdef PICOJSON_USE_INT64 + case int64_type: { + char buf[sizeof("-9223372036854775808")]; + SNPRINTF(buf, sizeof(buf), "%" PRId64, u_.int64_); + return buf; + } +#endif + case number_type: { + char buf[256]; + double tmp; + SNPRINTF(buf, sizeof(buf), fabs(u_.number_) < (1ULL << 53) && modf(u_.number_, &tmp) == 0 ? "%.f" : "%.17g", u_.number_); +#if PICOJSON_USE_LOCALE + char *decimal_point = localeconv()->decimal_point; + if (strcmp(decimal_point, ".") != 0) { + size_t decimal_point_len = strlen(decimal_point); + for (char *p = buf; *p != '\0'; ++p) { + if (strncmp(p, decimal_point, decimal_point_len) == 0) { + return std::string(buf, p) + "." + (p + decimal_point_len); + } + } + } +#endif + return buf; + } + case string_type: + return *u_.string_; + case array_type: + return "array"; + case object_type: + return "object"; + default: + PICOJSON_ASSERT(0); +#ifdef _MSC_VER + __assume(0); +#endif + } + return std::string(); +} + +template void copy(const std::string &s, Iter oi) { + std::copy(s.begin(), s.end(), oi); +} + +template struct serialize_str_char { + Iter oi; + void operator()(char c) { + switch (c) { +#define MAP(val, sym) \ + case val: \ + copy(sym, oi); \ + break + MAP('"', "\\\""); + MAP('\\', "\\\\"); + MAP('/', "\\/"); + MAP('\b', "\\b"); + MAP('\f', "\\f"); + MAP('\n', "\\n"); + MAP('\r', "\\r"); + MAP('\t', "\\t"); +#undef MAP + default: + if (static_cast(c) < 0x20 || c == 0x7f) { + char buf[7]; + SNPRINTF(buf, sizeof(buf), "\\u%04x", c & 0xff); + copy(buf, buf + 6, oi); + } else { + *oi++ = c; + } + break; + } + } +}; + +template void serialize_str(const std::string &s, Iter oi) { + *oi++ = '"'; + serialize_str_char process_char = {oi}; + std::for_each(s.begin(), s.end(), process_char); + *oi++ = '"'; +} + +template void value::serialize(Iter oi, bool prettify) const { + return _serialize(oi, prettify ? 0 : -1); +} + +inline std::string value::serialize(bool prettify) const { + return _serialize(prettify ? 0 : -1); +} + +template void value::_indent(Iter oi, int indent) { + *oi++ = '\n'; + for (int i = 0; i < indent * INDENT_WIDTH; ++i) { + *oi++ = ' '; + } +} + +template void value::_serialize(Iter oi, int indent) const { + switch (type_) { + case string_type: + serialize_str(*u_.string_, oi); + break; + case array_type: { + *oi++ = '['; + if (indent != -1) { + ++indent; + } + for (array::const_iterator i = u_.array_->begin(); i != u_.array_->end(); ++i) { + if (i != u_.array_->begin()) { + *oi++ = ','; + } + if (indent != -1) { + _indent(oi, indent); + } + i->_serialize(oi, indent); + } + if (indent != -1) { + --indent; + if (!u_.array_->empty()) { + _indent(oi, indent); + } + } + *oi++ = ']'; + break; + } + case object_type: { + *oi++ = '{'; + if (indent != -1) { + ++indent; + } + for (object::const_iterator i = u_.object_->begin(); i != u_.object_->end(); ++i) { + if (i != u_.object_->begin()) { + *oi++ = ','; + } + if (indent != -1) { + _indent(oi, indent); + } + serialize_str(i->first, oi); + *oi++ = ':'; + if (indent != -1) { + *oi++ = ' '; + } + i->second._serialize(oi, indent); + } + if (indent != -1) { + --indent; + if (!u_.object_->empty()) { + _indent(oi, indent); + } + } + *oi++ = '}'; + break; + } + default: + copy(to_str(), oi); + break; + } + if (indent == 0) { + *oi++ = '\n'; + } +} + +inline std::string value::_serialize(int indent) const { + std::string s; + _serialize(std::back_inserter(s), indent); + return s; +} + +template class input { +protected: + Iter cur_, end_; + bool consumed_; + int line_; + +public: + input(const Iter &first, const Iter &last) : cur_(first), end_(last), consumed_(false), line_(1) { + } + int getc() { + if (consumed_) { + if (*cur_ == '\n') { + ++line_; + } + ++cur_; + } + if (cur_ == end_) { + consumed_ = false; + return -1; + } + consumed_ = true; + return *cur_ & 0xff; + } + void ungetc() { + consumed_ = false; + } + Iter cur() const { + if (consumed_) { + input *self = const_cast *>(this); + self->consumed_ = false; + ++self->cur_; + } + return cur_; + } + int line() const { + return line_; + } + void skip_ws() { + while (1) { + int ch = getc(); + if (!(ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r')) { + ungetc(); + break; + } + } + } + bool expect(const int expected) { + skip_ws(); + if (getc() != expected) { + ungetc(); + return false; + } + return true; + } + bool match(const std::string &pattern) { + for (std::string::const_iterator pi(pattern.begin()); pi != pattern.end(); ++pi) { + if (getc() != *pi) { + ungetc(); + return false; + } + } + return true; + } +}; + +template inline int _parse_quadhex(input &in) { + int uni_ch = 0, hex; + for (int i = 0; i < 4; i++) { + if ((hex = in.getc()) == -1) { + return -1; + } + if ('0' <= hex && hex <= '9') { + hex -= '0'; + } else if ('A' <= hex && hex <= 'F') { + hex -= 'A' - 0xa; + } else if ('a' <= hex && hex <= 'f') { + hex -= 'a' - 0xa; + } else { + in.ungetc(); + return -1; + } + uni_ch = uni_ch * 16 + hex; + } + return uni_ch; +} + +template inline bool _parse_codepoint(String &out, input &in) { + int uni_ch; + if ((uni_ch = _parse_quadhex(in)) == -1) { + return false; + } + if (0xd800 <= uni_ch && uni_ch <= 0xdfff) { + if (0xdc00 <= uni_ch) { + // a second 16-bit of a surrogate pair appeared + return false; + } + // first 16-bit of surrogate pair, get the next one + if (in.getc() != '\\' || in.getc() != 'u') { + in.ungetc(); + return false; + } + int second = _parse_quadhex(in); + if (!(0xdc00 <= second && second <= 0xdfff)) { + return false; + } + uni_ch = ((uni_ch - 0xd800) << 10) | ((second - 0xdc00) & 0x3ff); + uni_ch += 0x10000; + } + if (uni_ch < 0x80) { + out.push_back(static_cast(uni_ch)); + } else { + if (uni_ch < 0x800) { + out.push_back(static_cast(0xc0 | (uni_ch >> 6))); + } else { + if (uni_ch < 0x10000) { + out.push_back(static_cast(0xe0 | (uni_ch >> 12))); + } else { + out.push_back(static_cast(0xf0 | (uni_ch >> 18))); + out.push_back(static_cast(0x80 | ((uni_ch >> 12) & 0x3f))); + } + out.push_back(static_cast(0x80 | ((uni_ch >> 6) & 0x3f))); + } + out.push_back(static_cast(0x80 | (uni_ch & 0x3f))); + } + return true; +} + +template inline bool _parse_string(String &out, input &in) { + while (1) { + int ch = in.getc(); + if (ch < ' ') { + in.ungetc(); + return false; + } else if (ch == '"') { + return true; + } else if (ch == '\\') { + if ((ch = in.getc()) == -1) { + return false; + } + switch (ch) { +#define MAP(sym, val) \ + case sym: \ + out.push_back(val); \ + break + MAP('"', '\"'); + MAP('\\', '\\'); + MAP('/', '/'); + MAP('b', '\b'); + MAP('f', '\f'); + MAP('n', '\n'); + MAP('r', '\r'); + MAP('t', '\t'); +#undef MAP + case 'u': + if (!_parse_codepoint(out, in)) { + return false; + } + break; + default: + return false; + } + } else { + out.push_back(static_cast(ch)); + } + } + return false; +} + +template inline bool _parse_array(Context &ctx, input &in) { + if (!ctx.parse_array_start()) { + return false; + } + size_t idx = 0; + if (in.expect(']')) { + return ctx.parse_array_stop(idx); + } + do { + if (!ctx.parse_array_item(in, idx)) { + return false; + } + idx++; + } while (in.expect(',')); + return in.expect(']') && ctx.parse_array_stop(idx); +} + +template inline bool _parse_object(Context &ctx, input &in) { + if (!ctx.parse_object_start()) { + return false; + } + if (in.expect('}')) { + return ctx.parse_object_stop(); + } + do { + std::string key; + if (!in.expect('"') || !_parse_string(key, in) || !in.expect(':')) { + return false; + } + if (!ctx.parse_object_item(in, key)) { + return false; + } + } while (in.expect(',')); + return in.expect('}') && ctx.parse_object_stop(); +} + +template inline std::string _parse_number(input &in) { + std::string num_str; + while (1) { + int ch = in.getc(); + if (('0' <= ch && ch <= '9') || ch == '+' || ch == '-' || ch == 'e' || ch == 'E') { + num_str.push_back(static_cast(ch)); + } else if (ch == '.') { +#if PICOJSON_USE_LOCALE + num_str += localeconv()->decimal_point; +#else + num_str.push_back('.'); +#endif + } else { + in.ungetc(); + break; + } + } + return num_str; +} + +template inline bool _parse(Context &ctx, input &in) { + in.skip_ws(); + int ch = in.getc(); + switch (ch) { +#define IS(ch, text, op) \ + case ch: \ + if (in.match(text) && op) { \ + return true; \ + } else { \ + return false; \ + } + IS('n', "ull", ctx.set_null()); + IS('f', "alse", ctx.set_bool(false)); + IS('t', "rue", ctx.set_bool(true)); +#undef IS + case '"': + return ctx.parse_string(in); + case '[': + return _parse_array(ctx, in); + case '{': + return _parse_object(ctx, in); + default: + if (('0' <= ch && ch <= '9') || ch == '-') { + double f; + char *endp; + in.ungetc(); + std::string num_str(_parse_number(in)); + if (num_str.empty()) { + return false; + } +#ifdef PICOJSON_USE_INT64 + { + errno = 0; + intmax_t ival = strtoimax(num_str.c_str(), &endp, 10); + if (errno == 0 && std::numeric_limits::min() <= ival && ival <= std::numeric_limits::max() && + endp == num_str.c_str() + num_str.size()) { + ctx.set_int64(ival); + return true; + } + } +#endif + f = strtod(num_str.c_str(), &endp); + if (endp == num_str.c_str() + num_str.size()) { + ctx.set_number(f); + return true; + } + return false; + } + break; + } + in.ungetc(); + return false; +} + +class deny_parse_context { +public: + bool set_null() { + return false; + } + bool set_bool(bool) { + return false; + } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t) { + return false; + } +#endif + bool set_number(double) { + return false; + } + template bool parse_string(input &) { + return false; + } + bool parse_array_start() { + return false; + } + template bool parse_array_item(input &, size_t) { + return false; + } + bool parse_array_stop(size_t) { + return false; + } + bool parse_object_start() { + return false; + } + template bool parse_object_item(input &, const std::string &) { + return false; + } +}; + +class default_parse_context { +protected: + value *out_; + size_t depths_; + +public: + default_parse_context(value *out, size_t depths = DEFAULT_MAX_DEPTHS) : out_(out), depths_(depths) { + } + bool set_null() { + *out_ = value(); + return true; + } + bool set_bool(bool b) { + *out_ = value(b); + return true; + } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t i) { + *out_ = value(i); + return true; + } +#endif + bool set_number(double f) { + *out_ = value(f); + return true; + } + template bool parse_string(input &in) { + *out_ = value(string_type, false); + return _parse_string(out_->get(), in); + } + bool parse_array_start() { + if (depths_ == 0) + return false; + --depths_; + *out_ = value(array_type, false); + return true; + } + template bool parse_array_item(input &in, size_t) { + array &a = out_->get(); + a.push_back(value()); + default_parse_context ctx(&a.back(), depths_); + return _parse(ctx, in); + } + bool parse_array_stop(size_t) { + ++depths_; + return true; + } + bool parse_object_start() { + if (depths_ == 0) + return false; + *out_ = value(object_type, false); + return true; + } + template bool parse_object_item(input &in, const std::string &key) { + object &o = out_->get(); + default_parse_context ctx(&o[key], depths_); + return _parse(ctx, in); + } + bool parse_object_stop() { + ++depths_; + return true; + } + +private: + default_parse_context(const default_parse_context &); + default_parse_context &operator=(const default_parse_context &); +}; + +class null_parse_context { +protected: + size_t depths_; + +public: + struct dummy_str { + void push_back(int) { + } + }; + +public: + null_parse_context(size_t depths = DEFAULT_MAX_DEPTHS) : depths_(depths) { + } + bool set_null() { + return true; + } + bool set_bool(bool) { + return true; + } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t) { + return true; + } +#endif + bool set_number(double) { + return true; + } + template bool parse_string(input &in) { + dummy_str s; + return _parse_string(s, in); + } + bool parse_array_start() { + if (depths_ == 0) + return false; + --depths_; + return true; + } + template bool parse_array_item(input &in, size_t) { + return _parse(*this, in); + } + bool parse_array_stop(size_t) { + ++depths_; + return true; + } + bool parse_object_start() { + if (depths_ == 0) + return false; + --depths_; + return true; + } + template bool parse_object_item(input &in, const std::string &) { + ++depths_; + return _parse(*this, in); + } + bool parse_object_stop() { + return true; + } + +private: + null_parse_context(const null_parse_context &); + null_parse_context &operator=(const null_parse_context &); +}; + +// obsolete, use the version below +template inline std::string parse(value &out, Iter &pos, const Iter &last) { + std::string err; + pos = parse(out, pos, last, &err); + return err; +} + +template inline Iter _parse(Context &ctx, const Iter &first, const Iter &last, std::string *err) { + input in(first, last); + if (!_parse(ctx, in) && err != NULL) { + char buf[64]; + SNPRINTF(buf, sizeof(buf), "syntax error at line %d near: ", in.line()); + *err = buf; + while (1) { + int ch = in.getc(); + if (ch == -1 || ch == '\n') { + break; + } else if (ch >= ' ') { + err->push_back(static_cast(ch)); + } + } + } + return in.cur(); +} + +template inline Iter parse(value &out, const Iter &first, const Iter &last, std::string *err) { + default_parse_context ctx(&out); + return _parse(ctx, first, last, err); +} + +inline std::string parse(value &out, const std::string &s) { + std::string err; + parse(out, s.begin(), s.end(), &err); + return err; +} + +inline std::string parse(value &out, std::istream &is) { + std::string err; + parse(out, std::istreambuf_iterator(is.rdbuf()), std::istreambuf_iterator(), &err); + return err; +} + +template struct last_error_t { static std::string s; }; +template std::string last_error_t::s; + +inline void set_last_error(const std::string &s) { + last_error_t::s = s; +} + +inline const std::string &get_last_error() { + return last_error_t::s; +} + +inline bool operator==(const value &x, const value &y) { + if (x.is()) + return y.is(); +#define PICOJSON_CMP(type) \ + if (x.is()) \ + return y.is() && x.get() == y.get() + PICOJSON_CMP(bool); + PICOJSON_CMP(double); + PICOJSON_CMP(std::string); + PICOJSON_CMP(array); + PICOJSON_CMP(object); +#undef PICOJSON_CMP + PICOJSON_ASSERT(0); +#ifdef _MSC_VER + __assume(0); +#endif + return false; +} + +inline bool operator!=(const value &x, const value &y) { + return !(x == y); +} +} + +#if !PICOJSON_USE_RVALUE_REFERENCE +namespace std { +template <> inline void swap(picojson::value &x, picojson::value &y) { + x.swap(y); +} +} +#endif + +inline std::istream &operator>>(std::istream &is, picojson::value &x) { + picojson::set_last_error(std::string()); + const std::string err(picojson::parse(x, is)); + if (!err.empty()) { + picojson::set_last_error(err); + is.setstate(std::ios::failbit); + } + return is; +} + +inline std::ostream &operator<<(std::ostream &os, const picojson::value &x) { + x.serialize(std::ostream_iterator(os)); + return os; +} +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +#endif \ No newline at end of file