You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/runtime/engine/kaldi/fstext/determinize-star-inl.h

1205 lines
45 KiB

// fstext/determinize-star-inl.h
// Copyright 2009-2011 Microsoft Corporation; Jan Silovsky
// 2015 Hainan Xu
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_
#define KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_
// Do not include this file directly. It is included by determinize-star.h
#include <algorithm>
#include <climits>
#include <deque>
#include <limits>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
using std::unordered_map;
#include "base/kaldi-error.h"
namespace fst {
// This class maps back and forth from/to integer id's to sequences of strings.
// used in determinization algorithm.
template <class Label, class StringId>
class StringRepository {
// Label and StringId are both integer types, possibly the same.
// This is a utility that maps back and forth between a vector<Label> and
// StringId representation of sequences of Labels. It is to save memory, and
// to save compute. We treat sequences of length zero and one separately, for
// efficiency.
public:
class VectorKey { // Hash function object.
public:
size_t operator()(const std::vector<Label> *vec) const {
assert(vec != NULL);
size_t hash = 0, factor = 1;
for (typename std::vector<Label>::const_iterator it = vec->begin();
it != vec->end(); it++) {
hash += factor * (*it);
factor *= 103333; // just an arbitrary prime number.
}
return hash;
}
};
class VectorEqual { // Equality-operator function object.
public:
size_t operator()(const std::vector<Label> *vec1,
const std::vector<Label> *vec2) const {
return (*vec1 == *vec2);
}
};
typedef unordered_map<const std::vector<Label> *, StringId, VectorKey,
VectorEqual>
MapType;
StringId IdOfEmpty() { return no_symbol; }
StringId IdOfLabel(Label l) {
if (l >= 0 && l <= (Label)single_symbol_range) {
return l + single_symbol_start;
} else {
// l is out of the allowed range so we have to treat it as a sequence of
// length one. Should be v. rare.
std::vector<Label> v;
v.push_back(l);
return IdOfSeqInternal(v);
}
}
StringId IdOfSeq(
const std::vector<Label> &v) { // also works for sizes 0 and 1.
size_t sz = v.size();
if (sz == 0)
return no_symbol;
else if (v.size() == 1)
return IdOfLabel(v[0]);
else
return IdOfSeqInternal(v);
}
inline bool IsEmptyString(StringId id) { return id == no_symbol; }
void SeqOfId(StringId id, std::vector<Label> *v) {
if (id == no_symbol) {
v->clear();
} else if (id >= single_symbol_start) {
v->resize(1);
(*v)[0] = id - single_symbol_start;
} else {
assert(static_cast<size_t>(id) < vec_.size());
*v = *(vec_[id]);
}
}
StringId RemovePrefix(StringId id, size_t prefix_len) {
if (prefix_len == 0) {
return id;
} else {
std::vector<Label> v;
SeqOfId(id, &v);
size_t sz = v.size();
assert(sz >= prefix_len);
std::vector<Label> v_noprefix(sz - prefix_len);
for (size_t i = 0; i < sz - prefix_len; i++)
v_noprefix[i] = v[i + prefix_len];
return IdOfSeq(v_noprefix);
}
}
StringRepository() {
// The following are really just constants but don't want to complicate
// compilation so make them class variables. Due to the brokenness of
// <limits>, they can't be accessed as constants.
string_end = (std::numeric_limits<StringId>::max() / 2) -
1; // all hash values must be <= this.
no_symbol = (std::numeric_limits<StringId>::max() /
2); // reserved for empty sequence.
single_symbol_start = (std::numeric_limits<StringId>::max() / 2) + 1;
single_symbol_range =
std::numeric_limits<StringId>::max() - single_symbol_start;
}
void Destroy() {
for (typename std::vector<std::vector<Label> *>::iterator iter =
vec_.begin();
iter != vec_.end(); ++iter)
delete *iter;
std::vector<std::vector<Label> *> tmp_vec;
tmp_vec.swap(vec_);
MapType tmp_map;
tmp_map.swap(map_);
}
~StringRepository() { Destroy(); }
private:
KALDI_DISALLOW_COPY_AND_ASSIGN(StringRepository);
StringId IdOfSeqInternal(const std::vector<Label> &v) {
typename MapType::iterator iter = map_.find(&v);
if (iter != map_.end()) {
return iter->second;
} else { // must add it to map.
StringId this_id = (StringId)vec_.size();
std::vector<Label> *v_new = new std::vector<Label>(v);
vec_.push_back(v_new);
map_[v_new] = this_id;
assert(this_id < string_end); // or we used up the labels.
return this_id;
}
}
std::vector<std::vector<Label> *> vec_;
MapType map_;
static const StringId string_start =
(StringId)0; // This must not change. It's assumed.
StringId string_end; // = (numeric_limits<StringId>::max() / 2) - 1; // all
// hash values must be <= this.
StringId no_symbol; // = (numeric_limits<StringId>::max() / 2); // reserved
// for empty sequence.
StringId
single_symbol_start; // = (numeric_limits<StringId>::max() / 2) + 1;
StringId single_symbol_range; // = numeric_limits<StringId>::max() -
// single_symbol_start;
};
template <class F>
class DeterminizerStar {
typedef typename F::Arc Arc;
public:
// Output to Gallic acceptor (so the strings go on weights, and there is a 1-1
// correspondence between our states and the states in ofst. If destroy ==
// true, release memory as we go (but we cannot output again).
void Output(MutableFst<GallicArc<Arc> > *ofst, bool destroy = true);
// Output to standard FST. We will create extra states to handle sequences of
// symbols on the output. If destroy == true, release memory as we go (but we
// cannot output again).
void Output(MutableFst<Arc> *ofst, bool destroy = true);
// Initializer. After initializing the object you will typically call
// Determinize() and then one of the Output functions.
DeterminizerStar(const Fst<Arc> &ifst, float delta = kDelta,
int max_states = -1, bool allow_partial = false)
: ifst_(ifst.Copy()),
delta_(delta),
max_states_(max_states),
determinized_(false),
allow_partial_(allow_partial),
is_partial_(false),
equal_(delta),
hash_(ifst.Properties(kExpanded, false)
? down_cast<const ExpandedFst<Arc> *, const Fst<Arc> >(&ifst)
->NumStates() /
2 +
3
: 20,
hasher_, equal_),
epsilon_closure_(ifst_, max_states, &repository_, delta) {}
void Determinize(bool *debug_ptr) {
assert(!determinized_);
// This determinizes the input fst but leaves it in the "special format"
// in "output_arcs_".
InputStateId start_id = ifst_->Start();
if (start_id == kNoStateId) {
determinized_ = true;
return; // Nothing to do.
} else { // Insert start state into hash and queue.
Element elem;
elem.state = start_id;
elem.weight = Weight::One();
elem.string = repository_.IdOfEmpty(); // Id of empty sequence.
std::vector<Element> vec;
vec.push_back(elem);
OutputStateId cur_id = SubsetToStateId(vec);
assert(cur_id == 0 && "Do not call Determinize twice.");
}
while (!Q_.empty()) {
std::pair<std::vector<Element> *, OutputStateId> cur_pair = Q_.front();
Q_.pop_front();
ProcessSubset(cur_pair);
if (debug_ptr && *debug_ptr) Debug(); // will exit.
if (max_states_ > 0 && output_arcs_.size() > max_states_) {
if (allow_partial_ == false) {
KALDI_ERR << "Determinization aborted since passed " << max_states_
<< " states";
} else {
KALDI_WARN << "Determinization terminated since passed "
<< max_states_
<< " states, partial results will be generated";
is_partial_ = true;
break;
}
}
}
determinized_ = true;
}
bool IsPartial() { return is_partial_; }
// frees all except output_arcs_, which contains the important info
// we need to output.
void FreeMostMemory() {
if (ifst_) {
delete ifst_;
ifst_ = NULL;
}
for (typename SubsetHash::iterator iter = hash_.begin();
iter != hash_.end(); ++iter)
delete iter->first;
SubsetHash tmp;
tmp.swap(hash_);
}
~DeterminizerStar() { FreeMostMemory(); }
private:
typedef typename Arc::Label Label;
typedef typename Arc::Weight Weight;
typedef typename Arc::StateId InputStateId;
typedef typename Arc::StateId
OutputStateId; // same as above but distinguish states in output Fst.
typedef typename Arc::Label StringId; // Id type used in the StringRepository
typedef StringRepository<Label, StringId> StringRepositoryType;
// Element of a subset [of original states]
struct Element {
InputStateId state;
StringId string;
Weight weight;
bool operator!=(const Element &other) const {
return (state != other.state || string != other.string ||
weight != other.weight);
}
};
// Arcs in the format we temporarily create in this class (a representation,
// essentially of a Gallic Fst).
struct TempArc {
Label ilabel;
StringId ostring; // Look it up in the StringRepository, it's a sequence of
// Labels.
OutputStateId nextstate; // or kNoState for final weights.
Weight weight;
};
// Hashing function used in hash of subsets.
// A subset is a pointer to vector<Element>.
// The Elements are in sorted order on state id, and without repeated states.
// Because the order of Elements is fixed, we can use a hashing function that
// is order-dependent. However the weights are not included in the hashing
// function-- we hash subsets that differ only in weight to the same key. This
// is not optimal in terms of the O(N) performance but typically if we have a
// lot of determinized states that differ only in weight then the input
// probably was pathological in some way, or even non-determinizable.
// We don't quantize the weights, in order to avoid inexactness in simple
// cases.
// Instead we apply the delta when comparing subsets for equality, and allow a
// small difference.
class SubsetKey {
public:
size_t operator()(const std::vector<Element> *subset)
const { // hashes only the state and string.
size_t hash = 0, factor = 1;
for (typename std::vector<Element>::const_iterator iter = subset->begin();
iter != subset->end(); ++iter) {
hash *= factor;
hash += iter->state + 103333 * iter->string;
factor *= 23531; // these numbers are primes.
}
return hash;
}
};
// This is the equality operator on subsets. It checks for exact match on
// state-id and string, and approximate match on weights.
class SubsetEqual {
public:
bool operator()(const std::vector<Element> *s1,
const std::vector<Element> *s2) const {
size_t sz = s1->size();
assert(sz >= 0);
if (sz != s2->size()) return false;
typename std::vector<Element>::const_iterator iter1 = s1->begin(),
iter1_end = s1->end(),
iter2 = s2->begin();
for (; iter1 < iter1_end; ++iter1, ++iter2) {
if (iter1->state != iter2->state || iter1->string != iter2->string ||
!ApproxEqual(iter1->weight, iter2->weight, delta_))
return false;
}
return true;
}
float delta_;
explicit SubsetEqual(float delta) : delta_(delta) {}
SubsetEqual() : delta_(kDelta) {}
};
// Operator that says whether two Elements have the same states.
// Used only for debug.
class SubsetEqualStates {
public:
bool operator()(const std::vector<Element> *s1,
const std::vector<Element> *s2) const {
size_t sz = s1->size();
assert(sz >= 0);
if (sz != s2->size()) return false;
typename std::vector<Element>::const_iterator iter1 = s1->begin(),
iter1_end = s1->end(),
iter2 = s2->begin();
for (; iter1 < iter1_end; ++iter1, ++iter2) {
if (iter1->state != iter2->state) return false;
}
return true;
}
};
// Define the hash type we use to store subsets.
typedef unordered_map<const std::vector<Element> *, OutputStateId, SubsetKey,
SubsetEqual>
SubsetHash;
class EpsilonClosure {
public:
EpsilonClosure(const Fst<Arc> *ifst, int max_states,
StringRepository<Label, StringId> *repository, float delta)
: ifst_(ifst),
max_states_(max_states),
repository_(repository),
delta_(delta) {}
// This function computes epsilon closure of subset of states by following
// epsilon links. Called by ProcessSubset. Has no side effects except on the
// repository.
void GetEpsilonClosure(const std::vector<Element> &input_subset,
std::vector<Element> *output_subset);
private:
struct EpsilonClosureInfo {
EpsilonClosureInfo() {}
EpsilonClosureInfo(const Element &e, const Weight &w, bool i)
: element(e), weight_to_process(w), in_queue(i) {}
// the weight in the Element struct is the total current weight
// that has been processed already
Element element;
// this stores the weight that we haven't processed (propagated)
Weight weight_to_process;
// whether "this" struct is in the queue
// we store the info here so that we don't have to look it up every time
bool in_queue;
bool operator<(const EpsilonClosureInfo &other) const {
return this->element.state < other.element.state;
}
};
// to further speed up EpsilonClosure() computation, we have 2 queues
// the 2nd queue is used when we first iterate over the input set -
// if queue_2_.empty() then we directly set output_set equal to input_set
// and return immediately
// Since Epsilon arcs are relatively rare, this way we could efficiently
// detect the epsilon-free case, without having to waste our computation
// e.g. allocating the EpsilonClosureInfo structure; this also lets us do a
// level-by-level traversal, which could avoid some (unfortunately not all)
// duplicate computation if epsilons form a DAG that is not a tree
//
// We put the queues here for better efficiency for memory allocation
std::deque<typename Arc::StateId> queue_;
std::vector<Element> queue_2_;
// the following 2 structures together form our *virtual "map"*
// basically we need a map from state_id to EpsilonClosureInfo that operates
// in O(1) time, while still takes relatively small mem, and this does it
// well for efficiency we don't clear id_to_index_ of its outdated
// information As a result each time we do a look-up, we need to check if
// (ecinfo_[id_to_index_[id]].element.state == id) Yet this is still faster
// than using a std::map<StateId, EpsilonClosureInfo>
std::vector<int> id_to_index_;
// unlike id_to_index_, we clear the content of ecinfo_ each time we call
// EpsilonClosure(). This needed because we need an efficient way to
// traverse the virtual map - it is just too costly to traverse the
// id_to_index_ vector.
std::vector<EpsilonClosureInfo> ecinfo_;
// Add one element (elem) into cur_subset
// it also adds the necessary stuff to queue_, set the correct weight
void AddOneElement(const Element &elem, const Weight &unprocessed_weight);
// Sub-routine that we call in EpsilonClosure()
// It takes the current "unprocessed_weight" and propagate it to the
// states accessible from elem.state by an epsilon arc
// and add the results to cur_subset.
// save_to_queue_2 is set true when we iterate over the initial subset
// - then we save it to queue_2 s.t. if it's empty, we directly return
// the input set
void ExpandOneElement(const Element &elem, bool sorted,
const Weight &unprocessed_weight,
bool save_to_queue_2 = false);
// no pointers below would take the ownership
const Fst<Arc> *ifst_;
int max_states_;
StringRepository<Label, StringId> *repository_;
float delta_;
};
// This function works out the final-weight of the determinized state.
// called by ProcessSubset.
// Has no side effects except on the variable repository_, and output_arcs_.
void ProcessFinal(const std::vector<Element> &closed_subset,
OutputStateId state) {
// processes final-weights for this subset.
bool is_final = false;
StringId final_string = 0; // = 0 to keep compiler happy.
Weight final_weight =
Weight::One(); // This value will never be accessed, and
// we just set it to avoid spurious compiler warnings. We avoid setting it
// to Zero() because floating-point infinities can sometimes generate
// interrupts and slow things down.
typename std::vector<Element>::const_iterator iter = closed_subset.begin(),
end = closed_subset.end();
for (; iter != end; ++iter) {
const Element &elem = *iter;
Weight this_final_weight = ifst_->Final(elem.state);
if (this_final_weight != Weight::Zero()) {
if (!is_final) { // first final-weight
final_string = elem.string;
final_weight = Times(elem.weight, this_final_weight);
is_final = true;
} else { // already have one.
if (final_string != elem.string) {
KALDI_ERR << "FST was not functional -> not determinizable";
}
final_weight =
Plus(final_weight, Times(elem.weight, this_final_weight));
}
}
}
if (is_final) {
// store final weights in TempArc structure, just like a transition.
TempArc temp_arc;
temp_arc.ilabel = 0;
temp_arc.nextstate =
kNoStateId; // special marker meaning "final weight".
temp_arc.ostring = final_string;
temp_arc.weight = final_weight;
output_arcs_[state].push_back(temp_arc);
}
}
// ProcessTransition is called from "ProcessTransitions". Broken out for
// clarity. Has side effects on output_arcs_, and (via SubsetToStateId), Q_
// and hash_.
void ProcessTransition(OutputStateId state, Label ilabel,
std::vector<Element> *subset);
// "less than" operator for pair<Label, Element>. Used in
// ProcessTransitions. Lexicographical order, with comparing the state only
// for "Element".
class PairComparator {
public:
inline bool operator()(const std::pair<Label, Element> &p1,
const std::pair<Label, Element> &p2) {
if (p1.first < p2.first) {
return true;
} else if (p1.first > p2.first) {
return false;
} else {
return p1.second.state < p2.second.state;
}
}
};
// ProcessTransitions handles transitions out of this subset of states.
// Ignores epsilon transitions (epsilon closure already handled that).
// Does not consider final states. Breaks the transitions up by ilabel,
// and creates a new transition in determinized FST, for each ilabel.
// Does this by creating a big vector of pairs <Label, Element> and then
// sorting them using a lexicographical ordering, and calling
// ProcessTransition for each range with the same ilabel. Side effects on
// repository, and (via ProcessTransition) on Q_, hash_, and output_arcs_.
void ProcessTransitions(const std::vector<Element> &closed_subset,
OutputStateId state) {
std::vector<std::pair<Label, Element> > all_elems;
{ // Push back into "all_elems", elements corresponding to all
// non-epsilon-input transitions
// out of all states in "closed_subset".
typename std::vector<Element>::const_iterator iter =
closed_subset.begin(),
end = closed_subset.end();
for (; iter != end; ++iter) {
const Element &elem = *iter;
for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel !=
0) { // Non-epsilon transition -- ignore epsilons here.
std::pair<Label, Element> this_pr;
this_pr.first = arc.ilabel;
Element &next_elem(this_pr.second);
next_elem.state = arc.nextstate;
next_elem.weight = Times(elem.weight, arc.weight);
if (arc.olabel == 0) { // output epsilon-- this is simple case so
// handle separately for efficiency
next_elem.string = elem.string;
} else {
std::vector<Label> seq;
repository_.SeqOfId(elem.string, &seq);
seq.push_back(arc.olabel);
next_elem.string = repository_.IdOfSeq(seq);
}
all_elems.push_back(this_pr);
}
}
}
}
PairComparator pc;
std::sort(all_elems.begin(), all_elems.end(), pc);
// now sorted first on input label, then on state.
typedef typename std::vector<std::pair<Label, Element> >::const_iterator
PairIter;
PairIter cur = all_elems.begin(), end = all_elems.end();
std::vector<Element> this_subset;
while (cur != end) {
// Process ranges that share the same input symbol.
Label ilabel = cur->first;
this_subset.clear();
while (cur != end && cur->first == ilabel) {
this_subset.push_back(cur->second);
cur++;
}
// We now have a subset for this ilabel.
ProcessTransition(state, ilabel, &this_subset);
}
}
// SubsetToStateId converts a subset (vector of Elements) to a StateId in the
// output fst. This is a hash lookup; if no such state exists, it adds a new
// state to the hash and adds a new pair to the queue. Side effects on hash_
// and Q_, and on output_arcs_ [just affects the size].
OutputStateId SubsetToStateId(
const std::vector<Element> &subset) { // may add the subset to the queue.
typedef typename SubsetHash::iterator IterType;
IterType iter = hash_.find(&subset);
if (iter == hash_.end()) { // was not there.
std::vector<Element> *new_subset = new std::vector<Element>(subset);
OutputStateId new_state_id = (OutputStateId)output_arcs_.size();
bool ans =
hash_
.insert(std::pair<const std::vector<Element> *, OutputStateId>(
new_subset, new_state_id))
.second;
assert(ans);
output_arcs_.push_back(std::vector<TempArc>());
if (allow_partial_ == false) {
// If --allow-partial is not requested, we do the old way.
Q_.push_front(std::pair<std::vector<Element> *, OutputStateId>(
new_subset, new_state_id));
} else {
// If --allow-partial is requested, we do breadth first search. This
// ensures that when we return partial results, we return the states
// that are reachable by the fewest steps from the start state.
Q_.push_back(std::pair<std::vector<Element> *, OutputStateId>(
new_subset, new_state_id));
}
return new_state_id;
} else {
return iter->second; // the OutputStateId.
}
}
// ProcessSubset does the processing of a determinized state, i.e. it creates
// transitions out of it and adds new determinized states to the queue if
// necessary. The first stage is "EpsilonClosure" (follow epsilons to get a
// possibly larger set of (states, weights)). After that we ignore epsilons.
// We process the final-weight of the state, and then handle transitions out
// (this may add more determinized states to the queue).
void ProcessSubset(
const std::pair<std::vector<Element> *, OutputStateId> &pair) {
const std::vector<Element> *subset = pair.first;
OutputStateId state = pair.second;
std::vector<Element> closed_subset; // subset after epsilon closure.
epsilon_closure_.GetEpsilonClosure(*subset, &closed_subset);
// Now follow non-epsilon arcs [and also process final states]
ProcessFinal(closed_subset, state);
// Now handle transitions out of these states.
ProcessTransitions(closed_subset, state);
}
void Debug();
KALDI_DISALLOW_COPY_AND_ASSIGN(DeterminizerStar);
std::deque<std::pair<std::vector<Element> *, OutputStateId> >
Q_; // queue of subsets to be processed.
std::vector<std::vector<TempArc> >
output_arcs_; // essentially an FST in our format.
const Fst<Arc> *ifst_;
float delta_;
int max_states_;
bool determinized_; // used to check usage.
bool allow_partial_; // output paritial results or not
bool is_partial_; // if we get partial results or not
SubsetKey hasher_; // object that computes keys-- has no data members.
SubsetEqual
equal_; // object that compares subsets-- only data member is delta_.
SubsetHash hash_; // hash from Subset to StateId in final Fst.
StringRepository<Label, StringId>
repository_; // associate integer id's with sequences of labels.
EpsilonClosure epsilon_closure_;
};
template <class F>
bool DeterminizeStar(F &ifst, // NOLINT
MutableFst<typename F::Arc> *ofst, float delta,
bool *debug_ptr, int max_states, bool allow_partial) {
ofst->SetOutputSymbols(ifst.OutputSymbols());
ofst->SetInputSymbols(ifst.InputSymbols());
DeterminizerStar<F> det(ifst, delta, max_states, allow_partial);
det.Determinize(debug_ptr);
det.Output(ofst);
return det.IsPartial();
}
template <class F>
bool DeterminizeStar(F &ifst, // NOLINT
MutableFst<GallicArc<typename F::Arc> > *ofst, float delta,
bool *debug_ptr, int max_states, bool allow_partial) {
ofst->SetOutputSymbols(ifst.InputSymbols());
ofst->SetInputSymbols(ifst.InputSymbols());
DeterminizerStar<F> det(ifst, delta, max_states, allow_partial);
det.Determinize(debug_ptr);
det.Output(ofst);
return det.IsPartial();
}
template <class F>
void DeterminizerStar<F>::EpsilonClosure::GetEpsilonClosure(
const std::vector<Element> &input_subset,
std::vector<Element> *output_subset) {
ecinfo_.resize(0);
size_t size = input_subset.size();
// find whether input fst is known to be sorted in input label.
bool sorted =
((ifst_->Properties(kILabelSorted, false) & kILabelSorted) != 0);
// size is still the input_subset.size()
for (size_t i = 0; i < size; i++) {
ExpandOneElement(input_subset[i], sorted, input_subset[i].weight, true);
}
size_t s = queue_2_.size();
if (s == 0) {
*output_subset = input_subset;
return;
} else {
// queue_2 not empty. Need to create the vector<info>
for (size_t i = 0; i < size; i++) {
// the weight has not been processed yet,
// so put all of them in the "weight_to_process"
ecinfo_.push_back(
EpsilonClosureInfo(input_subset[i], input_subset[i].weight, false));
ecinfo_.back().element.weight = Weight::Zero(); // clear the weight
if (id_to_index_.size() < input_subset[i].state + 1) {
id_to_index_.resize(2 * input_subset[i].state + 1, -1);
}
id_to_index_[input_subset[i].state] = ecinfo_.size() - 1;
}
}
{
Element elem;
elem.weight = Weight::Zero();
for (size_t i = 0; i < s; i++) {
elem.state = queue_2_[i].state;
elem.string = queue_2_[i].string;
AddOneElement(elem, queue_2_[i].weight);
}
queue_2_.resize(0);
}
int counter = 0; // relates to max-states option, used for test.
while (!queue_.empty()) {
InputStateId id = queue_.front();
// no need to check validity of the index
// since anything in the queue we are sure they're in the "virtual set"
int index = id_to_index_[id];
EpsilonClosureInfo &info = ecinfo_[index];
Element &elem = info.element;
Weight unprocessed_weight = info.weight_to_process;
elem.weight = Plus(elem.weight, unprocessed_weight);
info.weight_to_process = Weight::Zero();
info.in_queue = false;
queue_.pop_front();
if (max_states_ > 0 && counter++ > max_states_) {
KALDI_ERR << "Determinization aborted since looped more than "
<< max_states_ << " times during epsilon closure";
}
// generally we need to be careful about iterator-invalidation problem
// here we pass a reference (elem), which could be an issue.
// In the beginning of ExpandOneElement, we make a copy of elem.string
// to avoid that issue
ExpandOneElement(elem, sorted, unprocessed_weight);
}
{
// this sorting is based on StateId
sort(ecinfo_.begin(), ecinfo_.end());
output_subset->clear();
size = ecinfo_.size();
output_subset->reserve(size);
for (size_t i = 0; i < size; i++) {
EpsilonClosureInfo &info = ecinfo_[i];
if (info.weight_to_process != Weight::Zero()) {
info.element.weight = Plus(info.element.weight, info.weight_to_process);
}
output_subset->push_back(info.element);
}
}
}
template <class F>
void DeterminizerStar<F>::EpsilonClosure::AddOneElement(
const Element &elem, const Weight &unprocessed_weight) {
// first we try to find the element info in the ecinfo_ vector
int index = -1;
if (elem.state < id_to_index_.size()) {
index = id_to_index_[elem.state];
}
if (index != -1) {
if (index >= ecinfo_.size()) {
index = -1;
} else if (ecinfo_[index].element.state != elem.state) {
// since ecinfo_ might store outdated information, we need to check
index = -1;
}
}
if (index == -1) {
// was no such StateId: insert and add to queue.
ecinfo_.push_back(EpsilonClosureInfo(elem, unprocessed_weight, true));
size_t size = id_to_index_.size();
if (size < elem.state + 1) {
// double the size to reduce memory operations
id_to_index_.resize(2 * elem.state + 1, -1);
}
id_to_index_[elem.state] = ecinfo_.size() - 1;
queue_.push_back(elem.state);
} else { // one is already there. Add weights.
EpsilonClosureInfo &info = ecinfo_[index];
if (info.element.string != elem.string) {
// Non-functional FST.
std::ostringstream ss;
ss << "FST was not functional -> not determinizable.";
{ // Print some debugging information. Can be helpful to debug
// the inputs when FSTs are mysteriously non-functional.
std::vector<Label> tmp_seq;
repository_->SeqOfId(info.element.string, &tmp_seq);
ss << "\nFirst string:";
for (size_t i = 0; i < tmp_seq.size(); i++) ss << ' ' << tmp_seq[i];
ss << "\nSecond string:";
repository_->SeqOfId(elem.string, &tmp_seq);
for (size_t i = 0; i < tmp_seq.size(); i++) ss << ' ' << tmp_seq[i];
}
KALDI_ERR << ss.str();
}
info.weight_to_process = Plus(info.weight_to_process, unprocessed_weight);
if (!info.in_queue) {
// this is because the code in "else" below: the
// iter->second.weight_to_process might not be Zero()
Weight weight = Plus(info.element.weight, info.weight_to_process);
// What is done below is, we propagate the weight (by adding them
// to the queue only when the change is big enough;
// otherwise we just store the weight, until before returning
// we add the element.weight and weight_to_process together
if (!ApproxEqual(weight, info.element.weight, delta_)) {
// add extra part of weight to queue.
info.in_queue = true;
queue_.push_back(elem.state);
}
}
}
}
template <class F>
void DeterminizerStar<F>::EpsilonClosure::ExpandOneElement(
const Element &elem, bool sorted, const Weight &unprocessed_weight,
bool save_to_queue_2) {
StringId str =
elem.string; // copy it here because there is an iterator-
// - invalidation problem (it really happens for some FSTs)
// now we are going to propagate the "unprocessed_weight"
for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (sorted && arc.ilabel > 0) {
break;
// Break from the loop: due to sorting there will be no
// more transitions with epsilons as input labels.
}
if (arc.ilabel != 0) {
continue; // we only process epsilons here
}
Element next_elem;
next_elem.state = arc.nextstate;
next_elem.weight = Weight::Zero();
Weight next_unprocessed_weight = Times(unprocessed_weight, arc.weight);
// now must append strings
if (arc.olabel == 0) {
next_elem.string = str;
} else {
std::vector<Label> seq;
repository_->SeqOfId(str, &seq);
if (arc.olabel != 0) seq.push_back(arc.olabel);
next_elem.string = repository_->IdOfSeq(seq);
}
if (save_to_queue_2) {
next_elem.weight = next_unprocessed_weight;
queue_2_.push_back(next_elem);
} else {
AddOneElement(next_elem, next_unprocessed_weight);
}
}
}
template <class F>
void DeterminizerStar<F>::Output(MutableFst<GallicArc<Arc> > *ofst,
bool destroy) {
assert(determinized_);
if (destroy) determinized_ = false;
typedef GallicWeight<Label, Weight> ThisGallicWeight;
typedef typename Arc::StateId StateId;
if (destroy) FreeMostMemory();
StateId nStates = static_cast<StateId>(output_arcs_.size());
ofst->DeleteStates();
ofst->SetStart(kNoStateId);
if (nStates == 0) {
return;
}
for (StateId s = 0; s < nStates; s++) {
OutputStateId news = ofst->AddState();
assert(news == s);
}
ofst->SetStart(0);
// now process transitions.
for (StateId this_state = 0; this_state < nStates; this_state++) {
std::vector<TempArc> &this_vec(output_arcs_[this_state]);
typename std::vector<TempArc>::const_iterator iter = this_vec.begin(),
end = this_vec.end();
for (; iter != end; ++iter) {
const TempArc &temp_arc(*iter);
GallicArc<Arc> new_arc;
std::vector<Label> seq;
repository_.SeqOfId(temp_arc.ostring, &seq);
StringWeight<Label, STRING_LEFT> string_weight;
for (size_t i = 0; i < seq.size(); i++) string_weight.PushBack(seq[i]);
ThisGallicWeight gallic_weight(string_weight, temp_arc.weight);
if (temp_arc.nextstate == kNoStateId) { // is really final weight.
ofst->SetFinal(this_state, gallic_weight);
} else { // is really an arc.
new_arc.nextstate = temp_arc.nextstate;
new_arc.ilabel = temp_arc.ilabel;
new_arc.olabel = temp_arc.ilabel; // acceptor. input == output.
new_arc.weight = gallic_weight; // includes string and weight.
ofst->AddArc(this_state, new_arc);
}
}
// Free up memory. Do this inside the loop as ofst is also allocating
// memory
if (destroy) {
std::vector<TempArc> temp;
temp.swap(this_vec);
}
}
if (destroy) {
std::vector<std::vector<TempArc> > temp;
temp.swap(output_arcs_);
}
}
template <class F>
void DeterminizerStar<F>::Output(MutableFst<Arc> *ofst, bool destroy) {
assert(determinized_);
if (destroy) determinized_ = false;
// Outputs to standard fst.
OutputStateId num_states = static_cast<OutputStateId>(output_arcs_.size());
if (destroy) FreeMostMemory();
ofst->DeleteStates();
if (num_states == 0) {
ofst->SetStart(kNoStateId);
return;
}
// Add basic states-- but will add extra ones to account for strings on
// output.
for (OutputStateId s = 0; s < num_states; s++) {
OutputStateId news = ofst->AddState();
assert(news == s);
}
ofst->SetStart(0);
for (OutputStateId this_state = 0; this_state < num_states; this_state++) {
std::vector<TempArc> &this_vec(output_arcs_[this_state]);
typename std::vector<TempArc>::const_iterator iter = this_vec.begin(),
end = this_vec.end();
for (; iter != end; ++iter) {
const TempArc &temp_arc(*iter);
std::vector<Label> seq;
repository_.SeqOfId(temp_arc.ostring, &seq);
if (temp_arc.nextstate == kNoStateId) { // Really a final weight.
// Make a sequence of states going to a final state, with the strings as
// labels. Put the weight on the first arc.
OutputStateId cur_state = this_state;
for (size_t i = 0; i < seq.size(); i++) {
OutputStateId next_state = ofst->AddState();
Arc arc;
arc.nextstate = next_state;
arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
arc.ilabel = 0; // epsilon.
arc.olabel = seq[i];
ofst->AddArc(cur_state, arc);
cur_state = next_state;
}
ofst->SetFinal(cur_state,
(seq.size() == 0 ? temp_arc.weight : Weight::One()));
} else { // Really an arc.
OutputStateId cur_state = this_state;
// Have to be careful with this integer comparison (i+1 < seq.size())
// because unsigned. i < seq.size()-1 could fail for zero-length
// sequences.
for (size_t i = 0; i + 1 < seq.size(); i++) {
// for all but the last element of seq, create new state.
OutputStateId next_state = ofst->AddState();
Arc arc;
arc.nextstate = next_state;
arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
arc.ilabel = (i == 0 ? temp_arc.ilabel
: 0); // put ilabel on first element of seq.
arc.olabel = seq[i];
ofst->AddArc(cur_state, arc);
cur_state = next_state;
}
// Add the final arc in the sequence.
Arc arc;
arc.nextstate = temp_arc.nextstate;
arc.weight = (seq.size() <= 1 ? temp_arc.weight : Weight::One());
arc.ilabel = (seq.size() <= 1 ? temp_arc.ilabel : 0);
arc.olabel = (seq.size() > 0 ? seq.back() : 0);
ofst->AddArc(cur_state, arc);
}
}
// Free up memory. Do this inside the loop as ofst is also allocating
// memory
if (destroy) {
std::vector<TempArc> temp;
temp.swap(this_vec);
}
}
if (destroy) {
std::vector<std::vector<TempArc> > temp;
temp.swap(output_arcs_);
repository_.Destroy();
}
}
template <class F>
void DeterminizerStar<F>::ProcessTransition(OutputStateId state, Label ilabel,
std::vector<Element> *subset) {
// At input, "subset" may contain duplicates for a given dest state (but in
// sorted order). This function removes duplicates from "subset", normalizes
// it, and adds a transition to the dest. state (possibly affecting Q_ and
// hash_, if state did not exist).
typedef typename std::vector<Element>::iterator IterType;
{ // This block makes the subset have one unique Element per state, adding
// the weights.
IterType cur_in = subset->begin(), cur_out = cur_in, end = subset->end();
size_t num_out = 0;
// Merge elements with same state-id
while (cur_in != end) { // while we have more elements to process.
// At this point, cur_out points to location of next place we want to put
// an element, cur_in points to location of next element we want to
// process.
if (cur_in != cur_out) *cur_out = *cur_in;
cur_in++;
while (cur_in != end &&
cur_in->state == cur_out->state) { // merge elements.
if (cur_in->string != cur_out->string) {
KALDI_ERR << "FST was not functional -> not determinizable";
}
cur_out->weight = Plus(cur_out->weight, cur_in->weight);
cur_in++;
}
cur_out++;
num_out++;
}
subset->resize(num_out);
}
StringId common_str;
Weight tot_weight;
{ // This block computes common_str and tot_weight (essentially: the common
// divisor)
// and removes them from the elements.
std::vector<Label> seq;
IterType begin = subset->begin(), iter, end = subset->end();
{ // This block computes "seq", which is the common prefix, and
// "common_str",
// which is the StringId version of "seq".
std::vector<Label> tmp_seq;
for (iter = begin; iter != end; ++iter) {
if (iter == begin) {
repository_.SeqOfId(iter->string, &seq);
} else {
repository_.SeqOfId(iter->string, &tmp_seq);
if (tmp_seq.size() < seq.size())
seq.resize(tmp_seq.size()); // size of shortest one.
for (size_t i = 0; i < seq.size();
i++) // seq.size() is the shorter one at this point.
if (tmp_seq[i] != seq[i]) seq.resize(i);
}
if (seq.size() == 0) break; // will not get any prefix.
}
common_str = repository_.IdOfSeq(seq);
}
{ // This block computes "tot_weight".
iter = begin;
tot_weight = iter->weight;
for (++iter; iter != end; ++iter)
tot_weight = Plus(tot_weight, iter->weight);
}
// Now divide out common stuff from elements.
size_t prefix_len = seq.size();
for (iter = begin; iter != end; ++iter) {
iter->weight = Divide(iter->weight, tot_weight);
iter->string = repository_.RemovePrefix(iter->string, prefix_len);
}
}
// Now add an arc to the state that the subset represents.
// We may create a new state id for this (in SubsetToStateId).
TempArc temp_arc;
temp_arc.ilabel = ilabel;
temp_arc.nextstate =
SubsetToStateId(*subset); // may or may not really add the subset.
temp_arc.ostring = common_str;
temp_arc.weight = tot_weight;
output_arcs_[state].push_back(temp_arc); // record the arc.
}
template <class F>
void DeterminizerStar<F>::Debug() {
// this function called if you send a signal
// SIGUSR1 to the process (and it's caught by the handler in
// fstdeterminizestar). It prints out some traceback
// info and exits.
KALDI_WARN << "Debug function called (probably SIGUSR1 caught)";
// free up memory from the hash as we need a little memory
{
SubsetHash hash_tmp;
std::swap(hash_tmp, hash_);
}
if (output_arcs_.size() <= 2) {
KALDI_ERR << "Nothing to trace back";
}
size_t max_state = output_arcs_.size() - 2; // don't take the last
// one as we might be halfway into constructing it.
std::vector<OutputStateId> predecessor(max_state + 1, kNoStateId);
for (size_t i = 0; i < max_state; i++) {
for (size_t j = 0; j < output_arcs_[i].size(); j++) {
OutputStateId nextstate = output_arcs_[i][j].nextstate;
// Always find an earlier-numbered predecessor; this
// is always possible because of the way the algorithm
// works.
if (nextstate <= max_state && nextstate > i) predecessor[nextstate] = i;
}
}
std::vector<std::pair<Label, StringId> > traceback;
// 'traceback' is a pair of (ilabel, olabel-seq).
OutputStateId cur_state = max_state; // A recently constructed state.
while (cur_state != 0 && cur_state != kNoStateId) {
OutputStateId last_state = predecessor[cur_state];
std::pair<Label, StringId> p;
size_t i;
for (i = 0; i < output_arcs_[last_state].size(); i++) {
if (output_arcs_[last_state][i].nextstate == cur_state) {
p.first = output_arcs_[last_state][i].ilabel;
p.second = output_arcs_[last_state][i].ostring;
traceback.push_back(p);
break;
}
}
KALDI_ASSERT(i != output_arcs_[last_state].size()); // Or fell off loop.
cur_state = last_state;
}
if (cur_state == kNoStateId)
KALDI_WARN << "Traceback did not reach start state "
<< "(possibly debug-code error)";
std::stringstream ss;
ss << "Traceback follows in format "
<< "ilabel (olabel olabel) ilabel (olabel) ... :";
for (ssize_t i = traceback.size() - 1; i >= 0; i--) {
ss << ' ' << traceback[i].first << " ( ";
std::vector<Label> seq;
repository_.SeqOfId(traceback[i].second, &seq);
for (size_t j = 0; j < seq.size(); j++) ss << seq[j] << ' ';
ss << ')';
}
KALDI_ERR << ss.str();
}
} // namespace fst
#endif // KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_