You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/runtime/engine/kaldi/lat/lattice-functions.cc

1993 lines
79 KiB

// lat/lattice-functions.cc
// Copyright 2009-2011 Saarland University (Author: Arnab Ghoshal)
// 2012-2013 Johns Hopkins University (Author: Daniel Povey); Chao Weng;
// Bagher BabaAli
// 2013 Cisco Systems (author: Neha Agrawal) [code modified
// from original code in ../gmmbin/gmm-rescore-lattice.cc]
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/lattice-functions.h"
// #include "hmm/transition-model.h"
// #include "util/stl-utils.h"
#include "base/kaldi-math.h"
// #include "hmm/hmm-utils.h"
namespace kaldi {
using std::map;
using std::vector;
// void GetPerFrameAcousticCosts(const Lattice &nbest,
// Vector<BaseFloat> *per_frame_loglikes) {
// using namespace fst;
// typedef Lattice::Arc::Weight Weight;
// vector<BaseFloat> loglikes;
//
// int32 cur_state = nbest.Start();
// int32 prev_frame = -1;
// BaseFloat eps_acwt = 0.0;
// while(1) {
// Weight w = nbest.Final(cur_state);
// if (w != Weight::Zero()) {
// KALDI_ASSERT(nbest.NumArcs(cur_state) == 0);
// if (per_frame_loglikes != NULL) {
// SubVector<BaseFloat> subvec(&(loglikes[0]), loglikes.size());
// Vector<BaseFloat> vec(subvec);
// *per_frame_loglikes = vec;
// }
// break;
// } else {
// KALDI_ASSERT(nbest.NumArcs(cur_state) == 1);
// fst::ArcIterator<Lattice> iter(nbest, cur_state);
// const Lattice::Arc &arc = iter.Value();
// BaseFloat acwt = arc.weight.Value2();
// if (arc.ilabel != 0) {
// if (eps_acwt > 0) {
// acwt += eps_acwt;
// eps_acwt = 0.0;
// }
// loglikes.push_back(acwt);
// prev_frame++;
// } else if (acwt == acwt){
// if (prev_frame > -1) {
// loglikes[prev_frame] += acwt;
// } else {
// eps_acwt += acwt;
// }
// }
// cur_state = arc.nextstate;
// }
// }
// }
//
// int32 LatticeStateTimes(const Lattice &lat, vector<int32> *times) {
// if (!lat.Properties(fst::kTopSorted, true))
// KALDI_ERR << "Input lattice must be topologically sorted.";
// KALDI_ASSERT(lat.Start() == 0);
// int32 num_states = lat.NumStates();
// times->clear();
// times->resize(num_states, -1);
// (*times)[0] = 0;
// for (int32 state = 0; state < num_states; state++) {
// int32 cur_time = (*times)[state];
// for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
// aiter.Next()) {
// const LatticeArc &arc = aiter.Value();
//
// if (arc.ilabel != 0) { // Non-epsilon input label on arc
// // next time instance
// if ((*times)[arc.nextstate] == -1) {
// (*times)[arc.nextstate] = cur_time + 1;
// } else {
// KALDI_ASSERT((*times)[arc.nextstate] == cur_time + 1);
// }
// } else { // epsilon input label on arc
// // Same time instance
// if ((*times)[arc.nextstate] == -1)
// (*times)[arc.nextstate] = cur_time;
// else
// KALDI_ASSERT((*times)[arc.nextstate] == cur_time);
// }
// }
// }
// return (*std::max_element(times->begin(), times->end()));
// }
//
// int32 CompactLatticeStateTimes(const CompactLattice &lat,
// vector<int32> *times) {
// if (!lat.Properties(fst::kTopSorted, true))
// KALDI_ERR << "Input lattice must be topologically sorted.";
// KALDI_ASSERT(lat.Start() == 0);
// int32 num_states = lat.NumStates();
// times->clear();
// times->resize(num_states, -1);
// (*times)[0] = 0;
// int32 utt_len = -1;
// for (int32 state = 0; state < num_states; state++) {
// int32 cur_time = (*times)[state];
// for (fst::ArcIterator<CompactLattice> aiter(lat, state); !aiter.Done();
// aiter.Next()) {
// const CompactLatticeArc &arc = aiter.Value();
// int32 arc_len = static_cast<int32>(arc.weight.String().size());
// if ((*times)[arc.nextstate] == -1)
// (*times)[arc.nextstate] = cur_time + arc_len;
// else
// KALDI_ASSERT((*times)[arc.nextstate] == cur_time + arc_len);
// }
// if (lat.Final(state) != CompactLatticeWeight::Zero()) {
// int32 this_utt_len = (*times)[state] + lat.Final(state).String().size();
// if (utt_len == -1) utt_len = this_utt_len;
// else {
// if (this_utt_len != utt_len) {
// KALDI_WARN << "Utterance does not "
// "seem to have a consistent length.";
// utt_len = std::max(utt_len, this_utt_len);
// }
// }
// }
// }
// if (utt_len == -1) {
// KALDI_WARN << "Utterance does not have a final-state.";
// return 0;
// }
// return utt_len;
// }
//
// bool ComputeCompactLatticeAlphas(const CompactLattice &clat,
// vector<double> *alpha) {
// using namespace fst;
//
// // typedef the arc, weight types
// typedef CompactLattice::Arc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
//
// //Make sure the lattice is topologically sorted.
// if (clat.Properties(fst::kTopSorted, true) == 0) {
// KALDI_WARN << "Input lattice must be topologically sorted.";
// return false;
// }
// if (clat.Start() != 0) {
// KALDI_WARN << "Input lattice must start from state 0.";
// return false;
// }
//
// int32 num_states = clat.NumStates();
// (*alpha).resize(0);
// (*alpha).resize(num_states, kLogZeroDouble);
//
// // Now propagate alphas forward. Note that we don't account the weight of the
// // final state to alpha[final_state] -- we account it to beta[final_state];
// (*alpha)[0] = 0.0;
// for (StateId s = 0; s < num_states; s++) {
// double this_alpha = (*alpha)[s];
// for (ArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -(arc.weight.Weight().Value1() +
// arc.weight.Weight().Value2());
// (*alpha)[arc.nextstate] = LogAdd((*alpha)[arc.nextstate],
// this_alpha + arc_like);
// }
// }
//
// return true;
// }
//
// bool ComputeCompactLatticeBetas(const CompactLattice &clat,
// vector<double> *beta) {
// using namespace fst;
//
// // typedef the arc, weight types
// typedef CompactLattice::Arc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
//
// // Make sure the lattice is topologically sorted.
// if (clat.Properties(fst::kTopSorted, true) == 0) {
// KALDI_WARN << "Input lattice must be topologically sorted.";
// return false;
// }
// if (clat.Start() != 0) {
// KALDI_WARN << "Input lattice must start from state 0.";
// return false;
// }
//
// int32 num_states = clat.NumStates();
// (*beta).resize(0);
// (*beta).resize(num_states, kLogZeroDouble);
//
// // Now propagate betas backward. Note that beta[final_state] contains the
// // weight of the final state in the lattice -- compare that with alpha.
// for (StateId s = num_states-1; s >= 0; s--) {
// Weight f = clat.Final(s);
// double this_beta = -(f.Weight().Value1()+f.Weight().Value2());
// for (ArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -(arc.weight.Weight().Value1() +
// arc.weight.Weight().Value2());
// double arc_beta = (*beta)[arc.nextstate] + arc_like;
// this_beta = LogAdd(this_beta, arc_beta);
// }
// (*beta)[s] = this_beta;
// }
//
// return true;
// }
/// Prunes a lattice in place to within "beam" of the best path, using
/// Viterbi (best-path) forward and backward costs.
///
/// An arc is kept only if the cost of the best complete path through it is
/// at most (best final cost + beam).  A final-prob is kept under the same
/// condition for paths ending at that state.  Pruned arcs are redirected to
/// a freshly added non-final state, and fst::Connect() then removes every
/// arc and state that no longer lies on a successful path.
///
/// If the lattice is not topologically sorted, it is sorted in place first.
///
/// @param beam  Pruning beam in cost units; must be > 0.
/// @param lat   Lattice (Lattice or CompactLattice) to prune in place.
/// @return false if the lattice is empty, has no start state, contains
///         cycles, or has no states left after pruning; true otherwise.
template<class LatType> // could be Lattice or CompactLattice
bool PruneLattice(BaseFloat beam, LatType *lat) {
  typedef typename LatType::Arc Arc;
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;
  KALDI_ASSERT(beam > 0.0);
  // Nothing to prune in an empty lattice.
  if (lat->NumStates() == 0) return false;
  // A lattice that has states but no start state accepts nothing.  Without
  // this check, forward_cost[start] below would be indexed with
  // fst::kNoStateId (-1).  TopSort() would not help here, because its DFS
  // never visits any state when there is no start state.
  if (lat->Start() == fst::kNoStateId) {
    KALDI_WARN << "Lattice has no start state";
    return false;
  }
  if (!lat->Properties(fst::kTopSorted, true)) {
    if (fst::TopSort(lat) == false) {
      KALDI_WARN << "Cycles detected in lattice";
      return false;
    }
  }
  // Read these after any TopSort(), since sorting renumbers the states.
  // We assume states before "start" are not reachable, since
  // the lattice is topologically sorted.
  const StateId start = lat->Start();
  const StateId num_states = lat->NumStates();
  const double kInfinity = std::numeric_limits<double>::infinity();
  std::vector<double> forward_cost(num_states, kInfinity);  // viterbi forward.
  forward_cost[start] = 0.0;  // lattice can't have cycles so couldn't be
                              // less than this.
  double best_final_cost = kInfinity;
  // Forward pass: best cost from the start state to each state, and the best
  // cost of any complete path.  Visiting states in increasing order is valid
  // because the lattice is topologically sorted.
  // Thanks to Jing Zheng for finding a bug here.
  for (StateId state = 0; state < num_states; state++) {
    const double this_forward_cost = forward_cost[state];
    for (fst::ArcIterator<LatType> aiter(*lat, state);
         !aiter.Done();
         aiter.Next()) {
      const Arc &arc(aiter.Value());
      const StateId nextstate = arc.nextstate;
      KALDI_ASSERT(nextstate > state && nextstate < num_states);
      const double next_forward_cost = this_forward_cost +
          ConvertToCost(arc.weight);
      if (forward_cost[nextstate] > next_forward_cost)
        forward_cost[nextstate] = next_forward_cost;
    }
    const Weight final_weight = lat->Final(state);
    const double this_final_cost = this_forward_cost +
        ConvertToCost(final_weight);
    if (this_final_cost < best_final_cost)
      best_final_cost = this_final_cost;
  }
  const StateId bad_state = lat->AddState();  // this state is not final.
  const double cutoff = best_final_cost + beam;
  // Go backwards, updating the backward costs and pruning arcs and
  // final-probs.  The backward costs share memory with the forward costs:
  // backward_cost[state] is only written after forward_cost[state] has been
  // read for the last time.  We prune arcs by making them point to the
  // non-final state "bad_state", then call fst::Connect() to remove
  // unnecessary arcs and states (this is easier than doing it ourselves).
  std::vector<double> &backward_cost(forward_cost);
  for (StateId state = num_states - 1; state >= 0; state--) {
    const double this_forward_cost = forward_cost[state];
    double this_backward_cost = ConvertToCost(lat->Final(state));
    if (this_backward_cost + this_forward_cost > cutoff
        && this_backward_cost != kInfinity)
      lat->SetFinal(state, Weight::Zero());
    for (fst::MutableArcIterator<LatType> aiter(lat, state);
         !aiter.Done();
         aiter.Next()) {
      Arc arc(aiter.Value());
      const StateId nextstate = arc.nextstate;
      KALDI_ASSERT(nextstate > state && nextstate < num_states);
      const double arc_cost = ConvertToCost(arc.weight),
          arc_backward_cost = arc_cost + backward_cost[nextstate],
          this_fb_cost = this_forward_cost + arc_backward_cost;
      if (arc_backward_cost < this_backward_cost)
        this_backward_cost = arc_backward_cost;
      if (this_fb_cost > cutoff) {  // Prune the arc.
        arc.nextstate = bad_state;
        aiter.SetValue(arc);
      }
    }
    backward_cost[state] = this_backward_cost;
  }
  fst::Connect(lat);
  return (lat->NumStates() > 0);
}
// instantiate the template for lattice and CompactLattice.
template bool PruneLattice(BaseFloat beam, Lattice *lat);
template bool PruneLattice(BaseFloat beam, CompactLattice *lat);
// BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post,
// double *acoustic_like_sum) {
// // Note, Posterior is defined as follows: Indexed [frame], then a list
// // of (transition-id, posterior-probability) pairs.
// // typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior;
// using namespace fst;
// typedef Lattice::Arc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
//
// if (acoustic_like_sum) *acoustic_like_sum = 0.0;
//
// // Make sure the lattice is topologically sorted.
// if (lat.Properties(fst::kTopSorted, true) == 0)
// KALDI_ERR << "Input lattice must be topologically sorted.";
// KALDI_ASSERT(lat.Start() == 0);
//
// int32 num_states = lat.NumStates();
// vector<int32> state_times;
// int32 max_time = LatticeStateTimes(lat, &state_times);
// std::vector<double> alpha(num_states, kLogZeroDouble);
// std::vector<double> &beta(alpha); // we re-use the same memory for
// // this, but it's semantically distinct so we name it differently.
// double tot_forward_prob = kLogZeroDouble;
//
// post->clear();
// post->resize(max_time);
//
// alpha[0] = 0.0;
// // Propagate alphas forward.
// for (StateId s = 0; s < num_states; s++) {
// double this_alpha = alpha[s];
// for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight);
// alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha + arc_like);
// }
// Weight f = lat.Final(s);
// if (f != Weight::Zero()) {
// double final_like = this_alpha - (f.Value1() + f.Value2());
// tot_forward_prob = LogAdd(tot_forward_prob, final_like);
// KALDI_ASSERT(state_times[s] == max_time &&
// "Lattice is inconsistent (final-prob not at max_time)");
// }
// }
// for (StateId s = num_states-1; s >= 0; s--) {
// Weight f = lat.Final(s);
// double this_beta = -(f.Value1() + f.Value2());
// for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight),
// arc_beta = beta[arc.nextstate] + arc_like;
// this_beta = LogAdd(this_beta, arc_beta);
// int32 transition_id = arc.ilabel;
//
// // The following "if" is an optimization to avoid un-needed exp().
// if (transition_id != 0 || acoustic_like_sum != NULL) {
// double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob);
//
// if (transition_id != 0) // Arc has a transition-id on it [not epsilon]
// (*post)[state_times[s]].push_back(std::make_pair(transition_id,
// static_cast<kaldi::BaseFloat>(posterior)));
// if (acoustic_like_sum != NULL)
// *acoustic_like_sum -= posterior * arc.weight.Value2();
// }
// }
// if (acoustic_like_sum != NULL && f != Weight::Zero()) {
// double final_logprob = - ConvertToCost(f),
// posterior = Exp(alpha[s] + final_logprob - tot_forward_prob);
// *acoustic_like_sum -= posterior * f.Value2();
// }
// beta[s] = this_beta;
// }
// double tot_backward_prob = beta[0];
// if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) {
// KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob
// << ", while total backward probability = " << tot_backward_prob;
// }
// // Now combine any posteriors with the same transition-id.
// for (int32 t = 0; t < max_time; t++)
// MergePairVectorSumming(&((*post)[t]));
// return tot_backward_prob;
// }
//
//
// void LatticeActivePhones(const Lattice &lat, const TransitionModel &trans,
// const vector<int32> &silence_phones,
// vector< std::set<int32> > *active_phones) {
// KALDI_ASSERT(IsSortedAndUniq(silence_phones));
// vector<int32> state_times;
// int32 num_states = lat.NumStates();
// int32 max_time = LatticeStateTimes(lat, &state_times);
// active_phones->clear();
// active_phones->resize(max_time);
// for (int32 state = 0; state < num_states; state++) {
// int32 cur_time = state_times[state];
// for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
// aiter.Next()) {
// const LatticeArc &arc = aiter.Value();
// if (arc.ilabel != 0) { // Non-epsilon arc
// int32 phone = trans.TransitionIdToPhone(arc.ilabel);
// if (!std::binary_search(silence_phones.begin(),
// silence_phones.end(), phone))
// (*active_phones)[cur_time].insert(phone);
// }
// } // end looping over arcs
// } // end looping over states
// }
//
// void ConvertLatticeToPhones(const TransitionModel &trans,
// Lattice *lat) {
// typedef LatticeArc Arc;
// int32 num_states = lat->NumStates();
// for (int32 state = 0; state < num_states; state++) {
// for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
// aiter.Next()) {
// Arc arc(aiter.Value());
// arc.olabel = 0; // remove any word.
// if ((arc.ilabel != 0) // has a transition-id on input..
// && (trans.TransitionIdToHmmState(arc.ilabel) == 0)
// && (!trans.IsSelfLoop(arc.ilabel))) {
// // && trans.IsFinal(arc.ilabel)) // there is one of these per phone...
// arc.olabel = trans.TransitionIdToPhone(arc.ilabel);
// }
// aiter.SetValue(arc);
// } // end looping over arcs
// } // end looping over states
// }
//
//
// static inline double LogAddOrMax(bool viterbi, double a, double b) {
// if (viterbi)
// return std::max(a, b);
// else
// return LogAdd(a, b);
// }
//
// template<typename LatticeType>
// double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
// bool viterbi,
// vector<double> *alpha,
// vector<double> *beta) {
// typedef typename LatticeType::Arc Arc;
// typedef typename Arc::Weight Weight;
// typedef typename Arc::StateId StateId;
//
// StateId num_states = lat.NumStates();
// KALDI_ASSERT(lat.Properties(fst::kTopSorted, true) == fst::kTopSorted);
// KALDI_ASSERT(lat.Start() == 0);
// alpha->clear();
// beta->clear();
// alpha->resize(num_states, kLogZeroDouble);
// beta->resize(num_states, kLogZeroDouble);
//
// double tot_forward_prob = kLogZeroDouble;
// (*alpha)[0] = 0.0;
// // Propagate alphas forward.
// for (StateId s = 0; s < num_states; s++) {
// double this_alpha = (*alpha)[s];
// for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done();
// aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight);
// (*alpha)[arc.nextstate] = LogAddOrMax(viterbi, (*alpha)[arc.nextstate],
// this_alpha + arc_like);
// }
// Weight f = lat.Final(s);
// if (f != Weight::Zero()) {
// double final_like = this_alpha - ConvertToCost(f);
// tot_forward_prob = LogAddOrMax(viterbi, tot_forward_prob, final_like);
// }
// }
// for (StateId s = num_states-1; s >= 0; s--) { // it's guaranteed signed.
// double this_beta = -ConvertToCost(lat.Final(s));
// for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done();
// aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight),
// arc_beta = (*beta)[arc.nextstate] + arc_like;
// this_beta = LogAddOrMax(viterbi, this_beta, arc_beta);
// }
// (*beta)[s] = this_beta;
// }
// double tot_backward_prob = (*beta)[lat.Start()];
// if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) {
// KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob
// << ", while total backward probability = " << tot_backward_prob;
// }
// // Split the difference when returning... they should be the same.
// return 0.5 * (tot_backward_prob + tot_forward_prob);
// }
//
// // instantiate the template for Lattice and CompactLattice
// template
// double ComputeLatticeAlphasAndBetas(const Lattice &lat,
// bool viterbi,
// vector<double> *alpha,
// vector<double> *beta);
//
// template
// double ComputeLatticeAlphasAndBetas(const CompactLattice &lat,
// bool viterbi,
// vector<double> *alpha,
// vector<double> *beta);
//
//
//
// /// This is used in CompactLatticeLimitDepth.
// struct LatticeArcRecord {
// BaseFloat logprob; // logprob <= 0 is the best Viterbi logprob of this arc,
// // minus the overall best-cost of the lattice.
// CompactLatticeArc::StateId state; // state in the lattice.
// size_t arc; // arc index within the state.
// bool operator < (const LatticeArcRecord &other) const {
// return logprob < other.logprob;
// }
// };
//
// void CompactLatticeLimitDepth(int32 max_depth_per_frame,
// CompactLattice *clat) {
// typedef CompactLatticeArc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
//
// if (clat->Start() == fst::kNoStateId) {
// KALDI_WARN << "Limiting depth of empty lattice.";
// return;
// }
// if (clat->Properties(fst::kTopSorted, true) == 0) {
// if (!TopSort(clat))
// KALDI_ERR << "Topological sorting of lattice failed.";
// }
//
// vector<int32> state_times;
// int32 T = CompactLatticeStateTimes(*clat, &state_times);
//
// // The alpha and beta quantities here are "viterbi" alphas and beta.
// std::vector<double> alpha;
// std::vector<double> beta;
// bool viterbi = true;
// double best_prob = ComputeLatticeAlphasAndBetas(*clat, viterbi,
// &alpha, &beta);
//
// std::vector<std::vector<LatticeArcRecord> > arc_records(T);
//
// StateId num_states = clat->NumStates();
// for (StateId s = 0; s < num_states; s++) {
// for (fst::ArcIterator<CompactLattice> aiter(*clat, s); !aiter.Done();
// aiter.Next()) {
// const Arc &arc = aiter.Value();
// LatticeArcRecord arc_record;
// arc_record.state = s;
// arc_record.arc = aiter.Position();
// arc_record.logprob =
// (alpha[s] + beta[arc.nextstate] - ConvertToCost(arc.weight))
// - best_prob;
// KALDI_ASSERT(arc_record.logprob < 0.1); // Should be zero or negative.
// int32 num_frames = arc.weight.String().size(), start_t = state_times[s];
// for (int32 t = start_t; t < start_t + num_frames; t++) {
// KALDI_ASSERT(t < T);
// arc_records[t].push_back(arc_record);
// }
// }
// }
// StateId dead_state = clat->AddState(); // A non-coaccesible state which we use
// // to remove arcs (make them end
// // there).
// size_t max_depth = max_depth_per_frame;
// for (int32 t = 0; t < T; t++) {
// size_t size = arc_records[t].size();
// if (size > max_depth) {
// // we sort from worst to best, so we keep the later-numbered ones,
// // and delete the lower-numbered ones.
// size_t cutoff = size - max_depth;
// std::nth_element(arc_records[t].begin(),
// arc_records[t].begin() + cutoff,
// arc_records[t].end());
// for (size_t index = 0; index < cutoff; index++) {
// LatticeArcRecord record(arc_records[t][index]);
// fst::MutableArcIterator<CompactLattice> aiter(clat, record.state);
// aiter.Seek(record.arc);
// Arc arc = aiter.Value();
// if (arc.nextstate != dead_state) { // not already killed.
// arc.nextstate = dead_state;
// aiter.SetValue(arc);
// }
// }
// }
// }
// Connect(clat);
// TopSortCompactLatticeIfNeeded(clat);
// }
//
//
// void TopSortCompactLatticeIfNeeded(CompactLattice *clat) {
// if (clat->Properties(fst::kTopSorted, true) == 0) {
// if (fst::TopSort(clat) == false) {
// KALDI_ERR << "Topological sorting failed";
// }
// }
// }
//
// void TopSortLatticeIfNeeded(Lattice *lat) {
// if (lat->Properties(fst::kTopSorted, true) == 0) {
// if (fst::TopSort(lat) == false) {
// KALDI_ERR << "Topological sorting failed";
// }
// }
// }
//
//
// /// Returns the depth of the lattice, defined as the average number of
// /// arcs crossing any given frame. Returns 1 for empty lattices.
// /// Requires that input is topologically sorted.
// BaseFloat CompactLatticeDepth(const CompactLattice &clat,
// int32 *num_frames) {
// typedef CompactLattice::Arc::StateId StateId;
// if (clat.Properties(fst::kTopSorted, true) == 0) {
// KALDI_ERR << "Lattice input to CompactLatticeDepth was not topologically "
// << "sorted.";
// }
// if (clat.Start() == fst::kNoStateId) {
// *num_frames = 0;
// return 1.0;
// }
// size_t num_arc_frames = 0;
// int32 t;
// {
// vector<int32> state_times;
// t = CompactLatticeStateTimes(clat, &state_times);
// }
// if (num_frames != NULL)
// *num_frames = t;
// for (StateId s = 0; s < clat.NumStates(); s++) {
// for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
// aiter.Next()) {
// const CompactLatticeArc &arc = aiter.Value();
// num_arc_frames += arc.weight.String().size();
// }
// num_arc_frames += clat.Final(s).String().size();
// }
// return num_arc_frames / static_cast<BaseFloat>(t);
// }
//
//
// void CompactLatticeDepthPerFrame(const CompactLattice &clat,
// std::vector<int32> *depth_per_frame) {
// typedef CompactLattice::Arc::StateId StateId;
// if (clat.Properties(fst::kTopSorted, true) == 0) {
// KALDI_ERR << "Lattice input to CompactLatticeDepthPerFrame was not "
// << "topologically sorted.";
// }
// if (clat.Start() == fst::kNoStateId) {
// depth_per_frame->clear();
// return;
// }
// vector<int32> state_times;
// int32 T = CompactLatticeStateTimes(clat, &state_times);
//
// depth_per_frame->clear();
// if (T <= 0) {
// return;
// } else {
// depth_per_frame->resize(T, 0);
// for (StateId s = 0; s < clat.NumStates(); s++) {
// int32 start_time = state_times[s];
// for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
// aiter.Next()) {
// const CompactLatticeArc &arc = aiter.Value();
// int32 len = arc.weight.String().size();
// for (int32 t = start_time; t < start_time + len; t++) {
// KALDI_ASSERT(t < T);
// (*depth_per_frame)[t]++;
// }
// }
// int32 final_len = clat.Final(s).String().size();
// for (int32 t = start_time; t < start_time + final_len; t++) {
// KALDI_ASSERT(t < T);
// (*depth_per_frame)[t]++;
// }
// }
// }
// }
//
//
//
// void ConvertCompactLatticeToPhones(const TransitionModel &trans,
// CompactLattice *clat) {
// typedef CompactLatticeArc Arc;
// typedef Arc::Weight Weight;
// int32 num_states = clat->NumStates();
// for (int32 state = 0; state < num_states; state++) {
// for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
// !aiter.Done();
// aiter.Next()) {
// Arc arc(aiter.Value());
// std::vector<int32> phone_seq;
// const std::vector<int32> &tid_seq = arc.weight.String();
// for (std::vector<int32>::const_iterator iter = tid_seq.begin();
// iter != tid_seq.end(); ++iter) {
// if (trans.IsFinal(*iter))// note: there is one of these per phone...
// phone_seq.push_back(trans.TransitionIdToPhone(*iter));
// }
// arc.weight.SetString(phone_seq);
// aiter.SetValue(arc);
// } // end looping over arcs
// Weight f = clat->Final(state);
// if (f != Weight::Zero()) {
// std::vector<int32> phone_seq;
// const std::vector<int32> &tid_seq = f.String();
// for (std::vector<int32>::const_iterator iter = tid_seq.begin();
// iter != tid_seq.end(); ++iter) {
// if (trans.IsFinal(*iter))// note: there is one of these per phone...
// phone_seq.push_back(trans.TransitionIdToPhone(*iter));
// }
// f.SetString(phone_seq);
// clat->SetFinal(state, f);
// }
// } // end looping over states
// }
//
// bool LatticeBoost(const TransitionModel &trans,
// const std::vector<int32> &alignment,
// const std::vector<int32> &silence_phones,
// BaseFloat b,
// BaseFloat max_silence_error,
// Lattice *lat) {
// TopSortLatticeIfNeeded(lat);
//
// // get all stored properties (test==false means don't test if not known).
// uint64 props = lat->Properties(fst::kFstProperties,
// false);
//
// KALDI_ASSERT(IsSortedAndUniq(silence_phones));
// KALDI_ASSERT(max_silence_error >= 0.0 && max_silence_error <= 1.0);
// vector<int32> state_times;
// int32 num_states = lat->NumStates();
// int32 num_frames = LatticeStateTimes(*lat, &state_times);
// KALDI_ASSERT(num_frames == static_cast<int32>(alignment.size()));
// for (int32 state = 0; state < num_states; state++) {
// int32 cur_time = state_times[state];
// for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
// aiter.Next()) {
// LatticeArc arc = aiter.Value();
// if (arc.ilabel != 0) { // Non-epsilon arc
// if (arc.ilabel < 0 || arc.ilabel > trans.NumTransitionIds()) {
// KALDI_WARN << "Lattice has out-of-range transition-ids: "
// << "lattice/model mismatch?";
// return false;
// }
// int32 phone = trans.TransitionIdToPhone(arc.ilabel),
// ref_phone = trans.TransitionIdToPhone(alignment[cur_time]);
// BaseFloat frame_error;
// if (phone == ref_phone) {
// frame_error = 0.0;
// } else { // an error...
// if (std::binary_search(silence_phones.begin(), silence_phones.end(), phone))
// frame_error = max_silence_error;
// else
// frame_error = 1.0;
// }
// BaseFloat delta_cost = -b * frame_error; // negative cost if
// // frame is wrong, to boost likelihood of arcs with errors on them.
// // Add this cost to the graph part.
// arc.weight.SetValue1(arc.weight.Value1() + delta_cost);
// aiter.SetValue(arc);
// }
// }
// }
// // All we changed is the weights, so any properties that were
// // known before, are still known, except for whether or not the
// // lattice was weighted.
// lat->SetProperties(props,
// ~(fst::kWeighted|fst::kUnweighted));
//
// return true;
// }
//
//
//
// BaseFloat LatticeForwardBackwardMpeVariants(
// const TransitionModel &trans,
// const std::vector<int32> &silence_phones,
// const Lattice &lat,
// const std::vector<int32> &num_ali,
// std::string criterion,
// bool one_silence_class,
// Posterior *post) {
// using namespace fst;
// typedef Lattice::Arc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
//
// KALDI_ASSERT(criterion == "mpfe" || criterion == "smbr");
// bool is_mpfe = (criterion == "mpfe");
//
// if (lat.Properties(fst::kTopSorted, true) == 0)
// KALDI_ERR << "Input lattice must be topologically sorted.";
// KALDI_ASSERT(lat.Start() == 0);
//
// int32 num_states = lat.NumStates();
// vector<int32> state_times;
// int32 max_time = LatticeStateTimes(lat, &state_times);
// KALDI_ASSERT(max_time == static_cast<int32>(num_ali.size()));
// std::vector<double> alpha(num_states, kLogZeroDouble),
// alpha_smbr(num_states, 0), //forward variable for sMBR
// beta(num_states, kLogZeroDouble),
// beta_smbr(num_states, 0); //backward variable for sMBR
//
// double tot_forward_prob = kLogZeroDouble;
// double tot_forward_score = 0;
//
// post->clear();
// post->resize(max_time);
//
// alpha[0] = 0.0;
// // First Pass Forward,
// for (StateId s = 0; s < num_states; s++) {
// double this_alpha = alpha[s];
// for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight);
// alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha + arc_like);
// }
// Weight f = lat.Final(s);
// if (f != Weight::Zero()) {
// double final_like = this_alpha - (f.Value1() + f.Value2());
// tot_forward_prob = LogAdd(tot_forward_prob, final_like);
// KALDI_ASSERT(state_times[s] == max_time &&
// "Lattice is inconsistent (final-prob not at max_time)");
// }
// }
// // First Pass Backward,
// for (StateId s = num_states-1; s >= 0; s--) {
// Weight f = lat.Final(s);
// double this_beta = -(f.Value1() + f.Value2());
// for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight),
// arc_beta = beta[arc.nextstate] + arc_like;
// this_beta = LogAdd(this_beta, arc_beta);
// }
// beta[s] = this_beta;
// }
// // First Pass Forward-Backward Check
// double tot_backward_prob = beta[0];
// // may loosen the condition somewhat here 1e-6 (was 1e-8)
// if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-6)) {
// KALDI_ERR << "Total forward probability over lattice = " << tot_forward_prob
// << ", while total backward probability = " << tot_backward_prob;
// }
//
// alpha_smbr[0] = 0.0;
// // Second Pass Forward, calculate forward for MPFE/SMBR
// for (StateId s = 0; s < num_states; s++) {
// double this_alpha = alpha[s];
// for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight);
// double frame_acc = 0.0;
// if (arc.ilabel != 0) {
// int32 cur_time = state_times[s];
// int32 phone = trans.TransitionIdToPhone(arc.ilabel),
// ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]);
// bool phone_is_sil = std::binary_search(silence_phones.begin(),
// silence_phones.end(),
// phone),
// ref_phone_is_sil = std::binary_search(silence_phones.begin(),
// silence_phones.end(),
// ref_phone),
// both_sil = phone_is_sil && ref_phone_is_sil;
// if (!is_mpfe) { // smbr.
// int32 pdf = trans.TransitionIdToPdf(arc.ilabel),
// ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]);
// if (!one_silence_class) // old behavior
// frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0;
// else
// frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0;
// } else {
// if (!one_silence_class) // old behavior
// frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0;
// else
// frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0;
// }
// }
// double arc_scale = Exp(alpha[s] + arc_like - alpha[arc.nextstate]);
// alpha_smbr[arc.nextstate] += arc_scale * (alpha_smbr[s] + frame_acc);
// }
// Weight f = lat.Final(s);
// if (f != Weight::Zero()) {
// double final_like = this_alpha - (f.Value1() + f.Value2());
// double arc_scale = Exp(final_like - tot_forward_prob);
// tot_forward_score += arc_scale * alpha_smbr[s];
// KALDI_ASSERT(state_times[s] == max_time &&
// "Lattice is inconsistent (final-prob not at max_time)");
// }
// }
// // Second Pass Backward, collect Mpe style posteriors
// for (StateId s = num_states-1; s >= 0; s--) {
// for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight),
// arc_beta = beta[arc.nextstate] + arc_like;
// double frame_acc = 0.0;
// int32 transition_id = arc.ilabel;
// if (arc.ilabel != 0) {
// int32 cur_time = state_times[s];
// int32 phone = trans.TransitionIdToPhone(arc.ilabel),
// ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]);
// bool phone_is_sil = std::binary_search(silence_phones.begin(),
// silence_phones.end(), phone),
// ref_phone_is_sil = std::binary_search(silence_phones.begin(),
// silence_phones.end(),
// ref_phone),
// both_sil = phone_is_sil && ref_phone_is_sil;
// if (!is_mpfe) { // smbr.
// int32 pdf = trans.TransitionIdToPdf(arc.ilabel),
// ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]);
// if (!one_silence_class) // old behavior
// frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0;
// else
// frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0;
// } else {
// if (!one_silence_class) // old behavior
// frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0;
// else
// frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0;
// }
// }
// double arc_scale = Exp(beta[arc.nextstate] + arc_like - beta[s]);
// // Check arc_scale for NaN; this guards against partial paths in the
// // lattice, i.e. paths that don't survive to a final state.
// if (KALDI_ISNAN(arc_scale)) arc_scale = 0;
// beta_smbr[s] += arc_scale * (beta_smbr[arc.nextstate] + frame_acc);
//
// if (transition_id != 0) { // Arc has a transition-id on it [not epsilon]
// double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob);
// double acc_diff = alpha_smbr[s] + frame_acc + beta_smbr[arc.nextstate]
// - tot_forward_score;
// double posterior_smbr = posterior * acc_diff;
// (*post)[state_times[s]].push_back(std::make_pair(transition_id,
// static_cast<BaseFloat>(posterior_smbr)));
// }
// }
// }
//
// // Second pass: check that the forward and backward scores agree.
// double tot_backward_score = beta_smbr[0]; // Initial state id == 0
// // The tolerance may need to be loosened here (e.g. 1e-5 vs. 1e-4).
// if (!ApproxEqual(tot_forward_score, tot_backward_score, 1e-4)) {
// KALDI_ERR << "Total forward score over lattice = " << tot_forward_score
// << ", while total backward score = " << tot_backward_score;
// }
//
// // Output the computed posteriors
// for (int32 t = 0; t < max_time; t++)
// MergePairVectorSumming(&((*post)[t]));
// return tot_forward_score;
// }
//
// bool CompactLatticeToWordAlignment(const CompactLattice &clat,
// std::vector<int32> *words,
// std::vector<int32> *begin_times,
// std::vector<int32> *lengths) {
// words->clear();
// begin_times->clear();
// lengths->clear();
// typedef CompactLattice::Arc Arc;
// typedef Arc::Label Label;
// typedef CompactLattice::StateId StateId;
// typedef CompactLattice::Weight Weight;
// using namespace fst;
// StateId state = clat.Start();
// int32 cur_time = 0;
// if (state == kNoStateId) {
// KALDI_WARN << "Empty lattice.";
// return false;
// }
// while (1) {
// Weight final = clat.Final(state);
// size_t num_arcs = clat.NumArcs(state);
// if (final != Weight::Zero()) {
// if (num_arcs != 0) {
// KALDI_WARN << "Lattice is not linear.";
// return false;
// }
// if (! final.String().empty()) {
// KALDI_WARN << "Lattice has alignments on final-weight: probably "
// "was not word-aligned (alignments will be approximate)";
// }
// return true;
// } else {
// if (num_arcs != 1) {
// KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs;
// return false;
// }
// fst::ArcIterator<CompactLattice> aiter(clat, state);
// const Arc &arc = aiter.Value();
// Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor.
// // Also note: word_id may be zero; we output it anyway.
// int32 length = arc.weight.String().size();
// words->push_back(word_id);
// begin_times->push_back(cur_time);
// lengths->push_back(length);
// cur_time += length;
// state = arc.nextstate;
// }
// }
// }
//
//
// bool CompactLatticeToWordProns(
// const TransitionModel &tmodel,
// const CompactLattice &clat,
// std::vector<int32> *words,
// std::vector<int32> *begin_times,
// std::vector<int32> *lengths,
// std::vector<std::vector<int32> > *prons,
// std::vector<std::vector<int32> > *phone_lengths) {
// words->clear();
// begin_times->clear();
// lengths->clear();
// prons->clear();
// phone_lengths->clear();
// typedef CompactLattice::Arc Arc;
// typedef Arc::Label Label;
// typedef CompactLattice::StateId StateId;
// typedef CompactLattice::Weight Weight;
// using namespace fst;
// StateId state = clat.Start();
// int32 cur_time = 0;
// if (state == kNoStateId) {
// KALDI_WARN << "Empty lattice.";
// return false;
// }
// while (1) {
// Weight final = clat.Final(state);
// size_t num_arcs = clat.NumArcs(state);
// if (final != Weight::Zero()) {
// if (num_arcs != 0) {
// KALDI_WARN << "Lattice is not linear.";
// return false;
// }
// if (! final.String().empty()) {
// KALDI_WARN << "Lattice has alignments on final-weight: probably "
// "was not word-aligned (alignments will be approximate)";
// }
// return true;
// } else {
// if (num_arcs != 1) {
// KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs;
// return false;
// }
// fst::ArcIterator<CompactLattice> aiter(clat, state);
// const Arc &arc = aiter.Value();
// Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor.
// // Also note: word_id may be zero; we output it anyway.
// int32 length = arc.weight.String().size();
// words->push_back(word_id);
// begin_times->push_back(cur_time);
// lengths->push_back(length);
// const std::vector<int32> &arc_alignment = arc.weight.String();
// std::vector<std::vector<int32> > split_alignment;
// SplitToPhones(tmodel, arc_alignment, &split_alignment);
// std::vector<int32> phones(split_alignment.size());
// std::vector<int32> plengths(split_alignment.size());
// for (size_t i = 0; i < split_alignment.size(); i++) {
// KALDI_ASSERT(!split_alignment[i].empty());
// phones[i] = tmodel.TransitionIdToPhone(split_alignment[i][0]);
// plengths[i] = split_alignment[i].size();
// }
// prons->push_back(phones);
// phone_lengths->push_back(plengths);
//
// cur_time += length;
// state = arc.nextstate;
// }
// }
// }
//
//
//
// void CompactLatticeShortestPath(const CompactLattice &clat,
// CompactLattice *shortest_path) {
// using namespace fst;
// if (clat.Properties(fst::kTopSorted, true) == 0) {
// CompactLattice clat_copy(clat);
// if (!TopSort(&clat_copy))
// KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
// CompactLatticeShortestPath(clat_copy, shortest_path);
// return;
// }
// // Now we can assume it's topologically sorted.
// shortest_path->DeleteStates();
// if (clat.Start() == kNoStateId) return;
// typedef CompactLatticeArc Arc;
// typedef Arc::StateId StateId;
// typedef CompactLatticeWeight Weight;
// vector<std::pair<double, StateId> > best_cost_and_pred(clat.NumStates() + 1);
// StateId superfinal = clat.NumStates();
// for (StateId s = 0; s <= clat.NumStates(); s++) {
// best_cost_and_pred[s].first = std::numeric_limits<double>::infinity();
// best_cost_and_pred[s].second = fst::kNoStateId;
// }
// best_cost_and_pred[clat.Start()].first = 0;
// for (StateId s = 0; s < clat.NumStates(); s++) {
// double my_cost = best_cost_and_pred[s].first;
// for (ArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done();
// aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_cost = ConvertToCost(arc.weight),
// next_cost = my_cost + arc_cost;
// if (next_cost < best_cost_and_pred[arc.nextstate].first) {
// best_cost_and_pred[arc.nextstate].first = next_cost;
// best_cost_and_pred[arc.nextstate].second = s;
// }
// }
// double final_cost = ConvertToCost(clat.Final(s)),
// tot_final = my_cost + final_cost;
// if (tot_final < best_cost_and_pred[superfinal].first) {
// best_cost_and_pred[superfinal].first = tot_final;
// best_cost_and_pred[superfinal].second = s;
// }
// }
// std::vector<StateId> states; // states on best path.
// StateId cur_state = superfinal, start_state = clat.Start();
// while (cur_state != start_state) {
// StateId prev_state = best_cost_and_pred[cur_state].second;
// if (prev_state == kNoStateId) {
// KALDI_WARN << "Failure in best-path algorithm for lattice (infinite costs?)";
// return; // return empty best-path.
// }
// states.push_back(prev_state);
// KALDI_ASSERT(cur_state != prev_state && "Lattice with cycles");
// cur_state = prev_state;
// }
// std::reverse(states.begin(), states.end());
// for (size_t i = 0; i < states.size(); i++)
// shortest_path->AddState();
// for (StateId s = 0; static_cast<size_t>(s) < states.size(); s++) {
// if (s == 0) shortest_path->SetStart(s);
// if (static_cast<size_t>(s + 1) < states.size()) { // transition to next state.
// bool have_arc = false;
// Arc cur_arc;
// for (ArcIterator<CompactLattice> aiter(clat, states[s]);
// !aiter.Done();
// aiter.Next()) {
// const Arc &arc = aiter.Value();
// if (arc.nextstate == states[s+1]) {
// if (!have_arc ||
// ConvertToCost(arc.weight) < ConvertToCost(cur_arc.weight)) {
// cur_arc = arc;
// have_arc = true;
// }
// }
// }
// KALDI_ASSERT(have_arc && "Code error.");
// shortest_path->AddArc(s, Arc(cur_arc.ilabel, cur_arc.olabel,
// cur_arc.weight, s+1));
// } else { // final-prob.
// shortest_path->SetFinal(s, clat.Final(states[s]));
// }
// }
// }
//
//
// void ExpandCompactLattice(const CompactLattice &clat,
// double epsilon,
// CompactLattice *expand_clat) {
// using namespace fst;
// typedef CompactLattice::Arc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
// typedef std::pair<StateId, StateId> StatePair;
// typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType;
// typedef MapType::iterator IterType;
//
// if (clat.Start() == kNoStateId) return;
// // Make sure the input lattice is topologically sorted.
// if (clat.Properties(kTopSorted, true) == 0) {
// CompactLattice clat_copy(clat);
// KALDI_LOG << "Topsort this lattice.";
// if (!TopSort(&clat_copy))
// KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
// ExpandCompactLattice(clat_copy, epsilon, expand_clat);
// return;
// }
//
// // Compute backward logprobs betas for the expanded lattice.
// // Note: the backward logprobs in the original lattice <clat> and the
// // expanded lattice <expand_clat> are the same.
// int32 num_states = clat.NumStates();
// std::vector<double> beta(num_states, kLogZeroDouble);
// ComputeCompactLatticeBetas(clat, &beta);
// double tot_backward_logprob = beta[0];
// std::vector<double> alpha;
// alpha.push_back(0.0);
// expand_clat->DeleteStates();
// MapType state_map; // Map from state pair (orig_state, copy_state) to
// // copy_state, where orig_state is a state in the original lattice, and
// // copy_state is its corresponding one in the expanded lattice.
// unordered_map<StateId, StateId> states; // Map from orig_state to its
// // copy_state for states with incoming arcs' posteriors <= epsilon.
// std::queue<StatePair> state_queue;
//
// // Set start state in the expanded lattice.
// StateId start_state = expand_clat->AddState();
// expand_clat->SetStart(start_state);
// StatePair start_pair(clat.Start(), start_state);
// state_queue.push(start_pair);
// std::pair<IterType, bool> result =
// state_map.insert(std::make_pair(start_pair, start_state));
// KALDI_ASSERT(result.second == true);
//
// // Expand <clat> and update forward logprobs alphas in <expand_clat>.
// while (!state_queue.empty()) {
// StatePair s = state_queue.front();
// StateId s1 = s.first,
// s2 = s.second;
// state_queue.pop();
//
// Weight f = clat.Final(s1);
// if (f != Weight::Zero()) {
// KALDI_ASSERT(state_map.find(s) != state_map.end());
// expand_clat->SetFinal(state_map[s], f);
// }
//
// for (ArcIterator<CompactLattice> aiter(clat, s1);
// !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// StateId orig_state = arc.nextstate;
// double arc_like = -ConvertToCost(arc.weight),
// this_alpha = alpha[s2] + arc_like,
// arc_post = Exp(this_alpha + beta[orig_state] -
// tot_backward_logprob);
// // Generate the expanded lattice.
// StateId copy_state;
// if (arc_post > epsilon) {
// copy_state = expand_clat->AddState();
// StatePair next_pair(orig_state, copy_state);
// std::pair<IterType, bool> result =
// state_map.insert(std::make_pair(next_pair, copy_state));
// KALDI_ASSERT(result.second == true);
// state_queue.push(next_pair);
// } else {
// unordered_map<StateId, StateId>::iterator iter = states.find(orig_state);
// if (iter == states.end() ) { // The counterpart state of orig_state
// // has not been created in <expand_clat> yet.
// copy_state = expand_clat->AddState();
// StatePair next_pair(orig_state, copy_state);
// std::pair<IterType, bool> result =
// state_map.insert(std::make_pair(next_pair, copy_state));
// KALDI_ASSERT(result.second == true);
// state_queue.push(next_pair);
// states[orig_state] = copy_state;
// } else {
// copy_state = iter->second;
// }
// }
// // Create an arc from state_map[s] to copy_state in the expanded lattice.
// expand_clat->AddArc(state_map[s], Arc(arc.ilabel, arc.olabel, arc.weight,
// copy_state));
// // Compute forward logprobs alpha for the expanded lattice.
// if ((alpha.size() - 1) < copy_state) { // The first time to compute alpha
// // for copy_state in <expand_clat>.
// alpha.push_back(this_alpha);
// } else { // Accumulate alpha.
// alpha[copy_state] = LogAdd(alpha[copy_state], this_alpha);
// }
// }
// } // end while
// }
//
//
// void CompactLatticeBestCostsAndTracebacks(
// const CompactLattice &clat,
// CostTraceType *forward_best_cost_and_pred,
// CostTraceType *backward_best_cost_and_pred) {
//
// // typedef the arc, weight types
// typedef CompactLatticeArc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
//
// forward_best_cost_and_pred->clear();
// backward_best_cost_and_pred->clear();
// forward_best_cost_and_pred->resize(clat.NumStates());
// backward_best_cost_and_pred->resize(clat.NumStates());
// // Initialize the cost and predecessor state for each state.
// for (StateId s = 0; s < clat.NumStates(); s++) {
// (*forward_best_cost_and_pred)[s].first =
// std::numeric_limits<double>::infinity();
// (*backward_best_cost_and_pred)[s].first =
// std::numeric_limits<double>::infinity();
// (*forward_best_cost_and_pred)[s].second = fst::kNoStateId;
// (*backward_best_cost_and_pred)[s].second = fst::kNoStateId;
// }
//
// StateId start_state = clat.Start();
// (*forward_best_cost_and_pred)[start_state].first = 0;
// // Traverse the lattice forward to compute the best cost from the start
// // state to each state, and the best predecessor state of each state.
// for (StateId s = 0; s < clat.NumStates(); s++) {
// double cur_cost = (*forward_best_cost_and_pred)[s].first;
// for (fst::ArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double next_cost = cur_cost + ConvertToCost(arc.weight);
// if (next_cost < (*forward_best_cost_and_pred)[arc.nextstate].first) {
// (*forward_best_cost_and_pred)[arc.nextstate].first = next_cost;
// (*forward_best_cost_and_pred)[arc.nextstate].second = s;
// }
// }
// }
// // Traverse the lattice backward to compute the best cost from each state
// // to a final state, and the best successor state of each state.
// for (StateId s = clat.NumStates() - 1; s >= 0; s--) {
// double this_cost = ConvertToCost(clat.Final(s));
// for (fst::ArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double next_cost = (*backward_best_cost_and_pred)[arc.nextstate].first +
// ConvertToCost(arc.weight);
// if (next_cost < this_cost) {
// this_cost = next_cost;
// (*backward_best_cost_and_pred)[s].second = arc.nextstate;
// }
// }
// (*backward_best_cost_and_pred)[s].first = this_cost;
// }
// }
//
//
// void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores,
// CompactLattice *clat) {
// if (clat->Start() == fst::kNoStateId) return;
// // Make sure the input lattice is topologically sorted.
// if (clat->Properties(fst::kTopSorted, true) == 0) {
// KALDI_LOG << "Topsort this lattice.";
// if (!TopSort(clat))
// KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
// AddNnlmScoreToCompactLattice(nnlm_scores, clat);
// return;
// }
//
// // typedef the arc, weight types
// typedef CompactLatticeArc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
// typedef std::pair<int32, int32> StatePair;
//
// int32 num_states = clat->NumStates();
// unordered_map<StatePair, bool, PairHasher<int32> > final_state_check;
// for (StateId s = 0; s < num_states; s++) {
// for (fst::MutableArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done(); aiter.Next()) {
// Arc arc(aiter.Value());
// StatePair arc_index = std::make_pair(static_cast<int32>(s),
// static_cast<int32>(arc.nextstate));
// MapT::const_iterator it = nnlm_scores.find(arc_index);
// double nnlm_score;
// if (it != nnlm_scores.end())
// nnlm_score = it->second;
// else
// KALDI_ERR << "Some arc does not have neural language model score.";
// if (arc.ilabel != 0) { // if there is a word on this arc
// LatticeWeight weight = arc.weight.Weight();
// // Add associated neural LM score to each arc.
// weight.SetValue1(weight.Value1() + nnlm_score);
// arc.weight.SetWeight(weight);
// aiter.SetValue(arc);
// }
// Weight clat_final = clat->Final(arc.nextstate);
// StatePair final_pair = std::make_pair(arc.nextstate, arc.nextstate);
// // Add neural LM scores to each final state only once.
// if (clat_final != CompactLatticeWeight::Zero() &&
// final_state_check.find(final_pair) == final_state_check.end()) {
// MapT::const_iterator final_it = nnlm_scores.find(final_pair);
// double final_nnlm_score = 0.0;
// if (final_it != nnlm_scores.end())
// final_nnlm_score = final_it->second;
// // Add neural LM scores to the final weight.
// Weight final_weight(LatticeWeight(clat_final.Weight().Value1() +
// final_nnlm_score,
// clat_final.Weight().Value2()),
// clat_final.String());
// clat->SetFinal(arc.nextstate, final_weight);
// final_state_check[final_pair] = true;
// }
// } // end looping over arcs
// } // end looping over states
// }
//
// void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
// CompactLattice *clat) {
// typedef CompactLatticeArc Arc;
// int32 num_states = clat->NumStates();
//
// //scan the lattice
// for (int32 state = 0; state < num_states; state++) {
// for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
// !aiter.Done(); aiter.Next()) {
//
// Arc arc(aiter.Value());
//
// if (arc.ilabel != 0) { // if there is a word on this arc
// LatticeWeight weight = arc.weight.Weight();
// // add word insertion penalty to lattice
// weight.SetValue1( weight.Value1() + word_ins_penalty);
// arc.weight.SetWeight(weight);
// aiter.SetValue(arc);
// }
// } // end looping over arcs
// } // end looping over states
// }
//
// struct ClatRescoreTuple {
// ClatRescoreTuple(int32 state, int32 arc, int32 tid):
// state_id(state), arc_id(arc), tid(tid) { }
// int32 state_id;
// int32 arc_id;
// int32 tid;
// };
//
// /** RescoreCompactLatticeInternal is the internal code for both
// RescoreCompactLattice and RescoreCompactLatticeSpeedup. For
// RescoreCompactLattice, "tmodel" will be NULL and speedup_factor will be 1.0.
// */
// bool RescoreCompactLatticeInternal(
// const TransitionModel *tmodel,
// BaseFloat speedup_factor,
// DecodableInterface *decodable,
// CompactLattice *clat) {
// KALDI_ASSERT(speedup_factor >= 1.0);
// if (clat->NumStates() == 0) {
// KALDI_WARN << "Rescoring empty lattice";
// return false;
// }
// if (!clat->Properties(fst::kTopSorted, true)) {
// if (fst::TopSort(clat) == false) {
// KALDI_WARN << "Cycles detected in lattice.";
// return false;
// }
// }
// std::vector<int32> state_times;
// int32 utt_len = kaldi::CompactLatticeStateTimes(*clat, &state_times);
//
// std::vector<std::vector<ClatRescoreTuple> > time_to_state(utt_len);
//
// int32 num_states = clat->NumStates();
// KALDI_ASSERT(num_states == state_times.size());
// for (size_t state = 0; state < num_states; state++) {
// KALDI_ASSERT(state_times[state] >= 0);
// int32 t = state_times[state];
// int32 arc_id = 0;
// for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
// !aiter.Done(); aiter.Next(), arc_id++) {
// CompactLatticeArc arc = aiter.Value();
// std::vector<int32> arc_string = arc.weight.String();
//
// for (size_t offset = 0; offset < arc_string.size(); offset++) {
// if (t < utt_len) { // end state may be past this..
// int32 tid = arc_string[offset];
// time_to_state[t+offset].push_back(ClatRescoreTuple(state, arc_id, tid));
// } else {
// if (t != utt_len) {
// KALDI_WARN << "There appears to be lattice/feature mismatch, "
// << "aborting.";
// return false;
// }
// }
// }
// }
// if (clat->Final(state) != CompactLatticeWeight::Zero()) {
// arc_id = -1;
// std::vector<int32> arc_string = clat->Final(state).String();
// for (size_t offset = 0; offset < arc_string.size(); offset++) {
// KALDI_ASSERT(t + offset < utt_len); // already checked in
// // CompactLatticeStateTimes, so a failure here would be a code error.
// time_to_state[t+offset].push_back(
// ClatRescoreTuple(state, arc_id, arc_string[offset]));
// }
// }
// }
//
// for (int32 t = 0; t < utt_len; t++) {
// if ((t < utt_len - 1) && decodable->IsLastFrame(t)) {
// KALDI_WARN << "Features are too short for lattice: utt-len is "
// << utt_len << ", " << t << " is last frame";
// return false;
// }
// // frame_scale is the scale we put on the computed acoustic probs for this
// // frame. It will always be 1.0 if tmodel == NULL (i.e. if we are not doing
// // the "speedup" code). For frames with multiple pdf-ids it will be one.
// // For frames with only one pdf-id, it will equal speedup_factor (>=1.0)
// // with probability 1.0 / speedup_factor, and zero otherwise. If it is zero,
// // we can avoid computing the probabilities.
// BaseFloat frame_scale = 1.0;
// KALDI_ASSERT(!time_to_state[t].empty());
// if (tmodel != NULL) {
// int32 pdf_id = tmodel->TransitionIdToPdf(time_to_state[t][0].tid);
// bool frame_has_multiple_pdfs = false;
// for (size_t i = 1; i < time_to_state[t].size(); i++) {
// if (tmodel->TransitionIdToPdf(time_to_state[t][i].tid) != pdf_id) {
// frame_has_multiple_pdfs = true;
// break;
// }
// }
// if (frame_has_multiple_pdfs) {
// frame_scale = 1.0;
// } else {
// if (WithProb(1.0 / speedup_factor)) {
// frame_scale = speedup_factor;
// } else {
// frame_scale = 0.0;
// }
// }
// if (frame_scale == 0.0)
// continue; // the code below would be pointless.
// }
//
// for (size_t i = 0; i < time_to_state[t].size(); i++) {
// int32 state = time_to_state[t][i].state_id;
// int32 arc_id = time_to_state[t][i].arc_id;
// int32 tid = time_to_state[t][i].tid;
//
// if (arc_id == -1) { // Final state
// // Access the trans_id
// CompactLatticeWeight curr_clat_weight = clat->Final(state);
//
// // Calculate likelihood
// BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
// // update weight
// CompactLatticeWeight new_clat_weight = curr_clat_weight;
// LatticeWeight new_lat_weight = new_clat_weight.Weight();
// new_lat_weight.SetValue2(-log_like + curr_clat_weight.Weight().Value2());
// new_clat_weight.SetWeight(new_lat_weight);
// clat->SetFinal(state, new_clat_weight);
// } else {
// fst::MutableArcIterator<CompactLattice> aiter(clat, state);
//
// aiter.Seek(arc_id);
// CompactLatticeArc arc = aiter.Value();
//
// // Calculate likelihood
// BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
// // update weight
// LatticeWeight new_weight = arc.weight.Weight();
// new_weight.SetValue2(-log_like + arc.weight.Weight().Value2());
// arc.weight.SetWeight(new_weight);
// aiter.SetValue(arc);
// }
// }
// }
// return true;
// }
//
//
// bool RescoreCompactLatticeSpeedup(
// const TransitionModel &tmodel,
// BaseFloat speedup_factor,
// DecodableInterface *decodable,
// CompactLattice *clat) {
// return RescoreCompactLatticeInternal(&tmodel, speedup_factor, decodable, clat);
// }
//
// bool RescoreCompactLattice(DecodableInterface *decodable,
// CompactLattice *clat) {
// return RescoreCompactLatticeInternal(NULL, 1.0, decodable, clat);
// }
//
//
// bool RescoreLattice(DecodableInterface *decodable,
// Lattice *lat) {
// if (lat->NumStates() == 0) {
// KALDI_WARN << "Rescoring empty lattice";
// return false;
// }
// if (!lat->Properties(fst::kTopSorted, true)) {
// if (fst::TopSort(lat) == false) {
// KALDI_WARN << "Cycles detected in lattice.";
// return false;
// }
// }
// std::vector<int32> state_times;
// int32 utt_len = kaldi::LatticeStateTimes(*lat, &state_times);
//
// std::vector<std::vector<int32> > time_to_state(utt_len );
//
// int32 num_states = lat->NumStates();
// KALDI_ASSERT(num_states == state_times.size());
// for (size_t state = 0; state < num_states; state++) {
// int32 t = state_times[state];
// // Don't check t >= 0 because non-accessible states could have t = -1.
// KALDI_ASSERT(t <= utt_len);
// if (t >= 0 && t < utt_len)
// time_to_state[t].push_back(state);
// }
//
// for (int32 t = 0; t < utt_len; t++) {
// if ((t < utt_len - 1) && decodable->IsLastFrame(t)) {
// KALDI_WARN << "Features are too short for lattice: utt-len is "
// << utt_len << ", " << t << " is last frame";
// return false;
// }
// for (size_t i = 0; i < time_to_state[t].size(); i++) {
// int32 state = time_to_state[t][i];
// for (fst::MutableArcIterator<Lattice> aiter(lat, state);
// !aiter.Done(); aiter.Next()) {
// LatticeArc arc = aiter.Value();
// if (arc.ilabel != 0) {
// int32 trans_id = arc.ilabel; // Note: it doesn't necessarily
// // have to be a transition-id, just whatever the Decodable
// // object is expecting, but it's normally a transition-id.
//
// BaseFloat log_like = decodable->LogLikelihood(t, trans_id);
// arc.weight.SetValue2(-log_like + arc.weight.Value2());
// aiter.SetValue(arc);
// }
// }
// }
// }
// return true;
// }
//
//
// BaseFloat LatticeForwardBackwardMmi(
// const TransitionModel &tmodel,
// const Lattice &lat,
// const std::vector<int32> &num_ali,
// bool drop_frames,
// bool convert_to_pdf_ids,
// bool cancel,
// Posterior *post) {
// // First compute the MMI posteriors.
//
// Posterior den_post;
// BaseFloat ans = LatticeForwardBackward(lat,
// &den_post,
// NULL);
//
// Posterior num_post;
// AlignmentToPosterior(num_ali, &num_post);
//
// // Now negate the MMI posteriors and add the numerator
// // posteriors.
// ScalePosterior(-1.0, &den_post);
//
// if (convert_to_pdf_ids) {
// Posterior num_tmp;
// ConvertPosteriorToPdfs(tmodel, num_post, &num_tmp);
// num_tmp.swap(num_post);
// Posterior den_tmp;
// ConvertPosteriorToPdfs(tmodel, den_post, &den_tmp);
// den_tmp.swap(den_post);
// }
//
// MergePosteriors(num_post, den_post,
// cancel, drop_frames, post);
//
// return ans;
// }
//
//
// int32 LongestSentenceLength(const Lattice &lat) {
// typedef Lattice::Arc Arc;
// typedef Arc::Label Label;
// typedef Arc::StateId StateId;
//
// if (lat.Properties(fst::kTopSorted, true) == 0) {
// Lattice lat_copy(lat);
// if (!TopSort(&lat_copy))
// KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
// return LongestSentenceLength(lat_copy);
// }
// std::vector<int32> max_length(lat.NumStates(), 0);
// int32 lattice_max_length = 0;
// for (StateId s = 0; s < lat.NumStates(); s++) {
// int32 this_max_length = max_length[s];
// for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// bool arc_has_word = (arc.olabel != 0);
// StateId nextstate = arc.nextstate;
// KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size());
// if (arc_has_word) {
// // A lattice should ideally not have cycles anyway; a cycle with a word
// // on it is something very bad.
// KALDI_ASSERT(nextstate > s && "Lattice has cycles with words on.");
// max_length[nextstate] = std::max(max_length[nextstate],
// this_max_length + 1);
// } else {
// max_length[nextstate] = std::max(max_length[nextstate],
// this_max_length);
// }
// }
// if (lat.Final(s) != LatticeWeight::Zero())
// lattice_max_length = std::max(lattice_max_length, max_length[s]);
// }
// return lattice_max_length;
// }
//
// int32 LongestSentenceLength(const CompactLattice &clat) {
// typedef CompactLattice::Arc Arc;
// typedef Arc::Label Label;
// typedef Arc::StateId StateId;
//
// if (clat.Properties(fst::kTopSorted, true) == 0) {
// CompactLattice clat_copy(clat);
// if (!TopSort(&clat_copy))
// KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
// return LongestSentenceLength(clat_copy);
// }
// std::vector<int32> max_length(clat.NumStates(), 0);
// int32 lattice_max_length = 0;
// for (StateId s = 0; s < clat.NumStates(); s++) {
// int32 this_max_length = max_length[s];
// for (fst::ArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// bool arc_has_word = (arc.ilabel != 0); // note: olabel == ilabel.
// // also note: for normal CompactLattice, e.g. as produced by
// // determinization, all arcs will have nonzero labels, but the user might
// // decide to replace some of the labels with zero for some reason, and we
// // want to support this.
// StateId nextstate = arc.nextstate;
// KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size());
// KALDI_ASSERT(nextstate > s && "CompactLattice has cycles");
// if (arc_has_word)
// max_length[nextstate] = std::max(max_length[nextstate],
// this_max_length + 1);
// else
// max_length[nextstate] = std::max(max_length[nextstate],
// this_max_length);
// }
// if (clat.Final(s) != CompactLatticeWeight::Zero())
// lattice_max_length = std::max(lattice_max_length, max_length[s]);
// }
// return lattice_max_length;
// }
//
// void ComposeCompactLatticeDeterministic(
// const CompactLattice& clat,
// fst::DeterministicOnDemandFst<fst::StdArc>* det_fst,
// CompactLattice* composed_clat) {
// // StdFst::Arc and CompactLatticeArc have the same StateId type.
// typedef fst::StdArc::StateId StateId;
// typedef fst::StdArc::Weight Weight1;
// typedef CompactLatticeArc::Weight Weight2;
// typedef std::pair<StateId, StateId> StatePair;
// typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType;
// typedef MapType::iterator IterType;
//
// // Empties the output FST.
// KALDI_ASSERT(composed_clat != NULL);
// composed_clat->DeleteStates();
//
// MapType state_map;
// std::queue<StatePair> state_queue;
//
// // Sets start state in <composed_clat>.
// StateId start_state = composed_clat->AddState();
// StatePair start_pair(clat.Start(), det_fst->Start());
// composed_clat->SetStart(start_state);
// state_queue.push(start_pair);
// std::pair<IterType, bool> result =
// state_map.insert(std::make_pair(start_pair, start_state));
// KALDI_ASSERT(result.second == true);
//
// // Starts composition here.
// while (!state_queue.empty()) {
// // Gets the first state in the queue.
// StatePair s = state_queue.front();
// StateId s1 = s.first;
// StateId s2 = s.second;
// state_queue.pop();
//
//
// Weight2 clat_final = clat.Final(s1);
// if (clat_final.Weight().Value1() !=
// std::numeric_limits<BaseFloat>::infinity()) {
// // Test for whether the final-prob of state s1 was zero.
// Weight1 det_fst_final = det_fst->Final(s2);
// if (det_fst_final.Value() !=
// std::numeric_limits<BaseFloat>::infinity()) {
// // Test for whether the final-prob of state s2 was zero. If neither
// // source-state final prob was zero, then we should create final state
// // in <composed_clat>. We compute the product manually since this is more
// // efficient.
// Weight2 final_weight(LatticeWeight(clat_final.Weight().Value1() +
// det_fst_final.Value(),
// clat_final.Weight().Value2()),
// clat_final.String());
// // we can assume final_weight is not Zero(), since neither of
// // the sources was zero.
// KALDI_ASSERT(state_map.find(s) != state_map.end());
// composed_clat->SetFinal(state_map[s], final_weight);
// }
// }
//
// // Loops over pair of edges at s1 and s2.
// for (fst::ArcIterator<CompactLattice> aiter(clat, s1);
// !aiter.Done(); aiter.Next()) {
// const CompactLatticeArc& arc1 = aiter.Value();
// fst::StdArc arc2;
// StateId next_state1 = arc1.nextstate, next_state2;
// bool matched = false;
//
// if (arc1.olabel == 0) {
// // If the symbol on <arc1> is <epsilon>, we advance to the next state
// // in <clat> but keep <det_fst> at its current state.
// matched = true;
// next_state2 = s2;
// } else {
// // Otherwise try to find the matched arc in <det_fst>.
// matched = det_fst->GetArc(s2, arc1.olabel, &arc2);
// if (matched) {
// next_state2 = arc2.nextstate;
// }
// }
//
// // If matched arc is found in <det_fst>, then we have to add new arcs to
// // <composed_clat>.
// if (matched) {
// StatePair next_state_pair(next_state1, next_state2);
// IterType siter = state_map.find(next_state_pair);
// StateId next_state;
//
// // Adds composed state to <state_map>.
// if (siter == state_map.end()) {
// // If the composed state has not been created yet, create it.
// next_state = composed_clat->AddState();
// std::pair<const StatePair, StateId> next_state_map(next_state_pair,
// next_state);
// std::pair<IterType, bool> result = state_map.insert(next_state_map);
// KALDI_ASSERT(result.second);
// state_queue.push(next_state_pair);
// } else {
// // If the composed state is already in <state_map>, we can directly
// // use that.
// next_state = siter->second;
// }
//
// // Adds arc to <composed_clat>.
// if (arc1.olabel == 0) {
// composed_clat->AddArc(state_map[s],
// CompactLatticeArc(arc1.ilabel, 0,
// arc1.weight, next_state));
// } else {
// Weight2 composed_weight(
// LatticeWeight(arc1.weight.Weight().Value1() +
// arc2.weight.Value(),
// arc1.weight.Weight().Value2()),
// arc1.weight.String());
// composed_clat->AddArc(state_map[s],
// CompactLatticeArc(arc1.ilabel, arc2.olabel,
// composed_weight, next_state));
// }
// }
// }
// }
// fst::Connect(composed_clat);
// }
//
//
// void ComputeAcousticScoresMap(
// const Lattice &lat,
// unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> > *acoustic_scores) {
// // typedef the arc, weight types
// typedef Lattice::Arc Arc;
// typedef Arc::Weight LatticeWeight;
// typedef Arc::StateId StateId;
//
// acoustic_scores->clear();
//
// std::vector<int32> state_times;
// LatticeStateTimes(lat, &state_times); // Assumes the input is top sorted
//
// KALDI_ASSERT(lat.Start() == 0);
//
// for (StateId s = 0; s < lat.NumStates(); s++) {
// int32 t = state_times[s];
// for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done();
// aiter.Next()) {
// const Arc &arc = aiter.Value();
// const LatticeWeight &weight = arc.weight;
//
// int32 tid = arc.ilabel;
//
// if (tid != 0) {
// unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> >::iterator it = acoustic_scores->find(std::make_pair(t, tid));
// if (it == acoustic_scores->end()) {
// acoustic_scores->insert(std::make_pair(std::make_pair(t, tid),
// std::make_pair(weight.Value2(), 1)));
// } else {
// if (it->second.second == 2
// && it->second.first / it->second.second != weight.Value2()) {
// KALDI_VLOG(2) << "Transitions on the same frame have different "
// << "acoustic costs for tid " << tid << "; "
// << it->second.first / it->second.second
// << " vs " << weight.Value2();
// }
// it->second.first += weight.Value2();
// it->second.second++;
// }
// } else {
// // Arcs with epsilon input label (tid) must have 0 acoustic cost
// KALDI_ASSERT(weight.Value2() == 0);
// }
// }
//
// LatticeWeight f = lat.Final(s);
// if (f != LatticeWeight::Zero()) {
// // Final acoustic cost must be 0 as we are reading from
// // non-determinized, non-compact lattice
// KALDI_ASSERT(f.Value2() == 0.0);
// }
// }
// }
//
// void ReplaceAcousticScoresFromMap(
// const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> > &acoustic_scores,
// Lattice *lat) {
// // typedef the arc, weight types
// typedef Lattice::Arc Arc;
// typedef Arc::Weight LatticeWeight;
// typedef Arc::StateId StateId;
//
// TopSortLatticeIfNeeded(lat);
//
// std::vector<int32> state_times;
// LatticeStateTimes(*lat, &state_times);
//
// KALDI_ASSERT(lat->Start() == 0);
//
// for (StateId s = 0; s < lat->NumStates(); s++) {
// int32 t = state_times[s];
// for (fst::MutableArcIterator<Lattice> aiter(lat, s);
// !aiter.Done(); aiter.Next()) {
// Arc arc(aiter.Value());
//
// int32 tid = arc.ilabel;
// if (tid != 0) {
// unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> >::const_iterator it = acoustic_scores.find(std::make_pair(t, tid));
// if (it == acoustic_scores.end()) {
// KALDI_ERR << "Could not find tid " << tid << " at time " << t
// << " in the acoustic scores map.";
// } else {
// arc.weight.SetValue2(it->second.first / it->second.second);
// }
// } else {
// // For epsilon arcs, set acoustic cost to 0.0
// arc.weight.SetValue2(0.0);
// }
// aiter.SetValue(arc);
// }
//
// LatticeWeight f = lat->Final(s);
// if (f != LatticeWeight::Zero()) {
// // Set final acoustic cost to 0.0
// f.SetValue2(0.0);
// lat->SetFinal(s, f);
// }
// }
// }
} // namespace kaldi