From ad8ec177efe5ea92811611fbb483781105c26889 Mon Sep 17 00:00:00 2001 From: Yang Zhou Date: Thu, 24 Mar 2022 10:51:59 +0800 Subject: [PATCH 1/9] add tlg decoder --- speechx/examples/decoder/CMakeLists.txt | 4 + .../decoder/offline_wfst_decoder_main.cc | 101 + speechx/speechx/decoder/CMakeLists.txt | 3 +- .../speechx/decoder/ctc_beam_search_decoder.h | 2 +- speechx/speechx/decoder/ctc_tlg_decoder.cc | 50 + speechx/speechx/decoder/ctc_tlg_decoder.h | 47 + speechx/speechx/kaldi/CMakeLists.txt | 3 + speechx/speechx/kaldi/decoder/CMakeLists.txt | 6 + .../{nnet => kaldi/decoder}/decodable-itf.h | 2 +- .../kaldi/decoder/lattice-faster-decoder.cc | 4 - .../kaldi/decoder/lattice-faster-decoder.h | 3 +- .../decoder/lattice-faster-online-decoder.cc | 4 +- .../decoder/lattice-faster-online-decoder.h | 2 +- speechx/speechx/kaldi/fstext/CMakeLists.txt | 5 + .../kaldi/fstext/determinize-lattice-inl.h | 1357 ++++++ .../kaldi/fstext/determinize-lattice.h | 144 + .../kaldi/fstext/determinize-star-inl.h | 1204 ++++++ .../speechx/kaldi/fstext/determinize-star.h | 116 + speechx/speechx/kaldi/fstext/fstext-lib.h | 34 + .../speechx/kaldi/fstext/fstext-utils-inl.h | 1265 ++++++ speechx/speechx/kaldi/fstext/fstext-utils.h | 386 ++ .../speechx/kaldi/fstext/kaldi-fst-io-inl.h | 208 + speechx/speechx/kaldi/fstext/kaldi-fst-io.cc | 148 + speechx/speechx/kaldi/fstext/kaldi-fst-io.h | 158 + .../speechx/kaldi/fstext/lattice-utils-inl.h | 267 ++ speechx/speechx/kaldi/fstext/lattice-utils.h | 259 ++ speechx/speechx/kaldi/fstext/lattice-weight.h | 892 ++++ .../kaldi/fstext/pre-determinize-inl.h | 798 ++++ .../speechx/kaldi/fstext/pre-determinize.h | 98 + .../kaldi/fstext/remove-eps-local-inl.h | 318 ++ .../speechx/kaldi/fstext/remove-eps-local.h | 57 + speechx/speechx/kaldi/fstext/table-matcher.h | 387 ++ speechx/speechx/kaldi/lat/CMakeLists.txt | 6 + .../lat/determinize-lattice-pruned-test.cc | 147 - .../kaldi/lat/determinize-lattice-pruned.cc | 468 +-- .../kaldi/lat/determinize-lattice-pruned.h | 156 +- speechx/speechx/kaldi/lat/kaldi-lattice.h | 16 +- .../speechx/kaldi/lat/lattice-functions.cc | 3632 +++++++++-------- speechx/speechx/kaldi/lat/lattice-functions.h | 767 ++-- speechx/speechx/nnet/decodable.cc | 12 +- speechx/speechx/nnet/decodable.h | 6 +- 41 files changed, 10943 insertions(+), 2599 deletions(-) create mode 100644 speechx/examples/decoder/offline_wfst_decoder_main.cc create mode 100644 speechx/speechx/decoder/ctc_tlg_decoder.cc create mode 100644 speechx/speechx/decoder/ctc_tlg_decoder.h create mode 100644 speechx/speechx/kaldi/decoder/CMakeLists.txt rename speechx/speechx/{nnet => kaldi/decoder}/decodable-itf.h (99%) create mode 100644 speechx/speechx/kaldi/fstext/CMakeLists.txt create mode 100644 speechx/speechx/kaldi/fstext/determinize-lattice-inl.h create mode 100644 speechx/speechx/kaldi/fstext/determinize-lattice.h create mode 100644 speechx/speechx/kaldi/fstext/determinize-star-inl.h create mode 100644 speechx/speechx/kaldi/fstext/determinize-star.h create mode 100644 speechx/speechx/kaldi/fstext/fstext-lib.h create mode 100644 speechx/speechx/kaldi/fstext/fstext-utils-inl.h create mode 100644 speechx/speechx/kaldi/fstext/fstext-utils.h create mode 100644 speechx/speechx/kaldi/fstext/kaldi-fst-io-inl.h create mode 100644 speechx/speechx/kaldi/fstext/kaldi-fst-io.cc create mode 100644 speechx/speechx/kaldi/fstext/kaldi-fst-io.h create mode 100644 speechx/speechx/kaldi/fstext/lattice-utils-inl.h create mode 100644 speechx/speechx/kaldi/fstext/lattice-utils.h create mode 100644 speechx/speechx/kaldi/fstext/lattice-weight.h create mode 100644 speechx/speechx/kaldi/fstext/pre-determinize-inl.h create mode 100644 speechx/speechx/kaldi/fstext/pre-determinize.h create mode 100644 speechx/speechx/kaldi/fstext/remove-eps-local-inl.h create mode 100644 speechx/speechx/kaldi/fstext/remove-eps-local.h create mode 100644 speechx/speechx/kaldi/fstext/table-matcher.h create mode 100644 speechx/speechx/kaldi/lat/CMakeLists.txt delete mode 100644 speechx/speechx/kaldi/lat/determinize-lattice-pruned-test.cc diff --git a/speechx/examples/decoder/CMakeLists.txt b/speechx/examples/decoder/CMakeLists.txt index 4bd5c6cf..11e2ca91 100644 --- a/speechx/examples/decoder/CMakeLists.txt +++ b/speechx/examples/decoder/CMakeLists.txt @@ -3,3 +3,7 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc) target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) + +add_executable(offline_wfst_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_wfst_decoder_main.cc) +target_include_directories(offline_wfst_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(offline_wfst_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS}) diff --git a/speechx/examples/decoder/offline_wfst_decoder_main.cc b/speechx/examples/decoder/offline_wfst_decoder_main.cc new file mode 100644 index 00000000..7c1f6226 --- /dev/null +++ b/speechx/examples/decoder/offline_wfst_decoder_main.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// todo refactor, repalce with gtest + +#include "base/flags.h" +#include "base/log.h" +#include "decoder/ctc_tlg_decoder.h" +#include "frontend/raw_audio.h" +#include "kaldi/util/table-types.h" +#include "nnet/decodable.h" +#include "nnet/paddle_nnet.h" + +DEFINE_string(feature_respecifier, "", "test feature rspecifier"); +DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); +DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); +DEFINE_string(word_symbol_table, "vocab.txt", "word symbol table"); +DEFINE_string(graph_path, "TLG", "decoder graph"); + + +using kaldi::BaseFloat; +using kaldi::Matrix; +using std::vector; + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + kaldi::SequentialBaseFloatMatrixReader feature_reader( + FLAGS_feature_respecifier); + std::string model_graph = FLAGS_model_path; + std::string model_params = FLAGS_param_path; + std::string word_symbol_table = FLAGS_word_symbol_table; + std::string graph_path = FLAGS_graph_path; + + int32 num_done = 0, num_err = 0; + + ppspeech::TLGDecoderOptions opts; + opts.word_symbol_table = word_symbol_table; + opts.fst_path = graph_path; + ppspeech::TLGDecoder decoder(opts); + + ppspeech::ModelOptions model_opts; + model_opts.model_path = model_graph; + model_opts.params_path = model_params; + std::shared_ptr nnet( + new ppspeech::PaddleNnet(model_opts)); + std::shared_ptr raw_data( + new ppspeech::RawDataCache()); + std::shared_ptr decodable( + new ppspeech::Decodable(nnet, raw_data)); + + int32 chunk_size = 35; + decoder.InitDecoder(); + + for (; !feature_reader.Done(); feature_reader.Next()) { + string utt = feature_reader.Key(); + const kaldi::Matrix feature = feature_reader.Value(); + raw_data->SetDim(feature.NumCols()); + int32 row_idx = 0; + int32 num_chunks = feature.NumRows() / chunk_size; + for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + kaldi::Vector feature_chunk(chunk_size * + feature.NumCols()); + for (int row_id = 0; row_id < chunk_size; ++row_id) { + kaldi::SubVector tmp(feature, row_idx); + kaldi::SubVector f_chunk_tmp( + feature_chunk.Data() + row_id * feature.NumCols(), + feature.NumCols()); + f_chunk_tmp.CopyFromVec(tmp); + row_idx++; + } + raw_data->Accept(feature_chunk); + if (chunk_idx == num_chunks - 1) { + raw_data->SetFinished(); + } + decoder.AdvanceDecode(decodable); + } + std::string result; + result = decoder.GetFinalBestPath(); + KALDI_LOG << " the result of " << utt << " is " << result; + decodable->Reset(); + decoder.Reset(); + ++num_done; + } + + KALDI_LOG << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 0 : 1); +} diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index 7cd281b6..ee0863fd 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -6,5 +6,6 @@ add_library(decoder STATIC ctc_decoders/decoder_utils.cpp ctc_decoders/path_trie.cpp ctc_decoders/scorer.cpp + ctc_tlg_decoder.cc ) -target_link_libraries(decoder PUBLIC kenlm utils fst) \ No newline at end of file +target_link_libraries(decoder PUBLIC kenlm utils fst) diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h index 451f33c0..cf1824c6 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.h @@ -15,7 +15,7 @@ #include "base/common.h" #include "decoder/ctc_decoders/path_trie.h" #include "decoder/ctc_decoders/scorer.h" -#include "nnet/decodable-itf.h" +#include "kaldi/decoder/decodable-itf.h" #include "util/parse-options.h" #pragma once diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc new file mode 100644 index 00000000..c08a7d5b --- /dev/null +++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc @@ -0,0 +1,50 @@ +#include "decoder/ctc_tlg_decoder.h" +namespace ppspeech { + +TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { + fst_.reset(fst::Fst::Read(opts.fst_path)); + CHECK(fst_ != nullptr); + word_symbol_table_.reset(fst::SymbolTable::ReadText(opts.word_symbol_table)); + decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts)); + decoder_->InitDecoding(); +} + +void TLGDecoder::InitDecoder() { + decoder_->InitDecoding(); +} + +void TLGDecoder::AdvanceDecode(const std::shared_ptr& decodable) { + while (1) { + AdvanceDecoding(decodable.get()); + if (decodable->IsLastFrame(num_frame_decoded_)) break; + } +} + +void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) { + // skip blank frame? + decoder_->AdvanceDecoding(decodable, 1); + num_frame_decoded_++; +} + +void TLGDecoder::Reset() { + decoder_->InitDecoding(); + return; +} + +std::string TLGDecoder::GetFinalBestPath() { + decoder_->FinalizeDecoding(); + kaldi::Lattice lat; + kaldi::LatticeWeight weight; + std::vector alignment; + std::vector words_id; + decoder_->GetBestPath(&lat, true); + fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight); + std::string words; + for (int32 idx = 0; idx < words_id.size(); ++idx) { + std::string word = word_symbol_table_->Find(words_id[idx]); + words += word; + } + return words; +} + +} \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/speechx/speechx/decoder/ctc_tlg_decoder.h new file mode 100644 index 00000000..b4cd8c34 --- /dev/null +++ b/speechx/speechx/decoder/ctc_tlg_decoder.h @@ -0,0 +1,47 @@ +#pragma once + +#include "kaldi/decoder/lattice-faster-online-decoder.h" +#include "kaldi/decoder/decodable-itf.h" +#include "util/parse-options.h" +#include "base/basic_types.h" + +namespace ppspeech { + +struct TLGDecoderOptions { + kaldi::LatticeFasterDecoderConfig opts; + // todo remove later, add into decode resource + std::string word_symbol_table; + std::string fst_path; + + TLGDecoderOptions() + : word_symbol_table(""), + fst_path("") {} +}; + +class TLGDecoder { + public: + explicit TLGDecoder(TLGDecoderOptions opts); + void InitDecoder(); + void Decode(); + std::string GetBestPath(); + std::vector> GetNBestPath(); + std::string GetFinalBestPath(); + int NumFrameDecoded(); + int DecodeLikelihoods(const std::vector>& probs, + std::vector& nbest_words); + void AdvanceDecode( + const std::shared_ptr& decodable); + void Reset(); + + private: + void AdvanceDecoding(kaldi::DecodableInterface* decodable); + + std::shared_ptr decoder_; + std::shared_ptr> fst_; + std::shared_ptr word_symbol_table_; + int32 num_frame_decoded_; + }; + + + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/kaldi/CMakeLists.txt b/speechx/speechx/kaldi/CMakeLists.txt index 414a6fa0..6f7398cd 100644 --- a/speechx/speechx/kaldi/CMakeLists.txt +++ b/speechx/speechx/kaldi/CMakeLists.txt @@ -4,3 +4,6 @@ add_subdirectory(base) add_subdirectory(util) add_subdirectory(feat) add_subdirectory(matrix) +add_subdirectory(lat) +add_subdirectory(fstext) +add_subdirectory(decoder) diff --git a/speechx/speechx/kaldi/decoder/CMakeLists.txt b/speechx/speechx/kaldi/decoder/CMakeLists.txt new file mode 100644 index 00000000..f1ee6eab --- /dev/null +++ b/speechx/speechx/kaldi/decoder/CMakeLists.txt @@ -0,0 +1,6 @@ + +add_library(kaldi-decoder +lattice-faster-decoder.cc +lattice-faster-online-decoder.cc +) +target_link_libraries(kaldi-decoder PUBLIC kaldi-lat) diff --git a/speechx/speechx/nnet/decodable-itf.h b/speechx/speechx/kaldi/decoder/decodable-itf.h similarity index 99% rename from speechx/speechx/nnet/decodable-itf.h rename to speechx/speechx/kaldi/decoder/decodable-itf.h index 8e9a5a72..19e07498 100644 --- a/speechx/speechx/nnet/decodable-itf.h +++ b/speechx/speechx/kaldi/decoder/decodable-itf.h @@ -121,7 +121,7 @@ class DecodableInterface { /// decoding-from-matrix setting where we want to allow the last delta or /// LDA /// features to be flushed out for compatibility with the baseline setup. - virtual bool IsLastFrame(int32 frame) const = 0; + virtual bool IsLastFrame(int32 frame) = 0; /// The call NumFramesReady() will return the number of frames currently /// available diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc index 42d1d2af..ae6b7160 100644 --- a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc +++ b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc @@ -1007,14 +1007,10 @@ template class LatticeFasterDecoderTpl, decoder::StdToken> template class LatticeFasterDecoderTpl, decoder::StdToken >; template class LatticeFasterDecoderTpl, decoder::StdToken >; -template class LatticeFasterDecoderTpl; -template class LatticeFasterDecoderTpl; template class LatticeFasterDecoderTpl , decoder::BackpointerToken>; template class LatticeFasterDecoderTpl, decoder::BackpointerToken >; template class LatticeFasterDecoderTpl, decoder::BackpointerToken >; -template class LatticeFasterDecoderTpl; -template class LatticeFasterDecoderTpl; } // end namespace kaldi. diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h index 2016ad57..d142a8c7 100644 --- a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h +++ b/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h @@ -23,11 +23,10 @@ #ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_ #define KALDI_DECODER_LATTICE_FASTER_DECODER_H_ -#include "decoder/grammar-fst.h" #include "fst/fstlib.h" #include "fst/memory.h" #include "fstext/fstext-lib.h" -#include "itf/decodable-itf.h" +#include "decoder/decodable-itf.h" #include "lat/determinize-lattice-pruned.h" #include "lat/kaldi-lattice.h" #include "util/hash-list.h" diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc index ebdace7e..b5261503 100644 --- a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc +++ b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc @@ -278,8 +278,8 @@ bool LatticeFasterOnlineDecoderTpl::GetRawLatticePruned( template class LatticeFasterOnlineDecoderTpl >; template class LatticeFasterOnlineDecoderTpl >; template class LatticeFasterOnlineDecoderTpl >; -template class LatticeFasterOnlineDecoderTpl; -template class LatticeFasterOnlineDecoderTpl; +//template class LatticeFasterOnlineDecoderTpl; +//template class LatticeFasterOnlineDecoderTpl; } // end namespace kaldi. diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h index 8b10996f..f57368a4 100644 --- a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h +++ b/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h @@ -30,7 +30,7 @@ #include "util/stl-utils.h" #include "util/hash-list.h" #include "fst/fstlib.h" -#include "itf/decodable-itf.h" +#include "decoder/decodable-itf.h" #include "fstext/fstext-lib.h" #include "lat/determinize-lattice-pruned.h" #include "lat/kaldi-lattice.h" diff --git a/speechx/speechx/kaldi/fstext/CMakeLists.txt b/speechx/speechx/kaldi/fstext/CMakeLists.txt new file mode 100644 index 00000000..af91fd98 --- /dev/null +++ b/speechx/speechx/kaldi/fstext/CMakeLists.txt @@ -0,0 +1,5 @@ + +add_library(kaldi-fstext +kaldi-fst-io.cc +) +target_link_libraries(kaldi-fstext PUBLIC kaldi-util) diff --git a/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h b/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h new file mode 100644 index 00000000..0bfbc8f4 --- /dev/null +++ b/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h @@ -0,0 +1,1357 @@ +// fstext/determinize-lattice-inl.h + +// Copyright 2009-2012 Microsoft Corporation +// 2012-2013 Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_ +#define KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_ +// Do not include this file directly. It is included by determinize-lattice.h + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { + +// This class maps back and forth from/to integer id's to sequences of strings. +// used in determinization algorithm. It is constructed in such a way that +// finding the string-id of the successor of (string, next-label) has constant +// time. + +// Note: class IntType, typically int32, is the type of the element in the +// string (typically a template argument of the CompactLatticeWeightTpl). + +template +class LatticeStringRepository { + public: + struct Entry { + const Entry *parent; // NULL for empty string. + IntType i; + inline bool operator==(const Entry &other) const { + return (parent == other.parent && i == other.i); + } + Entry() {} + Entry(const Entry &e) : parent(e.parent), i(e.i) {} + }; + // Note: all Entry* pointers returned in function calls are + // owned by the repository itself, not by the caller! + + // Interface guarantees empty string is NULL. + inline const Entry *EmptyString() { return NULL; } + + // Returns string of "parent" with i appended. Pointer + // owned by repository + const Entry *Successor(const Entry *parent, IntType i) { + new_entry_->parent = parent; + new_entry_->i = i; + + std::pair pr = set_.insert(new_entry_); + if (pr.second) { // Was successfully inserted (was not there). We need to + // replace the element we inserted, which resides on the + // stack, with one from the heap. + const Entry *ans = new_entry_; + new_entry_ = new Entry(); + return ans; + } else { // Was not inserted because an equivalent Entry already + // existed. + return *pr.first; + } + } + + const Entry *Concatenate(const Entry *a, const Entry *b) { + if (a == NULL) + return b; + else if (b == NULL) + return a; + std::vector v; + ConvertToVector(b, &v); + const Entry *ans = a; + for (size_t i = 0; i < v.size(); i++) ans = Successor(ans, v[i]); + return ans; + } + const Entry *CommonPrefix(const Entry *a, const Entry *b) { + std::vector a_vec, b_vec; + ConvertToVector(a, &a_vec); + ConvertToVector(b, &b_vec); + const Entry *ans = NULL; + for (size_t i = 0; + i < a_vec.size() && i < b_vec.size() && a_vec[i] == b_vec[i]; i++) + ans = Successor(ans, a_vec[i]); + return ans; + } + + // removes any elements from b that are not part of + // a common prefix with a. + void ReduceToCommonPrefix(const Entry *a, std::vector *b) { + size_t a_size = Size(a), b_size = b->size(); + while (a_size > b_size) { + a = a->parent; + a_size--; + } + if (b_size > a_size) b_size = a_size; + typename std::vector::iterator b_begin = b->begin(); + while (a_size != 0) { + if (a->i != *(b_begin + a_size - 1)) b_size = a_size - 1; + a = a->parent; + a_size--; + } + if (b_size != b->size()) b->resize(b_size); + } + + // removes the first n elements of a. + const Entry *RemovePrefix(const Entry *a, size_t n) { + if (n == 0) return a; + std::vector a_vec; + ConvertToVector(a, &a_vec); + assert(a_vec.size() >= n); + const Entry *ans = NULL; + for (size_t i = n; i < a_vec.size(); i++) ans = Successor(ans, a_vec[i]); + return ans; + } + + // Returns true if a is a prefix of b. If a is prefix of b, + // time taken is |b| - |a|. Else, time taken is |b|. + bool IsPrefixOf(const Entry *a, const Entry *b) const { + if (a == NULL) return true; // empty string prefix of all. + if (a == b) return true; + if (b == NULL) return false; + return IsPrefixOf(a, b->parent); + } + + inline size_t Size(const Entry *entry) const { + size_t ans = 0; + while (entry != NULL) { + ans++; + entry = entry->parent; + } + return ans; + } + + void ConvertToVector(const Entry *entry, std::vector *out) const { + size_t length = Size(entry); + out->resize(length); + if (entry != NULL) { + typename std::vector::reverse_iterator iter = out->rbegin(); + while (entry != NULL) { + *iter = entry->i; + entry = entry->parent; + ++iter; + } + } + } + + const Entry *ConvertFromVector(const std::vector &vec) { + const Entry *e = NULL; + for (size_t i = 0; i < vec.size(); i++) e = Successor(e, vec[i]); + return e; + } + + LatticeStringRepository() { new_entry_ = new Entry; } + + void Destroy() { + for (typename SetType::iterator iter = set_.begin(); iter != set_.end(); + ++iter) + delete *iter; + SetType tmp; + tmp.swap(set_); + if (new_entry_) { + delete new_entry_; + new_entry_ = NULL; + } + } + + // Rebuild will rebuild this object, guaranteeing only + // to preserve the Entry values that are in the vector pointed + // to (this list does not have to be unique). The point of + // this is to save memory. + void Rebuild(const std::vector &to_keep) { + SetType tmp_set; + for (typename std::vector::const_iterator iter = + to_keep.begin(); + iter != to_keep.end(); ++iter) + RebuildHelper(*iter, &tmp_set); + // Now delete all elems not in tmp_set. + for (typename SetType::iterator iter = set_.begin(); iter != set_.end(); + ++iter) { + if (tmp_set.count(*iter) == 0) + delete (*iter); // delete the Entry; not needed. + } + set_.swap(tmp_set); + } + + ~LatticeStringRepository() { Destroy(); } + int32 MemSize() const { + return set_.size() * sizeof(Entry) * 2; // this is a lower bound + // on the size this structure might take. + } + + private: + class EntryKey { // Hash function object. + public: + inline size_t operator()(const Entry *entry) const { + size_t prime = 49109; + return static_cast(entry->i) + + prime * reinterpret_cast(entry->parent); + } + }; + class EntryEqual { + public: + inline bool operator()(const Entry *e1, const Entry *e2) const { + return (*e1 == *e2); + } + }; + typedef std::unordered_set SetType; + + void RebuildHelper(const Entry *to_add, SetType *tmp_set) { + while (true) { + if (to_add == NULL) return; + typename SetType::iterator iter = tmp_set->find(to_add); + if (iter == tmp_set->end()) { // not in tmp_set. + tmp_set->insert(to_add); + to_add = to_add->parent; // and loop. + } else { + return; + } + } + } + + KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeStringRepository); + Entry *new_entry_; // We always have a pre-allocated Entry ready to use, + // to avoid unnecessary news and deletes. + SetType set_; +}; + +// class LatticeDeterminizer is templated on the same types that +// CompactLatticeWeight is templated on: the base weight (Weight), typically +// LatticeWeightTpl etc. but could also be e.g. TropicalWeight, and the +// IntType, typically int32, used for the output symbols in the compact +// representation of strings [note: the output symbols would usually be +// p.d.f. id's in the anticipated use of this code] It has a special requirement +// on the Weight type: that there should be a Compare function on the weights +// such that Compare(w1, w2) returns -1 if w1 < w2, 0 if w1 == w2, and +1 if w1 +// > w2. This requires that there be a total order on the weights. + +template +class LatticeDeterminizer { + public: + // Output to Gallic acceptor (so the strings go on weights, and there is a 1-1 + // correspondence between our states and the states in ofst. If destroy == + // true, release memory as we go (but we cannot output again). + + typedef CompactLatticeWeightTpl CompactWeight; + typedef ArcTpl + CompactArc; // arc in compact, acceptor form of lattice + typedef ArcTpl Arc; // arc in non-compact version of lattice + + // Output to standard FST with CompactWeightTpl as its weight type + // (the weight stores the original output-symbol strings). If destroy == + // true, release memory as we go (but we cannot output again). + void Output(MutableFst *ofst, bool destroy = true) { + assert(determinized_); + typedef typename Arc::StateId StateId; + StateId nStates = static_cast(output_arcs_.size()); + if (destroy) FreeMostMemory(); + ofst->DeleteStates(); + ofst->SetStart(kNoStateId); + if (nStates == 0) { + return; + } + for (StateId s = 0; s < nStates; s++) { + OutputStateId news = ofst->AddState(); + assert(news == s); + } + ofst->SetStart(0); + // now process transitions. + for (StateId this_state = 0; this_state < nStates; this_state++) { + std::vector &this_vec(output_arcs_[this_state]); + typename std::vector::const_iterator iter = this_vec.begin(), + end = this_vec.end(); + + for (; iter != end; ++iter) { + const TempArc &temp_arc(*iter); + CompactArc new_arc; + std::vector