You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/runtime/engine/kaldi/fstext/remove-eps-local-inl.h

319 lines
12 KiB

3 years ago
// fstext/remove-eps-local-inl.h
// Copyright 2009-2011 Microsoft Corporation
// 2014 Johns Hopkins University (author: Daniel Povey
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
#define KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
#include <vector>
namespace fst {
template <class Weight>
struct ReweightPlusDefault {
inline Weight operator()(const Weight &a, const Weight &b) {
return Plus(a, b);
}
};
struct ReweightPlusLogArc {
inline TropicalWeight operator()(const TropicalWeight &a,
const TropicalWeight &b) {
LogWeight a_log(a.Value()), b_log(b.Value());
return TropicalWeight(Plus(a_log, b_log).Value());
}
};
template <class Arc,
class ReweightPlus = ReweightPlusDefault<typename Arc::Weight> >
class RemoveEpsLocalClass {
typedef typename Arc::StateId StateId;
typedef typename Arc::Label Label;
typedef typename Arc::Weight Weight;
public:
explicit RemoveEpsLocalClass(MutableFst<Arc> *fst) : fst_(fst) {
if (fst_->Start() == kNoStateId) return; // empty.
non_coacc_state_ = fst_->AddState();
InitNumArcs();
StateId num_states = fst_->NumStates();
for (StateId s = 0; s < num_states; s++)
for (size_t pos = 0; pos < fst_->NumArcs(s); pos++) RemoveEps(s, pos);
assert(CheckNumArcs());
Connect(fst); // remove inaccessible states.
}
private:
MutableFst<Arc> *fst_;
StateId non_coacc_state_; // use this to delete arcs: make it nextstate
std::vector<StateId> num_arcs_in_; // The number of arcs into the state, plus
// one if it's the start state.
std::vector<StateId> num_arcs_out_; // The number of arcs out of the state,
// plus one if it's a final state.
ReweightPlus reweight_plus_;
bool CanCombineArcs(const Arc &a, const Arc &b, Arc *c) {
if (a.ilabel != 0 && b.ilabel != 0) return false;
if (a.olabel != 0 && b.olabel != 0) return false;
c->weight = Times(a.weight, b.weight);
c->ilabel = (a.ilabel != 0 ? a.ilabel : b.ilabel);
c->olabel = (a.olabel != 0 ? a.olabel : b.olabel);
c->nextstate = b.nextstate;
return true;
}
static bool CanCombineFinal(const Arc &a, Weight final_prob,
Weight *final_prob_out) {
if (a.ilabel != 0 || a.olabel != 0) {
return false;
} else {
*final_prob_out = Times(a.weight, final_prob);
return true;
}
}
void InitNumArcs() { // init num transitions in/out of each state.
StateId num_states = fst_->NumStates();
num_arcs_in_.resize(num_states);
num_arcs_out_.resize(num_states);
num_arcs_in_[fst_->Start()]++; // count start as trans in.
for (StateId s = 0; s < num_states; s++) {
if (fst_->Final(s) != Weight::Zero())
num_arcs_out_[s]++; // count final as transition.
for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done();
aiter.Next()) {
num_arcs_in_[aiter.Value().nextstate]++;
num_arcs_out_[s]++;
}
}
}
bool CheckNumArcs() { // check num arcs in/out of each state, at end. Debug.
num_arcs_in_[fst_->Start()]--; // count start as trans in.
StateId num_states = fst_->NumStates();
for (StateId s = 0; s < num_states; s++) {
if (s == non_coacc_state_) continue;
if (fst_->Final(s) != Weight::Zero())
num_arcs_out_[s]--; // count final as transition.
for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done();
aiter.Next()) {
if (aiter.Value().nextstate == non_coacc_state_) continue;
num_arcs_in_[aiter.Value().nextstate]--;
num_arcs_out_[s]--;
}
}
for (StateId s = 0; s < num_states; s++) {
assert(num_arcs_in_[s] == 0);
assert(num_arcs_out_[s] == 0);
}
return true; // always does this. so we can assert it w/o warnings.
}
inline void GetArc(StateId s, size_t pos, Arc *arc) const {
ArcIterator<MutableFst<Arc> > aiter(*fst_, s);
aiter.Seek(pos);
*arc = aiter.Value();
}
inline void SetArc(StateId s, size_t pos, const Arc &arc) {
MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
aiter.Seek(pos);
aiter.SetValue(arc);
}
void Reweight(StateId s, size_t pos, Weight reweight) {
// Reweight is called from RemoveEpsPattern1; it is a step we
// do to preserve stochasticity. This function multiplies the
// arc at (s, pos) by reweight and divides all the arcs [+final-prob]
// out of the next state by the same. This is only valid if
// the next state has only one arc in and is not the start state.
assert(reweight != Weight::Zero());
MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
aiter.Seek(pos);
Arc arc = aiter.Value();
assert(num_arcs_in_[arc.nextstate] == 1);
arc.weight = Times(arc.weight, reweight);
aiter.SetValue(arc);
for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, arc.nextstate);
!aiter_next.Done(); aiter_next.Next()) {
Arc nextarc = aiter_next.Value();
if (nextarc.nextstate != non_coacc_state_) {
nextarc.weight = Divide(nextarc.weight, reweight, DIVIDE_LEFT);
aiter_next.SetValue(nextarc);
}
}
Weight final = fst_->Final(arc.nextstate);
if (final != Weight::Zero()) {
fst_->SetFinal(arc.nextstate, Divide(final, reweight, DIVIDE_LEFT));
}
}
// RemoveEpsPattern1 applies where this arc, which is not a
// self-loop, enters a state which has only one input transition
// [and is not the start state], and has multiple output
// transitions [counting being the final-state as a final-transition].
void RemoveEpsPattern1(StateId s, size_t pos, Arc arc) {
const StateId nextstate = arc.nextstate;
Weight total_removed = Weight::Zero(),
total_kept = Weight::Zero(); // totals out of nextstate.
std::vector<Arc> arcs_to_add; // to add to state s.
for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
!aiter_next.Done(); aiter_next.Next()) {
Arc nextarc = aiter_next.Value();
if (nextarc.nextstate == non_coacc_state_) continue; // deleted.
Arc combined;
if (CanCombineArcs(arc, nextarc, &combined)) {
total_removed = reweight_plus_(total_removed, nextarc.weight);
num_arcs_out_[nextstate]--;
num_arcs_in_[nextarc.nextstate]--;
nextarc.nextstate = non_coacc_state_;
aiter_next.SetValue(nextarc);
arcs_to_add.push_back(combined);
} else {
total_kept = reweight_plus_(total_kept, nextarc.weight);
}
}
{ // now final-state.
Weight next_final = fst_->Final(nextstate);
if (next_final != Weight::Zero()) {
Weight new_final;
if (CanCombineFinal(arc, next_final, &new_final)) {
total_removed = reweight_plus_(total_removed, next_final);
if (fst_->Final(s) == Weight::Zero())
num_arcs_out_[s]++; // final is counted as arc.
fst_->SetFinal(s, Plus(fst_->Final(s), new_final));
num_arcs_out_[nextstate]--;
fst_->SetFinal(nextstate, Weight::Zero());
} else {
total_kept = reweight_plus_(total_kept, next_final);
}
}
}
if (total_removed != Weight::Zero()) { // did something...
if (total_kept == Weight::Zero()) { // removed everything: remove arc.
num_arcs_out_[s]--;
num_arcs_in_[arc.nextstate]--;
arc.nextstate = non_coacc_state_;
SetArc(s, pos, arc);
} else {
// Have to reweight.
Weight total = reweight_plus_(total_removed, total_kept);
Weight reweight = Divide(total_kept, total, DIVIDE_LEFT); // <=1
Reweight(s, pos, reweight);
}
}
// Now add the arcs we were going to add.
for (size_t i = 0; i < arcs_to_add.size(); i++) {
num_arcs_out_[s]++;
num_arcs_in_[arcs_to_add[i].nextstate]++;
fst_->AddArc(s, arcs_to_add[i]);
}
}
void RemoveEpsPattern2(StateId s, size_t pos, Arc arc) {
// Pattern 2 is where "nextstate" has only one arc out, counting
// being-the-final-state as an arc, but possibly multiple arcs in.
// Also, nextstate != s.
const StateId nextstate = arc.nextstate;
bool can_delete_next = (num_arcs_in_[nextstate] == 1); // if
// we combine, can delete the corresponding out-arc/final-prob
// of nextstate.
bool delete_arc = false; // set to true if this arc to be deleted.
Weight next_final = fst_->Final(arc.nextstate);
if (next_final !=
Weight::Zero()) { // nextstate has no actual arcs out, only final-prob.
Weight new_final;
if (CanCombineFinal(arc, next_final, &new_final)) {
if (fst_->Final(s) == Weight::Zero())
num_arcs_out_[s]++; // final is counted as arc.
fst_->SetFinal(s, Plus(fst_->Final(s), new_final));
delete_arc = true; // will delete "arc".
if (can_delete_next) {
num_arcs_out_[nextstate]--;
fst_->SetFinal(nextstate, Weight::Zero());
}
}
} else { // has an arc but no final prob.
MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
assert(!aiter_next.Done());
while (aiter_next.Value().nextstate == non_coacc_state_) {
aiter_next.Next();
assert(!aiter_next.Done());
}
// now aiter_next points to a real arc out of nextstate.
Arc nextarc = aiter_next.Value();
Arc combined;
if (CanCombineArcs(arc, nextarc, &combined)) {
delete_arc = true;
if (can_delete_next) { // do it before we invalidate iterators
num_arcs_out_[nextstate]--;
num_arcs_in_[nextarc.nextstate]--;
nextarc.nextstate = non_coacc_state_;
aiter_next.SetValue(nextarc);
}
num_arcs_out_[s]++;
num_arcs_in_[combined.nextstate]++;
fst_->AddArc(s, combined);
}
}
if (delete_arc) {
num_arcs_out_[s]--;
num_arcs_in_[nextstate]--;
arc.nextstate = non_coacc_state_;
SetArc(s, pos, arc);
}
}
void RemoveEps(StateId s, size_t pos) {
// Tries to do local epsilon-removal for arc sequences starting with this
// arc
Arc arc;
GetArc(s, pos, &arc);
StateId nextstate = arc.nextstate;
if (nextstate == non_coacc_state_) return; // deleted arc.
if (nextstate == s) return; // don't handle self-loops: too complex.
if (num_arcs_in_[nextstate] == 1 && num_arcs_out_[nextstate] > 1) {
RemoveEpsPattern1(s, pos, arc);
} else if (num_arcs_out_[nextstate] == 1) {
RemoveEpsPattern2(s, pos, arc);
}
}
};
template <class Arc>
void RemoveEpsLocal(MutableFst<Arc> *fst) {
RemoveEpsLocalClass<Arc> c(fst); // work gets done in initializer.
}
void RemoveEpsLocalSpecial(MutableFst<StdArc> *fst) {
// work gets done in initializer.
RemoveEpsLocalClass<StdArc, ReweightPlusLogArc> c(fst);
}
} // end namespace fst.
#endif // KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_