// lat/lattice-functions.cc // Copyright 2009-2011 Saarland University (Author: Arnab Ghoshal) // 2012-2013 Johns Hopkins University (Author: Daniel Povey); Chao Weng; // Bagher BabaAli // 2013 Cisco Systems (author: Neha Agrawal) [code modified // from original code in ../gmmbin/gmm-rescore-lattice.cc] // 2014 Guoguo Chen // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. 
#include "lat/lattice-functions.h" // #include "hmm/transition-model.h" // #include "util/stl-utils.h" #include "base/kaldi-math.h" // #include "hmm/hmm-utils.h" namespace kaldi { using std::map; using std::vector; // void GetPerFrameAcousticCosts(const Lattice &nbest, // Vector *per_frame_loglikes) { // using namespace fst; // typedef Lattice::Arc::Weight Weight; // vector loglikes; // // int32 cur_state = nbest.Start(); // int32 prev_frame = -1; // BaseFloat eps_acwt = 0.0; // while(1) { // Weight w = nbest.Final(cur_state); // if (w != Weight::Zero()) { // KALDI_ASSERT(nbest.NumArcs(cur_state) == 0); // if (per_frame_loglikes != NULL) { // SubVector subvec(&(loglikes[0]), loglikes.size()); // Vector vec(subvec); // *per_frame_loglikes = vec; // } // break; // } else { // KALDI_ASSERT(nbest.NumArcs(cur_state) == 1); // fst::ArcIterator iter(nbest, cur_state); // const Lattice::Arc &arc = iter.Value(); // BaseFloat acwt = arc.weight.Value2(); // if (arc.ilabel != 0) { // if (eps_acwt > 0) { // acwt += eps_acwt; // eps_acwt = 0.0; // } // loglikes.push_back(acwt); // prev_frame++; // } else if (acwt == acwt){ // if (prev_frame > -1) { // loglikes[prev_frame] += acwt; // } else { // eps_acwt += acwt; // } // } // cur_state = arc.nextstate; // } // } // } // // int32 LatticeStateTimes(const Lattice &lat, vector *times) { // if (!lat.Properties(fst::kTopSorted, true)) // KALDI_ERR << "Input lattice must be topologically sorted."; // KALDI_ASSERT(lat.Start() == 0); // int32 num_states = lat.NumStates(); // times->clear(); // times->resize(num_states, -1); // (*times)[0] = 0; // for (int32 state = 0; state < num_states; state++) { // int32 cur_time = (*times)[state]; // for (fst::ArcIterator aiter(lat, state); !aiter.Done(); // aiter.Next()) { // const LatticeArc &arc = aiter.Value(); // // if (arc.ilabel != 0) { // Non-epsilon input label on arc // // next time instance // if ((*times)[arc.nextstate] == -1) { // (*times)[arc.nextstate] = cur_time + 1; // } else { // 
KALDI_ASSERT((*times)[arc.nextstate] == cur_time + 1); // } // } else { // epsilon input label on arc // // Same time instance // if ((*times)[arc.nextstate] == -1) // (*times)[arc.nextstate] = cur_time; // else // KALDI_ASSERT((*times)[arc.nextstate] == cur_time); // } // } // } // return (*std::max_element(times->begin(), times->end())); // } // // int32 CompactLatticeStateTimes(const CompactLattice &lat, // vector *times) { // if (!lat.Properties(fst::kTopSorted, true)) // KALDI_ERR << "Input lattice must be topologically sorted."; // KALDI_ASSERT(lat.Start() == 0); // int32 num_states = lat.NumStates(); // times->clear(); // times->resize(num_states, -1); // (*times)[0] = 0; // int32 utt_len = -1; // for (int32 state = 0; state < num_states; state++) { // int32 cur_time = (*times)[state]; // for (fst::ArcIterator aiter(lat, state); !aiter.Done(); // aiter.Next()) { // const CompactLatticeArc &arc = aiter.Value(); // int32 arc_len = static_cast(arc.weight.String().size()); // if ((*times)[arc.nextstate] == -1) // (*times)[arc.nextstate] = cur_time + arc_len; // else // KALDI_ASSERT((*times)[arc.nextstate] == cur_time + arc_len); // } // if (lat.Final(state) != CompactLatticeWeight::Zero()) { // int32 this_utt_len = (*times)[state] + lat.Final(state).String().size(); // if (utt_len == -1) utt_len = this_utt_len; // else { // if (this_utt_len != utt_len) { // KALDI_WARN << "Utterance does not " // "seem to have a consistent length."; // utt_len = std::max(utt_len, this_utt_len); // } // } // } // } // if (utt_len == -1) { // KALDI_WARN << "Utterance does not have a final-state."; // return 0; // } // return utt_len; // } // // bool ComputeCompactLatticeAlphas(const CompactLattice &clat, // vector *alpha) { // using namespace fst; // // // typedef the arc, weight types // typedef CompactLattice::Arc Arc; // typedef Arc::Weight Weight; // typedef Arc::StateId StateId; // // //Make sure the lattice is topologically sorted. 
// if (clat.Properties(fst::kTopSorted, true) == 0) { // KALDI_WARN << "Input lattice must be topologically sorted."; // return false; // } // if (clat.Start() != 0) { // KALDI_WARN << "Input lattice must start from state 0."; // return false; // } // // int32 num_states = clat.NumStates(); // (*alpha).resize(0); // (*alpha).resize(num_states, kLogZeroDouble); // // // Now propagate alphas forward. Note that we don't acount the weight of the // // final state to alpha[final_state] -- we acount it to beta[final_state]; // (*alpha)[0] = 0.0; // for (StateId s = 0; s < num_states; s++) { // double this_alpha = (*alpha)[s]; // for (ArcIterator aiter(clat, s); // !aiter.Done(); aiter.Next()) { // const Arc &arc = aiter.Value(); // double arc_like = -(arc.weight.Weight().Value1() + // arc.weight.Weight().Value2()); // (*alpha)[arc.nextstate] = LogAdd((*alpha)[arc.nextstate], // this_alpha + arc_like); // } // } // // return true; // } // // bool ComputeCompactLatticeBetas(const CompactLattice &clat, // vector *beta) { // using namespace fst; // // // typedef the arc, weight types // typedef CompactLattice::Arc Arc; // typedef Arc::Weight Weight; // typedef Arc::StateId StateId; // // // Make sure the lattice is topologically sorted. // if (clat.Properties(fst::kTopSorted, true) == 0) { // KALDI_WARN << "Input lattice must be topologically sorted."; // return false; // } // if (clat.Start() != 0) { // KALDI_WARN << "Input lattice must start from state 0."; // return false; // } // // int32 num_states = clat.NumStates(); // (*beta).resize(0); // (*beta).resize(num_states, kLogZeroDouble); // // // Now propagate betas backward. Note that beta[final_state] contains the // // weight of the final state in the lattice -- compare that with alpha. 
// for (StateId s = num_states-1; s >= 0; s--) { // Weight f = clat.Final(s); // double this_beta = -(f.Weight().Value1()+f.Weight().Value2()); // for (ArcIterator aiter(clat, s); // !aiter.Done(); aiter.Next()) { // const Arc &arc = aiter.Value(); // double arc_like = -(arc.weight.Weight().Value1() + // arc.weight.Weight().Value2()); // double arc_beta = (*beta)[arc.nextstate] + arc_like; // this_beta = LogAdd(this_beta, arc_beta); // } // (*beta)[s] = this_beta; // } // // return true; // } template // could be Lattice or CompactLattice bool PruneLattice(BaseFloat beam, LatType *lat) { typedef typename LatType::Arc Arc; typedef typename Arc::Weight Weight; typedef typename Arc::StateId StateId; KALDI_ASSERT(beam > 0.0); if (!lat->Properties(fst::kTopSorted, true)) { if (fst::TopSort(lat) == false) { KALDI_WARN << "Cycles detected in lattice"; return false; } } // We assume states before "start" are not reachable, since // the lattice is topologically sorted. int32 start = lat->Start(); int32 num_states = lat->NumStates(); if (num_states == 0) return false; std::vector forward_cost(num_states, std::numeric_limits::infinity()); // viterbi forward. forward_cost[start] = 0.0; // lattice can't have cycles so couldn't be // less than this. double best_final_cost = std::numeric_limits::infinity(); // Update the forward probs. // Thanks to Jing Zheng for finding a bug here. 
for (int32 state = 0; state < num_states; state++) { double this_forward_cost = forward_cost[state]; for (fst::ArcIterator aiter(*lat, state); !aiter.Done(); aiter.Next()) { const Arc &arc(aiter.Value()); StateId nextstate = arc.nextstate; KALDI_ASSERT(nextstate > state && nextstate < num_states); double next_forward_cost = this_forward_cost + ConvertToCost(arc.weight); if (forward_cost[nextstate] > next_forward_cost) forward_cost[nextstate] = next_forward_cost; } Weight final_weight = lat->Final(state); double this_final_cost = this_forward_cost + ConvertToCost(final_weight); if (this_final_cost < best_final_cost) best_final_cost = this_final_cost; } int32 bad_state = lat->AddState(); // this state is not final. double cutoff = best_final_cost + beam; // Go backwards updating the backward probs (which share memory with the // forward probs), and pruning arcs and deleting final-probs. We prune arcs // by making them point to the non-final state "bad_state". We'll then use // Trim() to remove unnecessary arcs and states. [this is just easier than // doing it ourselves.] std::vector &backward_cost(forward_cost); for (int32 state = num_states - 1; state >= 0; state--) { double this_forward_cost = forward_cost[state]; double this_backward_cost = ConvertToCost(lat->Final(state)); if (this_backward_cost + this_forward_cost > cutoff && this_backward_cost != std::numeric_limits::infinity()) lat->SetFinal(state, Weight::Zero()); for (fst::MutableArcIterator aiter(lat, state); !aiter.Done(); aiter.Next()) { Arc arc(aiter.Value()); StateId nextstate = arc.nextstate; KALDI_ASSERT(nextstate > state && nextstate < num_states); double arc_cost = ConvertToCost(arc.weight), arc_backward_cost = arc_cost + backward_cost[nextstate], this_fb_cost = this_forward_cost + arc_backward_cost; if (arc_backward_cost < this_backward_cost) this_backward_cost = arc_backward_cost; if (this_fb_cost > cutoff) { // Prune the arc. 
arc.nextstate = bad_state; aiter.SetValue(arc); } } backward_cost[state] = this_backward_cost; } fst::Connect(lat); return (lat->NumStates() > 0); } // instantiate the template for lattice and CompactLattice. template bool PruneLattice(BaseFloat beam, Lattice *lat); template bool PruneLattice(BaseFloat beam, CompactLattice *lat); // BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post, // double *acoustic_like_sum) { // // Note, Posterior is defined as follows: Indexed [frame], then a list // // of (transition-id, posterior-probability) pairs. // // typedef std::vector > > Posterior; // using namespace fst; // typedef Lattice::Arc Arc; // typedef Arc::Weight Weight; // typedef Arc::StateId StateId; // // if (acoustic_like_sum) *acoustic_like_sum = 0.0; // // // Make sure the lattice is topologically sorted. // if (lat.Properties(fst::kTopSorted, true) == 0) // KALDI_ERR << "Input lattice must be topologically sorted."; // KALDI_ASSERT(lat.Start() == 0); // // int32 num_states = lat.NumStates(); // vector state_times; // int32 max_time = LatticeStateTimes(lat, &state_times); // std::vector alpha(num_states, kLogZeroDouble); // std::vector &beta(alpha); // we re-use the same memory for // // this, but it's semantically distinct so we name it differently. // double tot_forward_prob = kLogZeroDouble; // // post->clear(); // post->resize(max_time); // // alpha[0] = 0.0; // // Propagate alphas forward. 
// for (StateId s = 0; s < num_states; s++) { // double this_alpha = alpha[s]; // for (ArcIterator aiter(lat, s); !aiter.Done(); aiter.Next()) { // const Arc &arc = aiter.Value(); // double arc_like = -ConvertToCost(arc.weight); // alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha + arc_like); // } // Weight f = lat.Final(s); // if (f != Weight::Zero()) { // double final_like = this_alpha - (f.Value1() + f.Value2()); // tot_forward_prob = LogAdd(tot_forward_prob, final_like); // KALDI_ASSERT(state_times[s] == max_time && // "Lattice is inconsistent (final-prob not at max_time)"); // } // } // for (StateId s = num_states-1; s >= 0; s--) { // Weight f = lat.Final(s); // double this_beta = -(f.Value1() + f.Value2()); // for (ArcIterator aiter(lat, s); !aiter.Done(); aiter.Next()) { // const Arc &arc = aiter.Value(); // double arc_like = -ConvertToCost(arc.weight), // arc_beta = beta[arc.nextstate] + arc_like; // this_beta = LogAdd(this_beta, arc_beta); // int32 transition_id = arc.ilabel; // // // The following "if" is an optimization to avoid un-needed exp(). 
// if (transition_id != 0 || acoustic_like_sum != NULL) { // double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob); // // if (transition_id != 0) // Arc has a transition-id on it [not epsilon] // (*post)[state_times[s]].push_back(std::make_pair(transition_id, // static_cast(posterior))); // if (acoustic_like_sum != NULL) // *acoustic_like_sum -= posterior * arc.weight.Value2(); // } // } // if (acoustic_like_sum != NULL && f != Weight::Zero()) { // double final_logprob = - ConvertToCost(f), // posterior = Exp(alpha[s] + final_logprob - tot_forward_prob); // *acoustic_like_sum -= posterior * f.Value2(); // } // beta[s] = this_beta; // } // double tot_backward_prob = beta[0]; // if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) { // KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob // << ", while total backward probability = " << tot_backward_prob; // } // // Now combine any posteriors with the same transition-id. // for (int32 t = 0; t < max_time; t++) // MergePairVectorSumming(&((*post)[t])); // return tot_backward_prob; // } // // // void LatticeActivePhones(const Lattice &lat, const TransitionModel &trans, // const vector &silence_phones, // vector< std::set > *active_phones) { // KALDI_ASSERT(IsSortedAndUniq(silence_phones)); // vector state_times; // int32 num_states = lat.NumStates(); // int32 max_time = LatticeStateTimes(lat, &state_times); // active_phones->clear(); // active_phones->resize(max_time); // for (int32 state = 0; state < num_states; state++) { // int32 cur_time = state_times[state]; // for (fst::ArcIterator aiter(lat, state); !aiter.Done(); // aiter.Next()) { // const LatticeArc &arc = aiter.Value(); // if (arc.ilabel != 0) { // Non-epsilon arc // int32 phone = trans.TransitionIdToPhone(arc.ilabel); // if (!std::binary_search(silence_phones.begin(), // silence_phones.end(), phone)) // (*active_phones)[cur_time].insert(phone); // } // } // end looping over arcs // } // end looping over states // } 
// // void ConvertLatticeToPhones(const TransitionModel &trans, // Lattice *lat) { // typedef LatticeArc Arc; // int32 num_states = lat->NumStates(); // for (int32 state = 0; state < num_states; state++) { // for (fst::MutableArcIterator aiter(lat, state); !aiter.Done(); // aiter.Next()) { // Arc arc(aiter.Value()); // arc.olabel = 0; // remove any word. // if ((arc.ilabel != 0) // has a transition-id on input.. // && (trans.TransitionIdToHmmState(arc.ilabel) == 0) // && (!trans.IsSelfLoop(arc.ilabel))) { // // && trans.IsFinal(arc.ilabel)) // there is one of these per phone... // arc.olabel = trans.TransitionIdToPhone(arc.ilabel); // } // aiter.SetValue(arc); // } // end looping over arcs // } // end looping over states // } // // // static inline double LogAddOrMax(bool viterbi, double a, double b) { // if (viterbi) // return std::max(a, b); // else // return LogAdd(a, b); // } // // template // double ComputeLatticeAlphasAndBetas(const LatticeType &lat, // bool viterbi, // vector *alpha, // vector *beta) { // typedef typename LatticeType::Arc Arc; // typedef typename Arc::Weight Weight; // typedef typename Arc::StateId StateId; // // StateId num_states = lat.NumStates(); // KALDI_ASSERT(lat.Properties(fst::kTopSorted, true) == fst::kTopSorted); // KALDI_ASSERT(lat.Start() == 0); // alpha->clear(); // beta->clear(); // alpha->resize(num_states, kLogZeroDouble); // beta->resize(num_states, kLogZeroDouble); // // double tot_forward_prob = kLogZeroDouble; // (*alpha)[0] = 0.0; // // Propagate alphas forward. 
// for (StateId s = 0; s < num_states; s++) { // double this_alpha = (*alpha)[s]; // for (fst::ArcIterator aiter(lat, s); !aiter.Done(); // aiter.Next()) { // const Arc &arc = aiter.Value(); // double arc_like = -ConvertToCost(arc.weight); // (*alpha)[arc.nextstate] = LogAddOrMax(viterbi, (*alpha)[arc.nextstate], // this_alpha + arc_like); // } // Weight f = lat.Final(s); // if (f != Weight::Zero()) { // double final_like = this_alpha - ConvertToCost(f); // tot_forward_prob = LogAddOrMax(viterbi, tot_forward_prob, final_like); // } // } // for (StateId s = num_states-1; s >= 0; s--) { // it's guaranteed signed. // double this_beta = -ConvertToCost(lat.Final(s)); // for (fst::ArcIterator aiter(lat, s); !aiter.Done(); // aiter.Next()) { // const Arc &arc = aiter.Value(); // double arc_like = -ConvertToCost(arc.weight), // arc_beta = (*beta)[arc.nextstate] + arc_like; // this_beta = LogAddOrMax(viterbi, this_beta, arc_beta); // } // (*beta)[s] = this_beta; // } // double tot_backward_prob = (*beta)[lat.Start()]; // if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) { // KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob // << ", while total backward probability = " << tot_backward_prob; // } // // Split the difference when returning... they should be the same. // return 0.5 * (tot_backward_prob + tot_forward_prob); // } // // // instantiate the template for Lattice and CompactLattice // template // double ComputeLatticeAlphasAndBetas(const Lattice &lat, // bool viterbi, // vector *alpha, // vector *beta); // // template // double ComputeLatticeAlphasAndBetas(const CompactLattice &lat, // bool viterbi, // vector *alpha, // vector *beta); // // // // /// This is used in CompactLatticeLimitDepth. // struct LatticeArcRecord { // BaseFloat logprob; // logprob <= 0 is the best Viterbi logprob of this arc, // // minus the overall best-cost of the lattice. // CompactLatticeArc::StateId state; // state in the lattice. 
// size_t arc; // arc index within the state. // bool operator < (const LatticeArcRecord &other) const { // return logprob < other.logprob; // } // }; // // void CompactLatticeLimitDepth(int32 max_depth_per_frame, // CompactLattice *clat) { // typedef CompactLatticeArc Arc; // typedef Arc::Weight Weight; // typedef Arc::StateId StateId; // // if (clat->Start() == fst::kNoStateId) { // KALDI_WARN << "Limiting depth of empty lattice."; // return; // } // if (clat->Properties(fst::kTopSorted, true) == 0) { // if (!TopSort(clat)) // KALDI_ERR << "Topological sorting of lattice failed."; // } // // vector state_times; // int32 T = CompactLatticeStateTimes(*clat, &state_times); // // // The alpha and beta quantities here are "viterbi" alphas and beta. // std::vector alpha; // std::vector beta; // bool viterbi = true; // double best_prob = ComputeLatticeAlphasAndBetas(*clat, viterbi, // &alpha, &beta); // // std::vector > arc_records(T); // // StateId num_states = clat->NumStates(); // for (StateId s = 0; s < num_states; s++) { // for (fst::ArcIterator aiter(*clat, s); !aiter.Done(); // aiter.Next()) { // const Arc &arc = aiter.Value(); // LatticeArcRecord arc_record; // arc_record.state = s; // arc_record.arc = aiter.Position(); // arc_record.logprob = // (alpha[s] + beta[arc.nextstate] - ConvertToCost(arc.weight)) // - best_prob; // KALDI_ASSERT(arc_record.logprob < 0.1); // Should be zero or negative. // int32 num_frames = arc.weight.String().size(), start_t = state_times[s]; // for (int32 t = start_t; t < start_t + num_frames; t++) { // KALDI_ASSERT(t < T); // arc_records[t].push_back(arc_record); // } // } // } // StateId dead_state = clat->AddState(); // A non-coaccesible state which we use // // to remove arcs (make them end // // there). 
// size_t max_depth = max_depth_per_frame; // for (int32 t = 0; t < T; t++) { // size_t size = arc_records[t].size(); // if (size > max_depth) { // // we sort from worst to best, so we keep the later-numbered ones, // // and delete the lower-numbered ones. // size_t cutoff = size - max_depth; // std::nth_element(arc_records[t].begin(), // arc_records[t].begin() + cutoff, // arc_records[t].end()); // for (size_t index = 0; index < cutoff; index++) { // LatticeArcRecord record(arc_records[t][index]); // fst::MutableArcIterator aiter(clat, record.state); // aiter.Seek(record.arc); // Arc arc = aiter.Value(); // if (arc.nextstate != dead_state) { // not already killed. // arc.nextstate = dead_state; // aiter.SetValue(arc); // } // } // } // } // Connect(clat); // TopSortCompactLatticeIfNeeded(clat); // } // // // void TopSortCompactLatticeIfNeeded(CompactLattice *clat) { // if (clat->Properties(fst::kTopSorted, true) == 0) { // if (fst::TopSort(clat) == false) { // KALDI_ERR << "Topological sorting failed"; // } // } // } // // void TopSortLatticeIfNeeded(Lattice *lat) { // if (lat->Properties(fst::kTopSorted, true) == 0) { // if (fst::TopSort(lat) == false) { // KALDI_ERR << "Topological sorting failed"; // } // } // } // // // /// Returns the depth of the lattice, defined as the average number of // /// arcs crossing any given frame. Returns 1 for empty lattices. // /// Requires that input is topologically sorted. 
// BaseFloat CompactLatticeDepth(const CompactLattice &clat, // int32 *num_frames) { // typedef CompactLattice::Arc::StateId StateId; // if (clat.Properties(fst::kTopSorted, true) == 0) { // KALDI_ERR << "Lattice input to CompactLatticeDepth was not topologically " // << "sorted."; // } // if (clat.Start() == fst::kNoStateId) { // *num_frames = 0; // return 1.0; // } // size_t num_arc_frames = 0; // int32 t; // { // vector state_times; // t = CompactLatticeStateTimes(clat, &state_times); // } // if (num_frames != NULL) // *num_frames = t; // for (StateId s = 0; s < clat.NumStates(); s++) { // for (fst::ArcIterator aiter(clat, s); !aiter.Done(); // aiter.Next()) { // const CompactLatticeArc &arc = aiter.Value(); // num_arc_frames += arc.weight.String().size(); // } // num_arc_frames += clat.Final(s).String().size(); // } // return num_arc_frames / static_cast(t); // } // // // void CompactLatticeDepthPerFrame(const CompactLattice &clat, // std::vector *depth_per_frame) { // typedef CompactLattice::Arc::StateId StateId; // if (clat.Properties(fst::kTopSorted, true) == 0) { // KALDI_ERR << "Lattice input to CompactLatticeDepthPerFrame was not " // << "topologically sorted."; // } // if (clat.Start() == fst::kNoStateId) { // depth_per_frame->clear(); // return; // } // vector state_times; // int32 T = CompactLatticeStateTimes(clat, &state_times); // // depth_per_frame->clear(); // if (T <= 0) { // return; // } else { // depth_per_frame->resize(T, 0); // for (StateId s = 0; s < clat.NumStates(); s++) { // int32 start_time = state_times[s]; // for (fst::ArcIterator aiter(clat, s); !aiter.Done(); // aiter.Next()) { // const CompactLatticeArc &arc = aiter.Value(); // int32 len = arc.weight.String().size(); // for (int32 t = start_time; t < start_time + len; t++) { // KALDI_ASSERT(t < T); // (*depth_per_frame)[t]++; // } // } // int32 final_len = clat.Final(s).String().size(); // for (int32 t = start_time; t < start_time + final_len; t++) { // KALDI_ASSERT(t < T); // 
(*depth_per_frame)[t]++; // } // } // } // } // // // // void ConvertCompactLatticeToPhones(const TransitionModel &trans, // CompactLattice *clat) { // typedef CompactLatticeArc Arc; // typedef Arc::Weight Weight; // int32 num_states = clat->NumStates(); // for (int32 state = 0; state < num_states; state++) { // for (fst::MutableArcIterator aiter(clat, state); // !aiter.Done(); // aiter.Next()) { // Arc arc(aiter.Value()); // std::vector phone_seq; // const std::vector &tid_seq = arc.weight.String(); // for (std::vector::const_iterator iter = tid_seq.begin(); // iter != tid_seq.end(); ++iter) { // if (trans.IsFinal(*iter))// note: there is one of these per phone... // phone_seq.push_back(trans.TransitionIdToPhone(*iter)); // } // arc.weight.SetString(phone_seq); // aiter.SetValue(arc); // } // end looping over arcs // Weight f = clat->Final(state); // if (f != Weight::Zero()) { // std::vector phone_seq; // const std::vector &tid_seq = f.String(); // for (std::vector::const_iterator iter = tid_seq.begin(); // iter != tid_seq.end(); ++iter) { // if (trans.IsFinal(*iter))// note: there is one of these per phone... // phone_seq.push_back(trans.TransitionIdToPhone(*iter)); // } // f.SetString(phone_seq); // clat->SetFinal(state, f); // } // } // end looping over states // } // // bool LatticeBoost(const TransitionModel &trans, // const std::vector &alignment, // const std::vector &silence_phones, // BaseFloat b, // BaseFloat max_silence_error, // Lattice *lat) { // TopSortLatticeIfNeeded(lat); // // // get all stored properties (test==false means don't test if not known). 
// uint64 props = lat->Properties(fst::kFstProperties, // false); // // KALDI_ASSERT(IsSortedAndUniq(silence_phones)); // KALDI_ASSERT(max_silence_error >= 0.0 && max_silence_error <= 1.0); // vector state_times; // int32 num_states = lat->NumStates(); // int32 num_frames = LatticeStateTimes(*lat, &state_times); // KALDI_ASSERT(num_frames == static_cast(alignment.size())); // for (int32 state = 0; state < num_states; state++) { // int32 cur_time = state_times[state]; // for (fst::MutableArcIterator aiter(lat, state); !aiter.Done(); // aiter.Next()) { // LatticeArc arc = aiter.Value(); // if (arc.ilabel != 0) { // Non-epsilon arc // if (arc.ilabel < 0 || arc.ilabel > trans.NumTransitionIds()) { // KALDI_WARN << "Lattice has out-of-range transition-ids: " // << "lattice/model mismatch?"; // return false; // } // int32 phone = trans.TransitionIdToPhone(arc.ilabel), // ref_phone = trans.TransitionIdToPhone(alignment[cur_time]); // BaseFloat frame_error; // if (phone == ref_phone) { // frame_error = 0.0; // } else { // an error... // if (std::binary_search(silence_phones.begin(), silence_phones.end(), phone)) // frame_error = max_silence_error; // else // frame_error = 1.0; // } // BaseFloat delta_cost = -b * frame_error; // negative cost if // // frame is wrong, to boost likelihood of arcs with errors on them. // // Add this cost to the graph part. // arc.weight.SetValue1(arc.weight.Value1() + delta_cost); // aiter.SetValue(arc); // } // } // } // // All we changed is the weights, so any properties that were // // known before, are still known, except for whether or not the // // lattice was weighted. 
// lat->SetProperties(props, // ~(fst::kWeighted|fst::kUnweighted)); // // return true; // } // // // // BaseFloat LatticeForwardBackwardMpeVariants( // const TransitionModel &trans, // const std::vector &silence_phones, // const Lattice &lat, // const std::vector &num_ali, // std::string criterion, // bool one_silence_class, // Posterior *post) { // using namespace fst; // typedef Lattice::Arc Arc; // typedef Arc::Weight Weight; // typedef Arc::StateId StateId; // // KALDI_ASSERT(criterion == "mpfe" || criterion == "smbr"); // bool is_mpfe = (criterion == "mpfe"); // // if (lat.Properties(fst::kTopSorted, true) == 0) // KALDI_ERR << "Input lattice must be topologically sorted."; // KALDI_ASSERT(lat.Start() == 0); // // int32 num_states = lat.NumStates(); // vector state_times; // int32 max_time = LatticeStateTimes(lat, &state_times); // KALDI_ASSERT(max_time == static_cast(num_ali.size())); // std::vector alpha(num_states, kLogZeroDouble), // alpha_smbr(num_states, 0), //forward variable for sMBR // beta(num_states, kLogZeroDouble), // beta_smbr(num_states, 0); //backward variable for sMBR // // double tot_forward_prob = kLogZeroDouble; // double tot_forward_score = 0; // // post->clear(); // post->resize(max_time); // // alpha[0] = 0.0; // // First Pass Forward, // for (StateId s = 0; s < num_states; s++) { // double this_alpha = alpha[s]; // for (ArcIterator aiter(lat, s); !aiter.Done(); aiter.Next()) { // const Arc &arc = aiter.Value(); // double arc_like = -ConvertToCost(arc.weight); // alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha + arc_like); // } // Weight f = lat.Final(s); // if (f != Weight::Zero()) { // double final_like = this_alpha - (f.Value1() + f.Value2()); // tot_forward_prob = LogAdd(tot_forward_prob, final_like); // KALDI_ASSERT(state_times[s] == max_time && // "Lattice is inconsistent (final-prob not at max_time)"); // } // } // // First Pass Backward, // for (StateId s = num_states-1; s >= 0; s--) { // Weight f = lat.Final(s); 
// double this_beta = -(f.Value1() + f.Value2()); // for (ArcIterator aiter(lat, s); !aiter.Done(); aiter.Next()) { // const Arc &arc = aiter.Value(); // double arc_like = -ConvertToCost(arc.weight), // arc_beta = beta[arc.nextstate] + arc_like; // this_beta = LogAdd(this_beta, arc_beta); // } // beta[s] = this_beta; // } // // First Pass Forward-Backward Check // double tot_backward_prob = beta[0]; // // may loose the condition somehow here 1e-6 (was 1e-8) // if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-6)) { // KALDI_ERR << "Total forward probability over lattice = " << tot_forward_prob // << ", while total backward probability = " << tot_backward_prob; // } // // alpha_smbr[0] = 0.0; // // Second Pass Forward, calculate forward for MPFE/SMBR // for (StateId s = 0; s < num_states; s++) { // double this_alpha = alpha[s]; // for (ArcIterator aiter(lat, s); !aiter.Done(); aiter.Next()) { // const Arc &arc = aiter.Value(); // double arc_like = -ConvertToCost(arc.weight); // double frame_acc = 0.0; // if (arc.ilabel != 0) { // int32 cur_time = state_times[s]; // int32 phone = trans.TransitionIdToPhone(arc.ilabel), // ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]); // bool phone_is_sil = std::binary_search(silence_phones.begin(), // silence_phones.end(), // phone), // ref_phone_is_sil = std::binary_search(silence_phones.begin(), // silence_phones.end(), // ref_phone), // both_sil = phone_is_sil && ref_phone_is_sil; // if (!is_mpfe) { // smbr. // int32 pdf = trans.TransitionIdToPdf(arc.ilabel), // ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]); // if (!one_silence_class) // old behavior // frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0; // else // frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0; // } else { // if (!one_silence_class) // old behavior // frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0; // else // frame_acc = (phone == ref_phone || both_sil) ? 
1.0 : 0.0; // } // } // double arc_scale = Exp(alpha[s] + arc_like - alpha[arc.nextstate]); // alpha_smbr[arc.nextstate] += arc_scale * (alpha_smbr[s] + frame_acc); // } // Weight f = lat.Final(s); // if (f != Weight::Zero()) { // double final_like = this_alpha - (f.Value1() + f.Value2()); // double arc_scale = Exp(final_like - tot_forward_prob); // tot_forward_score += arc_scale * alpha_smbr[s]; // KALDI_ASSERT(state_times[s] == max_time && // "Lattice is inconsistent (final-prob not at max_time)"); // } // } // // Second Pass Backward, collect Mpe style posteriors // for (StateId s = num_states-1; s >= 0; s--) { // for (ArcIterator aiter(lat, s); !aiter.Done(); aiter.Next()) { // const Arc &arc = aiter.Value(); // double arc_like = -ConvertToCost(arc.weight), // arc_beta = beta[arc.nextstate] + arc_like; // double frame_acc = 0.0; // int32 transition_id = arc.ilabel; // if (arc.ilabel != 0) { // int32 cur_time = state_times[s]; // int32 phone = trans.TransitionIdToPhone(arc.ilabel), // ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]); // bool phone_is_sil = std::binary_search(silence_phones.begin(), // silence_phones.end(), phone), // ref_phone_is_sil = std::binary_search(silence_phones.begin(), // silence_phones.end(), // ref_phone), // both_sil = phone_is_sil && ref_phone_is_sil; // if (!is_mpfe) { // smbr. // int32 pdf = trans.TransitionIdToPdf(arc.ilabel), // ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]); // if (!one_silence_class) // old behavior // frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0; // else // frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0; // } else { // if (!one_silence_class) // old behavior // frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0; // else // frame_acc = (phone == ref_phone || both_sil) ? 
1.0 : 0.0; // } // } // double arc_scale = Exp(beta[arc.nextstate] + arc_like - beta[s]); // // check arc_scale NAN, // // this is to prevent partial paths in Lattices // // i.e., paths don't survive to the final state // if (KALDI_ISNAN(arc_scale)) arc_scale = 0; // beta_smbr[s] += arc_scale * (beta_smbr[arc.nextstate] + frame_acc); // // if (transition_id != 0) { // Arc has a transition-id on it [not epsilon] // double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob); // double acc_diff = alpha_smbr[s] + frame_acc + beta_smbr[arc.nextstate] // - tot_forward_score; // double posterior_smbr = posterior * acc_diff; // (*post)[state_times[s]].push_back(std::make_pair(transition_id, // static_cast(posterior_smbr))); // } // } // } // // //Second Pass Forward Backward check // double tot_backward_score = beta_smbr[0]; // Initial state id == 0 // // may loose the condition somehow here 1e-5/1e-4 // if (!ApproxEqual(tot_forward_score, tot_backward_score, 1e-4)) { // KALDI_ERR << "Total forward score over lattice = " << tot_forward_score // << ", while total backward score = " << tot_backward_score; // } // // // Output the computed posteriors // for (int32 t = 0; t < max_time; t++) // MergePairVectorSumming(&((*post)[t])); // return tot_forward_score; // } // // bool CompactLatticeToWordAlignment(const CompactLattice &clat, // std::vector *words, // std::vector *begin_times, // std::vector *lengths) { // words->clear(); // begin_times->clear(); // lengths->clear(); // typedef CompactLattice::Arc Arc; // typedef Arc::Label Label; // typedef CompactLattice::StateId StateId; // typedef CompactLattice::Weight Weight; // using namespace fst; // StateId state = clat.Start(); // int32 cur_time = 0; // if (state == kNoStateId) { // KALDI_WARN << "Empty lattice."; // return false; // } // while (1) { // Weight final = clat.Final(state); // size_t num_arcs = clat.NumArcs(state); // if (final != Weight::Zero()) { // if (num_arcs != 0) { // KALDI_WARN << "Lattice is not 
linear."; // return false; // } // if (! final.String().empty()) { // KALDI_WARN << "Lattice has alignments on final-weight: probably " // "was not word-aligned (alignments will be approximate)"; // } // return true; // } else { // if (num_arcs != 1) { // KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs; // return false; // } // fst::ArcIterator aiter(clat, state); // const Arc &arc = aiter.Value(); // Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor. // // Also note: word_id may be zero; we output it anyway. // int32 length = arc.weight.String().size(); // words->push_back(word_id); // begin_times->push_back(cur_time); // lengths->push_back(length); // cur_time += length; // state = arc.nextstate; // } // } // } // // // bool CompactLatticeToWordProns( // const TransitionModel &tmodel, // const CompactLattice &clat, // std::vector *words, // std::vector *begin_times, // std::vector *lengths, // std::vector > *prons, // std::vector > *phone_lengths) { // words->clear(); // begin_times->clear(); // lengths->clear(); // prons->clear(); // phone_lengths->clear(); // typedef CompactLattice::Arc Arc; // typedef Arc::Label Label; // typedef CompactLattice::StateId StateId; // typedef CompactLattice::Weight Weight; // using namespace fst; // StateId state = clat.Start(); // int32 cur_time = 0; // if (state == kNoStateId) { // KALDI_WARN << "Empty lattice."; // return false; // } // while (1) { // Weight final = clat.Final(state); // size_t num_arcs = clat.NumArcs(state); // if (final != Weight::Zero()) { // if (num_arcs != 0) { // KALDI_WARN << "Lattice is not linear."; // return false; // } // if (! 
final.String().empty()) { // KALDI_WARN << "Lattice has alignments on final-weight: probably " // "was not word-aligned (alignments will be approximate)"; // } // return true; // } else { // if (num_arcs != 1) { // KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs; // return false; // } // fst::ArcIterator aiter(clat, state); // const Arc &arc = aiter.Value(); // Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor. // // Also note: word_id may be zero; we output it anyway. // int32 length = arc.weight.String().size(); // words->push_back(word_id); // begin_times->push_back(cur_time); // lengths->push_back(length); // const std::vector &arc_alignment = arc.weight.String(); // std::vector > split_alignment; // SplitToPhones(tmodel, arc_alignment, &split_alignment); // std::vector phones(split_alignment.size()); // std::vector plengths(split_alignment.size()); // for (size_t i = 0; i < split_alignment.size(); i++) { // KALDI_ASSERT(!split_alignment[i].empty()); // phones[i] = tmodel.TransitionIdToPhone(split_alignment[i][0]); // plengths[i] = split_alignment[i].size(); // } // prons->push_back(phones); // phone_lengths->push_back(plengths); // // cur_time += length; // state = arc.nextstate; // } // } // } // // // // void CompactLatticeShortestPath(const CompactLattice &clat, // CompactLattice *shortest_path) { // using namespace fst; // if (clat.Properties(fst::kTopSorted, true) == 0) { // CompactLattice clat_copy(clat); // if (!TopSort(&clat_copy)) // KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)"; // CompactLatticeShortestPath(clat_copy, shortest_path); // return; // } // // Now we can assume it's topologically sorted. 
// shortest_path->DeleteStates(); // if (clat.Start() == kNoStateId) return; // typedef CompactLatticeArc Arc; // typedef Arc::StateId StateId; // typedef CompactLatticeWeight Weight; // vector > best_cost_and_pred(clat.NumStates() + 1); // StateId superfinal = clat.NumStates(); // for (StateId s = 0; s <= clat.NumStates(); s++) { // best_cost_and_pred[s].first = std::numeric_limits::infinity(); // best_cost_and_pred[s].second = fst::kNoStateId; // } // best_cost_and_pred[clat.Start()].first = 0; // for (StateId s = 0; s < clat.NumStates(); s++) { // double my_cost = best_cost_and_pred[s].first; // for (ArcIterator aiter(clat, s); // !aiter.Done(); // aiter.Next()) { // const Arc &arc = aiter.Value(); // double arc_cost = ConvertToCost(arc.weight), // next_cost = my_cost + arc_cost; // if (next_cost < best_cost_and_pred[arc.nextstate].first) { // best_cost_and_pred[arc.nextstate].first = next_cost; // best_cost_and_pred[arc.nextstate].second = s; // } // } // double final_cost = ConvertToCost(clat.Final(s)), // tot_final = my_cost + final_cost; // if (tot_final < best_cost_and_pred[superfinal].first) { // best_cost_and_pred[superfinal].first = tot_final; // best_cost_and_pred[superfinal].second = s; // } // } // std::vector states; // states on best path. // StateId cur_state = superfinal, start_state = clat.Start(); // while (cur_state != start_state) { // StateId prev_state = best_cost_and_pred[cur_state].second; // if (prev_state == kNoStateId) { // KALDI_WARN << "Failure in best-path algorithm for lattice (infinite costs?)"; // return; // return empty best-path. 
// } // states.push_back(prev_state); // KALDI_ASSERT(cur_state != prev_state && "Lattice with cycles"); // cur_state = prev_state; // } // std::reverse(states.begin(), states.end()); // for (size_t i = 0; i < states.size(); i++) // shortest_path->AddState(); // for (StateId s = 0; static_cast(s) < states.size(); s++) { // if (s == 0) shortest_path->SetStart(s); // if (static_cast(s + 1) < states.size()) { // transition to next state. // bool have_arc = false; // Arc cur_arc; // for (ArcIterator aiter(clat, states[s]); // !aiter.Done(); // aiter.Next()) { // const Arc &arc = aiter.Value(); // if (arc.nextstate == states[s+1]) { // if (!have_arc || // ConvertToCost(arc.weight) < ConvertToCost(cur_arc.weight)) { // cur_arc = arc; // have_arc = true; // } // } // } // KALDI_ASSERT(have_arc && "Code error."); // shortest_path->AddArc(s, Arc(cur_arc.ilabel, cur_arc.olabel, // cur_arc.weight, s+1)); // } else { // final-prob. // shortest_path->SetFinal(s, clat.Final(states[s])); // } // } // } // // // void ExpandCompactLattice(const CompactLattice &clat, // double epsilon, // CompactLattice *expand_clat) { // using namespace fst; // typedef CompactLattice::Arc Arc; // typedef Arc::Weight Weight; // typedef Arc::StateId StateId; // typedef std::pair StatePair; // typedef unordered_map > MapType; // typedef MapType::iterator IterType; // // if (clat.Start() == kNoStateId) return; // // Make sure the input lattice is topologically sorted. // if (clat.Properties(kTopSorted, true) == 0) { // CompactLattice clat_copy(clat); // KALDI_LOG << "Topsort this lattice."; // if (!TopSort(&clat_copy)) // KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)"; // ExpandCompactLattice(clat_copy, epsilon, expand_clat); // return; // } // // // Compute backward logprobs betas for the expanded lattice. // // Note: the backward logprobs in the original lattice and the // // expanded lattice are the same. 
// int32 num_states = clat.NumStates(); // std::vector beta(num_states, kLogZeroDouble); // ComputeCompactLatticeBetas(clat, &beta); // double tot_backward_logprob = beta[0]; // std::vector alpha; // alpha.push_back(0.0); // expand_clat->DeleteStates(); // MapType state_map; // Map from state pair (orig_state, copy_state) to // // copy_state, where orig_state is a state in the original lattice, and // // copy_state is its corresponding one in the expanded lattice. // unordered_map states; // Map from orig_state to its // // copy_state for states with incoming arcs' posteriors <= epsilon. // std::queue state_queue; // // // Set start state in the expanded lattice. // StateId start_state = expand_clat->AddState(); // expand_clat->SetStart(start_state); // StatePair start_pair(clat.Start(), start_state); // state_queue.push(start_pair); // std::pair result = // state_map.insert(std::make_pair(start_pair, start_state)); // KALDI_ASSERT(result.second == true); // // // Expand and update forward logprobs alphas in . // while (!state_queue.empty()) { // StatePair s = state_queue.front(); // StateId s1 = s.first, // s2 = s.second; // state_queue.pop(); // // Weight f = clat.Final(s1); // if (f != Weight::Zero()) { // KALDI_ASSERT(state_map.find(s) != state_map.end()); // expand_clat->SetFinal(state_map[s], f); // } // // for (ArcIterator aiter(clat, s1); // !aiter.Done(); aiter.Next()) { // const Arc &arc = aiter.Value(); // StateId orig_state = arc.nextstate; // double arc_like = -ConvertToCost(arc.weight), // this_alpha = alpha[s2] + arc_like, // arc_post = Exp(this_alpha + beta[orig_state] - // tot_backward_logprob); // // Generate the expanded lattice. 
// StateId copy_state; // if (arc_post > epsilon) { // copy_state = expand_clat->AddState(); // StatePair next_pair(orig_state, copy_state); // std::pair result = // state_map.insert(std::make_pair(next_pair, copy_state)); // KALDI_ASSERT(result.second == true); // state_queue.push(next_pair); // } else { // unordered_map::iterator iter = states.find(orig_state); // if (iter == states.end() ) { // The counterpart state of orig_state // // has not been created in yet. // copy_state = expand_clat->AddState(); // StatePair next_pair(orig_state, copy_state); // std::pair result = // state_map.insert(std::make_pair(next_pair, copy_state)); // KALDI_ASSERT(result.second == true); // state_queue.push(next_pair); // states[orig_state] = copy_state; // } else { // copy_state = iter->second; // } // } // // Create an arc from state_map[s] to copy_state in the expanded lattice. // expand_clat->AddArc(state_map[s], Arc(arc.ilabel, arc.olabel, arc.weight, // copy_state)); // // Compute forward logprobs alpha for the expanded lattice. // if ((alpha.size() - 1) < copy_state) { // The first time to compute alpha // // for copy_state in . // alpha.push_back(this_alpha); // } else { // Accumulate alpha. // alpha[copy_state] = LogAdd(alpha[copy_state], this_alpha); // } // } // } // end while // } // // // void CompactLatticeBestCostsAndTracebacks( // const CompactLattice &clat, // CostTraceType *forward_best_cost_and_pred, // CostTraceType *backward_best_cost_and_pred) { // // // typedef the arc, weight types // typedef CompactLatticeArc Arc; // typedef Arc::Weight Weight; // typedef Arc::StateId StateId; // // forward_best_cost_and_pred->clear(); // backward_best_cost_and_pred->clear(); // forward_best_cost_and_pred->resize(clat.NumStates()); // backward_best_cost_and_pred->resize(clat.NumStates()); // // Initialize the cost and predecessor state for each state. 
// for (StateId s = 0; s < clat.NumStates(); s++) { // (*forward_best_cost_and_pred)[s].first = // std::numeric_limits::infinity(); // (*backward_best_cost_and_pred)[s].first = // std::numeric_limits::infinity(); // (*forward_best_cost_and_pred)[s].second = fst::kNoStateId; // (*backward_best_cost_and_pred)[s].second = fst::kNoStateId; // } // // StateId start_state = clat.Start(); // (*forward_best_cost_and_pred)[start_state].first = 0; // // Traverse the lattice forward to compute the best cost from the start // // state to each state and the best predecessor state of each state. // for (StateId s = 0; s < clat.NumStates(); s++) { // double cur_cost = (*forward_best_cost_and_pred)[s].first; // for (fst::ArcIterator aiter(clat, s); // !aiter.Done(); aiter.Next()) { // const Arc &arc = aiter.Value(); // double next_cost = cur_cost + ConvertToCost(arc.weight); // if (next_cost < (*forward_best_cost_and_pred)[arc.nextstate].first) { // (*forward_best_cost_and_pred)[arc.nextstate].first = next_cost; // (*forward_best_cost_and_pred)[arc.nextstate].second = s; // } // } // } // // Traverse the lattice backward to compute the best cost from a final // // state to each state and the best predecessor state of each state. // for (StateId s = clat.NumStates() - 1; s >= 0; s--) { // double this_cost = ConvertToCost(clat.Final(s)); // for (fst::ArcIterator aiter(clat, s); // !aiter.Done(); aiter.Next()) { // const Arc &arc = aiter.Value(); // double next_cost = (*backward_best_cost_and_pred)[arc.nextstate].first + // ConvertToCost(arc.weight); // if (next_cost < this_cost) { // this_cost = next_cost; // (*backward_best_cost_and_pred)[s].second = arc.nextstate; // } // } // (*backward_best_cost_and_pred)[s].first = this_cost; // } // } // // // void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores, // CompactLattice *clat) { // if (clat->Start() == fst::kNoStateId) return; // // Make sure the input lattice is topologically sorted. 
// if (clat->Properties(fst::kTopSorted, true) == 0) { // KALDI_LOG << "Topsort this lattice."; // if (!TopSort(clat)) // KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)"; // AddNnlmScoreToCompactLattice(nnlm_scores, clat); // return; // } // // // typedef the arc, weight types // typedef CompactLatticeArc Arc; // typedef Arc::Weight Weight; // typedef Arc::StateId StateId; // typedef std::pair StatePair; // // int32 num_states = clat->NumStates(); // unordered_map > final_state_check; // for (StateId s = 0; s < num_states; s++) { // for (fst::MutableArcIterator aiter(clat, s); // !aiter.Done(); aiter.Next()) { // Arc arc(aiter.Value()); // StatePair arc_index = std::make_pair(static_cast(s), // static_cast(arc.nextstate)); // MapT::const_iterator it = nnlm_scores.find(arc_index); // double nnlm_score; // if (it != nnlm_scores.end()) // nnlm_score = it->second; // else // KALDI_ERR << "Some arc does not have neural language model score."; // if (arc.ilabel != 0) { // if there is a word on this arc // LatticeWeight weight = arc.weight.Weight(); // // Add associated neural LM score to each arc. // weight.SetValue1(weight.Value1() + nnlm_score); // arc.weight.SetWeight(weight); // aiter.SetValue(arc); // } // Weight clat_final = clat->Final(arc.nextstate); // StatePair final_pair = std::make_pair(arc.nextstate, arc.nextstate); // // Add neural LM scores to each final state only once. // if (clat_final != CompactLatticeWeight::Zero() && // final_state_check.find(final_pair) == final_state_check.end()) { // MapT::const_iterator final_it = nnlm_scores.find(final_pair); // double final_nnlm_score = 0.0; // if (final_it != nnlm_scores.end()) // final_nnlm_score = final_it->second; // // Add neural LM scores to the final weight. 
// Weight final_weight(LatticeWeight(clat_final.Weight().Value1() + // final_nnlm_score, // clat_final.Weight().Value2()), // clat_final.String()); // clat->SetFinal(arc.nextstate, final_weight); // final_state_check[final_pair] = true; // } // } // end looping over arcs // } // end looping over states // } // // void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty, // CompactLattice *clat) { // typedef CompactLatticeArc Arc; // int32 num_states = clat->NumStates(); // // //scan the lattice // for (int32 state = 0; state < num_states; state++) { // for (fst::MutableArcIterator aiter(clat, state); // !aiter.Done(); aiter.Next()) { // // Arc arc(aiter.Value()); // // if (arc.ilabel != 0) { // if there is a word on this arc // LatticeWeight weight = arc.weight.Weight(); // // add word insertion penalty to lattice // weight.SetValue1( weight.Value1() + word_ins_penalty); // arc.weight.SetWeight(weight); // aiter.SetValue(arc); // } // } // end looping over arcs // } // end looping over states // } // // struct ClatRescoreTuple { // ClatRescoreTuple(int32 state, int32 arc, int32 tid): // state_id(state), arc_id(arc), tid(tid) { } // int32 state_id; // int32 arc_id; // int32 tid; // }; // // /** RescoreCompactLatticeInternal is the internal code for both // RescoreCompactLattice and RescoreCompactLatticeSpeedup. For // RescoreCompactLattice, "tmodel" will be NULL and speedup_factor will be 1.0. 
// */ // bool RescoreCompactLatticeInternal( // const TransitionModel *tmodel, // BaseFloat speedup_factor, // DecodableInterface *decodable, // CompactLattice *clat) { // KALDI_ASSERT(speedup_factor >= 1.0); // if (clat->NumStates() == 0) { // KALDI_WARN << "Rescoring empty lattice"; // return false; // } // if (!clat->Properties(fst::kTopSorted, true)) { // if (fst::TopSort(clat) == false) { // KALDI_WARN << "Cycles detected in lattice."; // return false; // } // } // std::vector state_times; // int32 utt_len = kaldi::CompactLatticeStateTimes(*clat, &state_times); // // std::vector > time_to_state(utt_len); // // int32 num_states = clat->NumStates(); // KALDI_ASSERT(num_states == state_times.size()); // for (size_t state = 0; state < num_states; state++) { // KALDI_ASSERT(state_times[state] >= 0); // int32 t = state_times[state]; // int32 arc_id = 0; // for (fst::MutableArcIterator aiter(clat, state); // !aiter.Done(); aiter.Next(), arc_id++) { // CompactLatticeArc arc = aiter.Value(); // std::vector arc_string = arc.weight.String(); // // for (size_t offset = 0; offset < arc_string.size(); offset++) { // if (t < utt_len) { // end state may be past this.. // int32 tid = arc_string[offset]; // time_to_state[t+offset].push_back(ClatRescoreTuple(state, arc_id, tid)); // } else { // if (t != utt_len) { // KALDI_WARN << "There appears to be lattice/feature mismatch, " // << "aborting."; // return false; // } // } // } // } // if (clat->Final(state) != CompactLatticeWeight::Zero()) { // arc_id = -1; // std::vector arc_string = clat->Final(state).String(); // for (size_t offset = 0; offset < arc_string.size(); offset++) { // KALDI_ASSERT(t + offset < utt_len); // already checked in // // CompactLatticeStateTimes, so would be code error. 
// time_to_state[t+offset].push_back( // ClatRescoreTuple(state, arc_id, arc_string[offset])); // } // } // } // // for (int32 t = 0; t < utt_len; t++) { // if ((t < utt_len - 1) && decodable->IsLastFrame(t)) { // KALDI_WARN << "Features are too short for lattice: utt-len is " // << utt_len << ", " << t << " is last frame"; // return false; // } // // frame_scale is the scale we put on the computed acoustic probs for this // // frame. It will always be 1.0 if tmodel == NULL (i.e. if we are not doing // // the "speedup" code). For frames with multiple pdf-ids it will be one. // // For frames with only one pdf-id, it will equal speedup_factor (>=1.0) // // with probability 1.0 / speedup_factor, and zero otherwise. If it is zero, // // we can avoid computing the probabilities. // BaseFloat frame_scale = 1.0; // KALDI_ASSERT(!time_to_state[t].empty()); // if (tmodel != NULL) { // int32 pdf_id = tmodel->TransitionIdToPdf(time_to_state[t][0].tid); // bool frame_has_multiple_pdfs = false; // for (size_t i = 1; i < time_to_state[t].size(); i++) { // if (tmodel->TransitionIdToPdf(time_to_state[t][i].tid) != pdf_id) { // frame_has_multiple_pdfs = true; // break; // } // } // if (frame_has_multiple_pdfs) { // frame_scale = 1.0; // } else { // if (WithProb(1.0 / speedup_factor)) { // frame_scale = speedup_factor; // } else { // frame_scale = 0.0; // } // } // if (frame_scale == 0.0) // continue; // the code below would be pointless. 
// } // // for (size_t i = 0; i < time_to_state[t].size(); i++) { // int32 state = time_to_state[t][i].state_id; // int32 arc_id = time_to_state[t][i].arc_id; // int32 tid = time_to_state[t][i].tid; // // if (arc_id == -1) { // Final state // // Access the trans_id // CompactLatticeWeight curr_clat_weight = clat->Final(state); // // // Calculate likelihood // BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale; // // update weight // CompactLatticeWeight new_clat_weight = curr_clat_weight; // LatticeWeight new_lat_weight = new_clat_weight.Weight(); // new_lat_weight.SetValue2(-log_like + curr_clat_weight.Weight().Value2()); // new_clat_weight.SetWeight(new_lat_weight); // clat->SetFinal(state, new_clat_weight); // } else { // fst::MutableArcIterator aiter(clat, state); // // aiter.Seek(arc_id); // CompactLatticeArc arc = aiter.Value(); // // // Calculate likelihood // BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale; // // update weight // LatticeWeight new_weight = arc.weight.Weight(); // new_weight.SetValue2(-log_like + arc.weight.Weight().Value2()); // arc.weight.SetWeight(new_weight); // aiter.SetValue(arc); // } // } // } // return true; // } // // // bool RescoreCompactLatticeSpeedup( // const TransitionModel &tmodel, // BaseFloat speedup_factor, // DecodableInterface *decodable, // CompactLattice *clat) { // return RescoreCompactLatticeInternal(&tmodel, speedup_factor, decodable, clat); // } // // bool RescoreCompactLattice(DecodableInterface *decodable, // CompactLattice *clat) { // return RescoreCompactLatticeInternal(NULL, 1.0, decodable, clat); // } // // // bool RescoreLattice(DecodableInterface *decodable, // Lattice *lat) { // if (lat->NumStates() == 0) { // KALDI_WARN << "Rescoring empty lattice"; // return false; // } // if (!lat->Properties(fst::kTopSorted, true)) { // if (fst::TopSort(lat) == false) { // KALDI_WARN << "Cycles detected in lattice."; // return false; // } // } // std::vector state_times; // int32 
utt_len = kaldi::LatticeStateTimes(*lat, &state_times); // // std::vector > time_to_state(utt_len ); // // int32 num_states = lat->NumStates(); // KALDI_ASSERT(num_states == state_times.size()); // for (size_t state = 0; state < num_states; state++) { // int32 t = state_times[state]; // // Don't check t >= 0 because non-accessible states could have t = -1. // KALDI_ASSERT(t <= utt_len); // if (t >= 0 && t < utt_len) // time_to_state[t].push_back(state); // } // // for (int32 t = 0; t < utt_len; t++) { // if ((t < utt_len - 1) && decodable->IsLastFrame(t)) { // KALDI_WARN << "Features are too short for lattice: utt-len is " // << utt_len << ", " << t << " is last frame"; // return false; // } // for (size_t i = 0; i < time_to_state[t].size(); i++) { // int32 state = time_to_state[t][i]; // for (fst::MutableArcIterator aiter(lat, state); // !aiter.Done(); aiter.Next()) { // LatticeArc arc = aiter.Value(); // if (arc.ilabel != 0) { // int32 trans_id = arc.ilabel; // Note: it doesn't necessarily // // have to be a transition-id, just whatever the Decodable // // object is expecting, but it's normally a transition-id. // // BaseFloat log_like = decodable->LogLikelihood(t, trans_id); // arc.weight.SetValue2(-log_like + arc.weight.Value2()); // aiter.SetValue(arc); // } // } // } // } // return true; // } // // // BaseFloat LatticeForwardBackwardMmi( // const TransitionModel &tmodel, // const Lattice &lat, // const std::vector &num_ali, // bool drop_frames, // bool convert_to_pdf_ids, // bool cancel, // Posterior *post) { // // First compute the MMI posteriors. // // Posterior den_post; // BaseFloat ans = LatticeForwardBackward(lat, // &den_post, // NULL); // // Posterior num_post; // AlignmentToPosterior(num_ali, &num_post); // // // Now negate the MMI posteriors and add the numerator // // posteriors. 
// ScalePosterior(-1.0, &den_post); // // if (convert_to_pdf_ids) { // Posterior num_tmp; // ConvertPosteriorToPdfs(tmodel, num_post, &num_tmp); // num_tmp.swap(num_post); // Posterior den_tmp; // ConvertPosteriorToPdfs(tmodel, den_post, &den_tmp); // den_tmp.swap(den_post); // } // // MergePosteriors(num_post, den_post, // cancel, drop_frames, post); // // return ans; // } // // // int32 LongestSentenceLength(const Lattice &lat) { // typedef Lattice::Arc Arc; // typedef Arc::Label Label; // typedef Arc::StateId StateId; // // if (lat.Properties(fst::kTopSorted, true) == 0) { // Lattice lat_copy(lat); // if (!TopSort(&lat_copy)) // KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)"; // return LongestSentenceLength(lat_copy); // } // std::vector max_length(lat.NumStates(), 0); // int32 lattice_max_length = 0; // for (StateId s = 0; s < lat.NumStates(); s++) { // int32 this_max_length = max_length[s]; // for (fst::ArcIterator aiter(lat, s); !aiter.Done(); aiter.Next()) { // const Arc &arc = aiter.Value(); // bool arc_has_word = (arc.olabel != 0); // StateId nextstate = arc.nextstate; // KALDI_ASSERT(static_cast(nextstate) < max_length.size()); // if (arc_has_word) { // // A lattice should ideally not have cycles anyway; a cycle with a word // // on is something very bad. 
// KALDI_ASSERT(nextstate > s && "Lattice has cycles with words on."); // max_length[nextstate] = std::max(max_length[nextstate], // this_max_length + 1); // } else { // max_length[nextstate] = std::max(max_length[nextstate], // this_max_length); // } // } // if (lat.Final(s) != LatticeWeight::Zero()) // lattice_max_length = std::max(lattice_max_length, max_length[s]); // } // return lattice_max_length; // } // // int32 LongestSentenceLength(const CompactLattice &clat) { // typedef CompactLattice::Arc Arc; // typedef Arc::Label Label; // typedef Arc::StateId StateId; // // if (clat.Properties(fst::kTopSorted, true) == 0) { // CompactLattice clat_copy(clat); // if (!TopSort(&clat_copy)) // KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)"; // return LongestSentenceLength(clat_copy); // } // std::vector max_length(clat.NumStates(), 0); // int32 lattice_max_length = 0; // for (StateId s = 0; s < clat.NumStates(); s++) { // int32 this_max_length = max_length[s]; // for (fst::ArcIterator aiter(clat, s); // !aiter.Done(); aiter.Next()) { // const Arc &arc = aiter.Value(); // bool arc_has_word = (arc.ilabel != 0); // note: olabel == ilabel. // // also note: for normal CompactLattice, e.g. as produced by // // determinization, all arcs will have nonzero labels, but the user might // // decide to replace some of the labels with zero for some reason, and we // // want to support this. 
// StateId nextstate = arc.nextstate; // KALDI_ASSERT(static_cast(nextstate) < max_length.size()); // KALDI_ASSERT(nextstate > s && "CompactLattice has cycles"); // if (arc_has_word) // max_length[nextstate] = std::max(max_length[nextstate], // this_max_length + 1); // else // max_length[nextstate] = std::max(max_length[nextstate], // this_max_length); // } // if (clat.Final(s) != CompactLatticeWeight::Zero()) // lattice_max_length = std::max(lattice_max_length, max_length[s]); // } // return lattice_max_length; // } // // void ComposeCompactLatticeDeterministic( // const CompactLattice& clat, // fst::DeterministicOnDemandFst* det_fst, // CompactLattice* composed_clat) { // // StdFst::Arc and CompactLatticeArc has the same StateId type. // typedef fst::StdArc::StateId StateId; // typedef fst::StdArc::Weight Weight1; // typedef CompactLatticeArc::Weight Weight2; // typedef std::pair StatePair; // typedef unordered_map > MapType; // typedef MapType::iterator IterType; // // // Empties the output FST. // KALDI_ASSERT(composed_clat != NULL); // composed_clat->DeleteStates(); // // MapType state_map; // std::queue state_queue; // // // Sets start state in . // StateId start_state = composed_clat->AddState(); // StatePair start_pair(clat.Start(), det_fst->Start()); // composed_clat->SetStart(start_state); // state_queue.push(start_pair); // std::pair result = // state_map.insert(std::make_pair(start_pair, start_state)); // KALDI_ASSERT(result.second == true); // // // Starts composition here. // while (!state_queue.empty()) { // // Gets the first state in the queue. // StatePair s = state_queue.front(); // StateId s1 = s.first; // StateId s2 = s.second; // state_queue.pop(); // // // Weight2 clat_final = clat.Final(s1); // if (clat_final.Weight().Value1() != // std::numeric_limits::infinity()) { // // Test for whether the final-prob of state s1 was zero. 
// Weight1 det_fst_final = det_fst->Final(s2); // if (det_fst_final.Value() != // std::numeric_limits::infinity()) { // // Test for whether the final-prob of state s2 was zero. If neither // // source-state final prob was zero, then we should create final state // // in composed_clat. We compute the product manually since this is more // // efficient. // Weight2 final_weight(LatticeWeight(clat_final.Weight().Value1() + // det_fst_final.Value(), // clat_final.Weight().Value2()), // clat_final.String()); // // we can assume final_weight is not Zero(), since neither of // // the sources was zero. // KALDI_ASSERT(state_map.find(s) != state_map.end()); // composed_clat->SetFinal(state_map[s], final_weight); // } // } // // // Loops over pair of edges at s1 and s2. // for (fst::ArcIterator aiter(clat, s1); // !aiter.Done(); aiter.Next()) { // const CompactLatticeArc& arc1 = aiter.Value(); // fst::StdArc arc2; // StateId next_state1 = arc1.nextstate, next_state2; // bool matched = false; // // if (arc1.olabel == 0) { // // If the symbol on arc1 is epsilon, we transit to the next state // // for clat, but keep det_fst at the current state. // matched = true; // next_state2 = s2; // } else { // // Otherwise try to find the matched arc in det_fst. // matched = det_fst->GetArc(s2, arc1.olabel, &arc2); // if (matched) { // next_state2 = arc2.nextstate; // } // } // // // If matched arc is found in det_fst, then we have to add new arcs to // // composed_clat. // if (matched) { // StatePair next_state_pair(next_state1, next_state2); // IterType siter = state_map.find(next_state_pair); // StateId next_state; // // // Adds composed state to state_map. // if (siter == state_map.end()) { // // If the composed state has not been created yet, create it.
// next_state = composed_clat->AddState(); // std::pair next_state_map(next_state_pair, // next_state); // std::pair result = state_map.insert(next_state_map); // KALDI_ASSERT(result.second); // state_queue.push(next_state_pair); // } else { // // If the composed state is already in state_map, we can directly // // use that. // next_state = siter->second; // } // // // Adds arc to composed_clat. // if (arc1.olabel == 0) { // composed_clat->AddArc(state_map[s], // CompactLatticeArc(arc1.ilabel, 0, // arc1.weight, next_state)); // } else { // Weight2 composed_weight( // LatticeWeight(arc1.weight.Weight().Value1() + // arc2.weight.Value(), // arc1.weight.Weight().Value2()), // arc1.weight.String()); // composed_clat->AddArc(state_map[s], // CompactLatticeArc(arc1.ilabel, arc2.olabel, // composed_weight, next_state)); // } // } // } // } // fst::Connect(composed_clat); // } // // // void ComputeAcousticScoresMap( // const Lattice &lat, // unordered_map, std::pair, // PairHasher > *acoustic_scores) { // // typedef the arc, weight types // typedef Lattice::Arc Arc; // typedef Arc::Weight LatticeWeight; // typedef Arc::StateId StateId; // // acoustic_scores->clear(); // // std::vector state_times; // LatticeStateTimes(lat, &state_times); // Assumes the input is top sorted // // KALDI_ASSERT(lat.Start() == 0); // // for (StateId s = 0; s < lat.NumStates(); s++) { // int32 t = state_times[s]; // for (fst::ArcIterator aiter(lat, s); !aiter.Done(); // aiter.Next()) { // const Arc &arc = aiter.Value(); // const LatticeWeight &weight = arc.weight; // // int32 tid = arc.ilabel; // // if (tid != 0) { // unordered_map, std::pair, // PairHasher >::iterator it = acoustic_scores->find(std::make_pair(t, tid)); // if (it == acoustic_scores->end()) { // acoustic_scores->insert(std::make_pair(std::make_pair(t, tid), // std::make_pair(weight.Value2(), 1))); // } else { // if (it->second.second == 2 // && it->second.first / it->second.second != weight.Value2()) { // KALDI_VLOG(2) << "Transitions on the same 
frame have different " // << "acoustic costs for tid " << tid << "; " // << it->second.first / it->second.second // << " vs " << weight.Value2(); // } // it->second.first += weight.Value2(); // it->second.second++; // } // } else { // // Arcs with epsilon input label (tid) must have 0 acoustic cost // KALDI_ASSERT(weight.Value2() == 0); // } // } // // LatticeWeight f = lat.Final(s); // if (f != LatticeWeight::Zero()) { // // Final acoustic cost must be 0 as we are reading from // // non-determinized, non-compact lattice // KALDI_ASSERT(f.Value2() == 0.0); // } // } // } // // void ReplaceAcousticScoresFromMap( // const unordered_map, std::pair, // PairHasher > &acoustic_scores, // Lattice *lat) { // // typedef the arc, weight types // typedef Lattice::Arc Arc; // typedef Arc::Weight LatticeWeight; // typedef Arc::StateId StateId; // // TopSortLatticeIfNeeded(lat); // // std::vector state_times; // LatticeStateTimes(*lat, &state_times); // // KALDI_ASSERT(lat->Start() == 0); // // for (StateId s = 0; s < lat->NumStates(); s++) { // int32 t = state_times[s]; // for (fst::MutableArcIterator aiter(lat, s); // !aiter.Done(); aiter.Next()) { // Arc arc(aiter.Value()); // // int32 tid = arc.ilabel; // if (tid != 0) { // unordered_map, std::pair, // PairHasher >::const_iterator it = acoustic_scores.find(std::make_pair(t, tid)); // if (it == acoustic_scores.end()) { // KALDI_ERR << "Could not find tid " << tid << " at time " << t // << " in the acoustic scores map."; // } else { // arc.weight.SetValue2(it->second.first / it->second.second); // } // } else { // // For epsilon arcs, set acoustic cost to 0.0 // arc.weight.SetValue2(0.0); // } // aiter.SetValue(arc); // } // // LatticeWeight f = lat->Final(s); // if (f != LatticeWeight::Zero()) { // // Set final acoustic cost to 0.0 // f.SetValue2(0.0); // lat->SetFinal(s, f); // } // } // } } // namespace kaldi