You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
319 lines
12 KiB
319 lines
12 KiB
3 years ago
|
// fstext/remove-eps-local-inl.h
|
||
|
|
||
|
// Copyright 2009-2011 Microsoft Corporation
|
||
|
// 2014 Johns Hopkins University (author: Daniel Povey)
|
||
|
|
||
|
// See ../../COPYING for clarification regarding multiple authors
|
||
|
//
|
||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
// you may not use this file except in compliance with the License.
|
||
|
// You may obtain a copy of the License at
|
||
|
//
|
||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||
|
//
|
||
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||
|
// See the Apache 2 License for the specific language governing permissions and
|
||
|
// limitations under the License.
|
||
|
|
||
|
#ifndef KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
|
||
|
#define KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
|
||
|
|
||
|
#include <vector>
|
||
|
|
||
|
namespace fst {
|
||
|
|
||
|
// Default "plus" used by RemoveEpsLocalClass when it totals weights to decide
// how to reweight: it simply applies the weight type's own Plus().
// The functor is stateless, so operator() is const and can be invoked through
// a const instance.
template <class Weight>
struct ReweightPlusDefault {
  inline Weight operator()(const Weight &a, const Weight &b) const {
    return Plus(a, b);
  }
};
|
||
|
|
||
|
struct ReweightPlusLogArc {
|
||
|
inline TropicalWeight operator()(const TropicalWeight &a,
|
||
|
const TropicalWeight &b) {
|
||
|
LogWeight a_log(a.Value()), b_log(b.Value());
|
||
|
return TropicalWeight(Plus(a_log, b_log).Value());
|
||
|
}
|
||
|
};
|
||
|
|
||
|
template <class Arc,
|
||
|
class ReweightPlus = ReweightPlusDefault<typename Arc::Weight> >
|
||
|
class RemoveEpsLocalClass {
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
typedef typename Arc::Label Label;
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
|
||
|
public:
|
||
|
explicit RemoveEpsLocalClass(MutableFst<Arc> *fst) : fst_(fst) {
|
||
|
if (fst_->Start() == kNoStateId) return; // empty.
|
||
|
non_coacc_state_ = fst_->AddState();
|
||
|
InitNumArcs();
|
||
|
StateId num_states = fst_->NumStates();
|
||
|
for (StateId s = 0; s < num_states; s++)
|
||
|
for (size_t pos = 0; pos < fst_->NumArcs(s); pos++) RemoveEps(s, pos);
|
||
|
assert(CheckNumArcs());
|
||
|
Connect(fst); // remove inaccessible states.
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
MutableFst<Arc> *fst_;
|
||
|
StateId non_coacc_state_; // use this to delete arcs: make it nextstate
|
||
|
std::vector<StateId> num_arcs_in_; // The number of arcs into the state, plus
|
||
|
// one if it's the start state.
|
||
|
std::vector<StateId> num_arcs_out_; // The number of arcs out of the state,
|
||
|
// plus one if it's a final state.
|
||
|
ReweightPlus reweight_plus_;
|
||
|
|
||
|
bool CanCombineArcs(const Arc &a, const Arc &b, Arc *c) {
|
||
|
if (a.ilabel != 0 && b.ilabel != 0) return false;
|
||
|
if (a.olabel != 0 && b.olabel != 0) return false;
|
||
|
c->weight = Times(a.weight, b.weight);
|
||
|
c->ilabel = (a.ilabel != 0 ? a.ilabel : b.ilabel);
|
||
|
c->olabel = (a.olabel != 0 ? a.olabel : b.olabel);
|
||
|
c->nextstate = b.nextstate;
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
static bool CanCombineFinal(const Arc &a, Weight final_prob,
|
||
|
Weight *final_prob_out) {
|
||
|
if (a.ilabel != 0 || a.olabel != 0) {
|
||
|
return false;
|
||
|
} else {
|
||
|
*final_prob_out = Times(a.weight, final_prob);
|
||
|
return true;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
void InitNumArcs() { // init num transitions in/out of each state.
|
||
|
StateId num_states = fst_->NumStates();
|
||
|
num_arcs_in_.resize(num_states);
|
||
|
num_arcs_out_.resize(num_states);
|
||
|
num_arcs_in_[fst_->Start()]++; // count start as trans in.
|
||
|
for (StateId s = 0; s < num_states; s++) {
|
||
|
if (fst_->Final(s) != Weight::Zero())
|
||
|
num_arcs_out_[s]++; // count final as transition.
|
||
|
for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done();
|
||
|
aiter.Next()) {
|
||
|
num_arcs_in_[aiter.Value().nextstate]++;
|
||
|
num_arcs_out_[s]++;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
bool CheckNumArcs() { // check num arcs in/out of each state, at end. Debug.
|
||
|
num_arcs_in_[fst_->Start()]--; // count start as trans in.
|
||
|
StateId num_states = fst_->NumStates();
|
||
|
for (StateId s = 0; s < num_states; s++) {
|
||
|
if (s == non_coacc_state_) continue;
|
||
|
if (fst_->Final(s) != Weight::Zero())
|
||
|
num_arcs_out_[s]--; // count final as transition.
|
||
|
for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done();
|
||
|
aiter.Next()) {
|
||
|
if (aiter.Value().nextstate == non_coacc_state_) continue;
|
||
|
num_arcs_in_[aiter.Value().nextstate]--;
|
||
|
num_arcs_out_[s]--;
|
||
|
}
|
||
|
}
|
||
|
for (StateId s = 0; s < num_states; s++) {
|
||
|
assert(num_arcs_in_[s] == 0);
|
||
|
assert(num_arcs_out_[s] == 0);
|
||
|
}
|
||
|
return true; // always does this. so we can assert it w/o warnings.
|
||
|
}
|
||
|
|
||
|
inline void GetArc(StateId s, size_t pos, Arc *arc) const {
|
||
|
ArcIterator<MutableFst<Arc> > aiter(*fst_, s);
|
||
|
aiter.Seek(pos);
|
||
|
*arc = aiter.Value();
|
||
|
}
|
||
|
|
||
|
inline void SetArc(StateId s, size_t pos, const Arc &arc) {
|
||
|
MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
|
||
|
aiter.Seek(pos);
|
||
|
aiter.SetValue(arc);
|
||
|
}
|
||
|
|
||
|
void Reweight(StateId s, size_t pos, Weight reweight) {
|
||
|
// Reweight is called from RemoveEpsPattern1; it is a step we
|
||
|
// do to preserve stochasticity. This function multiplies the
|
||
|
// arc at (s, pos) by reweight and divides all the arcs [+final-prob]
|
||
|
// out of the next state by the same. This is only valid if
|
||
|
// the next state has only one arc in and is not the start state.
|
||
|
assert(reweight != Weight::Zero());
|
||
|
MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
|
||
|
aiter.Seek(pos);
|
||
|
Arc arc = aiter.Value();
|
||
|
assert(num_arcs_in_[arc.nextstate] == 1);
|
||
|
arc.weight = Times(arc.weight, reweight);
|
||
|
aiter.SetValue(arc);
|
||
|
|
||
|
for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, arc.nextstate);
|
||
|
!aiter_next.Done(); aiter_next.Next()) {
|
||
|
Arc nextarc = aiter_next.Value();
|
||
|
if (nextarc.nextstate != non_coacc_state_) {
|
||
|
nextarc.weight = Divide(nextarc.weight, reweight, DIVIDE_LEFT);
|
||
|
aiter_next.SetValue(nextarc);
|
||
|
}
|
||
|
}
|
||
|
Weight final = fst_->Final(arc.nextstate);
|
||
|
if (final != Weight::Zero()) {
|
||
|
fst_->SetFinal(arc.nextstate, Divide(final, reweight, DIVIDE_LEFT));
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// RemoveEpsPattern1 applies where this arc, which is not a
|
||
|
// self-loop, enters a state which has only one input transition
|
||
|
// [and is not the start state], and has multiple output
|
||
|
// transitions [counting being the final-state as a final-transition].
|
||
|
|
||
|
void RemoveEpsPattern1(StateId s, size_t pos, Arc arc) {
|
||
|
const StateId nextstate = arc.nextstate;
|
||
|
Weight total_removed = Weight::Zero(),
|
||
|
total_kept = Weight::Zero(); // totals out of nextstate.
|
||
|
std::vector<Arc> arcs_to_add; // to add to state s.
|
||
|
for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
|
||
|
!aiter_next.Done(); aiter_next.Next()) {
|
||
|
Arc nextarc = aiter_next.Value();
|
||
|
if (nextarc.nextstate == non_coacc_state_) continue; // deleted.
|
||
|
Arc combined;
|
||
|
if (CanCombineArcs(arc, nextarc, &combined)) {
|
||
|
total_removed = reweight_plus_(total_removed, nextarc.weight);
|
||
|
num_arcs_out_[nextstate]--;
|
||
|
num_arcs_in_[nextarc.nextstate]--;
|
||
|
nextarc.nextstate = non_coacc_state_;
|
||
|
aiter_next.SetValue(nextarc);
|
||
|
arcs_to_add.push_back(combined);
|
||
|
} else {
|
||
|
total_kept = reweight_plus_(total_kept, nextarc.weight);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
{ // now final-state.
|
||
|
Weight next_final = fst_->Final(nextstate);
|
||
|
if (next_final != Weight::Zero()) {
|
||
|
Weight new_final;
|
||
|
if (CanCombineFinal(arc, next_final, &new_final)) {
|
||
|
total_removed = reweight_plus_(total_removed, next_final);
|
||
|
if (fst_->Final(s) == Weight::Zero())
|
||
|
num_arcs_out_[s]++; // final is counted as arc.
|
||
|
fst_->SetFinal(s, Plus(fst_->Final(s), new_final));
|
||
|
num_arcs_out_[nextstate]--;
|
||
|
fst_->SetFinal(nextstate, Weight::Zero());
|
||
|
} else {
|
||
|
total_kept = reweight_plus_(total_kept, next_final);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if (total_removed != Weight::Zero()) { // did something...
|
||
|
if (total_kept == Weight::Zero()) { // removed everything: remove arc.
|
||
|
num_arcs_out_[s]--;
|
||
|
num_arcs_in_[arc.nextstate]--;
|
||
|
arc.nextstate = non_coacc_state_;
|
||
|
SetArc(s, pos, arc);
|
||
|
} else {
|
||
|
// Have to reweight.
|
||
|
Weight total = reweight_plus_(total_removed, total_kept);
|
||
|
Weight reweight = Divide(total_kept, total, DIVIDE_LEFT); // <=1
|
||
|
Reweight(s, pos, reweight);
|
||
|
}
|
||
|
}
|
||
|
// Now add the arcs we were going to add.
|
||
|
for (size_t i = 0; i < arcs_to_add.size(); i++) {
|
||
|
num_arcs_out_[s]++;
|
||
|
num_arcs_in_[arcs_to_add[i].nextstate]++;
|
||
|
fst_->AddArc(s, arcs_to_add[i]);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
void RemoveEpsPattern2(StateId s, size_t pos, Arc arc) {
|
||
|
// Pattern 2 is where "nextstate" has only one arc out, counting
|
||
|
// being-the-final-state as an arc, but possibly multiple arcs in.
|
||
|
// Also, nextstate != s.
|
||
|
|
||
|
const StateId nextstate = arc.nextstate;
|
||
|
bool can_delete_next = (num_arcs_in_[nextstate] == 1); // if
|
||
|
// we combine, can delete the corresponding out-arc/final-prob
|
||
|
// of nextstate.
|
||
|
bool delete_arc = false; // set to true if this arc to be deleted.
|
||
|
|
||
|
Weight next_final = fst_->Final(arc.nextstate);
|
||
|
if (next_final !=
|
||
|
Weight::Zero()) { // nextstate has no actual arcs out, only final-prob.
|
||
|
Weight new_final;
|
||
|
if (CanCombineFinal(arc, next_final, &new_final)) {
|
||
|
if (fst_->Final(s) == Weight::Zero())
|
||
|
num_arcs_out_[s]++; // final is counted as arc.
|
||
|
fst_->SetFinal(s, Plus(fst_->Final(s), new_final));
|
||
|
delete_arc = true; // will delete "arc".
|
||
|
if (can_delete_next) {
|
||
|
num_arcs_out_[nextstate]--;
|
||
|
fst_->SetFinal(nextstate, Weight::Zero());
|
||
|
}
|
||
|
}
|
||
|
} else { // has an arc but no final prob.
|
||
|
MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
|
||
|
assert(!aiter_next.Done());
|
||
|
while (aiter_next.Value().nextstate == non_coacc_state_) {
|
||
|
aiter_next.Next();
|
||
|
assert(!aiter_next.Done());
|
||
|
}
|
||
|
// now aiter_next points to a real arc out of nextstate.
|
||
|
Arc nextarc = aiter_next.Value();
|
||
|
Arc combined;
|
||
|
if (CanCombineArcs(arc, nextarc, &combined)) {
|
||
|
delete_arc = true;
|
||
|
if (can_delete_next) { // do it before we invalidate iterators
|
||
|
num_arcs_out_[nextstate]--;
|
||
|
num_arcs_in_[nextarc.nextstate]--;
|
||
|
nextarc.nextstate = non_coacc_state_;
|
||
|
aiter_next.SetValue(nextarc);
|
||
|
}
|
||
|
num_arcs_out_[s]++;
|
||
|
num_arcs_in_[combined.nextstate]++;
|
||
|
fst_->AddArc(s, combined);
|
||
|
}
|
||
|
}
|
||
|
if (delete_arc) {
|
||
|
num_arcs_out_[s]--;
|
||
|
num_arcs_in_[nextstate]--;
|
||
|
arc.nextstate = non_coacc_state_;
|
||
|
SetArc(s, pos, arc);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
void RemoveEps(StateId s, size_t pos) {
|
||
|
// Tries to do local epsilon-removal for arc sequences starting with this
|
||
|
// arc
|
||
|
Arc arc;
|
||
|
GetArc(s, pos, &arc);
|
||
|
StateId nextstate = arc.nextstate;
|
||
|
if (nextstate == non_coacc_state_) return; // deleted arc.
|
||
|
if (nextstate == s) return; // don't handle self-loops: too complex.
|
||
|
|
||
|
if (num_arcs_in_[nextstate] == 1 && num_arcs_out_[nextstate] > 1) {
|
||
|
RemoveEpsPattern1(s, pos, arc);
|
||
|
} else if (num_arcs_out_[nextstate] == 1) {
|
||
|
RemoveEpsPattern2(s, pos, arc);
|
||
|
}
|
||
|
}
|
||
|
};
|
||
|
|
||
|
template <class Arc>
|
||
|
void RemoveEpsLocal(MutableFst<Arc> *fst) {
|
||
|
RemoveEpsLocalClass<Arc> c(fst); // work gets done in initializer.
|
||
|
}
|
||
|
|
||
|
// As RemoveEpsLocal, but totals weights for reweighting with
// ReweightPlusLogArc, i.e. the tropical weights are combined as if they were
// log-semiring weights.
// Declared inline: this is a non-template function defined in a header, so
// without inline every translation unit that includes the header would emit
// its own external definition and linking them together would fail.
inline void RemoveEpsLocalSpecial(MutableFst<StdArc> *fst) {
  // work gets done in initializer.
  RemoveEpsLocalClass<StdArc, ReweightPlusLogArc> c(fst);
}
|
||
|
|
||
|
} // end namespace fst.
|
||
|
|
||
|
#endif // KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
|