You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1358 lines
54 KiB
1358 lines
54 KiB
// fstext/determinize-lattice-inl.h
|
|
|
|
// Copyright 2009-2012 Microsoft Corporation
|
|
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
|
|
|
|
// See ../../COPYING for clarification regarding multiple authors
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
|
// See the Apache 2 License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#ifndef KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_
|
|
#define KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_
|
|
// Do not include this file directly. It is included by determinize-lattice.h
|
|
|
|
#include <algorithm>
|
|
#include <climits>
|
|
#include <deque>
|
|
#include <string>
|
|
#include <unordered_map>
|
|
#include <unordered_set>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
namespace fst {
|
|
|
|
// This class maps back and forth from/to integer id's to sequences of strings.
|
|
// used in determinization algorithm. It is constructed in such a way that
|
|
// finding the string-id of the successor of (string, next-label) has constant
|
|
// time.
|
|
|
|
// Note: class IntType, typically int32, is the type of the element in the
|
|
// string (typically a template argument of the CompactLatticeWeightTpl).
|
|
|
|
template <class IntType>
|
|
class LatticeStringRepository {
|
|
public:
|
|
struct Entry {
|
|
const Entry *parent; // NULL for empty string.
|
|
IntType i;
|
|
inline bool operator==(const Entry &other) const {
|
|
return (parent == other.parent && i == other.i);
|
|
}
|
|
Entry() {}
|
|
Entry(const Entry &e) : parent(e.parent), i(e.i) {}
|
|
};
|
|
// Note: all Entry* pointers returned in function calls are
|
|
// owned by the repository itself, not by the caller!
|
|
|
|
// Interface guarantees empty string is NULL.
|
|
inline const Entry *EmptyString() { return NULL; }
|
|
|
|
// Returns string of "parent" with i appended. Pointer
|
|
// owned by repository
|
|
const Entry *Successor(const Entry *parent, IntType i) {
|
|
new_entry_->parent = parent;
|
|
new_entry_->i = i;
|
|
|
|
std::pair<typename SetType::iterator, bool> pr = set_.insert(new_entry_);
|
|
if (pr.second) { // Was successfully inserted (was not there). We need to
|
|
// replace the element we inserted, which resides on the
|
|
// stack, with one from the heap.
|
|
const Entry *ans = new_entry_;
|
|
new_entry_ = new Entry();
|
|
return ans;
|
|
} else { // Was not inserted because an equivalent Entry already
|
|
// existed.
|
|
return *pr.first;
|
|
}
|
|
}
|
|
|
|
const Entry *Concatenate(const Entry *a, const Entry *b) {
|
|
if (a == NULL)
|
|
return b;
|
|
else if (b == NULL)
|
|
return a;
|
|
std::vector<IntType> v;
|
|
ConvertToVector(b, &v);
|
|
const Entry *ans = a;
|
|
for (size_t i = 0; i < v.size(); i++) ans = Successor(ans, v[i]);
|
|
return ans;
|
|
}
|
|
const Entry *CommonPrefix(const Entry *a, const Entry *b) {
|
|
std::vector<IntType> a_vec, b_vec;
|
|
ConvertToVector(a, &a_vec);
|
|
ConvertToVector(b, &b_vec);
|
|
const Entry *ans = NULL;
|
|
for (size_t i = 0;
|
|
i < a_vec.size() && i < b_vec.size() && a_vec[i] == b_vec[i]; i++)
|
|
ans = Successor(ans, a_vec[i]);
|
|
return ans;
|
|
}
|
|
|
|
// removes any elements from b that are not part of
|
|
// a common prefix with a.
|
|
void ReduceToCommonPrefix(const Entry *a, std::vector<IntType> *b) {
|
|
size_t a_size = Size(a), b_size = b->size();
|
|
while (a_size > b_size) {
|
|
a = a->parent;
|
|
a_size--;
|
|
}
|
|
if (b_size > a_size) b_size = a_size;
|
|
typename std::vector<IntType>::iterator b_begin = b->begin();
|
|
while (a_size != 0) {
|
|
if (a->i != *(b_begin + a_size - 1)) b_size = a_size - 1;
|
|
a = a->parent;
|
|
a_size--;
|
|
}
|
|
if (b_size != b->size()) b->resize(b_size);
|
|
}
|
|
|
|
// removes the first n elements of a.
|
|
const Entry *RemovePrefix(const Entry *a, size_t n) {
|
|
if (n == 0) return a;
|
|
std::vector<IntType> a_vec;
|
|
ConvertToVector(a, &a_vec);
|
|
assert(a_vec.size() >= n);
|
|
const Entry *ans = NULL;
|
|
for (size_t i = n; i < a_vec.size(); i++) ans = Successor(ans, a_vec[i]);
|
|
return ans;
|
|
}
|
|
|
|
// Returns true if a is a prefix of b. If a is prefix of b,
|
|
// time taken is |b| - |a|. Else, time taken is |b|.
|
|
bool IsPrefixOf(const Entry *a, const Entry *b) const {
|
|
if (a == NULL) return true; // empty string prefix of all.
|
|
if (a == b) return true;
|
|
if (b == NULL) return false;
|
|
return IsPrefixOf(a, b->parent);
|
|
}
|
|
|
|
inline size_t Size(const Entry *entry) const {
|
|
size_t ans = 0;
|
|
while (entry != NULL) {
|
|
ans++;
|
|
entry = entry->parent;
|
|
}
|
|
return ans;
|
|
}
|
|
|
|
void ConvertToVector(const Entry *entry, std::vector<IntType> *out) const {
|
|
size_t length = Size(entry);
|
|
out->resize(length);
|
|
if (entry != NULL) {
|
|
typename std::vector<IntType>::reverse_iterator iter = out->rbegin();
|
|
while (entry != NULL) {
|
|
*iter = entry->i;
|
|
entry = entry->parent;
|
|
++iter;
|
|
}
|
|
}
|
|
}
|
|
|
|
const Entry *ConvertFromVector(const std::vector<IntType> &vec) {
|
|
const Entry *e = NULL;
|
|
for (size_t i = 0; i < vec.size(); i++) e = Successor(e, vec[i]);
|
|
return e;
|
|
}
|
|
|
|
LatticeStringRepository() { new_entry_ = new Entry; }
|
|
|
|
void Destroy() {
|
|
for (typename SetType::iterator iter = set_.begin(); iter != set_.end();
|
|
++iter)
|
|
delete *iter;
|
|
SetType tmp;
|
|
tmp.swap(set_);
|
|
if (new_entry_) {
|
|
delete new_entry_;
|
|
new_entry_ = NULL;
|
|
}
|
|
}
|
|
|
|
// Rebuild will rebuild this object, guaranteeing only
|
|
// to preserve the Entry values that are in the vector pointed
|
|
// to (this list does not have to be unique). The point of
|
|
// this is to save memory.
|
|
void Rebuild(const std::vector<const Entry *> &to_keep) {
|
|
SetType tmp_set;
|
|
for (typename std::vector<const Entry *>::const_iterator iter =
|
|
to_keep.begin();
|
|
iter != to_keep.end(); ++iter)
|
|
RebuildHelper(*iter, &tmp_set);
|
|
// Now delete all elems not in tmp_set.
|
|
for (typename SetType::iterator iter = set_.begin(); iter != set_.end();
|
|
++iter) {
|
|
if (tmp_set.count(*iter) == 0)
|
|
delete (*iter); // delete the Entry; not needed.
|
|
}
|
|
set_.swap(tmp_set);
|
|
}
|
|
|
|
~LatticeStringRepository() { Destroy(); }
|
|
int32 MemSize() const {
|
|
return set_.size() * sizeof(Entry) * 2; // this is a lower bound
|
|
// on the size this structure might take.
|
|
}
|
|
|
|
private:
|
|
class EntryKey { // Hash function object.
|
|
public:
|
|
inline size_t operator()(const Entry *entry) const {
|
|
size_t prime = 49109;
|
|
return static_cast<size_t>(entry->i) +
|
|
prime * reinterpret_cast<size_t>(entry->parent);
|
|
}
|
|
};
|
|
class EntryEqual {
|
|
public:
|
|
inline bool operator()(const Entry *e1, const Entry *e2) const {
|
|
return (*e1 == *e2);
|
|
}
|
|
};
|
|
typedef std::unordered_set<const Entry *, EntryKey, EntryEqual> SetType;
|
|
|
|
void RebuildHelper(const Entry *to_add, SetType *tmp_set) {
|
|
while (true) {
|
|
if (to_add == NULL) return;
|
|
typename SetType::iterator iter = tmp_set->find(to_add);
|
|
if (iter == tmp_set->end()) { // not in tmp_set.
|
|
tmp_set->insert(to_add);
|
|
to_add = to_add->parent; // and loop.
|
|
} else {
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeStringRepository);
|
|
Entry *new_entry_; // We always have a pre-allocated Entry ready to use,
|
|
// to avoid unnecessary news and deletes.
|
|
SetType set_;
|
|
};
|
|
|
|
// class LatticeDeterminizer is templated on the same types that
|
|
// CompactLatticeWeight is templated on: the base weight (Weight), typically
|
|
// LatticeWeightTpl<float> etc. but could also be e.g. TropicalWeight, and the
|
|
// IntType, typically int32, used for the output symbols in the compact
|
|
// representation of strings [note: the output symbols would usually be
|
|
// p.d.f. id's in the anticipated use of this code] It has a special requirement
|
|
// on the Weight type: that there should be a Compare function on the weights
|
|
// such that Compare(w1, w2) returns -1 if w1 < w2, 0 if w1 == w2, and +1 if w1
|
|
// > w2. This requires that there be a total order on the weights.
|
|
|
|
template <class Weight, class IntType>
|
|
class LatticeDeterminizer {
|
|
public:
|
|
// Output to Gallic acceptor (so the strings go on weights, and there is a 1-1
|
|
// correspondence between our states and the states in ofst. If destroy ==
|
|
// true, release memory as we go (but we cannot output again).
|
|
|
|
typedef CompactLatticeWeightTpl<Weight, IntType> CompactWeight;
|
|
typedef ArcTpl<CompactWeight>
|
|
CompactArc; // arc in compact, acceptor form of lattice
|
|
typedef ArcTpl<Weight> Arc; // arc in non-compact version of lattice
|
|
|
|
// Output to standard FST with CompactWeightTpl<Weight> as its weight type
|
|
// (the weight stores the original output-symbol strings). If destroy ==
|
|
// true, release memory as we go (but we cannot output again).
|
|
void Output(MutableFst<CompactArc> *ofst, bool destroy = true) {
|
|
assert(determinized_);
|
|
typedef typename Arc::StateId StateId;
|
|
StateId nStates = static_cast<StateId>(output_arcs_.size());
|
|
if (destroy) FreeMostMemory();
|
|
ofst->DeleteStates();
|
|
ofst->SetStart(kNoStateId);
|
|
if (nStates == 0) {
|
|
return;
|
|
}
|
|
for (StateId s = 0; s < nStates; s++) {
|
|
OutputStateId news = ofst->AddState();
|
|
assert(news == s);
|
|
}
|
|
ofst->SetStart(0);
|
|
// now process transitions.
|
|
for (StateId this_state = 0; this_state < nStates; this_state++) {
|
|
std::vector<TempArc> &this_vec(output_arcs_[this_state]);
|
|
typename std::vector<TempArc>::const_iterator iter = this_vec.begin(),
|
|
end = this_vec.end();
|
|
|
|
for (; iter != end; ++iter) {
|
|
const TempArc &temp_arc(*iter);
|
|
CompactArc new_arc;
|
|
std::vector<Label> seq;
|
|
repository_.ConvertToVector(temp_arc.string, &seq);
|
|
CompactWeight weight(temp_arc.weight, seq);
|
|
if (temp_arc.nextstate == kNoStateId) { // is really final weight.
|
|
ofst->SetFinal(this_state, weight);
|
|
} else { // is really an arc.
|
|
new_arc.nextstate = temp_arc.nextstate;
|
|
new_arc.ilabel = temp_arc.ilabel;
|
|
new_arc.olabel = temp_arc.ilabel; // acceptor. input == output.
|
|
new_arc.weight = weight; // includes string and weight.
|
|
ofst->AddArc(this_state, new_arc);
|
|
}
|
|
}
|
|
// Free up memory. Do this inside the loop as ofst is also allocating
|
|
// memory
|
|
if (destroy) {
|
|
std::vector<TempArc> temp;
|
|
std::swap(temp, this_vec);
|
|
}
|
|
}
|
|
if (destroy) {
|
|
std::vector<std::vector<TempArc> > temp;
|
|
std::swap(temp, output_arcs_);
|
|
}
|
|
}
|
|
|
|
// Output to standard FST with Weight as its weight type. We will create
|
|
// extra states to handle sequences of symbols on the output. If destroy ==
|
|
// true, release memory as we go (but we cannot output again).
|
|
void Output(MutableFst<Arc> *ofst, bool destroy = true) {
|
|
// Outputs to standard fst.
|
|
OutputStateId nStates = static_cast<OutputStateId>(output_arcs_.size());
|
|
ofst->DeleteStates();
|
|
if (nStates == 0) {
|
|
ofst->SetStart(kNoStateId);
|
|
return;
|
|
}
|
|
if (destroy) FreeMostMemory();
|
|
// Add basic states-- but we will add extra ones to account for strings on
|
|
// output.
|
|
for (OutputStateId s = 0; s < nStates; s++) {
|
|
OutputStateId news = ofst->AddState();
|
|
assert(news == s);
|
|
}
|
|
ofst->SetStart(0);
|
|
for (OutputStateId this_state = 0; this_state < nStates; this_state++) {
|
|
std::vector<TempArc> &this_vec(output_arcs_[this_state]);
|
|
|
|
typename std::vector<TempArc>::const_iterator iter = this_vec.begin(),
|
|
end = this_vec.end();
|
|
for (; iter != end; ++iter) {
|
|
const TempArc &temp_arc(*iter);
|
|
std::vector<Label> seq;
|
|
repository_.ConvertToVector(temp_arc.string, &seq);
|
|
|
|
if (temp_arc.nextstate == kNoStateId) { // Really a final weight.
|
|
// Make a sequence of states going to a final state, with the strings
|
|
// as labels. Put the weight on the first arc.
|
|
OutputStateId cur_state = this_state;
|
|
for (size_t i = 0; i < seq.size(); i++) {
|
|
OutputStateId next_state = ofst->AddState();
|
|
Arc arc;
|
|
arc.nextstate = next_state;
|
|
arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
|
|
arc.ilabel = 0; // epsilon.
|
|
arc.olabel = seq[i];
|
|
ofst->AddArc(cur_state, arc);
|
|
cur_state = next_state;
|
|
}
|
|
ofst->SetFinal(cur_state,
|
|
(seq.size() == 0 ? temp_arc.weight : Weight::One()));
|
|
} else { // Really an arc.
|
|
OutputStateId cur_state = this_state;
|
|
// Have to be careful with this integer comparison (i+1 < seq.size())
|
|
// because unsigned. i < seq.size()-1 could fail for zero-length
|
|
// sequences.
|
|
for (size_t i = 0; i + 1 < seq.size(); i++) {
|
|
// for all but the last element of seq, create new state.
|
|
OutputStateId next_state = ofst->AddState();
|
|
Arc arc;
|
|
arc.nextstate = next_state;
|
|
arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
|
|
arc.ilabel = (i == 0 ? temp_arc.ilabel
|
|
: 0); // put ilabel on first element of seq.
|
|
arc.olabel = seq[i];
|
|
ofst->AddArc(cur_state, arc);
|
|
cur_state = next_state;
|
|
}
|
|
// Add the final arc in the sequence.
|
|
Arc arc;
|
|
arc.nextstate = temp_arc.nextstate;
|
|
arc.weight = (seq.size() <= 1 ? temp_arc.weight : Weight::One());
|
|
arc.ilabel = (seq.size() <= 1 ? temp_arc.ilabel : 0);
|
|
arc.olabel = (seq.size() > 0 ? seq.back() : 0);
|
|
ofst->AddArc(cur_state, arc);
|
|
}
|
|
}
|
|
// Free up memory. Do this inside the loop as ofst is also allocating
|
|
// memory
|
|
if (destroy) {
|
|
std::vector<TempArc> temp;
|
|
temp.swap(this_vec);
|
|
}
|
|
}
|
|
if (destroy) {
|
|
std::vector<std::vector<TempArc> > temp;
|
|
temp.swap(output_arcs_);
|
|
repository_.Destroy();
|
|
}
|
|
}
|
|
|
|
// Initializer. After initializing the object you will typically
|
|
// call Determinize() and then call one of the Output functions.
|
|
// Note: ifst.Copy() will generally do a
|
|
// shallow copy. We do it like this for memory safety, rather than
|
|
// keeping a reference or pointer to ifst_.
|
|
LatticeDeterminizer(const Fst<Arc> &ifst, DeterminizeLatticeOptions opts)
|
|
: num_arcs_(0),
|
|
num_elems_(0),
|
|
ifst_(ifst.Copy()),
|
|
opts_(opts),
|
|
equal_(opts_.delta),
|
|
determinized_(false),
|
|
minimal_hash_(3, hasher_, equal_),
|
|
initial_hash_(3, hasher_, equal_) {
|
|
KALDI_ASSERT(Weight::Properties() & kIdempotent); // this algorithm won't
|
|
// work correctly otherwise.
|
|
}
|
|
|
|
// frees all except output_arcs_, which contains the important info
|
|
// we need to output the FST.
|
|
void FreeMostMemory() {
|
|
if (ifst_) {
|
|
delete ifst_;
|
|
ifst_ = NULL;
|
|
}
|
|
for (typename MinimalSubsetHash::iterator iter = minimal_hash_.begin();
|
|
iter != minimal_hash_.end(); ++iter)
|
|
delete iter->first;
|
|
{
|
|
MinimalSubsetHash tmp;
|
|
tmp.swap(minimal_hash_);
|
|
}
|
|
for (typename InitialSubsetHash::iterator iter = initial_hash_.begin();
|
|
iter != initial_hash_.end(); ++iter)
|
|
delete iter->first;
|
|
{
|
|
InitialSubsetHash tmp;
|
|
tmp.swap(initial_hash_);
|
|
}
|
|
{
|
|
std::vector<std::vector<Element> *> output_states_tmp;
|
|
output_states_tmp.swap(output_states_);
|
|
}
|
|
{
|
|
std::vector<char> tmp;
|
|
tmp.swap(isymbol_or_final_);
|
|
}
|
|
{
|
|
std::vector<OutputStateId> tmp;
|
|
tmp.swap(queue_);
|
|
}
|
|
{
|
|
std::vector<std::pair<Label, Element> > tmp;
|
|
tmp.swap(all_elems_tmp_);
|
|
}
|
|
}
|
|
|
|
~LatticeDeterminizer() {
|
|
FreeMostMemory(); // rest is deleted by destructors.
|
|
}
|
|
void RebuildRepository() { // rebuild the string repository,
|
|
// freeing stuff we don't need.. we call this when memory usage
|
|
// passes a supplied threshold. We need to accumulate all the
|
|
// strings we need the repository to "remember", then tell it
|
|
// to clean the repository.
|
|
std::vector<StringId> needed_strings;
|
|
for (size_t i = 0; i < output_arcs_.size(); i++)
|
|
for (size_t j = 0; j < output_arcs_[i].size(); j++)
|
|
needed_strings.push_back(output_arcs_[i][j].string);
|
|
|
|
// the following loop covers strings present in minimal_hash_
|
|
// which are also accessible via output_states_.
|
|
for (size_t i = 0; i < output_states_.size(); i++)
|
|
for (size_t j = 0; j < output_states_[i]->size(); j++)
|
|
needed_strings.push_back((*(output_states_[i]))[j].string);
|
|
|
|
// the following loop covers strings present in initial_hash_.
|
|
for (typename InitialSubsetHash::const_iterator iter =
|
|
initial_hash_.begin();
|
|
iter != initial_hash_.end(); ++iter) {
|
|
const std::vector<Element> &vec = *(iter->first);
|
|
Element elem = iter->second;
|
|
for (size_t i = 0; i < vec.size(); i++)
|
|
needed_strings.push_back(vec[i].string);
|
|
needed_strings.push_back(elem.string);
|
|
}
|
|
|
|
std::sort(needed_strings.begin(), needed_strings.end());
|
|
needed_strings.erase(
|
|
std::unique(needed_strings.begin(), needed_strings.end()),
|
|
needed_strings.end()); // uniq the strings.
|
|
repository_.Rebuild(needed_strings);
|
|
}
|
|
|
|
bool CheckMemoryUsage() {
|
|
int32 repo_size = repository_.MemSize(),
|
|
arcs_size = num_arcs_ * sizeof(TempArc),
|
|
elems_size = num_elems_ * sizeof(Element),
|
|
total_size = repo_size + arcs_size + elems_size;
|
|
if (opts_.max_mem > 0 &&
|
|
total_size > opts_.max_mem) { // We passed the memory threshold.
|
|
// This is usually due to the repository getting large, so we
|
|
// clean this out.
|
|
RebuildRepository();
|
|
int32 new_repo_size = repository_.MemSize(),
|
|
new_total_size = new_repo_size + arcs_size + elems_size;
|
|
|
|
KALDI_VLOG(2) << "Rebuilt repository in determinize-lattice: repository "
|
|
"shrank from "
|
|
<< repo_size << " to " << new_repo_size
|
|
<< " bytes (approximately)";
|
|
|
|
if (new_total_size > static_cast<int32>(opts_.max_mem * 0.8)) {
|
|
// Rebuilding didn't help enough-- we need a margin to stop
|
|
// having to rebuild too often.
|
|
KALDI_WARN << "Failure in determinize-lattice: size exceeds maximum "
|
|
<< opts_.max_mem << " bytes; (repo,arcs,elems) = ("
|
|
<< repo_size << "," << arcs_size << "," << elems_size
|
|
<< "), after rebuilding, repo size was " << new_repo_size;
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Returns true on success. Can fail for out-of-memory
|
|
// or max-states related reasons.
|
|
bool Determinize(bool *debug_ptr) {
|
|
assert(!determinized_);
|
|
// This determinizes the input fst but leaves it in the "special format"
|
|
// in "output_arcs_". Must be called after Initialize(). To get the
|
|
// output, call one of the Output routines.
|
|
try {
|
|
InitializeDeterminization(); // some start-up tasks.
|
|
while (!queue_.empty()) {
|
|
OutputStateId out_state = queue_.back();
|
|
queue_.pop_back();
|
|
ProcessState(out_state);
|
|
if (debug_ptr && *debug_ptr) Debug(); // will exit.
|
|
if (!CheckMemoryUsage()) return false;
|
|
}
|
|
return (determinized_ = true);
|
|
} catch (const std::bad_alloc &) {
|
|
int32 repo_size = repository_.MemSize(),
|
|
arcs_size = num_arcs_ * sizeof(TempArc),
|
|
elems_size = num_elems_ * sizeof(Element),
|
|
total_size = repo_size + arcs_size + elems_size;
|
|
KALDI_WARN
|
|
<< "Memory allocation error doing lattice determinization; using "
|
|
<< total_size << " bytes (max = " << opts_.max_mem
|
|
<< " (repo,arcs,elems) = (" << repo_size << "," << arcs_size << ","
|
|
<< elems_size << ")";
|
|
return (determinized_ = false);
|
|
} catch (const std::runtime_error &) {
|
|
KALDI_WARN << "Caught exception doing lattice determinization";
|
|
return (determinized_ = false);
|
|
}
|
|
}
|
|
|
|
private:
|
|
typedef typename Arc::Label Label;
|
|
typedef typename Arc::StateId
|
|
StateId; // use this when we don't know if it's input or output.
|
|
typedef typename Arc::StateId InputStateId; // state in the input FST.
|
|
typedef typename Arc::StateId OutputStateId; // same as above but distinguish
|
|
// states in output Fst.
|
|
|
|
typedef LatticeStringRepository<IntType> StringRepositoryType;
|
|
typedef const typename StringRepositoryType::Entry *StringId;
|
|
|
|
// Element of a subset [of original states]
|
|
struct Element {
|
|
StateId state; // use StateId as this is usually InputStateId but in one
|
|
// case OutputStateId.
|
|
StringId string;
|
|
Weight weight;
|
|
bool operator!=(const Element &other) const {
|
|
return (state != other.state || string != other.string ||
|
|
weight != other.weight);
|
|
}
|
|
// This operator is only intended to support sorting in EpsilonClosure()
|
|
bool operator<(const Element &other) const { return state < other.state; }
|
|
};
|
|
|
|
// Arcs in the format we temporarily create in this class (a representation,
|
|
// essentially of a Gallic Fst).
|
|
struct TempArc {
|
|
Label ilabel;
|
|
StringId string; // Look it up in the StringRepository, it's a sequence of
|
|
// Labels.
|
|
OutputStateId nextstate; // or kNoState for final weights.
|
|
Weight weight;
|
|
};
|
|
|
|
// Hashing function used in hash of subsets.
|
|
// A subset is a pointer to vector<Element>.
|
|
// The Elements are in sorted order on state id, and without repeated states.
|
|
// Because the order of Elements is fixed, we can use a hashing function that
|
|
// is order-dependent. However the weights are not included in the hashing
|
|
// function-- we hash subsets that differ only in weight to the same key. This
|
|
// is not optimal in terms of the O(N) performance but typically if we have a
|
|
// lot of determinized states that differ only in weight then the input
|
|
// probably was pathological in some way, or even non-determinizable.
|
|
// We don't quantize the weights, in order to avoid inexactness in simple
|
|
// cases.
|
|
// Instead we apply the delta when comparing subsets for equality, and allow a
|
|
// small difference.
|
|
|
|
class SubsetKey {
|
|
public:
|
|
size_t operator()(const std::vector<Element> *subset)
|
|
const { // hashes only the state and string.
|
|
size_t hash = 0, factor = 1;
|
|
for (typename std::vector<Element>::const_iterator iter = subset->begin();
|
|
iter != subset->end(); ++iter) {
|
|
hash *= factor;
|
|
hash += iter->state + reinterpret_cast<size_t>(iter->string);
|
|
factor *= 23531; // these numbers are primes.
|
|
}
|
|
return hash;
|
|
}
|
|
};
|
|
|
|
// This is the equality operator on subsets. It checks for exact match on
|
|
// state-id and string, and approximate match on weights.
|
|
class SubsetEqual {
|
|
public:
|
|
bool operator()(const std::vector<Element> *s1,
|
|
const std::vector<Element> *s2) const {
|
|
size_t sz = s1->size();
|
|
assert(sz >= 0);
|
|
if (sz != s2->size()) return false;
|
|
typename std::vector<Element>::const_iterator iter1 = s1->begin(),
|
|
iter1_end = s1->end(),
|
|
iter2 = s2->begin();
|
|
for (; iter1 < iter1_end; ++iter1, ++iter2) {
|
|
if (iter1->state != iter2->state || iter1->string != iter2->string ||
|
|
!ApproxEqual(iter1->weight, iter2->weight, delta_))
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
float delta_;
|
|
explicit SubsetEqual(float delta) : delta_(delta) {}
|
|
SubsetEqual() : delta_(kDelta) {}
|
|
};
|
|
|
|
// Operator that says whether two Elements have the same states.
|
|
// Used only for debug.
|
|
class SubsetEqualStates {
|
|
public:
|
|
bool operator()(const std::vector<Element> *s1,
|
|
const std::vector<Element> *s2) const {
|
|
size_t sz = s1->size();
|
|
assert(sz >= 0);
|
|
if (sz != s2->size()) return false;
|
|
typename std::vector<Element>::const_iterator iter1 = s1->begin(),
|
|
iter1_end = s1->end(),
|
|
iter2 = s2->begin();
|
|
for (; iter1 < iter1_end; ++iter1, ++iter2) {
|
|
if (iter1->state != iter2->state) return false;
|
|
}
|
|
return true;
|
|
}
|
|
};
|
|
|
|
// Define the hash type we use to map subsets (in minimal
|
|
// representation) to OutputStateId.
|
|
typedef std::unordered_map<const std::vector<Element> *, OutputStateId,
|
|
SubsetKey, SubsetEqual>
|
|
MinimalSubsetHash;
|
|
|
|
// Define the hash type we use to map subsets (in initial
|
|
// representation) to OutputStateId, together with an
|
|
// extra weight. [note: we interpret the Element.state in here
|
|
// as an OutputStateId even though it's declared as InputStateId;
|
|
// these types are the same anyway].
|
|
typedef std::unordered_map<const std::vector<Element> *, Element, SubsetKey,
|
|
SubsetEqual>
|
|
InitialSubsetHash;
|
|
|
|
// converts the representation of the subset from canonical (all states) to
|
|
// minimal (only states with output symbols on arcs leaving them, and final
|
|
// states). Output is not necessarily normalized, even if input_subset was.
|
|
void ConvertToMinimal(std::vector<Element> *subset) {
|
|
assert(!subset->empty());
|
|
typename std::vector<Element>::iterator cur_in = subset->begin(),
|
|
cur_out = subset->begin(),
|
|
end = subset->end();
|
|
while (cur_in != end) {
|
|
if (IsIsymbolOrFinal(cur_in->state)) { // keep it...
|
|
*cur_out = *cur_in;
|
|
cur_out++;
|
|
}
|
|
cur_in++;
|
|
}
|
|
subset->resize(cur_out - subset->begin());
|
|
}
|
|
|
|
// Takes a minimal, normalized subset, and converts it to an OutputStateId.
|
|
// Involves a hash lookup, and possibly adding a new OutputStateId.
|
|
// If it creates a new OutputStateId, it adds it to the queue.
|
|
OutputStateId MinimalToStateId(const std::vector<Element> &subset) {
|
|
typename MinimalSubsetHash::const_iterator iter =
|
|
minimal_hash_.find(&subset);
|
|
if (iter != minimal_hash_.end()) // Found a matching subset.
|
|
return iter->second;
|
|
OutputStateId ans = static_cast<OutputStateId>(output_arcs_.size());
|
|
std::vector<Element> *subset_ptr = new std::vector<Element>(subset);
|
|
output_states_.push_back(subset_ptr);
|
|
num_elems_ += subset_ptr->size();
|
|
output_arcs_.push_back(std::vector<TempArc>());
|
|
minimal_hash_[subset_ptr] = ans;
|
|
queue_.push_back(ans);
|
|
return ans;
|
|
}
|
|
|
|
// Given a normalized initial subset of elements (i.e. before epsilon
|
|
// closure), compute the corresponding output-state.
|
|
OutputStateId InitialToStateId(const std::vector<Element> &subset_in,
|
|
Weight *remaining_weight,
|
|
StringId *common_prefix) {
|
|
typename InitialSubsetHash::const_iterator iter =
|
|
initial_hash_.find(&subset_in);
|
|
if (iter != initial_hash_.end()) { // Found a matching subset.
|
|
const Element &elem = iter->second;
|
|
*remaining_weight = elem.weight;
|
|
*common_prefix = elem.string;
|
|
if (elem.weight == Weight::Zero()) KALDI_WARN << "Zero weight!"; // TEMP
|
|
return elem.state;
|
|
}
|
|
// else no matching subset-- have to work it out.
|
|
std::vector<Element> subset(subset_in);
|
|
// Follow through epsilons. Will add no duplicate states. note: after
|
|
// EpsilonClosure, it is the same as "canonical" subset, except not
|
|
// normalized (actually we never compute the normalized canonical subset,
|
|
// only the normalized minimal one).
|
|
EpsilonClosure(&subset); // follow epsilons.
|
|
ConvertToMinimal(&subset); // remove all but emitting and final states.
|
|
|
|
Element elem; // will be used to store remaining weight and string, and
|
|
// OutputStateId, in initial_hash_;
|
|
NormalizeSubset(&subset, &elem.weight,
|
|
&elem.string); // normalize subset; put
|
|
// common string and weight in "elem". The subset is now a minimal,
|
|
// normalized subset.
|
|
|
|
OutputStateId ans = MinimalToStateId(subset);
|
|
*remaining_weight = elem.weight;
|
|
*common_prefix = elem.string;
|
|
if (elem.weight == Weight::Zero()) KALDI_WARN << "Zero weight!"; // TEMP
|
|
|
|
// Before returning "ans", add the initial subset to the hash,
|
|
// so that we can bypass the epsilon-closure etc., next time
|
|
// we process the same initial subset.
|
|
std::vector<Element> *initial_subset_ptr =
|
|
new std::vector<Element>(subset_in);
|
|
elem.state = ans;
|
|
initial_hash_[initial_subset_ptr] = elem;
|
|
num_elems_ += initial_subset_ptr->size(); // keep track of memory usage.
|
|
return ans;
|
|
}
|
|
|
|
// returns the Compare value (-1 if a < b, 0 if a == b, 1 if a > b) according
|
|
// to the ordering we defined on strings for the CompactLatticeWeightTpl.
|
|
// see function
|
|
// inline int Compare (const CompactLatticeWeightTpl<WeightType,IntType> &w1,
|
|
// const CompactLatticeWeightTpl<WeightType,IntType> &w2)
|
|
// in lattice-weight.h.
|
|
// this is the same as that, but optimized for our data structures.
|
|
inline int Compare(const Weight &a_w, StringId a_str, const Weight &b_w,
|
|
StringId b_str) const {
|
|
int weight_comp = fst::Compare(a_w, b_w);
|
|
if (weight_comp != 0) return weight_comp;
|
|
// now comparing strings.
|
|
if (a_str == b_str) return 0;
|
|
std::vector<IntType> a_vec, b_vec;
|
|
repository_.ConvertToVector(a_str, &a_vec);
|
|
repository_.ConvertToVector(b_str, &b_vec);
|
|
// First compare their lengths.
|
|
int a_len = a_vec.size(), b_len = b_vec.size();
|
|
// use opposite order on the string lengths (c.f. Compare in
|
|
// lattice-weight.h)
|
|
if (a_len > b_len)
|
|
return -1;
|
|
else if (a_len < b_len)
|
|
return 1;
|
|
for (int i = 0; i < a_len; i++) {
|
|
if (a_vec[i] < b_vec[i])
|
|
return -1;
|
|
else if (a_vec[i] > b_vec[i])
|
|
return 1;
|
|
}
|
|
assert(
|
|
0); // because we checked if a_str == b_str above, shouldn't reach here
|
|
return 0;
|
|
}
|
|
|
|
// This function computes epsilon closure of subset of states by following
|
|
// epsilon links. Called by InitialToStateId and Initialize. Has no side
|
|
// effects except on the string repository. The "output_subset" is not
|
|
// necessarily normalized (in the sense of there being no common substring),
|
|
// unless input_subset was.
|
|
void EpsilonClosure(std::vector<Element> *subset) {
|
|
// at input, subset must have only one example of each StateId. [will still
|
|
// be so at output]. This function follows input-epsilons, and augments the
|
|
// subset accordingly.
|
|
|
|
std::deque<Element> queue;
|
|
std::unordered_map<InputStateId, Element> cur_subset;
|
|
typedef
|
|
typename std::unordered_map<InputStateId, Element>::iterator MapIter;
|
|
typedef typename std::vector<Element>::const_iterator VecIter;
|
|
|
|
for (VecIter iter = subset->begin(); iter != subset->end(); ++iter) {
|
|
queue.push_back(*iter);
|
|
cur_subset[iter->state] = *iter;
|
|
}
|
|
|
|
// find whether input fst is known to be sorted on input label.
|
|
bool sorted =
|
|
((ifst_->Properties(kILabelSorted, false) & kILabelSorted) != 0);
|
|
bool replaced_elems = false; // relates to an optimization, see below.
|
|
int counter =
|
|
0; // stops infinite loops here for non-lattice-determinizable input;
|
|
// useful in testing.
|
|
while (queue.size() != 0) {
|
|
Element elem = queue.front();
|
|
queue.pop_front();
|
|
|
|
// The next if-statement is a kind of optimization. It's to prevent us
|
|
// unnecessarily repeating the processing of a state. "cur_subset" always
|
|
// contains only one Element with a particular state. The issue is that
|
|
// whenever we modify the Element corresponding to that state in
|
|
// "cur_subset", both the new (optimal) and old (less-optimal) Element
|
|
// will still be in "queue". The next if-statement stops us from wasting
|
|
// compute by processing the old Element.
|
|
if (replaced_elems && cur_subset[elem.state] != elem) continue;
|
|
if (opts_.max_loop > 0 && counter++ > opts_.max_loop) {
|
|
KALDI_ERR << "Lattice determinization aborted since looped more than "
|
|
<< opts_.max_loop << " times during epsilon closure";
|
|
}
|
|
for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done();
|
|
aiter.Next()) {
|
|
const Arc &arc = aiter.Value();
|
|
if (sorted && arc.ilabel != 0)
|
|
break; // Break from the loop: due to sorting there will be no
|
|
// more transitions with epsilons as input labels.
|
|
if (arc.ilabel == 0 &&
|
|
arc.weight != Weight::Zero()) { // Epsilon transition.
|
|
Element next_elem;
|
|
next_elem.state = arc.nextstate;
|
|
next_elem.weight = Times(elem.weight, arc.weight);
|
|
// now must append strings
|
|
if (arc.olabel == 0)
|
|
next_elem.string = elem.string;
|
|
else
|
|
next_elem.string = repository_.Successor(elem.string, arc.olabel);
|
|
|
|
MapIter iter = cur_subset.find(next_elem.state);
|
|
if (iter == cur_subset.end()) {
|
|
// was no such StateId: insert and add to queue.
|
|
cur_subset[next_elem.state] = next_elem;
|
|
queue.push_back(next_elem);
|
|
} else {
|
|
// was not inserted because one already there. In normal
|
|
// determinization we'd add the weights. Here, we find which one
|
|
// has the better weight, and keep its corresponding string.
|
|
int comp = Compare(next_elem.weight, next_elem.string,
|
|
iter->second.weight, iter->second.string);
|
|
if (comp ==
|
|
1) { // next_elem is better, so use its (weight, string)
|
|
iter->second.string = next_elem.string;
|
|
iter->second.weight = next_elem.weight;
|
|
queue.push_back(next_elem);
|
|
replaced_elems = true;
|
|
}
|
|
// else it is the same or worse, so use original one.
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
{ // copy cur_subset to subset.
|
|
subset->clear();
|
|
subset->reserve(cur_subset.size());
|
|
MapIter iter = cur_subset.begin(), end = cur_subset.end();
|
|
for (; iter != end; ++iter) subset->push_back(iter->second);
|
|
// sort by state ID, because the subset hash function is
|
|
// order-dependent(see SubsetKey)
|
|
std::sort(subset->begin(), subset->end());
|
|
}
|
|
}
|
|
|
|
// This function works out the final-weight of the determinized state.
|
|
// called by ProcessSubset.
|
|
// Has no side effects except on the variable repository_, and output_arcs_.
|
|
|
|
void ProcessFinal(OutputStateId output_state) {
|
|
const std::vector<Element> &minimal_subset =
|
|
*(output_states_[output_state]);
|
|
// processes final-weights for this subset.
|
|
|
|
// minimal_subset may be empty if the graphs is not connected/trimmed, I
|
|
// think, do don't check that it's nonempty.
|
|
bool is_final = false;
|
|
StringId final_string = NULL; // = NULL to keep compiler happy.
|
|
Weight final_weight = Weight::Zero();
|
|
typename std::vector<Element>::const_iterator iter = minimal_subset.begin(),
|
|
end = minimal_subset.end();
|
|
for (; iter != end; ++iter) {
|
|
const Element &elem = *iter;
|
|
Weight this_final_weight = Times(elem.weight, ifst_->Final(elem.state));
|
|
StringId this_final_string = elem.string;
|
|
if (this_final_weight != Weight::Zero() &&
|
|
(!is_final || Compare(this_final_weight, this_final_string,
|
|
final_weight, final_string) == 1)) { // the new
|
|
// (weight, string) pair is more in semiring than our current
|
|
// one.
|
|
is_final = true;
|
|
final_weight = this_final_weight;
|
|
final_string = this_final_string;
|
|
}
|
|
}
|
|
if (is_final) {
|
|
// store final weights in TempArc structure, just like a transition.
|
|
TempArc temp_arc;
|
|
temp_arc.ilabel = 0;
|
|
temp_arc.nextstate =
|
|
kNoStateId; // special marker meaning "final weight".
|
|
temp_arc.string = final_string;
|
|
temp_arc.weight = final_weight;
|
|
output_arcs_[output_state].push_back(temp_arc);
|
|
num_arcs_++;
|
|
}
|
|
}
|
|
|
|
// NormalizeSubset normalizes the subset "elems" by
|
|
// removing any common string prefix (putting it in common_str),
|
|
// and dividing by the total weight (putting it in tot_weight).
|
|
void NormalizeSubset(std::vector<Element> *elems, Weight *tot_weight,
|
|
StringId *common_str) {
|
|
if (elems->empty()) { // just set common_str, tot_weight
|
|
KALDI_WARN << "[empty subset]"; // TEMP
|
|
// to defaults and return...
|
|
*common_str = repository_.EmptyString();
|
|
*tot_weight = Weight::Zero();
|
|
return;
|
|
}
|
|
size_t size = elems->size();
|
|
std::vector<IntType> common_prefix;
|
|
repository_.ConvertToVector((*elems)[0].string, &common_prefix);
|
|
Weight weight = (*elems)[0].weight;
|
|
for (size_t i = 1; i < size; i++) {
|
|
weight = Plus(weight, (*elems)[i].weight);
|
|
repository_.ReduceToCommonPrefix((*elems)[i].string, &common_prefix);
|
|
}
|
|
assert(weight != Weight::Zero()); // we made sure to ignore arcs with zero
|
|
// weights on them, so we shouldn't have zero here.
|
|
size_t prefix_len = common_prefix.size();
|
|
for (size_t i = 0; i < size; i++) {
|
|
(*elems)[i].weight = Divide((*elems)[i].weight, weight, DIVIDE_LEFT);
|
|
(*elems)[i].string =
|
|
repository_.RemovePrefix((*elems)[i].string, prefix_len);
|
|
}
|
|
*common_str = repository_.ConvertFromVector(common_prefix);
|
|
*tot_weight = weight;
|
|
}
|
|
|
|
// Take a subset of Elements that is sorted on state, and
|
|
// merge any Elements that have the same state (taking the best
|
|
// (weight, string) pair in the semiring).
|
|
void MakeSubsetUnique(std::vector<Element> *subset) {
|
|
typedef typename std::vector<Element>::iterator IterType;
|
|
|
|
// This assert is designed to fail (usually) if the subset is not sorted on
|
|
// state.
|
|
assert(subset->size() < 2 || (*subset)[0].state <= (*subset)[1].state);
|
|
|
|
IterType cur_in = subset->begin(), cur_out = cur_in, end = subset->end();
|
|
size_t num_out = 0;
|
|
// Merge elements with same state-id
|
|
while (cur_in != end) { // while we have more elements to process.
|
|
// At this point, cur_out points to location of next place we want to put
|
|
// an element, cur_in points to location of next element we want to
|
|
// process.
|
|
if (cur_in != cur_out) *cur_out = *cur_in;
|
|
cur_in++;
|
|
while (cur_in != end && cur_in->state == cur_out->state) {
|
|
if (Compare(cur_in->weight, cur_in->string, cur_out->weight,
|
|
cur_out->string) == 1) {
|
|
// if *cur_in > *cur_out in semiring, then take *cur_in.
|
|
cur_out->string = cur_in->string;
|
|
cur_out->weight = cur_in->weight;
|
|
}
|
|
cur_in++;
|
|
}
|
|
cur_out++;
|
|
num_out++;
|
|
}
|
|
subset->resize(num_out);
|
|
}
|
|
|
|
// ProcessTransition is called from "ProcessTransitions". Broken out for
|
|
// clarity. Processes a transition from state "state". The set of Elements
|
|
// represents a set of next-states with associated weights and strings, each
|
|
// one arising from an arc from some state in a determinized-state; the
|
|
// next-states are not necessarily unique (i.e. there may be >1 entry
|
|
// associated with each), and any such sets of Elements have to be merged
|
|
// within this routine (we take the [weight, string] pair that's better in the
|
|
// semiring).
|
|
void ProcessTransition(OutputStateId state, Label ilabel,
|
|
std::vector<Element> *subset) {
|
|
MakeSubsetUnique(subset); // remove duplicates with the same state.
|
|
|
|
StringId common_str;
|
|
Weight tot_weight;
|
|
NormalizeSubset(subset, &tot_weight, &common_str);
|
|
|
|
OutputStateId nextstate;
|
|
{
|
|
Weight next_tot_weight;
|
|
StringId next_common_str;
|
|
nextstate = InitialToStateId(*subset, &next_tot_weight, &next_common_str);
|
|
common_str = repository_.Concatenate(common_str, next_common_str);
|
|
tot_weight = Times(tot_weight, next_tot_weight);
|
|
}
|
|
|
|
// Now add an arc to the next state (would have been created if necessary by
|
|
// InitialToStateId).
|
|
TempArc temp_arc;
|
|
temp_arc.ilabel = ilabel;
|
|
temp_arc.nextstate = nextstate;
|
|
temp_arc.string = common_str;
|
|
temp_arc.weight = tot_weight;
|
|
output_arcs_[state].push_back(temp_arc); // record the arc.
|
|
num_arcs_++;
|
|
}
|
|
|
|
// "less than" operator for pair<Label, Element>. Used in
|
|
// ProcessTransitions. Lexicographical order, which only compares the state
|
|
// when ordering the "Element" member of the pair.
|
|
|
|
class PairComparator {
|
|
public:
|
|
inline bool operator()(const std::pair<Label, Element> &p1,
|
|
const std::pair<Label, Element> &p2) {
|
|
if (p1.first < p2.first) {
|
|
return true;
|
|
} else if (p1.first > p2.first) {
|
|
return false;
|
|
} else {
|
|
return p1.second.state < p2.second.state;
|
|
}
|
|
}
|
|
};
|
|
|
|
// ProcessTransitions processes emitting transitions (transitions
|
|
// with ilabels) out of this subset of states.
|
|
// Does not consider final states. Breaks the emitting transitions up by
|
|
// ilabel, and creates a new transition in the determinized FST for each
|
|
// unique ilabel. Does this by creating a big vector of pairs <Label, Element>
|
|
// and then sorting them using a lexicographical ordering, and calling
|
|
// ProcessTransition for each range with the same ilabel. Side effects on
|
|
// repository, and (via ProcessTransition) on Q_, hash_, and output_arcs_.
|
|
|
|
void ProcessTransitions(OutputStateId output_state) {
|
|
const std::vector<Element> &minimal_subset =
|
|
*(output_states_[output_state]);
|
|
// it's possible that minimal_subset could be empty if there are
|
|
// unreachable parts of the graph, so don't check that it's nonempty.
|
|
std::vector<std::pair<Label, Element> > &all_elems(
|
|
all_elems_tmp_); // use class member
|
|
// to avoid memory allocation/deallocation.
|
|
{
|
|
// Push back into "all_elems", elements corresponding to all
|
|
// non-epsilon-input transitions out of all states in "minimal_subset".
|
|
typename std::vector<Element>::const_iterator iter =
|
|
minimal_subset.begin(),
|
|
end = minimal_subset.end();
|
|
for (; iter != end; ++iter) {
|
|
const Element &elem = *iter;
|
|
for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done();
|
|
aiter.Next()) {
|
|
const Arc &arc = aiter.Value();
|
|
if (arc.ilabel != 0 &&
|
|
arc.weight != Weight::Zero()) { // Non-epsilon transition --
|
|
// ignore epsilons here.
|
|
std::pair<Label, Element> this_pr;
|
|
this_pr.first = arc.ilabel;
|
|
Element &next_elem(this_pr.second);
|
|
next_elem.state = arc.nextstate;
|
|
next_elem.weight = Times(elem.weight, arc.weight);
|
|
if (arc.olabel == 0) // output epsilon
|
|
next_elem.string = elem.string;
|
|
else
|
|
next_elem.string = repository_.Successor(elem.string, arc.olabel);
|
|
all_elems.push_back(this_pr);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
PairComparator pc;
|
|
std::sort(all_elems.begin(), all_elems.end(), pc);
|
|
// now sorted first on input label, then on state.
|
|
typedef typename std::vector<std::pair<Label, Element> >::const_iterator
|
|
PairIter;
|
|
PairIter cur = all_elems.begin(), end = all_elems.end();
|
|
std::vector<Element> this_subset;
|
|
while (cur != end) {
|
|
// Process ranges that share the same input symbol.
|
|
Label ilabel = cur->first;
|
|
this_subset.clear();
|
|
while (cur != end && cur->first == ilabel) {
|
|
this_subset.push_back(cur->second);
|
|
cur++;
|
|
}
|
|
// We now have a subset for this ilabel.
|
|
assert(!this_subset.empty()); // temp.
|
|
ProcessTransition(output_state, ilabel, &this_subset);
|
|
}
|
|
all_elems.clear(); // as it's a class variable-- want it to stay
|
|
// emtpy.
|
|
}
|
|
|
|
// ProcessState does the processing of a determinized state, i.e. it creates
|
|
// transitions out of it and the final-probability if any.
|
|
void ProcessState(OutputStateId output_state) {
|
|
ProcessFinal(output_state);
|
|
ProcessTransitions(output_state);
|
|
}
|
|
|
|
void Debug() { // this function called if you send a signal
|
|
// SIGUSR1 to the process (and it's caught by the handler in
|
|
// fstdeterminizestar). It prints out some traceback
|
|
// info and exits.
|
|
|
|
KALDI_WARN << "Debug function called (probably SIGUSR1 caught)";
|
|
// free up memory from the hash as we need a little memory
|
|
{
|
|
MinimalSubsetHash hash_tmp;
|
|
hash_tmp.swap(minimal_hash_);
|
|
}
|
|
|
|
if (output_arcs_.size() <= 2) {
|
|
KALDI_ERR << "Nothing to trace back";
|
|
}
|
|
size_t max_state = output_arcs_.size() - 2; // Don't take the last
|
|
// one as we might be halfway into constructing it.
|
|
|
|
std::vector<OutputStateId> predecessor(max_state + 1, kNoStateId);
|
|
for (size_t i = 0; i < max_state; i++) {
|
|
for (size_t j = 0; j < output_arcs_[i].size(); j++) {
|
|
OutputStateId nextstate = output_arcs_[i][j].nextstate;
|
|
// Always find an earlier-numbered predecessor; this
|
|
// is always possible because of the way the algorithm
|
|
// works.
|
|
if (nextstate <= max_state && nextstate > i) predecessor[nextstate] = i;
|
|
}
|
|
}
|
|
std::vector<std::pair<Label, StringId> > traceback;
|
|
// 'traceback' is a pair of (ilabel, olabel-seq).
|
|
OutputStateId cur_state = max_state; // A recently constructed state.
|
|
|
|
while (cur_state != 0 && cur_state != kNoStateId) {
|
|
OutputStateId last_state = predecessor[cur_state];
|
|
std::pair<Label, StringId> p;
|
|
size_t i;
|
|
for (i = 0; i < output_arcs_[last_state].size(); i++) {
|
|
if (output_arcs_[last_state][i].nextstate == cur_state) {
|
|
p.first = output_arcs_[last_state][i].ilabel;
|
|
p.second = output_arcs_[last_state][i].string;
|
|
traceback.push_back(p);
|
|
break;
|
|
}
|
|
}
|
|
KALDI_ASSERT(i != output_arcs_[last_state].size()); // Or fell off loop.
|
|
cur_state = last_state;
|
|
}
|
|
if (cur_state == kNoStateId)
|
|
KALDI_WARN << "Traceback did not reach start state "
|
|
<< "(possibly debug-code error)";
|
|
|
|
std::stringstream ss;
|
|
ss << "Traceback follows in format "
|
|
<< "ilabel (olabel olabel) ilabel (olabel) ... :";
|
|
for (ssize_t i = traceback.size() - 1; i >= 0; i--) {
|
|
ss << ' ' << traceback[i].first << " ( ";
|
|
std::vector<Label> seq;
|
|
repository_.ConvertToVector(traceback[i].second, &seq);
|
|
for (size_t j = 0; j < seq.size(); j++) ss << seq[j] << ' ';
|
|
ss << ')';
|
|
}
|
|
KALDI_ERR << ss.str();
|
|
}
|
|
|
|
bool IsIsymbolOrFinal(InputStateId state) { // returns true if this state
|
|
// of the input FST either is final or has an osymbol on an arc out of it.
|
|
// Uses the vector isymbol_or_final_ as a cache for this info.
|
|
assert(state >= 0);
|
|
if (isymbol_or_final_.size() <= state)
|
|
isymbol_or_final_.resize(state + 1, static_cast<char>(OSF_UNKNOWN));
|
|
if (isymbol_or_final_[state] == static_cast<char>(OSF_NO))
|
|
return false;
|
|
else if (isymbol_or_final_[state] == static_cast<char>(OSF_YES))
|
|
return true;
|
|
// else work it out...
|
|
isymbol_or_final_[state] = static_cast<char>(OSF_NO);
|
|
if (ifst_->Final(state) != Weight::Zero())
|
|
isymbol_or_final_[state] = static_cast<char>(OSF_YES);
|
|
for (ArcIterator<Fst<Arc> > aiter(*ifst_, state); !aiter.Done();
|
|
aiter.Next()) {
|
|
const Arc &arc = aiter.Value();
|
|
if (arc.ilabel != 0 && arc.weight != Weight::Zero()) {
|
|
isymbol_or_final_[state] = static_cast<char>(OSF_YES);
|
|
return true;
|
|
}
|
|
}
|
|
return IsIsymbolOrFinal(state); // will only recurse once.
|
|
}
|
|
|
|
void InitializeDeterminization() {
|
|
if (ifst_->Properties(kExpanded, false) != 0) { // if we know the number of
|
|
// states in ifst_, it might be a bit more efficient
|
|
// to pre-size the hashes so we're not constantly rebuilding them.
|
|
#if !(__GNUC__ == 4 && __GNUC_MINOR__ == 0)
|
|
StateId num_states =
|
|
down_cast<const ExpandedFst<Arc> *, const Fst<Arc> >(ifst_)
|
|
->NumStates();
|
|
minimal_hash_.rehash(num_states / 2 + 3);
|
|
initial_hash_.rehash(num_states / 2 + 3);
|
|
#endif
|
|
}
|
|
InputStateId start_id = ifst_->Start();
|
|
if (start_id != kNoStateId) {
|
|
/* Insert determinized-state corresponding to the start state into hash
|
|
and queue. Unlike all the other states, we don't "normalize" the
|
|
representation of this determinized-state before we put it into
|
|
minimal_hash_. This is actually what we want, as otherwise we'd have
|
|
problems dealing with any extra weight and string and might have to
|
|
create a "super-initial" state which would make the output
|
|
nondeterministic. Normalization is only needed to make the
|
|
determinized output more minimal anyway, it's not needed for
|
|
correctness. Note, we don't put anything in the initial_hash_. The
|
|
initial_hash_ is only a lookaside buffer anyway, so this isn't a
|
|
problem-- it will get populated later if it needs to be.
|
|
*/
|
|
Element elem;
|
|
elem.state = start_id;
|
|
elem.weight = Weight::One();
|
|
elem.string = repository_.EmptyString(); // Id of empty sequence.
|
|
std::vector<Element> subset;
|
|
subset.push_back(elem);
|
|
EpsilonClosure(&subset); // follow through epsilon-inputs links
|
|
ConvertToMinimal(&subset); // remove all but final states and
|
|
// states with input-labels on arcs out of them.
|
|
std::vector<Element> *subset_ptr = new std::vector<Element>(subset);
|
|
assert(output_arcs_.empty() && output_states_.empty());
|
|
// add the new state...
|
|
output_states_.push_back(subset_ptr);
|
|
output_arcs_.push_back(std::vector<TempArc>());
|
|
OutputStateId initial_state = 0;
|
|
minimal_hash_[subset_ptr] = initial_state;
|
|
queue_.push_back(initial_state);
|
|
}
|
|
}
|
|
|
|
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeDeterminizer);
|
|
|
|
std::vector<std::vector<Element> *>
|
|
output_states_; // maps from output state to
|
|
// minimal representation [normalized].
|
|
// View pointers as owned in
|
|
// minimal_hash_.
|
|
std::vector<std::vector<TempArc> >
|
|
output_arcs_; // essentially an FST in our format.
|
|
|
|
int num_arcs_; // keep track of memory usage: number of arcs in output_arcs_
|
|
int num_elems_; // keep track of memory usage: number of elems in
|
|
// output_states_
|
|
|
|
const Fst<Arc> *ifst_;
|
|
DeterminizeLatticeOptions opts_;
|
|
SubsetKey hasher_; // object that computes keys-- has no data members.
|
|
SubsetEqual
|
|
equal_; // object that compares subsets-- only data member is delta_.
|
|
bool determinized_; // set to true when user called Determinize(); used to
|
|
// make
|
|
// sure this object is used correctly.
|
|
MinimalSubsetHash
|
|
minimal_hash_; // hash from Subset to OutputStateId. Subset is "minimal
|
|
// representation" (only include final and states and
|
|
// states with nonzero ilabel on arc out of them. Owns
|
|
// the pointers in its keys.
|
|
InitialSubsetHash initial_hash_; // hash from Subset to Element, which
|
|
// represents the OutputStateId together
|
|
// with an extra weight and string. Subset
|
|
// is "initial representation". The extra
|
|
// weight and string is needed because after
|
|
// we convert to minimal representation and
|
|
// normalize, there may be an extra weight
|
|
// and string. Owns the pointers
|
|
// in its keys.
|
|
std::vector<OutputStateId>
|
|
queue_; // Queue of output-states to process. Starts with
|
|
// state 0, and increases and then (hopefully) decreases in length during
|
|
// determinization. LIFO queue (queue discipline doesn't really matter).
|
|
|
|
std::vector<std::pair<Label, Element> >
|
|
all_elems_tmp_; // temporary vector used in ProcessTransitions.
|
|
|
|
enum IsymbolOrFinal { OSF_UNKNOWN = 0, OSF_NO = 1, OSF_YES = 2 };
|
|
|
|
std::vector<char> isymbol_or_final_; // A kind of cache; it says whether
|
|
// each state is (emitting or final) where emitting means it has at least one
|
|
// non-epsilon output arc. Only accessed by IsIsymbolOrFinal()
|
|
|
|
LatticeStringRepository<IntType>
|
|
repository_; // defines a compact and fast way of
|
|
// storing sequences of labels.
|
|
};
|
|
|
|
// normally Weight would be LatticeWeight<float> (which has two floats),
|
|
// or possibly TropicalWeightTpl<float>, and IntType would be int32.
|
|
template <class Weight, class IntType>
|
|
bool DeterminizeLattice(const Fst<ArcTpl<Weight> > &ifst,
|
|
MutableFst<ArcTpl<Weight> > *ofst,
|
|
DeterminizeLatticeOptions opts, bool *debug_ptr) {
|
|
ofst->SetInputSymbols(ifst.InputSymbols());
|
|
ofst->SetOutputSymbols(ifst.OutputSymbols());
|
|
LatticeDeterminizer<Weight, IntType> det(ifst, opts);
|
|
if (!det.Determinize(debug_ptr)) return false;
|
|
det.Output(ofst);
|
|
return true;
|
|
}
|
|
|
|
// normally Weight would be LatticeWeight<float> (which has two floats),
|
|
// or possibly TropicalWeightTpl<float>, and IntType would be int32.
|
|
template <class Weight, class IntType>
|
|
bool DeterminizeLattice(
|
|
const Fst<ArcTpl<Weight> > &ifst,
|
|
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
|
|
DeterminizeLatticeOptions opts, bool *debug_ptr) {
|
|
ofst->SetInputSymbols(ifst.InputSymbols());
|
|
ofst->SetOutputSymbols(ifst.OutputSymbols());
|
|
LatticeDeterminizer<Weight, IntType> det(ifst, opts);
|
|
if (!det.Determinize(debug_ptr)) return false;
|
|
det.Output(ofst);
|
|
return true;
|
|
}
|
|
|
|
} // namespace fst
|
|
|
|
#endif // KALDI_FSTEXT_DETERMINIZE_LATTICE_INL_H_
|