// fstext/determinize-lattice-inl.h // Copyright 2009-2012 Microsoft Corporation // 2012-2013 Johns Hopkins University (Author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_ #define KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_ // Do not include this file directly. It is included by determinize-lattice.h #include #include #include #include #include #include #include #include namespace fst { // This class maps back and forth from/to integer id's to sequences of strings. // used in determinization algorithm. It is constructed in such a way that // finding the string-id of the successor of (string, next-label) has constant // time. // Note: class IntType, typically int32, is the type of the element in the // string (typically a template argument of the CompactLatticeWeightTpl). template class LatticeStringRepository { public: struct Entry { const Entry *parent; // NULL for empty string. IntType i; inline bool operator==(const Entry &other) const { return (parent == other.parent && i == other.i); } Entry() {} Entry(const Entry &e) : parent(e.parent), i(e.i) {} }; // Note: all Entry* pointers returned in function calls are // owned by the repository itself, not by the caller! // Interface guarantees empty string is NULL. inline const Entry *EmptyString() { return NULL; } // Returns string of "parent" with i appended. Pointer // owned by repository const Entry *Successor(const Entry *parent, IntType i) { new_entry_->parent = parent; new_entry_->i = i; std::pair pr = set_.insert(new_entry_); if (pr.second) { // Was successfully inserted (was not there). We need to // replace the element we inserted, which resides on the // stack, with one from the heap. const Entry *ans = new_entry_; new_entry_ = new Entry(); return ans; } else { // Was not inserted because an equivalent Entry already // existed. return *pr.first; } } const Entry *Concatenate(const Entry *a, const Entry *b) { if (a == NULL) return b; else if (b == NULL) return a; std::vector v; ConvertToVector(b, &v); const Entry *ans = a; for (size_t i = 0; i < v.size(); i++) ans = Successor(ans, v[i]); return ans; } const Entry *CommonPrefix(const Entry *a, const Entry *b) { std::vector a_vec, b_vec; ConvertToVector(a, &a_vec); ConvertToVector(b, &b_vec); const Entry *ans = NULL; for (size_t i = 0; i < a_vec.size() && i < b_vec.size() && a_vec[i] == b_vec[i]; i++) ans = Successor(ans, a_vec[i]); return ans; } // removes any elements from b that are not part of // a common prefix with a. void ReduceToCommonPrefix(const Entry *a, std::vector *b) { size_t a_size = Size(a), b_size = b->size(); while (a_size > b_size) { a = a->parent; a_size--; } if (b_size > a_size) b_size = a_size; typename std::vector::iterator b_begin = b->begin(); while (a_size != 0) { if (a->i != *(b_begin + a_size - 1)) b_size = a_size - 1; a = a->parent; a_size--; } if (b_size != b->size()) b->resize(b_size); } // removes the first n elements of a. const Entry *RemovePrefix(const Entry *a, size_t n) { if (n == 0) return a; std::vector a_vec; ConvertToVector(a, &a_vec); assert(a_vec.size() >= n); const Entry *ans = NULL; for (size_t i = n; i < a_vec.size(); i++) ans = Successor(ans, a_vec[i]); return ans; } // Returns true if a is a prefix of b. If a is prefix of b, // time taken is |b| - |a|. Else, time taken is |b|. bool IsPrefixOf(const Entry *a, const Entry *b) const { if (a == NULL) return true; // empty string prefix of all. if (a == b) return true; if (b == NULL) return false; return IsPrefixOf(a, b->parent); } inline size_t Size(const Entry *entry) const { size_t ans = 0; while (entry != NULL) { ans++; entry = entry->parent; } return ans; } void ConvertToVector(const Entry *entry, std::vector *out) const { size_t length = Size(entry); out->resize(length); if (entry != NULL) { typename std::vector::reverse_iterator iter = out->rbegin(); while (entry != NULL) { *iter = entry->i; entry = entry->parent; ++iter; } } } const Entry *ConvertFromVector(const std::vector &vec) { const Entry *e = NULL; for (size_t i = 0; i < vec.size(); i++) e = Successor(e, vec[i]); return e; } LatticeStringRepository() { new_entry_ = new Entry; } void Destroy() { for (typename SetType::iterator iter = set_.begin(); iter != set_.end(); ++iter) delete *iter; SetType tmp; tmp.swap(set_); if (new_entry_) { delete new_entry_; new_entry_ = NULL; } } // Rebuild will rebuild this object, guaranteeing only // to preserve the Entry values that are in the vector pointed // to (this list does not have to be unique). The point of // this is to save memory. void Rebuild(const std::vector &to_keep) { SetType tmp_set; for (typename std::vector::const_iterator iter = to_keep.begin(); iter != to_keep.end(); ++iter) RebuildHelper(*iter, &tmp_set); // Now delete all elems not in tmp_set. for (typename SetType::iterator iter = set_.begin(); iter != set_.end(); ++iter) { if (tmp_set.count(*iter) == 0) delete (*iter); // delete the Entry; not needed. } set_.swap(tmp_set); } ~LatticeStringRepository() { Destroy(); } int32 MemSize() const { return set_.size() * sizeof(Entry) * 2; // this is a lower bound // on the size this structure might take. } private: class EntryKey { // Hash function object. public: inline size_t operator()(const Entry *entry) const { size_t prime = 49109; return static_cast(entry->i) + prime * reinterpret_cast(entry->parent); } }; class EntryEqual { public: inline bool operator()(const Entry *e1, const Entry *e2) const { return (*e1 == *e2); } }; typedef std::unordered_set SetType; void RebuildHelper(const Entry *to_add, SetType *tmp_set) { while (true) { if (to_add == NULL) return; typename SetType::iterator iter = tmp_set->find(to_add); if (iter == tmp_set->end()) { // not in tmp_set. tmp_set->insert(to_add); to_add = to_add->parent; // and loop. } else { return; } } } KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeStringRepository); Entry *new_entry_; // We always have a pre-allocated Entry ready to use, // to avoid unnecessary news and deletes. SetType set_; }; // class LatticeDeterminizer is templated on the same types that // CompactLatticeWeight is templated on: the base weight (Weight), typically // LatticeWeightTpl etc. but could also be e.g. TropicalWeight, and the // IntType, typically int32, used for the output symbols in the compact // representation of strings [note: the output symbols would usually be // p.d.f. id's in the anticipated use of this code] It has a special requirement // on the Weight type: that there should be a Compare function on the weights // such that Compare(w1, w2) returns -1 if w1 < w2, 0 if w1 == w2, and +1 if w1 // > w2. This requires that there be a total order on the weights. template class LatticeDeterminizer { public: // Output to Gallic acceptor (so the strings go on weights, and there is a 1-1 // correspondence between our states and the states in ofst. If destroy == // true, release memory as we go (but we cannot output again). typedef CompactLatticeWeightTpl CompactWeight; typedef ArcTpl CompactArc; // arc in compact, acceptor form of lattice typedef ArcTpl Arc; // arc in non-compact version of lattice // Output to standard FST with CompactWeightTpl as its weight type // (the weight stores the original output-symbol strings). If destroy == // true, release memory as we go (but we cannot output again). void Output(MutableFst *ofst, bool destroy = true) { assert(determinized_); typedef typename Arc::StateId StateId; StateId nStates = static_cast(output_arcs_.size()); if (destroy) FreeMostMemory(); ofst->DeleteStates(); ofst->SetStart(kNoStateId); if (nStates == 0) { return; } for (StateId s = 0; s < nStates; s++) { OutputStateId news = ofst->AddState(); assert(news == s); } ofst->SetStart(0); // now process transitions. for (StateId this_state = 0; this_state < nStates; this_state++) { std::vector &this_vec(output_arcs_[this_state]); typename std::vector::const_iterator iter = this_vec.begin(), end = this_vec.end(); for (; iter != end; ++iter) { const TempArc &temp_arc(*iter); CompactArc new_arc; std::vector