You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
268 lines
10 KiB
268 lines
10 KiB
// fstext/lattice-utils-inl.h
|
|
|
|
// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author:
|
|
// Daniel Povey)
|
|
|
|
// See ../../COPYING for clarification regarding multiple authors
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
|
// MERCHANTABILITY OR NON-INFRINGEMENT.
|
|
// See the Apache 2 License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#ifndef KALDI_FSTEXT_LATTICE_UTILS_INL_H_
|
|
#define KALDI_FSTEXT_LATTICE_UTILS_INL_H_
|
|
// Do not include this file directly. It is included by lattice-utils.h
|
|
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
namespace fst {
|
|
|
|
/* Convert from FST with arc-type Weight, to one with arc-type
|
|
CompactLatticeWeight. Uses FactorFst to identify chains
|
|
of states which can be turned into a single output arc. */
|
|
|
|
template <class Weight, class Int>
|
|
void ConvertLattice(
|
|
const ExpandedFst<ArcTpl<Weight> > &ifst,
|
|
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > *ofst,
|
|
bool invert) {
|
|
typedef ArcTpl<Weight> Arc;
|
|
typedef typename Arc::StateId StateId;
|
|
typedef CompactLatticeWeightTpl<Weight, Int> CompactWeight;
|
|
typedef ArcTpl<CompactWeight> CompactArc;
|
|
|
|
VectorFst<ArcTpl<Weight> > ffst;
|
|
std::vector<std::vector<Int> > labels;
|
|
if (invert) { // normal case: want the ilabels as sequences on the arcs of
|
|
Factor(ifst, &ffst, &labels); // the output... Factor makes seqs of
|
|
// ilabels.
|
|
} else {
|
|
VectorFst<ArcTpl<Weight> > invfst(ifst);
|
|
Invert(&invfst);
|
|
Factor(invfst, &ffst, &labels);
|
|
}
|
|
|
|
TopSort(&ffst); // Put the states in ffst in topological order, which is
|
|
// easier on the eye when reading the text-form lattices and corresponds to
|
|
// what we get when we generate the lattices in the decoder.
|
|
|
|
ofst->DeleteStates();
|
|
|
|
// The states will be numbered exactly the same as the original FST.
|
|
// Add the states to the new FST.
|
|
StateId num_states = ffst.NumStates();
|
|
for (StateId s = 0; s < num_states; s++) {
|
|
StateId news = ofst->AddState();
|
|
assert(news == s);
|
|
}
|
|
ofst->SetStart(ffst.Start());
|
|
for (StateId s = 0; s < num_states; s++) {
|
|
Weight final_weight = ffst.Final(s);
|
|
if (final_weight != Weight::Zero()) {
|
|
CompactWeight final_compact_weight(final_weight, std::vector<Int>());
|
|
ofst->SetFinal(s, final_compact_weight);
|
|
}
|
|
for (ArcIterator<ExpandedFst<Arc> > iter(ffst, s); !iter.Done();
|
|
iter.Next()) {
|
|
const Arc &arc = iter.Value();
|
|
KALDI_PARANOID_ASSERT(arc.weight != Weight::Zero());
|
|
// note: zero-weight arcs not allowed anyway so weight should not be zero,
|
|
// but no harm in checking.
|
|
CompactArc compact_arc(arc.olabel, arc.olabel,
|
|
CompactWeight(arc.weight, labels[arc.ilabel]),
|
|
arc.nextstate);
|
|
ofst->AddArc(s, compact_arc);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <class Weight, class Int>
|
|
void ConvertLattice(
|
|
const ExpandedFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > &ifst,
|
|
MutableFst<ArcTpl<Weight> > *ofst, bool invert) {
|
|
typedef ArcTpl<Weight> Arc;
|
|
typedef typename Arc::StateId StateId;
|
|
typedef typename Arc::Label Label;
|
|
typedef CompactLatticeWeightTpl<Weight, Int> CompactWeight;
|
|
typedef ArcTpl<CompactWeight> CompactArc;
|
|
ofst->DeleteStates();
|
|
// make the states in the new FST have the same numbers as
|
|
// the original ones, and add chains of states as necessary
|
|
// to encode the string-valued weights.
|
|
StateId num_states = ifst.NumStates();
|
|
for (StateId s = 0; s < num_states; s++) {
|
|
StateId news = ofst->AddState();
|
|
assert(news == s);
|
|
}
|
|
ofst->SetStart(ifst.Start());
|
|
for (StateId s = 0; s < num_states; s++) {
|
|
CompactWeight final_weight = ifst.Final(s);
|
|
if (final_weight != CompactWeight::Zero()) {
|
|
StateId cur_state = s;
|
|
size_t string_length = final_weight.String().size();
|
|
for (size_t n = 0; n < string_length; n++) {
|
|
StateId next_state = ofst->AddState();
|
|
Label ilabel = 0;
|
|
Arc arc(ilabel, final_weight.String()[n],
|
|
(n == 0 ? final_weight.Weight() : Weight::One()), next_state);
|
|
if (invert) std::swap(arc.ilabel, arc.olabel);
|
|
ofst->AddArc(cur_state, arc);
|
|
cur_state = next_state;
|
|
}
|
|
ofst->SetFinal(cur_state,
|
|
string_length > 0 ? Weight::One() : final_weight.Weight());
|
|
}
|
|
for (ArcIterator<ExpandedFst<CompactArc> > iter(ifst, s); !iter.Done();
|
|
iter.Next()) {
|
|
const CompactArc &arc = iter.Value();
|
|
size_t string_length = arc.weight.String().size();
|
|
StateId cur_state = s;
|
|
// for all but the last element in the string--
|
|
// add a temporary state.
|
|
for (size_t n = 0; n + 1 < string_length; n++) {
|
|
StateId next_state = ofst->AddState();
|
|
Label ilabel = (n == 0 ? arc.ilabel : 0),
|
|
olabel = static_cast<Label>(arc.weight.String()[n]);
|
|
Weight weight = (n == 0 ? arc.weight.Weight() : Weight::One());
|
|
Arc new_arc(ilabel, olabel, weight, next_state);
|
|
if (invert) std::swap(new_arc.ilabel, new_arc.olabel);
|
|
ofst->AddArc(cur_state, new_arc);
|
|
cur_state = next_state;
|
|
}
|
|
Label ilabel = (string_length <= 1 ? arc.ilabel : 0),
|
|
olabel = (string_length > 0 ? arc.weight.String()[string_length - 1]
|
|
: 0);
|
|
Weight weight =
|
|
(string_length <= 1 ? arc.weight.Weight() : Weight::One());
|
|
Arc new_arc(ilabel, olabel, weight, arc.nextstate);
|
|
if (invert) std::swap(new_arc.ilabel, new_arc.olabel);
|
|
ofst->AddArc(cur_state, new_arc);
|
|
}
|
|
}
|
|
}
|
|
|
|
// This function converts lattices between float and double;
|
|
// it works for both CompactLatticeWeight and LatticeWeight.
|
|
template <class WeightIn, class WeightOut>
|
|
void ConvertLattice(const ExpandedFst<ArcTpl<WeightIn> > &ifst,
|
|
MutableFst<ArcTpl<WeightOut> > *ofst) {
|
|
typedef ArcTpl<WeightIn> ArcIn;
|
|
typedef ArcTpl<WeightOut> ArcOut;
|
|
typedef typename ArcIn::StateId StateId;
|
|
ofst->DeleteStates();
|
|
// The states will be numbered exactly the same as the original FST.
|
|
// Add the states to the new FST.
|
|
StateId num_states = ifst.NumStates();
|
|
for (StateId s = 0; s < num_states; s++) {
|
|
StateId news = ofst->AddState();
|
|
assert(news == s);
|
|
}
|
|
ofst->SetStart(ifst.Start());
|
|
for (StateId s = 0; s < num_states; s++) {
|
|
WeightIn final_iweight = ifst.Final(s);
|
|
if (final_iweight != WeightIn::Zero()) {
|
|
WeightOut final_oweight;
|
|
ConvertLatticeWeight(final_iweight, &final_oweight);
|
|
ofst->SetFinal(s, final_oweight);
|
|
}
|
|
for (ArcIterator<ExpandedFst<ArcIn> > iter(ifst, s); !iter.Done();
|
|
iter.Next()) {
|
|
ArcIn arc = iter.Value();
|
|
KALDI_PARANOID_ASSERT(arc.weight != WeightIn::Zero());
|
|
ArcOut oarc;
|
|
ConvertLatticeWeight(arc.weight, &oarc.weight);
|
|
oarc.ilabel = arc.ilabel;
|
|
oarc.olabel = arc.olabel;
|
|
oarc.nextstate = arc.nextstate;
|
|
ofst->AddArc(s, oarc);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <class Weight, class ScaleFloat>
|
|
void ScaleLattice(const std::vector<std::vector<ScaleFloat> > &scale,
|
|
MutableFst<ArcTpl<Weight> > *fst) {
|
|
assert(scale.size() == 2 && scale[0].size() == 2 && scale[1].size() == 2);
|
|
if (scale == DefaultLatticeScale()) // nothing to do.
|
|
return;
|
|
typedef ArcTpl<Weight> Arc;
|
|
typedef MutableFst<Arc> Fst;
|
|
typedef typename Arc::StateId StateId;
|
|
StateId num_states = fst->NumStates();
|
|
for (StateId s = 0; s < num_states; s++) {
|
|
for (MutableArcIterator<Fst> aiter(fst, s); !aiter.Done(); aiter.Next()) {
|
|
Arc arc = aiter.Value();
|
|
arc.weight = Weight(ScaleTupleWeight(arc.weight, scale));
|
|
aiter.SetValue(arc);
|
|
}
|
|
Weight final_weight = fst->Final(s);
|
|
if (final_weight != Weight::Zero())
|
|
fst->SetFinal(s, Weight(ScaleTupleWeight(final_weight, scale)));
|
|
}
|
|
}
|
|
|
|
template <class Weight, class Int>
|
|
void RemoveAlignmentsFromCompactLattice(
|
|
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > *fst) {
|
|
typedef CompactLatticeWeightTpl<Weight, Int> W;
|
|
typedef ArcTpl<W> Arc;
|
|
typedef MutableFst<Arc> Fst;
|
|
typedef typename Arc::StateId StateId;
|
|
StateId num_states = fst->NumStates();
|
|
for (StateId s = 0; s < num_states; s++) {
|
|
for (MutableArcIterator<Fst> aiter(fst, s); !aiter.Done(); aiter.Next()) {
|
|
Arc arc = aiter.Value();
|
|
arc.weight = W(arc.weight.Weight(), std::vector<Int>());
|
|
aiter.SetValue(arc);
|
|
}
|
|
W final_weight = fst->Final(s);
|
|
if (final_weight != W::Zero())
|
|
fst->SetFinal(s, W(final_weight.Weight(), std::vector<Int>()));
|
|
}
|
|
}
|
|
|
|
template <class Weight, class Int>
|
|
bool CompactLatticeHasAlignment(
|
|
const ExpandedFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > &fst) {
|
|
typedef CompactLatticeWeightTpl<Weight, Int> W;
|
|
typedef ArcTpl<W> Arc;
|
|
typedef ExpandedFst<Arc> Fst;
|
|
typedef typename Arc::StateId StateId;
|
|
StateId num_states = fst.NumStates();
|
|
for (StateId s = 0; s < num_states; s++) {
|
|
for (ArcIterator<Fst> aiter(fst, s); !aiter.Done(); aiter.Next()) {
|
|
const Arc &arc = aiter.Value();
|
|
if (!arc.weight.String().empty()) return true;
|
|
}
|
|
W final_weight = fst.Final(s);
|
|
if (!final_weight.String().empty()) return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template <class Real>
|
|
void ConvertFstToLattice(const ExpandedFst<ArcTpl<TropicalWeight> > &ifst,
|
|
MutableFst<ArcTpl<LatticeWeightTpl<Real> > > *ofst) {
|
|
int32 num_states_cache = 50000;
|
|
fst::CacheOptions cache_opts(true, num_states_cache);
|
|
fst::MapFstOptions mapfst_opts(cache_opts);
|
|
StdToLatticeMapper<Real> mapper;
|
|
MapFst<StdArc, ArcTpl<LatticeWeightTpl<Real> >, StdToLatticeMapper<Real> >
|
|
map_fst(ifst, mapper, mapfst_opts);
|
|
*ofst = map_fst;
|
|
}
|
|
|
|
} // namespace fst
|
|
|
|
#endif // KALDI_FSTEXT_LATTICE_UTILS_INL_H_
|