You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1266 lines
47 KiB
1266 lines
47 KiB
3 years ago
|
// fstext/fstext-utils-inl.h
|
||
|
|
||
|
// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author:
|
||
|
// Daniel Povey)
|
||
|
// 2014 Telepoint Global Hosting Service, LLC. (Author: David
|
||
|
// Snyder)
|
||
|
|
||
|
// See ../../COPYING for clarification regarding multiple authors
|
||
|
//
|
||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
// you may not use this file except in compliance with the License.
|
||
|
// You may obtain a copy of the License at
|
||
|
//
|
||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||
|
//
|
||
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||
|
// See the Apache 2 License for the specific language governing permissions and
|
||
|
// limitations under the License.
|
||
|
|
||
|
#ifndef KALDI_FSTEXT_FSTEXT_UTILS_INL_H_
|
||
|
#define KALDI_FSTEXT_FSTEXT_UTILS_INL_H_
|
||
|
|
||
|
#include <algorithm>
|
||
|
#include <cstring>
|
||
|
#include <map>
|
||
|
#include <set>
|
||
|
#include <sstream>
|
||
|
#include <string>
|
||
|
#include <unordered_map>
|
||
|
#include <unordered_set>
|
||
|
#include <utility>
|
||
|
#include <vector>
|
||
|
|
||
|
#include "base/kaldi-common.h"
|
||
|
#include "fstext/determinize-star.h"
|
||
|
#include "fstext/pre-determinize.h"
|
||
|
#include "util/const-integer-set.h"
|
||
|
#include "util/kaldi-io.h"
|
||
|
#include "util/stl-utils.h"
|
||
|
#include "util/text-utils.h"
|
||
|
|
||
|
namespace fst {
|
||
|
|
||
|
template <class Arc>
|
||
|
typename Arc::Label HighestNumberedOutputSymbol(const Fst<Arc> &fst) {
|
||
|
typename Arc::Label ans = 0;
|
||
|
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
|
||
|
typename Arc::StateId s = siter.Value();
|
||
|
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
|
||
|
const Arc &arc = aiter.Value();
|
||
|
ans = std::max(ans, arc.olabel);
|
||
|
}
|
||
|
}
|
||
|
return ans;
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
typename Arc::Label HighestNumberedInputSymbol(const Fst<Arc> &fst) {
|
||
|
typename Arc::Label ans = 0;
|
||
|
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
|
||
|
typename Arc::StateId s = siter.Value();
|
||
|
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
|
||
|
const Arc &arc = aiter.Value();
|
||
|
ans = std::max(ans, arc.ilabel);
|
||
|
}
|
||
|
}
|
||
|
return ans;
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
typename Arc::StateId NumArcs(const ExpandedFst<Arc> &fst) {
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
StateId num_arcs = 0;
|
||
|
for (StateId s = 0; s < fst.NumStates(); s++) num_arcs += fst.NumArcs(s);
|
||
|
return num_arcs;
|
||
|
}
|
||
|
|
||
|
template <class Arc, class I>
|
||
|
void GetOutputSymbols(const Fst<Arc> &fst, bool include_eps,
|
||
|
std::vector<I> *symbols) {
|
||
|
KALDI_ASSERT_IS_INTEGER_TYPE(I);
|
||
|
std::set<I> all_syms;
|
||
|
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
|
||
|
typename Arc::StateId s = siter.Value();
|
||
|
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
|
||
|
const Arc &arc = aiter.Value();
|
||
|
all_syms.insert(arc.olabel);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Remove epsilon, if instructed.
|
||
|
if (!include_eps && !all_syms.empty() && *all_syms.begin() == 0)
|
||
|
all_syms.erase(0);
|
||
|
KALDI_ASSERT(symbols != NULL);
|
||
|
kaldi::CopySetToVector(all_syms, symbols);
|
||
|
}
|
||
|
|
||
|
template <class Arc, class I>
|
||
|
void GetInputSymbols(const Fst<Arc> &fst, bool include_eps,
|
||
|
std::vector<I> *symbols) {
|
||
|
KALDI_ASSERT_IS_INTEGER_TYPE(I);
|
||
|
unordered_set<I> all_syms;
|
||
|
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
|
||
|
typename Arc::StateId s = siter.Value();
|
||
|
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
|
||
|
const Arc &arc = aiter.Value();
|
||
|
all_syms.insert(arc.ilabel);
|
||
|
}
|
||
|
}
|
||
|
// Remove epsilon, if instructed.
|
||
|
if (!include_eps && all_syms.count(0) != 0) all_syms.erase(0);
|
||
|
KALDI_ASSERT(symbols != NULL);
|
||
|
kaldi::CopySetToVector(all_syms, symbols);
|
||
|
std::sort(symbols->begin(), symbols->end());
|
||
|
}
|
||
|
|
||
|
template <class Arc, class I>
|
||
|
class RemoveSomeInputSymbolsMapper {
|
||
|
public:
|
||
|
Arc operator()(const Arc &arc_in) {
|
||
|
Arc ans = arc_in;
|
||
|
if (to_remove_set_.count(ans.ilabel) != 0)
|
||
|
ans.ilabel = 0; // remove this symbol
|
||
|
return ans;
|
||
|
}
|
||
|
MapFinalAction FinalAction() { return MAP_NO_SUPERFINAL; }
|
||
|
MapSymbolsAction InputSymbolsAction() { return MAP_CLEAR_SYMBOLS; }
|
||
|
MapSymbolsAction OutputSymbolsAction() { return MAP_COPY_SYMBOLS; }
|
||
|
uint64 Properties(uint64 props) const {
|
||
|
// remove the following as we don't know now if any of them are true.
|
||
|
uint64 to_remove = kAcceptor | kNotAcceptor | kIDeterministic |
|
||
|
kNonIDeterministic | kNoEpsilons | kNoIEpsilons |
|
||
|
kILabelSorted | kNotILabelSorted;
|
||
|
return props & ~to_remove;
|
||
|
}
|
||
|
explicit RemoveSomeInputSymbolsMapper(const std::vector<I> &to_remove)
|
||
|
: to_remove_set_(to_remove) {
|
||
|
KALDI_ASSERT_IS_INTEGER_TYPE(I);
|
||
|
assert(to_remove_set_.count(0) == 0); // makes no sense to remove epsilon.
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
kaldi::ConstIntegerSet<I> to_remove_set_;
|
||
|
};
|
||
|
|
||
|
template <class Arc, class I>
|
||
|
using LookaheadFst = ArcMapFst<Arc, Arc, RemoveSomeInputSymbolsMapper<Arc, I> >;
|
||
|
|
||
|
// Lookahead composition is used for optimized online
|
||
|
// composition of FSTs during decoding. See
|
||
|
// nnet3/nnet3-latgen-faster-lookahead.cc. For details of compose filters
|
||
|
// see DefaultLookAhead in fst/compose.h
|
||
|
template <class Arc, class I>
|
||
|
LookaheadFst<Arc, I> *LookaheadComposeFst(const Fst<Arc> &ifst1,
|
||
|
const Fst<Arc> &ifst2,
|
||
|
const std::vector<I> &to_remove) {
|
||
|
fst::CacheOptions cache_opts(true, 1 << 25LL);
|
||
|
fst::CacheOptions cache_opts_map(true, 0);
|
||
|
fst::ArcMapFstOptions arcmap_opts(cache_opts);
|
||
|
RemoveSomeInputSymbolsMapper<Arc, I> mapper(to_remove);
|
||
|
return new LookaheadFst<Arc, I>(ComposeFst<Arc>(ifst1, ifst2, cache_opts),
|
||
|
mapper, arcmap_opts);
|
||
|
}
|
||
|
|
||
|
template <class Arc, class I>
|
||
|
void RemoveSomeInputSymbols(const std::vector<I> &to_remove,
|
||
|
MutableFst<Arc> *fst) {
|
||
|
KALDI_ASSERT_IS_INTEGER_TYPE(I);
|
||
|
RemoveSomeInputSymbolsMapper<Arc, I> mapper(to_remove);
|
||
|
Map(fst, mapper);
|
||
|
}
|
||
|
|
||
|
template <class Arc, class I>
|
||
|
class MapInputSymbolsMapper {
|
||
|
public:
|
||
|
Arc operator()(const Arc &arc_in) {
|
||
|
Arc ans = arc_in;
|
||
|
if (ans.ilabel > 0 && ans.ilabel < static_cast<typename Arc::Label>(
|
||
|
(*symbol_mapping_).size()))
|
||
|
ans.ilabel = (*symbol_mapping_)[ans.ilabel];
|
||
|
return ans;
|
||
|
}
|
||
|
MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
|
||
|
MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; }
|
||
|
MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
|
||
|
uint64 Properties(uint64 props) const { // Not tested.
|
||
|
bool remove_epsilons =
|
||
|
(symbol_mapping_->size() > 0 && (*symbol_mapping_)[0] != 0);
|
||
|
bool add_epsilons = (symbol_mapping_->size() > 1 &&
|
||
|
*std::min_element(symbol_mapping_->begin() + 1,
|
||
|
symbol_mapping_->end()) == 0);
|
||
|
|
||
|
// remove the following as we don't know now if any of them are true.
|
||
|
uint64 props_to_remove = kAcceptor | kNotAcceptor | kIDeterministic |
|
||
|
kNonIDeterministic | kILabelSorted |
|
||
|
kNotILabelSorted;
|
||
|
if (remove_epsilons) props_to_remove |= kEpsilons | kIEpsilons;
|
||
|
if (add_epsilons) props_to_remove |= kNoEpsilons | kNoIEpsilons;
|
||
|
uint64 props_to_add = 0;
|
||
|
if (remove_epsilons && !add_epsilons)
|
||
|
props_to_add |= kNoEpsilons | kNoIEpsilons;
|
||
|
return (props & ~props_to_remove) | props_to_add;
|
||
|
}
|
||
|
// initialize with copy = false only if the "to_remove" argument will not be
|
||
|
// deleted in the lifetime of this object.
|
||
|
MapInputSymbolsMapper(const std::vector<I> &to_remove, bool copy) {
|
||
|
KALDI_ASSERT_IS_INTEGER_TYPE(I);
|
||
|
if (copy)
|
||
|
symbol_mapping_ = new std::vector<I>(to_remove);
|
||
|
else
|
||
|
symbol_mapping_ = &to_remove;
|
||
|
owned = copy;
|
||
|
}
|
||
|
~MapInputSymbolsMapper() {
|
||
|
if (owned && symbol_mapping_ != NULL) delete symbol_mapping_;
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
bool owned;
|
||
|
const std::vector<I> *symbol_mapping_;
|
||
|
};
|
||
|
|
||
|
template <class Arc, class I>
|
||
|
void MapInputSymbols(const std::vector<I> &symbol_mapping,
|
||
|
MutableFst<Arc> *fst) {
|
||
|
KALDI_ASSERT_IS_INTEGER_TYPE(I);
|
||
|
// false == don't copy the "symbol_mapping", retain pointer--
|
||
|
// safe since short-lived object.
|
||
|
MapInputSymbolsMapper<Arc, I> mapper(symbol_mapping, false);
|
||
|
Map(fst, mapper);
|
||
|
}
|
||
|
|
||
|
template <class Arc, class I>
|
||
|
bool GetLinearSymbolSequence(const Fst<Arc> &fst, std::vector<I> *isymbols_out,
|
||
|
std::vector<I> *osymbols_out,
|
||
|
typename Arc::Weight *tot_weight_out) {
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
|
||
|
Weight tot_weight = Weight::One();
|
||
|
std::vector<I> ilabel_seq;
|
||
|
std::vector<I> olabel_seq;
|
||
|
|
||
|
StateId cur_state = fst.Start();
|
||
|
if (cur_state == kNoStateId) { // empty sequence.
|
||
|
if (isymbols_out != NULL) isymbols_out->clear();
|
||
|
if (osymbols_out != NULL) osymbols_out->clear();
|
||
|
if (tot_weight_out != NULL) *tot_weight_out = Weight::Zero();
|
||
|
return true;
|
||
|
}
|
||
|
while (1) {
|
||
|
Weight w = fst.Final(cur_state);
|
||
|
if (w != Weight::Zero()) { // is final..
|
||
|
tot_weight = Times(w, tot_weight);
|
||
|
if (fst.NumArcs(cur_state) != 0) return false;
|
||
|
if (isymbols_out != NULL) *isymbols_out = ilabel_seq;
|
||
|
if (osymbols_out != NULL) *osymbols_out = olabel_seq;
|
||
|
if (tot_weight_out != NULL) *tot_weight_out = tot_weight;
|
||
|
return true;
|
||
|
} else {
|
||
|
if (fst.NumArcs(cur_state) != 1) return false;
|
||
|
|
||
|
ArcIterator<Fst<Arc> > iter(fst, cur_state); // get the only arc.
|
||
|
const Arc &arc = iter.Value();
|
||
|
tot_weight = Times(arc.weight, tot_weight);
|
||
|
if (arc.ilabel != 0) ilabel_seq.push_back(arc.ilabel);
|
||
|
if (arc.olabel != 0) olabel_seq.push_back(arc.olabel);
|
||
|
cur_state = arc.nextstate;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// see fstext-utils.h for comment.
|
||
|
template <class Arc>
|
||
|
void ConvertNbestToVector(const Fst<Arc> &fst,
|
||
|
std::vector<VectorFst<Arc> > *fsts_out) {
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
fsts_out->clear();
|
||
|
StateId start_state = fst.Start();
|
||
|
if (start_state == kNoStateId) return; // No output.
|
||
|
size_t n_arcs = fst.NumArcs(start_state);
|
||
|
bool start_is_final = (fst.Final(start_state) != Weight::Zero());
|
||
|
fsts_out->reserve(n_arcs + (start_is_final ? 1 : 0));
|
||
|
|
||
|
if (start_is_final) {
|
||
|
fsts_out->resize(fsts_out->size() + 1);
|
||
|
StateId start_state_out = fsts_out->back().AddState();
|
||
|
fsts_out->back().SetFinal(start_state_out, fst.Final(start_state));
|
||
|
}
|
||
|
|
||
|
for (ArcIterator<Fst<Arc> > start_aiter(fst, start_state);
|
||
|
!start_aiter.Done(); start_aiter.Next()) {
|
||
|
fsts_out->resize(fsts_out->size() + 1);
|
||
|
VectorFst<Arc> &ofst = fsts_out->back();
|
||
|
const Arc &first_arc = start_aiter.Value();
|
||
|
StateId cur_state = start_state, cur_ostate = ofst.AddState();
|
||
|
ofst.SetStart(cur_ostate);
|
||
|
StateId next_ostate = ofst.AddState();
|
||
|
ofst.AddArc(cur_ostate, Arc(first_arc.ilabel, first_arc.olabel,
|
||
|
first_arc.weight, next_ostate));
|
||
|
cur_state = first_arc.nextstate;
|
||
|
cur_ostate = next_ostate;
|
||
|
while (1) {
|
||
|
size_t this_n_arcs = fst.NumArcs(cur_state);
|
||
|
KALDI_ASSERT(this_n_arcs <= 1); // or it violates our assumptions
|
||
|
// about the input.
|
||
|
if (this_n_arcs == 1) {
|
||
|
KALDI_ASSERT(fst.Final(cur_state) == Weight::Zero());
|
||
|
// or problem with ShortestPath.
|
||
|
ArcIterator<Fst<Arc> > aiter(fst, cur_state);
|
||
|
const Arc &arc = aiter.Value();
|
||
|
next_ostate = ofst.AddState();
|
||
|
ofst.AddArc(cur_ostate,
|
||
|
Arc(arc.ilabel, arc.olabel, arc.weight, next_ostate));
|
||
|
cur_state = arc.nextstate;
|
||
|
cur_ostate = next_ostate;
|
||
|
} else {
|
||
|
KALDI_ASSERT(fst.Final(cur_state) != Weight::Zero());
|
||
|
// or problem with ShortestPath.
|
||
|
ofst.SetFinal(cur_ostate, fst.Final(cur_state));
|
||
|
break;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// see fstext-utils.sh for comment.
|
||
|
template <class Arc>
|
||
|
void NbestAsFsts(const Fst<Arc> &fst, size_t n,
|
||
|
std::vector<VectorFst<Arc> > *fsts_out) {
|
||
|
KALDI_ASSERT(n > 0);
|
||
|
KALDI_ASSERT(fsts_out != NULL);
|
||
|
VectorFst<Arc> nbest_fst;
|
||
|
ShortestPath(fst, &nbest_fst, n);
|
||
|
ConvertNbestToVector(nbest_fst, fsts_out);
|
||
|
}
|
||
|
|
||
|
template <class Arc, class I>
|
||
|
void MakeLinearAcceptorWithAlternatives(
|
||
|
const std::vector<std::vector<I> > &labels, MutableFst<Arc> *ofst) {
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
|
||
|
ofst->DeleteStates();
|
||
|
StateId cur_state = ofst->AddState();
|
||
|
ofst->SetStart(cur_state);
|
||
|
for (size_t i = 0; i < labels.size(); i++) {
|
||
|
KALDI_ASSERT(labels[i].size() != 0);
|
||
|
StateId next_state = ofst->AddState();
|
||
|
for (size_t j = 0; j < labels[i].size(); j++) {
|
||
|
Arc arc(labels[i][j], labels[i][j], Weight::One(), next_state);
|
||
|
ofst->AddArc(cur_state, arc);
|
||
|
}
|
||
|
cur_state = next_state;
|
||
|
}
|
||
|
ofst->SetFinal(cur_state, Weight::One());
|
||
|
}
|
||
|
|
||
|
template <class Arc, class I>
|
||
|
void MakeLinearAcceptor(const std::vector<I> &labels, MutableFst<Arc> *ofst) {
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
|
||
|
ofst->DeleteStates();
|
||
|
StateId cur_state = ofst->AddState();
|
||
|
ofst->SetStart(cur_state);
|
||
|
for (size_t i = 0; i < labels.size(); i++) {
|
||
|
StateId next_state = ofst->AddState();
|
||
|
Arc arc(labels[i], labels[i], Weight::One(), next_state);
|
||
|
ofst->AddArc(cur_state, arc);
|
||
|
cur_state = next_state;
|
||
|
}
|
||
|
ofst->SetFinal(cur_state, Weight::One());
|
||
|
}
|
||
|
|
||
|
template <class I>
|
||
|
void GetSymbols(const SymbolTable &symtab, bool include_eps,
|
||
|
std::vector<I> *syms_out) {
|
||
|
KALDI_ASSERT(syms_out != NULL);
|
||
|
syms_out->clear();
|
||
|
for (SymbolTableIterator iter(symtab); !iter.Done(); iter.Next()) {
|
||
|
if (include_eps || iter.Value() != 0) {
|
||
|
syms_out->push_back(iter.Value());
|
||
|
KALDI_ASSERT(syms_out->back() ==
|
||
|
iter.Value()); // an integer-range thing.
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
void SafeDeterminizeWrapper(MutableFst<Arc> *ifst, MutableFst<Arc> *ofst,
|
||
|
float delta) {
|
||
|
typename Arc::Label highest_sym = HighestNumberedInputSymbol(*ifst);
|
||
|
std::vector<typename Arc::Label> extra_syms;
|
||
|
PreDeterminize(ifst, (typename Arc::Label)(highest_sym + 1), &extra_syms);
|
||
|
DeterminizeStar(*ifst, ofst, delta);
|
||
|
RemoveSomeInputSymbols(extra_syms, ofst); // remove the extra symbols.
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
void SafeDeterminizeMinimizeWrapper(MutableFst<Arc> *ifst, VectorFst<Arc> *ofst,
|
||
|
float delta) {
|
||
|
typename Arc::Label highest_sym = HighestNumberedInputSymbol(*ifst);
|
||
|
std::vector<typename Arc::Label> extra_syms;
|
||
|
PreDeterminize(ifst, (typename Arc::Label)(highest_sym + 1), &extra_syms);
|
||
|
DeterminizeStar(*ifst, ofst, delta);
|
||
|
RemoveSomeInputSymbols(extra_syms, ofst); // remove the extra symbols.
|
||
|
RemoveEpsLocal(ofst); // this is "safe" and will never hurt.
|
||
|
MinimizeEncoded(ofst, delta);
|
||
|
}
|
||
|
|
||
|
// DeterminizeStarInLog determinizes 'fst' in the log semiring, using
// the DeterminizeStar algorithm (which also removes epsilons).
//
// *fst is overwritten with the result.  "delta", "debug_ptr" and
// "max_states" are passed straight through to DeterminizeStar.
//
// The temporary LogArc FSTs are automatic objects, so they are released even
// if determinization exits by throwing; raw new/delete leaked in that case.
inline void DeterminizeStarInLog(VectorFst<StdArc> *fst, float delta,
                                 bool *debug_ptr, int max_states) {
  ArcSort(fst, ILabelCompare<StdArc>());  // helps DeterminizeStar to be faster.
  VectorFst<LogArc> fst_log;  // Want to determinize in log semiring.
  Cast(*fst, &fst_log);
  VectorFst<StdArc> tmp;
  *fst = tmp;  // make fst empty to free up memory. [actually may make no
               // difference..]
  VectorFst<LogArc> fst_det_log;
  DeterminizeStar(fst_log, &fst_det_log, delta, debug_ptr, max_states);
  Cast(fst_det_log, fst);
}
|
||
|
|
||
|
// DeterminizeInLog determinizes 'fst' in the log semiring, using Determinize.
// *fst is overwritten with the result.
//
// The temporary LogArc FSTs are automatic objects, so they are released even
// if determinization exits by throwing; raw new/delete leaked in that case.
inline void DeterminizeInLog(VectorFst<StdArc> *fst) {
  ArcSort(fst, ILabelCompare<StdArc>());  // sort arcs on input label first.
  VectorFst<LogArc> fst_log;  // Want to determinize in log semiring.
  Cast(*fst, &fst_log);
  VectorFst<StdArc> tmp;
  *fst = tmp;  // make fst empty to free up memory. [actually may make no
               // difference..]
  VectorFst<LogArc> fst_det_log;
  Determinize(fst_log, &fst_det_log);
  Cast(fst_det_log, fst);
}
|
||
|
|
||
|
// make it inline to avoid having to put it in a .cc file.
|
||
|
// destructive algorithm (changes ifst as well as ofst).
|
||
|
inline void SafeDeterminizeMinimizeWrapperInLog(VectorFst<StdArc> *ifst,
|
||
|
VectorFst<StdArc> *ofst,
|
||
|
float delta) {
|
||
|
VectorFst<LogArc> *ifst_log =
|
||
|
new VectorFst<LogArc>; // Want to determinize in log semiring.
|
||
|
Cast(*ifst, ifst_log);
|
||
|
VectorFst<LogArc> *ofst_log = new VectorFst<LogArc>;
|
||
|
SafeDeterminizeWrapper(ifst_log, ofst_log, delta);
|
||
|
Cast(*ofst_log, ofst);
|
||
|
delete ifst_log;
|
||
|
delete ofst_log;
|
||
|
RemoveEpsLocal(ofst); // this is "safe" and will never hurt. Do this in
|
||
|
// tropical, which is important.
|
||
|
MinimizeEncoded(ofst, delta); // Non-deterministic minimization will fail in
|
||
|
// log semiring so do it with StdARc.
|
||
|
}
|
||
|
|
||
|
// Casts *ifst to the log semiring, runs SafeDeterminizeWrapper there, and
// casts the result back into *ofst.
//
// The temporary LogArc FSTs are automatic objects, so they are released even
// if a step exits by throwing; raw new/delete leaked in that case.
inline void SafeDeterminizeWrapperInLog(VectorFst<StdArc> *ifst,
                                        VectorFst<StdArc> *ofst, float delta) {
  VectorFst<LogArc> ifst_log;  // Want to determinize in log semiring.
  Cast(*ifst, &ifst_log);
  VectorFst<LogArc> ofst_log;
  SafeDeterminizeWrapper(&ifst_log, &ofst_log, delta);
  Cast(ofst_log, ofst);
}
|
||
|
|
||
|
template <class Arc>
|
||
|
void RemoveWeights(MutableFst<Arc> *ifst) {
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
|
||
|
for (StateIterator<MutableFst<Arc> > siter(*ifst); !siter.Done();
|
||
|
siter.Next()) {
|
||
|
StateId s = siter.Value();
|
||
|
for (MutableArcIterator<MutableFst<Arc> > aiter(ifst, s); !aiter.Done();
|
||
|
aiter.Next()) {
|
||
|
Arc arc(aiter.Value());
|
||
|
arc.weight = Weight::One();
|
||
|
aiter.SetValue(arc);
|
||
|
}
|
||
|
if (ifst->Final(s) != Weight::Zero()) ifst->SetFinal(s, Weight::One());
|
||
|
}
|
||
|
ifst->SetProperties(kUnweighted, kUnweighted);
|
||
|
}
|
||
|
|
||
|
// Functor that returns its argument unchanged.  It is the label-to-class
// function used by the non-functor versions of
// PrecedingInputSymbolsAreSame and similar routines.
template <class T>
struct IdentityFunction {
  using Arg = T;
  using Result = T;
  Result operator()(const Arg &value) const { return value; }
};
|
||
|
|
||
|
template <class Arc>
|
||
|
bool PrecedingInputSymbolsAreSame(bool start_is_epsilon, const Fst<Arc> &fst) {
|
||
|
IdentityFunction<typename Arc::Label> f;
|
||
|
return PrecedingInputSymbolsAreSameClass(start_is_epsilon, fst, f);
|
||
|
}
|
||
|
|
||
|
template <class Arc, class F> // F is functor type from labels to classes.
|
||
|
bool PrecedingInputSymbolsAreSameClass(bool start_is_epsilon,
|
||
|
const Fst<Arc> &fst, const F &f) {
|
||
|
typedef typename F::Result ClassType;
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
std::vector<ClassType> classes;
|
||
|
ClassType noClass = f(kNoLabel);
|
||
|
|
||
|
if (start_is_epsilon) {
|
||
|
StateId start_state = fst.Start();
|
||
|
if (start_state < 0 || start_state == kNoStateId)
|
||
|
return true; // empty fst-- doesn't matter.
|
||
|
classes.resize(start_state + 1, noClass);
|
||
|
classes[start_state] = 0;
|
||
|
}
|
||
|
|
||
|
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
|
||
|
StateId s = siter.Value();
|
||
|
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
|
||
|
const Arc &arc = aiter.Value();
|
||
|
if (classes.size() <= arc.nextstate)
|
||
|
classes.resize(arc.nextstate + 1, noClass);
|
||
|
if (classes[arc.nextstate] == noClass)
|
||
|
classes[arc.nextstate] = f(arc.ilabel);
|
||
|
else if (classes[arc.nextstate] != f(arc.ilabel))
|
||
|
return false;
|
||
|
}
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
bool FollowingInputSymbolsAreSame(bool end_is_epsilon, const Fst<Arc> &fst) {
|
||
|
IdentityFunction<typename Arc::Label> f;
|
||
|
return FollowingInputSymbolsAreSameClass(end_is_epsilon, fst, f);
|
||
|
}
|
||
|
|
||
|
template <class Arc, class F>
|
||
|
bool FollowingInputSymbolsAreSameClass(bool end_is_epsilon, const Fst<Arc> &fst,
|
||
|
const F &f) {
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
typedef typename F::Result ClassType;
|
||
|
const ClassType noClass = f(kNoLabel), epsClass = f(0);
|
||
|
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
|
||
|
StateId s = siter.Value();
|
||
|
ClassType c = noClass;
|
||
|
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
|
||
|
const Arc &arc = aiter.Value();
|
||
|
if (c == noClass)
|
||
|
c = f(arc.ilabel);
|
||
|
else if (c != f(arc.ilabel))
|
||
|
return false;
|
||
|
}
|
||
|
if (end_is_epsilon && c != noClass && c != epsClass &&
|
||
|
fst.Final(s) != Weight::Zero())
|
||
|
return false;
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
void MakePrecedingInputSymbolsSame(bool start_is_epsilon,
|
||
|
MutableFst<Arc> *fst) {
|
||
|
IdentityFunction<typename Arc::Label> f;
|
||
|
MakePrecedingInputSymbolsSameClass(start_is_epsilon, fst, f);
|
||
|
}
|
||
|
|
||
|
template <class Arc, class F>
|
||
|
void MakePrecedingInputSymbolsSameClass(bool start_is_epsilon,
|
||
|
MutableFst<Arc> *fst, const F &f) {
|
||
|
typedef typename F::Result ClassType;
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
std::vector<ClassType> classes;
|
||
|
ClassType noClass = f(kNoLabel);
|
||
|
ClassType epsClass = f(0);
|
||
|
if (start_is_epsilon) { // treat having-start-state as epsilon in-transition.
|
||
|
StateId start_state = fst->Start();
|
||
|
if (start_state < 0 || start_state == kNoStateId) // empty FST.
|
||
|
return;
|
||
|
classes.resize(start_state + 1, noClass);
|
||
|
classes[start_state] = epsClass;
|
||
|
}
|
||
|
|
||
|
// Find bad states (states with multiple input-symbols into them).
|
||
|
std::set<StateId> bad_states; // states that we need to change.
|
||
|
for (StateIterator<Fst<Arc> > siter(*fst); !siter.Done(); siter.Next()) {
|
||
|
StateId s = siter.Value();
|
||
|
for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
|
||
|
const Arc &arc = aiter.Value();
|
||
|
if (classes.size() <= static_cast<size_t>(arc.nextstate))
|
||
|
classes.resize(arc.nextstate + 1, noClass);
|
||
|
if (classes[arc.nextstate] == noClass)
|
||
|
classes[arc.nextstate] = f(arc.ilabel);
|
||
|
else if (classes[arc.nextstate] != f(arc.ilabel))
|
||
|
bad_states.insert(arc.nextstate);
|
||
|
}
|
||
|
}
|
||
|
if (bad_states.empty()) return; // Nothing to do.
|
||
|
kaldi::ConstIntegerSet<StateId> bad_states_ciset(
|
||
|
bad_states); // faster lookup.
|
||
|
|
||
|
// Work out list of arcs we have to change as (state, arc-offset).
|
||
|
// Can't do the actual changes in this pass, since we have to add new
|
||
|
// states which invalidates the iterators.
|
||
|
std::vector<std::pair<StateId, size_t> > arcs_to_change;
|
||
|
for (StateIterator<Fst<Arc> > siter(*fst); !siter.Done(); siter.Next()) {
|
||
|
StateId s = siter.Value();
|
||
|
for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
|
||
|
const Arc &arc = aiter.Value();
|
||
|
if (arc.ilabel != 0 && bad_states_ciset.count(arc.nextstate) != 0)
|
||
|
arcs_to_change.push_back(std::make_pair(s, aiter.Position()));
|
||
|
}
|
||
|
}
|
||
|
KALDI_ASSERT(!arcs_to_change.empty()); // since !bad_states.empty().
|
||
|
|
||
|
std::map<std::pair<StateId, ClassType>, StateId> state_map;
|
||
|
// state_map is a map from (bad-state, input-symbol-class) to dummy-state.
|
||
|
|
||
|
for (size_t i = 0; i < arcs_to_change.size(); i++) {
|
||
|
StateId s = arcs_to_change[i].first;
|
||
|
ArcIterator<MutableFst<Arc> > aiter(*fst, s);
|
||
|
aiter.Seek(arcs_to_change[i].second);
|
||
|
Arc arc = aiter.Value();
|
||
|
|
||
|
// Transition is non-eps transition to "bad" state. Introduce new state (or
|
||
|
// find existing one).
|
||
|
std::pair<StateId, ClassType> p(arc.nextstate, f(arc.ilabel));
|
||
|
if (state_map.count(p) == 0) {
|
||
|
StateId newstate = state_map[p] = fst->AddState();
|
||
|
fst->AddArc(newstate, Arc(0, 0, Weight::One(), arc.nextstate));
|
||
|
}
|
||
|
StateId dst_state = state_map[p];
|
||
|
arc.nextstate = dst_state;
|
||
|
|
||
|
// Initialize the MutableArcIterator only now, as the call to NewState()
|
||
|
// may have invalidated the first arc iterator.
|
||
|
MutableArcIterator<MutableFst<Arc> > maiter(fst, s);
|
||
|
maiter.Seek(arcs_to_change[i].second);
|
||
|
maiter.SetValue(arc);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
void MakeFollowingInputSymbolsSame(bool end_is_epsilon, MutableFst<Arc> *fst) {
|
||
|
IdentityFunction<typename Arc::Label> f;
|
||
|
MakeFollowingInputSymbolsSameClass(end_is_epsilon, fst, f);
|
||
|
}
|
||
|
|
||
|
template <class Arc, class F>
|
||
|
void MakeFollowingInputSymbolsSameClass(bool end_is_epsilon,
|
||
|
MutableFst<Arc> *fst, const F &f) {
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
typedef typename F::Result ClassType;
|
||
|
std::vector<StateId> bad_states;
|
||
|
ClassType noClass = f(kNoLabel);
|
||
|
ClassType epsClass = f(0);
|
||
|
for (StateIterator<Fst<Arc> > siter(*fst); !siter.Done(); siter.Next()) {
|
||
|
StateId s = siter.Value();
|
||
|
ClassType c = noClass;
|
||
|
bool bad = false;
|
||
|
for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
|
||
|
const Arc &arc = aiter.Value();
|
||
|
if (c == noClass) {
|
||
|
c = f(arc.ilabel);
|
||
|
} else if (c != f(arc.ilabel)) {
|
||
|
bad = true;
|
||
|
break;
|
||
|
}
|
||
|
}
|
||
|
if (end_is_epsilon && c != noClass && c != epsClass &&
|
||
|
fst->Final(s) != Weight::Zero())
|
||
|
bad = true;
|
||
|
if (bad) bad_states.push_back(s);
|
||
|
}
|
||
|
std::vector<Arc> my_arcs;
|
||
|
for (size_t i = 0; i < bad_states.size(); i++) {
|
||
|
StateId s = bad_states[i];
|
||
|
my_arcs.clear();
|
||
|
for (ArcIterator<MutableFst<Arc> > aiter(*fst, s); !aiter.Done();
|
||
|
aiter.Next())
|
||
|
my_arcs.push_back(aiter.Value());
|
||
|
|
||
|
for (size_t j = 0; j < my_arcs.size(); j++) {
|
||
|
Arc &arc = my_arcs[j];
|
||
|
if (arc.ilabel != 0) {
|
||
|
StateId newstate = fst->AddState();
|
||
|
// Create a new state for each non-eps arc in original FST, out of each
|
||
|
// bad state. Not as optimal as it could be, but does avoid some
|
||
|
// complicated weight-pushing issues in which, to maintain
|
||
|
// stochasticity, we would have to know which semiring we want to
|
||
|
// maintain stochasticity in.
|
||
|
fst->AddArc(newstate, Arc(arc.ilabel, 0, Weight::One(), arc.nextstate));
|
||
|
MutableArcIterator<MutableFst<Arc> > maiter(fst, s);
|
||
|
maiter.Seek(j);
|
||
|
maiter.SetValue(Arc(0, arc.olabel, arc.weight, newstate));
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
VectorFst<Arc> *MakeLoopFst(const std::vector<const ExpandedFst<Arc> *> &fsts) {
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
typedef typename Arc::Label Label;
|
||
|
|
||
|
VectorFst<Arc> *ans = new VectorFst<Arc>;
|
||
|
StateId loop_state = ans->AddState(); // = 0.
|
||
|
ans->SetStart(loop_state);
|
||
|
ans->SetFinal(loop_state, Weight::One());
|
||
|
|
||
|
// "cache" is used as an optimization when some of the pointers in "fsts"
|
||
|
// may have the same value.
|
||
|
unordered_map<const ExpandedFst<Arc> *, Arc> cache;
|
||
|
|
||
|
for (Label i = 0; i < static_cast<Label>(fsts.size()); i++) {
|
||
|
const ExpandedFst<Arc> *fst = fsts[i];
|
||
|
if (fst == NULL) continue;
|
||
|
{ // optimization with cache: helpful if some members of "fsts" may
|
||
|
// contain the same pointer value (e.g. in GetHTransducer).
|
||
|
typename unordered_map<const ExpandedFst<Arc> *, Arc>::iterator iter =
|
||
|
cache.find(fst);
|
||
|
if (iter != cache.end()) {
|
||
|
Arc arc = iter->second;
|
||
|
arc.olabel = i;
|
||
|
ans->AddArc(0, arc);
|
||
|
continue;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
KALDI_ASSERT(fst->Properties(kAcceptor, true) ==
|
||
|
kAcceptor); // expect acceptor.
|
||
|
|
||
|
StateId fst_num_states = fst->NumStates();
|
||
|
StateId fst_start_state = fst->Start();
|
||
|
|
||
|
if (fst_start_state == kNoStateId) continue; // empty fst.
|
||
|
|
||
|
bool share_start_state =
|
||
|
fst->Properties(kInitialAcyclic, true) == kInitialAcyclic &&
|
||
|
fst->NumArcs(fst_start_state) == 1 &&
|
||
|
fst->Final(fst_start_state) == Weight::Zero();
|
||
|
|
||
|
std::vector<StateId> state_map(fst_num_states); // fst state -> ans state
|
||
|
for (StateId s = 0; s < fst_num_states; s++) {
|
||
|
if (s == fst_start_state && share_start_state)
|
||
|
state_map[s] = loop_state;
|
||
|
else
|
||
|
state_map[s] = ans->AddState();
|
||
|
}
|
||
|
if (!share_start_state) {
|
||
|
Arc arc(0, i, Weight::One(), state_map[fst_start_state]);
|
||
|
cache[fst] = arc;
|
||
|
ans->AddArc(0, arc);
|
||
|
}
|
||
|
for (StateId s = 0; s < fst_num_states; s++) {
|
||
|
// Add arcs out of state s.
|
||
|
for (ArcIterator<ExpandedFst<Arc> > aiter(*fst, s); !aiter.Done();
|
||
|
aiter.Next()) {
|
||
|
const Arc &arc = aiter.Value();
|
||
|
Label olabel = (s == fst_start_state && share_start_state ? i : 0);
|
||
|
Arc newarc(arc.ilabel, olabel, arc.weight, state_map[arc.nextstate]);
|
||
|
ans->AddArc(state_map[s], newarc);
|
||
|
if (s == fst_start_state && share_start_state) cache[fst] = newarc;
|
||
|
}
|
||
|
if (fst->Final(s) != Weight::Zero()) {
|
||
|
KALDI_ASSERT(!(s == fst_start_state && share_start_state));
|
||
|
ans->AddArc(state_map[s], Arc(0, 0, fst->Final(s), loop_state));
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return ans;
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
void ClearSymbols(bool clear_input, bool clear_output, MutableFst<Arc> *fst) {
|
||
|
for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
|
||
|
siter.Next()) {
|
||
|
typename Arc::StateId s = siter.Value();
|
||
|
for (MutableArcIterator<MutableFst<Arc> > aiter(fst, s); !aiter.Done();
|
||
|
aiter.Next()) {
|
||
|
Arc arc = aiter.Value();
|
||
|
bool change = false;
|
||
|
if (clear_input && arc.ilabel != 0) {
|
||
|
arc.ilabel = 0;
|
||
|
change = true;
|
||
|
}
|
||
|
if (clear_output && arc.olabel != 0) {
|
||
|
arc.olabel = 0;
|
||
|
change = true;
|
||
|
}
|
||
|
if (change) {
|
||
|
aiter.SetValue(arc);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
void ApplyProbabilityScale(float scale, MutableFst<Arc> *fst) {
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
|
||
|
siter.Next()) {
|
||
|
StateId s = siter.Value();
|
||
|
for (MutableArcIterator<MutableFst<Arc> > aiter(fst, s); !aiter.Done();
|
||
|
aiter.Next()) {
|
||
|
Arc arc = aiter.Value();
|
||
|
arc.weight = Weight(arc.weight.Value() * scale);
|
||
|
aiter.SetValue(arc);
|
||
|
}
|
||
|
if (fst->Final(s) != Weight::Zero())
|
||
|
fst->SetFinal(s, Weight(fst->Final(s).Value() * scale));
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// return arc-offset of self-loop with ilabel (or -1 if none exists).
|
||
|
// if more than one such self-loop, pick first one.
|
||
|
template <class Arc>
|
||
|
ssize_t FindSelfLoopWithILabel(const Fst<Arc> &fst, typename Arc::StateId s) {
|
||
|
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next())
|
||
|
if (aiter.Value().nextstate == s && aiter.Value().ilabel != 0)
|
||
|
return static_cast<ssize_t>(aiter.Position());
|
||
|
return static_cast<ssize_t>(-1);
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
bool EqualAlign(const Fst<Arc> &ifst, typename Arc::StateId length,
|
||
|
int rand_seed, MutableFst<Arc> *ofst, int num_retries) {
|
||
|
srand(rand_seed);
|
||
|
KALDI_ASSERT(ofst->NumStates() == 0); // make sure ofst empty.
|
||
|
// make sure all states can reach final-state (or this algorithm may enter
|
||
|
// infinite loop.
|
||
|
KALDI_ASSERT(ifst.Properties(kCoAccessible, true) == kCoAccessible);
|
||
|
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
|
||
|
if (ifst.Start() == kNoStateId) {
|
||
|
KALDI_WARN << "Empty input fst.";
|
||
|
return false;
|
||
|
}
|
||
|
// First select path through ifst.
|
||
|
std::vector<StateId> path;
|
||
|
std::vector<size_t> arc_offsets; // arc taken out of each state.
|
||
|
std::vector<int> nof_ilabels;
|
||
|
|
||
|
StateId num_ilabels = 0;
|
||
|
int retry_no = 0;
|
||
|
|
||
|
// Under normal circumstances, this will be one-pass-only process
|
||
|
// Multiple tries might be needed in special cases, typically when
|
||
|
// the number of frames is close to number of transitions from
|
||
|
// the start node to the final node. It usually happens for really
|
||
|
// short utterances
|
||
|
do {
|
||
|
num_ilabels = 0;
|
||
|
arc_offsets.clear();
|
||
|
path.clear();
|
||
|
path.push_back(ifst.Start());
|
||
|
|
||
|
while (1) {
|
||
|
// Select either an arc or final-prob.
|
||
|
StateId s = path.back();
|
||
|
size_t num_arcs = ifst.NumArcs(s);
|
||
|
size_t num_arcs_tot = num_arcs;
|
||
|
if (ifst.Final(s) != Weight::Zero()) num_arcs_tot++;
|
||
|
// kaldi::RandInt is a bit like Rand(), but gets around situations
|
||
|
// where RAND_MAX is very small.
|
||
|
// Change this to Rand() % num_arcs_tot if compile issues arise
|
||
|
size_t arc_offset =
|
||
|
static_cast<size_t>(kaldi::RandInt(0, num_arcs_tot - 1));
|
||
|
|
||
|
if (arc_offset < num_arcs) { // an actual arc.
|
||
|
ArcIterator<Fst<Arc> > aiter(ifst, s);
|
||
|
aiter.Seek(arc_offset);
|
||
|
const Arc &arc = aiter.Value();
|
||
|
if (arc.nextstate == s) {
|
||
|
continue; // don't take this self-loop arc
|
||
|
} else {
|
||
|
arc_offsets.push_back(arc_offset);
|
||
|
path.push_back(arc.nextstate);
|
||
|
if (arc.ilabel != 0) num_ilabels++;
|
||
|
}
|
||
|
} else {
|
||
|
break; // Chose final-prob.
|
||
|
}
|
||
|
}
|
||
|
|
||
|
nof_ilabels.push_back(num_ilabels);
|
||
|
} while ((++retry_no < num_retries) && (num_ilabels > length));
|
||
|
|
||
|
if (num_ilabels > length) {
|
||
|
std::stringstream ilabel_vec;
|
||
|
std::copy(nof_ilabels.begin(), nof_ilabels.end(),
|
||
|
std::ostream_iterator<int>(ilabel_vec, ","));
|
||
|
std::string s = ilabel_vec.str();
|
||
|
s.erase(s.end() - 1);
|
||
|
KALDI_WARN << "EqualAlign: the randomly constructed paths lengths: " << s;
|
||
|
KALDI_WARN << "EqualAlign: utterance has too few frames " << length
|
||
|
<< " to align.";
|
||
|
return false; // can't make it shorter by adding self-loops!.
|
||
|
}
|
||
|
|
||
|
StateId num_self_loops = 0;
|
||
|
std::vector<ssize_t> self_loop_offsets(path.size());
|
||
|
for (size_t i = 0; i < path.size(); i++)
|
||
|
if ((self_loop_offsets[i] = FindSelfLoopWithILabel(ifst, path[i])) !=
|
||
|
static_cast<ssize_t>(-1))
|
||
|
num_self_loops++;
|
||
|
|
||
|
if (num_self_loops == 0 && num_ilabels < length) {
|
||
|
KALDI_WARN << "No self-loops on chosen path; cannot match length.";
|
||
|
return false; // no self-loops to make it longer.
|
||
|
}
|
||
|
|
||
|
StateId num_extra = length - num_ilabels; // Number of self-loops we need.
|
||
|
|
||
|
StateId min_num_loops = 0;
|
||
|
if (num_extra != 0)
|
||
|
min_num_loops = num_extra / num_self_loops; // prevent div by zero.
|
||
|
StateId num_with_one_more_loop = num_extra - (min_num_loops * num_self_loops);
|
||
|
KALDI_ASSERT(num_with_one_more_loop < num_self_loops || num_self_loops == 0);
|
||
|
|
||
|
ofst->AddState();
|
||
|
ofst->SetStart(0);
|
||
|
StateId cur_state = 0;
|
||
|
StateId counter = 0; // tell us when we should stop adding one more loop.
|
||
|
for (size_t i = 0; i < path.size(); i++) {
|
||
|
// First, add any self-loops that are necessary.
|
||
|
StateId num_loops = 0;
|
||
|
if (self_loop_offsets[i] != static_cast<ssize_t>(-1)) {
|
||
|
num_loops = min_num_loops + (counter < num_with_one_more_loop ? 1 : 0);
|
||
|
counter++;
|
||
|
}
|
||
|
for (StateId j = 0; j < num_loops; j++) {
|
||
|
ArcIterator<Fst<Arc> > aiter(ifst, path[i]);
|
||
|
aiter.Seek(self_loop_offsets[i]);
|
||
|
Arc arc = aiter.Value();
|
||
|
KALDI_ASSERT(arc.nextstate == path[i] &&
|
||
|
arc.ilabel != 0); // make sure self-loop with ilabel.
|
||
|
StateId next_state = ofst->AddState();
|
||
|
ofst->AddArc(cur_state,
|
||
|
Arc(arc.ilabel, arc.olabel, arc.weight, next_state));
|
||
|
cur_state = next_state;
|
||
|
}
|
||
|
if (i + 1 < path.size()) { // add forward transition.
|
||
|
ArcIterator<Fst<Arc> > aiter(ifst, path[i]);
|
||
|
aiter.Seek(arc_offsets[i]);
|
||
|
Arc arc = aiter.Value();
|
||
|
KALDI_ASSERT(arc.nextstate == path[i + 1]);
|
||
|
StateId next_state = ofst->AddState();
|
||
|
ofst->AddArc(cur_state,
|
||
|
Arc(arc.ilabel, arc.olabel, arc.weight, next_state));
|
||
|
cur_state = next_state;
|
||
|
} else { // add final-prob.
|
||
|
Weight weight = ifst.Final(path[i]);
|
||
|
KALDI_ASSERT(weight != Weight::Zero());
|
||
|
ofst->SetFinal(cur_state, weight);
|
||
|
}
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
// This function identifies two types of useless arcs:
|
||
|
// those where arc A and arc B both go from state X to
|
||
|
// state Y with the same input symbol (remove the one
|
||
|
// with smaller probability, or an arbitrary one if they
|
||
|
// are the same); and those where A is an arc from state X
|
||
|
// to state X, with epsilon input symbol [remove A].
|
||
|
// Only works for tropical (not log) semiring as it uses
|
||
|
// NaturalLess.
|
||
|
template <class Arc>
|
||
|
void RemoveUselessArcs(MutableFst<Arc> *fst) {
|
||
|
typedef typename Arc::Label Label;
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
NaturalLess<Weight> nl;
|
||
|
StateId non_coacc_state = kNoStateId;
|
||
|
size_t num_arcs_removed = 0, tot_arcs = 0;
|
||
|
for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
|
||
|
siter.Next()) {
|
||
|
std::vector<size_t> arcs_to_delete;
|
||
|
std::vector<Arc> arcs;
|
||
|
// pair2arclist lets us look up the arcs
|
||
|
std::map<std::pair<Label, StateId>, std::vector<size_t> > pair2arclist;
|
||
|
StateId state = siter.Value();
|
||
|
for (ArcIterator<MutableFst<Arc> > aiter(*fst, state); !aiter.Done();
|
||
|
aiter.Next()) {
|
||
|
size_t pos = arcs.size();
|
||
|
const Arc &arc = aiter.Value();
|
||
|
arcs.push_back(arc);
|
||
|
pair2arclist[std::make_pair(arc.ilabel, arc.nextstate)].push_back(pos);
|
||
|
}
|
||
|
typename std::map<std::pair<Label, StateId>, std::vector<size_t> >::iterator
|
||
|
iter = pair2arclist.begin(),
|
||
|
end = pair2arclist.end();
|
||
|
for (; iter != end; ++iter) {
|
||
|
const std::vector<size_t> &poslist = iter->second;
|
||
|
if (poslist.size() > 1) { // >1 arc with same ilabel, dest-state
|
||
|
size_t best_pos = poslist[0];
|
||
|
Weight best_weight = arcs[best_pos].weight;
|
||
|
for (size_t j = 1; j < poslist.size(); j++) {
|
||
|
size_t pos = poslist[j];
|
||
|
Weight this_weight = arcs[pos].weight;
|
||
|
if (nl(this_weight,
|
||
|
best_weight)) { // NaturalLess seems to be somehow
|
||
|
// "backwards".
|
||
|
best_weight = this_weight; // found a better one.
|
||
|
best_pos = pos;
|
||
|
}
|
||
|
}
|
||
|
for (size_t j = 0; j < poslist.size(); j++)
|
||
|
if (poslist[j] != best_pos) arcs_to_delete.push_back(poslist[j]);
|
||
|
} else {
|
||
|
KALDI_ASSERT(poslist.size() == 1);
|
||
|
size_t pos = poslist[0];
|
||
|
Arc &arc = arcs[pos];
|
||
|
if (arc.ilabel == 0 && arc.nextstate == state)
|
||
|
arcs_to_delete.push_back(pos);
|
||
|
}
|
||
|
}
|
||
|
tot_arcs += arcs.size();
|
||
|
if (arcs_to_delete.size() != 0) {
|
||
|
num_arcs_removed += arcs_to_delete.size();
|
||
|
if (non_coacc_state == kNoStateId) non_coacc_state = fst->AddState();
|
||
|
MutableArcIterator<MutableFst<Arc> > maiter(fst, state);
|
||
|
for (size_t j = 0; j < arcs_to_delete.size(); j++) {
|
||
|
size_t pos = arcs_to_delete[j];
|
||
|
maiter.Seek(pos);
|
||
|
arcs[pos].nextstate = non_coacc_state;
|
||
|
maiter.SetValue(arcs[pos]);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
if (non_coacc_state != kNoStateId) Connect(fst);
|
||
|
KALDI_VLOG(1) << "removed " << num_arcs_removed << " of " << tot_arcs
|
||
|
<< "arcs.";
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
void PhiCompose(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
|
||
|
typename Arc::Label phi_label, MutableFst<Arc> *ofst) {
|
||
|
KALDI_ASSERT(phi_label !=
|
||
|
kNoLabel); // just use regular compose in this case.
|
||
|
typedef Fst<Arc> F;
|
||
|
typedef PhiMatcher<SortedMatcher<F> > PM;
|
||
|
CacheOptions base_opts;
|
||
|
base_opts.gc_limit = 0; // Cache only the last state for fastest copy.
|
||
|
// ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
|
||
|
// The matcher for fst1 doesn't matter; we'll use fst2's matcher.
|
||
|
ComposeFstImplOptions<SortedMatcher<F>, PM> impl_opts(base_opts);
|
||
|
|
||
|
// the false below is something called phi_loop which is something I don't
|
||
|
// fully understand, but I don't think we want it.
|
||
|
|
||
|
// These pointers are taken ownership of, by ComposeFst.
|
||
|
PM *phi_matcher = new PM(fst2, MATCH_INPUT, phi_label, false);
|
||
|
SortedMatcher<F> *sorted_matcher =
|
||
|
new SortedMatcher<F>(fst1, MATCH_NONE); // tell it
|
||
|
// not to use this matcher, as this would mean we would
|
||
|
// not follow phi transitions.
|
||
|
impl_opts.matcher1 = sorted_matcher;
|
||
|
impl_opts.matcher2 = phi_matcher;
|
||
|
*ofst = ComposeFst<Arc>(fst1, fst2, impl_opts);
|
||
|
Connect(ofst);
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
void PropagateFinalInternal(typename Arc::Label phi_label,
|
||
|
typename Arc::StateId s, MutableFst<Arc> *fst) {
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
if (fst->Final(s) == Weight::Zero()) {
|
||
|
// search for phi transition. We assume there
|
||
|
// is just one-- phi nondeterminism is not allowed
|
||
|
// anyway.
|
||
|
int num_phis = 0;
|
||
|
for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
|
||
|
const Arc &arc = aiter.Value();
|
||
|
if (arc.ilabel == phi_label) {
|
||
|
num_phis++;
|
||
|
if (arc.nextstate == s) continue; // don't expect
|
||
|
// phi loops but ignore them anyway.
|
||
|
|
||
|
// If this recurses infinitely, it means there
|
||
|
// are loops of phi transitions, which there should
|
||
|
// not be in a normal backoff LM. We could make this
|
||
|
// routine work for this case, but currently there is
|
||
|
// no need.
|
||
|
PropagateFinalInternal(phi_label, arc.nextstate, fst);
|
||
|
if (fst->Final(arc.nextstate) != Weight::Zero())
|
||
|
fst->SetFinal(s, Times(fst->Final(arc.nextstate), arc.weight));
|
||
|
}
|
||
|
KALDI_ASSERT(num_phis <= 1 && "Phi nondeterminism found");
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
void PropagateFinal(typename Arc::Label phi_label, MutableFst<Arc> *fst) {
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
if (fst->Properties(kIEpsilons, true)) // just warn.
|
||
|
KALDI_WARN << "PropagateFinal: this may not work as desired "
|
||
|
"since your FST has input epsilons.";
|
||
|
StateId num_states = fst->NumStates();
|
||
|
for (StateId s = 0; s < num_states; s++)
|
||
|
PropagateFinalInternal(phi_label, s, fst);
|
||
|
}
|
||
|
|
||
|
template <class Arc>
|
||
|
void RhoCompose(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
|
||
|
typename Arc::Label rho_label, MutableFst<Arc> *ofst) {
|
||
|
KALDI_ASSERT(rho_label !=
|
||
|
kNoLabel); // just use regular compose in this case.
|
||
|
typedef Fst<Arc> F;
|
||
|
typedef RhoMatcher<SortedMatcher<F> > RM;
|
||
|
CacheOptions base_opts;
|
||
|
base_opts.gc_limit = 0; // Cache only the last state for fastest copy.
|
||
|
// ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
|
||
|
// The matcher for fst1 doesn't matter; we'll use fst2's matcher.
|
||
|
ComposeFstImplOptions<SortedMatcher<F>, RM> impl_opts(base_opts);
|
||
|
|
||
|
// the false below is something called rho_loop which is something I don't
|
||
|
// fully understand, but I don't think we want it.
|
||
|
|
||
|
// These pointers are taken ownership of, by ComposeFst.
|
||
|
RM *rho_matcher = new RM(fst2, MATCH_INPUT, rho_label);
|
||
|
SortedMatcher<F> *sorted_matcher =
|
||
|
new SortedMatcher<F>(fst1, MATCH_NONE); // tell it
|
||
|
// not to use this matcher, as this would mean we would
|
||
|
// not follow rho transitions.
|
||
|
impl_opts.matcher1 = sorted_matcher;
|
||
|
impl_opts.matcher2 = rho_matcher;
|
||
|
*ofst = ComposeFst<Arc>(fst1, fst2, impl_opts);
|
||
|
Connect(ofst);
|
||
|
}
|
||
|
|
||
|
// Forward-declare the LogArc specialization of the template below (it is defined after the generic version).
|
||
|
template <>
|
||
|
inline bool IsStochasticFst(const Fst<LogArc> &fst, float delta,
|
||
|
LogArc::Weight *min_sum, LogArc::Weight *max_sum);
|
||
|
|
||
|
// Will override this for LogArc where NaturalLess will not work.
|
||
|
template <class Arc>
|
||
|
inline bool IsStochasticFst(const Fst<Arc> &fst, float delta,
|
||
|
typename Arc::Weight *min_sum,
|
||
|
typename Arc::Weight *max_sum) {
|
||
|
typedef typename Arc::StateId StateId;
|
||
|
typedef typename Arc::Weight Weight;
|
||
|
NaturalLess<Weight> nl;
|
||
|
bool first_time = true;
|
||
|
bool ans = true;
|
||
|
if (min_sum) *min_sum = Arc::Weight::One();
|
||
|
if (max_sum) *max_sum = Arc::Weight::One();
|
||
|
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
|
||
|
StateId s = siter.Value();
|
||
|
Weight sum = fst.Final(s);
|
||
|
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
|
||
|
const Arc &arc = aiter.Value();
|
||
|
sum = Plus(sum, arc.weight);
|
||
|
}
|
||
|
if (!ApproxEqual(Weight::One(), sum, delta)) ans = false;
|
||
|
if (first_time) {
|
||
|
first_time = false;
|
||
|
if (max_sum) *max_sum = sum;
|
||
|
if (min_sum) *min_sum = sum;
|
||
|
} else {
|
||
|
if (max_sum && nl(*max_sum, sum)) *max_sum = sum;
|
||
|
if (min_sum && nl(sum, *min_sum)) *min_sum = sum;
|
||
|
}
|
||
|
}
|
||
|
if (first_time) { // just avoid NaNs if FST was empty.
|
||
|
if (max_sum) *max_sum = Weight::One();
|
||
|
if (min_sum) *min_sum = Weight::One();
|
||
|
}
|
||
|
return ans;
|
||
|
}
|
||
|
|
||
|
// Overriding template for LogArc as NaturalLess does not work there.
|
||
|
template <>
|
||
|
inline bool IsStochasticFst(const Fst<LogArc> &fst, float delta,
|
||
|
LogArc::Weight *min_sum, LogArc::Weight *max_sum) {
|
||
|
typedef LogArc Arc;
|
||
|
typedef Arc::StateId StateId;
|
||
|
typedef Arc::Weight Weight;
|
||
|
bool first_time = true;
|
||
|
bool ans = true;
|
||
|
if (min_sum) *min_sum = LogArc::Weight::One();
|
||
|
if (max_sum) *max_sum = LogArc::Weight::One();
|
||
|
for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
|
||
|
StateId s = siter.Value();
|
||
|
Weight sum = fst.Final(s);
|
||
|
for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
|
||
|
const Arc &arc = aiter.Value();
|
||
|
sum = Plus(sum, arc.weight);
|
||
|
}
|
||
|
if (!ApproxEqual(Weight::One(), sum, delta)) ans = false;
|
||
|
if (first_time) {
|
||
|
first_time = false;
|
||
|
if (max_sum) *max_sum = sum;
|
||
|
if (min_sum) *min_sum = sum;
|
||
|
} else {
|
||
|
// note that max and min are reversed from their normal
|
||
|
// meanings here (max and min w.r.t. the underlying probabilities).
|
||
|
if (max_sum && sum.Value() < max_sum->Value()) *max_sum = sum;
|
||
|
if (min_sum && sum.Value() > min_sum->Value()) *min_sum = sum;
|
||
|
}
|
||
|
}
|
||
|
if (first_time) { // just avoid NaNs if FST was empty.
|
||
|
if (max_sum) *max_sum = Weight::One();
|
||
|
if (min_sum) *min_sum = Weight::One();
|
||
|
}
|
||
|
return ans;
|
||
|
}
|
||
|
|
||
|
// Tests whether a tropical FST is stochastic in the log
|
||
|
// semiring. (casts it and does the check.)
|
||
|
// This function deals with the generic fst.
|
||
|
// This version currently supports ConstFst<StdArc> or VectorFst<StdArc>.
|
||
|
// Otherwise, it will be died with an error.
|
||
|
inline bool IsStochasticFstInLog(const Fst<StdArc> &fst, float delta,
|
||
|
StdArc::Weight *min_sum,
|
||
|
StdArc::Weight *max_sum) {
|
||
|
bool ans = false;
|
||
|
LogArc::Weight log_min = LogArc::Weight::One(),
|
||
|
log_max = LogArc::Weight::Zero();
|
||
|
if (fst.Type() == "const") {
|
||
|
ConstFst<LogArc> logfst;
|
||
|
Cast(dynamic_cast<const ConstFst<StdArc> &>(fst), &logfst);
|
||
|
ans = IsStochasticFst(logfst, delta, &log_min, &log_max);
|
||
|
} else if (fst.Type() == "vector") {
|
||
|
VectorFst<LogArc> logfst;
|
||
|
Cast(dynamic_cast<const VectorFst<StdArc> &>(fst), &logfst);
|
||
|
ans = IsStochasticFst(logfst, delta, &log_min, &log_max);
|
||
|
} else {
|
||
|
KALDI_ERR << "This version currently supports ConstFst<StdArc> "
|
||
|
<< "or VectorFst<StdArc>";
|
||
|
}
|
||
|
if (min_sum) *min_sum = StdArc::Weight(log_min.Value());
|
||
|
if (max_sum) *max_sum = StdArc::Weight(log_max.Value());
|
||
|
return ans;
|
||
|
}
|
||
|
|
||
|
} // namespace fst.
|
||
|
|
||
|
#endif // KALDI_FSTEXT_FSTEXT_UTILS_INL_H_
|