commit
fdc189a386
@ -0,0 +1,158 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// todo refactor, repalce with gtest
|
||||
|
||||
#include "base/flags.h"
|
||||
#include "base/log.h"
|
||||
#include "decoder/ctc_tlg_decoder.h"
|
||||
#include "frontend/audio/data_cache.h"
|
||||
#include "kaldi/util/table-types.h"
|
||||
#include "nnet/decodable.h"
|
||||
#include "nnet/paddle_nnet.h"
|
||||
|
||||
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
|
||||
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
||||
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
|
||||
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
|
||||
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
|
||||
DEFINE_string(graph_path, "TLG", "decoder graph");
|
||||
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
|
||||
DEFINE_int32(max_active, 7500, "decoder graph");
|
||||
DEFINE_int32(receptive_field_length,
|
||||
7,
|
||||
"receptive field of two CNN(kernel=5) downsampling module.");
|
||||
DEFINE_int32(downsampling_rate,
|
||||
4,
|
||||
"two CNN(kernel=5) module downsampling rate.");
|
||||
DEFINE_string(model_output_names,
|
||||
"save_infer_model/scale_0.tmp_1,save_infer_model/"
|
||||
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
|
||||
"scale_3.tmp_1",
|
||||
"model output names");
|
||||
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
|
||||
|
||||
using kaldi::BaseFloat;
|
||||
using kaldi::Matrix;
|
||||
using std::vector;
|
||||
|
||||
// test TLG decoder by feeding speech feature.
|
||||
int main(int argc, char* argv[]) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
kaldi::SequentialBaseFloatMatrixReader feature_reader(
|
||||
FLAGS_feature_rspecifier);
|
||||
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
|
||||
std::string model_graph = FLAGS_model_path;
|
||||
std::string model_params = FLAGS_param_path;
|
||||
std::string word_symbol_table = FLAGS_word_symbol_table;
|
||||
std::string graph_path = FLAGS_graph_path;
|
||||
LOG(INFO) << "model path: " << model_graph;
|
||||
LOG(INFO) << "model param: " << model_params;
|
||||
LOG(INFO) << "word symbol path: " << word_symbol_table;
|
||||
LOG(INFO) << "graph path: " << graph_path;
|
||||
|
||||
int32 num_done = 0, num_err = 0;
|
||||
|
||||
ppspeech::TLGDecoderOptions opts;
|
||||
opts.word_symbol_table = word_symbol_table;
|
||||
opts.fst_path = graph_path;
|
||||
opts.opts.max_active = FLAGS_max_active;
|
||||
opts.opts.beam = 15.0;
|
||||
opts.opts.lattice_beam = 7.5;
|
||||
ppspeech::TLGDecoder decoder(opts);
|
||||
|
||||
ppspeech::ModelOptions model_opts;
|
||||
model_opts.model_path = model_graph;
|
||||
model_opts.params_path = model_params;
|
||||
model_opts.cache_shape = FLAGS_model_cache_names;
|
||||
model_opts.output_names = FLAGS_model_output_names;
|
||||
std::shared_ptr<ppspeech::PaddleNnet> nnet(
|
||||
new ppspeech::PaddleNnet(model_opts));
|
||||
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
|
||||
std::shared_ptr<ppspeech::Decodable> decodable(
|
||||
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
|
||||
|
||||
int32 chunk_size = FLAGS_receptive_field_length;
|
||||
int32 chunk_stride = FLAGS_downsampling_rate;
|
||||
int32 receptive_field_length = FLAGS_receptive_field_length;
|
||||
LOG(INFO) << "chunk size (frame): " << chunk_size;
|
||||
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
|
||||
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
|
||||
decoder.InitDecoder();
|
||||
|
||||
for (; !feature_reader.Done(); feature_reader.Next()) {
|
||||
string utt = feature_reader.Key();
|
||||
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
|
||||
raw_data->SetDim(feature.NumCols());
|
||||
LOG(INFO) << "process utt: " << utt;
|
||||
LOG(INFO) << "rows: " << feature.NumRows();
|
||||
LOG(INFO) << "cols: " << feature.NumCols();
|
||||
|
||||
int32 row_idx = 0;
|
||||
int32 padding_len = 0;
|
||||
int32 ori_feature_len = feature.NumRows();
|
||||
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
|
||||
padding_len =
|
||||
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
|
||||
feature.Resize(feature.NumRows() + padding_len,
|
||||
feature.NumCols(),
|
||||
kaldi::kCopyData);
|
||||
}
|
||||
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
|
||||
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
|
||||
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
|
||||
feature.NumCols());
|
||||
int32 feature_chunk_size = 0;
|
||||
if (ori_feature_len > chunk_idx * chunk_stride) {
|
||||
feature_chunk_size = std::min(
|
||||
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
|
||||
}
|
||||
if (feature_chunk_size < receptive_field_length) break;
|
||||
|
||||
int32 start = chunk_idx * chunk_stride;
|
||||
for (int row_id = 0; row_id < chunk_size; ++row_id) {
|
||||
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
|
||||
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
|
||||
feature_chunk.Data() + row_id * feature.NumCols(),
|
||||
feature.NumCols());
|
||||
f_chunk_tmp.CopyFromVec(tmp);
|
||||
++start;
|
||||
}
|
||||
raw_data->Accept(feature_chunk);
|
||||
if (chunk_idx == num_chunks - 1) {
|
||||
raw_data->SetFinished();
|
||||
}
|
||||
decoder.AdvanceDecode(decodable);
|
||||
}
|
||||
std::string result;
|
||||
result = decoder.GetFinalBestPath();
|
||||
decodable->Reset();
|
||||
decoder.Reset();
|
||||
if (result.empty()) {
|
||||
// the TokenWriter can not write empty string.
|
||||
++num_err;
|
||||
KALDI_LOG << " the result of " << utt << " is empty";
|
||||
continue;
|
||||
}
|
||||
KALDI_LOG << " the result of " << utt << " is " << result;
|
||||
result_writer.Write(utt, result);
|
||||
++num_done;
|
||||
}
|
||||
|
||||
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
|
||||
<< " with errors.";
|
||||
return (num_done != 0 ? 0 : 1);
|
||||
}
|
@ -0,0 +1,66 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "decoder/ctc_tlg_decoder.h"
|
||||
namespace ppspeech {
|
||||
|
||||
TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
|
||||
fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
|
||||
CHECK(fst_ != nullptr);
|
||||
word_symbol_table_.reset(
|
||||
fst::SymbolTable::ReadText(opts.word_symbol_table));
|
||||
decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts));
|
||||
decoder_->InitDecoding();
|
||||
frame_decoded_size_ = 0;
|
||||
}
|
||||
|
||||
void TLGDecoder::InitDecoder() {
|
||||
decoder_->InitDecoding();
|
||||
frame_decoded_size_ = 0;
|
||||
}
|
||||
|
||||
void TLGDecoder::AdvanceDecode(
|
||||
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
|
||||
while (!decodable->IsLastFrame(frame_decoded_size_)) {
|
||||
LOG(INFO) << "num frame decode: " << frame_decoded_size_;
|
||||
AdvanceDecoding(decodable.get());
|
||||
}
|
||||
}
|
||||
|
||||
void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
|
||||
decoder_->AdvanceDecoding(decodable, 1);
|
||||
frame_decoded_size_++;
|
||||
}
|
||||
|
||||
void TLGDecoder::Reset() {
|
||||
InitDecoder();
|
||||
return;
|
||||
}
|
||||
|
||||
std::string TLGDecoder::GetFinalBestPath() {
|
||||
decoder_->FinalizeDecoding();
|
||||
kaldi::Lattice lat;
|
||||
kaldi::LatticeWeight weight;
|
||||
std::vector<int> alignment;
|
||||
std::vector<int> words_id;
|
||||
decoder_->GetBestPath(&lat, true);
|
||||
fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);
|
||||
std::string words;
|
||||
for (int32 idx = 0; idx < words_id.size(); ++idx) {
|
||||
std::string word = word_symbol_table_->Find(words_id[idx]);
|
||||
words += word;
|
||||
}
|
||||
return words;
|
||||
}
|
||||
}
|
@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "base/basic_types.h"
|
||||
#include "kaldi/decoder/decodable-itf.h"
|
||||
#include "kaldi/decoder/lattice-faster-online-decoder.h"
|
||||
#include "util/parse-options.h"
|
||||
|
||||
namespace ppspeech {
|
||||
|
||||
struct TLGDecoderOptions {
|
||||
kaldi::LatticeFasterDecoderConfig opts;
|
||||
// todo remove later, add into decode resource
|
||||
std::string word_symbol_table;
|
||||
std::string fst_path;
|
||||
|
||||
TLGDecoderOptions() : word_symbol_table(""), fst_path("") {}
|
||||
};
|
||||
|
||||
class TLGDecoder {
|
||||
public:
|
||||
explicit TLGDecoder(TLGDecoderOptions opts);
|
||||
void InitDecoder();
|
||||
void Decode();
|
||||
std::string GetBestPath();
|
||||
std::vector<std::pair<double, std::string>> GetNBestPath();
|
||||
std::string GetFinalBestPath();
|
||||
int NumFrameDecoded();
|
||||
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
|
||||
std::vector<std::string>& nbest_words);
|
||||
void AdvanceDecode(
|
||||
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
|
||||
void Reset();
|
||||
|
||||
private:
|
||||
void AdvanceDecoding(kaldi::DecodableInterface* decodable);
|
||||
|
||||
std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
|
||||
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
|
||||
std::shared_ptr<fst::SymbolTable> word_symbol_table_;
|
||||
// the frame size which have decoded starts from 0.
|
||||
int32 frame_decoded_size_;
|
||||
};
|
||||
|
||||
|
||||
} // namespace ppspeech
|
@ -0,0 +1,6 @@
|
||||
|
||||
add_library(kaldi-decoder
|
||||
lattice-faster-decoder.cc
|
||||
lattice-faster-online-decoder.cc
|
||||
)
|
||||
target_link_libraries(kaldi-decoder PUBLIC kaldi-lat)
|
@ -0,0 +1,5 @@
|
||||
|
||||
add_library(kaldi-fstext
|
||||
kaldi-fst-io.cc
|
||||
)
|
||||
target_link_libraries(kaldi-fstext PUBLIC kaldi-util)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,144 @@
|
||||
// fstext/determinize-lattice.h
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_FSTEXT_DETERMINIZE_LATTICE_H_
|
||||
#define KALDI_FSTEXT_DETERMINIZE_LATTICE_H_
|
||||
#include <fst/fst-decl.h>
|
||||
#include <fst/fstlib.h>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "fstext/lattice-weight.h"
|
||||
|
||||
namespace fst {
|
||||
|
||||
/// \addtogroup fst_extensions
|
||||
/// @{
|
||||
|
||||
// For example of usage, see test-determinize-lattice.cc
|
||||
|
||||
/*
|
||||
DeterminizeLattice implements a special form of determinization
|
||||
with epsilon removal, optimized for a phase of lattice generation.
|
||||
Its input is an FST with weight-type BaseWeightType (usually a pair of
|
||||
floats, with a lexicographical type of order, such as
|
||||
LatticeWeightTpl<float>). Typically this would be a state-level lattice, with
|
||||
input symbols equal to words, and output-symbols equal to p.d.f's (so like
|
||||
the inverse of HCLG). Imagine representing this as an acceptor of type
|
||||
CompactLatticeWeightTpl<float>, in which the input/output symbols are words,
|
||||
and the weights contain the original weights together with strings (with zero
|
||||
or one symbol in them) containing the original output labels (the p.d.f.'s).
|
||||
We determinize this using acceptor determinization with epsilon removal.
|
||||
Remember (from lattice-weight.h) that CompactLatticeWeightTpl has a special
|
||||
kind of semiring where we always take the string corresponding to the best
|
||||
cost (of type BaseWeightType), and discard the other. This corresponds to
|
||||
taking the best output-label sequence (of p.d.f.'s) for each input-label
|
||||
sequence (of words). We couldn't use the Gallic weight for this, or it would
|
||||
die as soon as it detected that the input FST was non-functional. In our
|
||||
case, any acyclic FST (and many cyclic ones) can be determinized. We assume
|
||||
that there is a function Compare(const BaseWeightType &a, const
|
||||
BaseWeightType &b) that returns (-1, 0, 1) according to whether (a < b, a ==
|
||||
b, a > b) in the total order on the BaseWeightType... this information should
|
||||
be the same as NaturalLess would give, but it's more efficient to do it this
|
||||
way. You can define this for things like TropicalWeight if you need to
|
||||
instantiate this class for that weight type.
|
||||
|
||||
We implement this determinization in a special way to make it efficient for
|
||||
the types of FSTs that we will apply it to. One issue is that if we
|
||||
explicitly represent the strings (in CompactLatticeWeightTpl) as vectors of
|
||||
type vector<IntType>, the algorithm takes time quadratic in the length of
|
||||
words (in states), because propagating each arc involves copying a whole
|
||||
vector (of integers representing p.d.f.'s). Instead we use a hash structure
|
||||
where each string is a pointer (Entry*), and uses a hash from (Entry*,
|
||||
IntType), to the successor string (and a way to get the latest IntType and
|
||||
the ancestor Entry*). [this is the class LatticeStringRepository].
|
||||
|
||||
Another issue is that rather than representing a determinized-state as a
|
||||
collection of (state, weight), we represent it in a couple of reduced forms.
|
||||
Suppose a determinized-state is a collection of (state, weight) pairs; call
|
||||
this the "canonical representation". Note: these collections are always
|
||||
normalized to remove any common weight and string part. Define end-states as
|
||||
the subset of states that have an arc out of them with a label on, or are
|
||||
final. If we represent a determinized-state a the set of just its
|
||||
(end-state, weight) pairs, this will be a valid and more compact
|
||||
representation, and will lead to a smaller set of determinized states (like
|
||||
early minimization). Call this collection of (end-state, weight) pairs the
|
||||
"minimal representation". As a mechanism to reduce compute, we can also
|
||||
consider another representation. In the determinization algorithm, we start
|
||||
off with a set of (begin-state, weight) pairs (where the "begin-states" are
|
||||
initial or have a label on the transition into them), and the "canonical
|
||||
representation" consists of the epsilon-closure of this set (i.e. follow
|
||||
epsilons). Call this set of (begin-state, weight) pairs, appropriately
|
||||
normalized, the "initial representation". If two initial representations are
|
||||
the same, the "canonical representation" and hence the "minimal
|
||||
representation" will be the same. We can use this to reduce compute. Note
|
||||
that if two initial representations are different, this does not preclude the
|
||||
other representations from being the same.
|
||||
|
||||
*/
|
||||
|
||||
struct DeterminizeLatticeOptions {
|
||||
float delta; // A small offset used to measure equality of weights.
|
||||
int max_mem; // If >0, determinization will fail and return false
|
||||
// when the algorithm's (approximate) memory consumption crosses this
|
||||
// threshold.
|
||||
int max_loop; // If >0, can be used to detect non-determinizable input
|
||||
// (a case that wouldn't be caught by max_mem).
|
||||
DeterminizeLatticeOptions() : delta(kDelta), max_mem(-1), max_loop(-1) {}
|
||||
};
|
||||
|
||||
/**
|
||||
This function implements the normal version of DeterminizeLattice, in which
|
||||
the output strings are represented using sequences of arcs, where all but
|
||||
the first one has an epsilon on the input side. The debug_ptr argument is
|
||||
an optional pointer to a bool that, if it becomes true while the algorithm
|
||||
is executing, the algorithm will print a traceback and terminate (used in
|
||||
fstdeterminizestar.cc debug non-terminating determinization). More
|
||||
efficient if ifst is arc-sorted on input label. If the number of arcs gets
|
||||
more than max_states, it will throw std::runtime_error (otherwise this code
|
||||
does not use exceptions). This is mainly useful for debug. */
|
||||
template <class Weight, class IntType>
|
||||
bool DeterminizeLattice(
|
||||
const Fst<ArcTpl<Weight> > &ifst, MutableFst<ArcTpl<Weight> > *ofst,
|
||||
DeterminizeLatticeOptions opts = DeterminizeLatticeOptions(),
|
||||
bool *debug_ptr = NULL);
|
||||
|
||||
/* This is a version of DeterminizeLattice with a slightly more "natural"
|
||||
output format, where the output sequences are encoded using the
|
||||
CompactLatticeArcTpl template (i.e. the sequences of output symbols are
|
||||
represented directly as strings) More efficient if ifst is arc-sorted on
|
||||
input label. If the #arcs gets more than max_arcs, it will throw
|
||||
std::runtime_error (otherwise this code does not use exceptions). This is
|
||||
mainly useful for debug.
|
||||
*/
|
||||
template <class Weight, class IntType>
|
||||
bool DeterminizeLattice(
|
||||
const Fst<ArcTpl<Weight> > &ifst,
|
||||
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
|
||||
DeterminizeLatticeOptions opts = DeterminizeLatticeOptions(),
|
||||
bool *debug_ptr = NULL);
|
||||
|
||||
/// @} end "addtogroup fst_extensions"
|
||||
|
||||
} // end namespace fst
|
||||
|
||||
#include "fstext/determinize-lattice-inl.h"
|
||||
|
||||
#endif // KALDI_FSTEXT_DETERMINIZE_LATTICE_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,116 @@
|
||||
// fstext/determinize-star.h
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
// 2014 Guoguo Chen
|
||||
// 2015 Hainan Xu
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_FSTEXT_DETERMINIZE_STAR_H_
|
||||
#define KALDI_FSTEXT_DETERMINIZE_STAR_H_
|
||||
#include <fst/fst-decl.h>
|
||||
#include <fst/fstlib.h>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <stdexcept> // this algorithm uses exceptions
|
||||
#include <vector>
|
||||
|
||||
namespace fst {
|
||||
|
||||
/// \addtogroup fst_extensions
|
||||
/// @{
|
||||
|
||||
// For example of usage, see test-determinize-star.cc
|
||||
|
||||
/*
|
||||
DeterminizeStar implements determinization with epsilon removal, which we
|
||||
distinguish with a star.
|
||||
|
||||
We define a determinized* FST as one in which no state has more than one
|
||||
transition with the same input-label. Epsilon input labels are not allowed
|
||||
except starting from states that have exactly one arc exiting them (and are
|
||||
not final). [In the normal definition of determinized, epsilon-input labels
|
||||
are not allowed at all, whereas in Mohri's definition, epsilons are treated
|
||||
as ordinary symbols]. The determinized* definition is intended to simulate
|
||||
the effect of allowing strings of output symbols at each state.
|
||||
|
||||
The algorithm implemented here takes an Fst<Arc>, and a pointer to a
|
||||
MutableFst<Arc> where it puts its output. The weight type is assumed to be a
|
||||
float-weight. It does epsilon removal and determinization.
|
||||
This algorithm may fail if the input has epsilon cycles under
|
||||
certain circumstances (i.e. the semiring is non-idempotent, e.g. the log
|
||||
semiring, or there are negative cost epsilon cycles).
|
||||
|
||||
This implementation is much less fancy than the one in fst/determinize.h, and
|
||||
does not have an "on-demand" version.
|
||||
|
||||
The algorithm is a fairly normal determinization algorithm. We keep in
|
||||
memory the subsets of states, together with their leftover strings and their
|
||||
weights. The only difference is we detect input epsilon transitions and
|
||||
treat them "specially".
|
||||
*/
|
||||
|
||||
// This algorithm will be slightly faster if you sort the input fst on input
|
||||
// label.
|
||||
|
||||
/**
|
||||
This function implements the normal version of DeterminizeStar, in which the
|
||||
output strings are represented using sequences of arcs, where all but the
|
||||
first one has an epsilon on the input side. The debug_ptr argument is an
|
||||
optional pointer to a bool that, if it becomes true while the algorithm is
|
||||
executing, the algorithm will print a traceback and terminate (used in
|
||||
fstdeterminizestar.cc debug non-terminating determinization).
|
||||
If max_states is positive, it will stop determinization and throw an
|
||||
exception as soon as the max-states is reached. This can be useful in test.
|
||||
If allow_partial is true, the algorithm will output partial results when the
|
||||
specified max_states is reached (when larger than zero), instead of throwing
|
||||
out an error.
|
||||
|
||||
Caution, the return status is un-intuitive: this function will return false
|
||||
if determinization completed normally, and true if it was stopped early by
|
||||
reaching the 'max-states' limit, and a partial FST was generated.
|
||||
*/
|
||||
template <class F>
|
||||
bool DeterminizeStar(F &ifst, MutableFst<typename F::Arc> *ofst, // NOLINT
|
||||
float delta = kDelta, bool *debug_ptr = NULL,
|
||||
int max_states = -1, bool allow_partial = false);
|
||||
|
||||
/* This is a version of DeterminizeStar with a slightly more "natural" output
|
||||
format, where the output sequences are encoded using the GallicArc (i.e. the
|
||||
output symbols are strings. If max_states is positive, it will stop
|
||||
determinization and throw an exception as soon as the max-states is reached.
|
||||
This can be useful in test. If allow_partial is true, the algorithm will
|
||||
output partial results when the specified max_states is reached (when larger
|
||||
than zero), instead of throwing out an error.
|
||||
|
||||
Caution, the return status is un-intuitive: this function will return false
|
||||
if determinization completed normally, and true if it was stopped early by
|
||||
reaching the 'max-states' limit, and a partial FST was generated.
|
||||
*/
|
||||
template <class F>
|
||||
bool DeterminizeStar(F &ifst, // NOLINT
|
||||
MutableFst<GallicArc<typename F::Arc> > *ofst,
|
||||
float delta = kDelta, bool *debug_ptr = NULL,
|
||||
int max_states = -1, bool allow_partial = false);
|
||||
|
||||
/// @} end "addtogroup fst_extensions"
|
||||
|
||||
} // end namespace fst
|
||||
|
||||
#include "fstext/determinize-star-inl.h"
|
||||
|
||||
#endif // KALDI_FSTEXT_DETERMINIZE_STAR_H_
|
@ -0,0 +1,34 @@
|
||||
// fstext/fstext-lib.h
|
||||
|
||||
// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (author:
|
||||
// Daniel Povey)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_FSTEXT_FSTEXT_LIB_H_
|
||||
#define KALDI_FSTEXT_FSTEXT_LIB_H_
|
||||
|
||||
#include "fst/fstlib.h"
|
||||
#include "fstext/determinize-lattice.h"
|
||||
#include "fstext/determinize-star.h"
|
||||
#include "fstext/fstext-utils.h"
|
||||
#include "fstext/kaldi-fst-io.h"
|
||||
#include "fstext/lattice-utils.h"
|
||||
#include "fstext/lattice-weight.h"
|
||||
#include "fstext/pre-determinize.h"
|
||||
#include "fstext/table-matcher.h"
|
||||
|
||||
#endif // KALDI_FSTEXT_FSTEXT_LIB_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,386 @@
|
||||
// fstext/fstext-utils.h
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
|
||||
// 2013 Guoguo Chen
|
||||
// 2014 Telepoint Global Hosting Service, LLC. (Author: David
|
||||
// Snyder)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_FSTEXT_FSTEXT_UTILS_H_
|
||||
#define KALDI_FSTEXT_FSTEXT_UTILS_H_
|
||||
|
||||
#include <fst/fst-decl.h>
|
||||
#include <fst/fstlib.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include "fstext/determinize-star.h"
|
||||
#include "fstext/remove-eps-local.h"
|
||||
#include "base/kaldi-common.h" // for error reporting macros.
|
||||
#include "util/text-utils.h" // for SplitStringToVector
|
||||
#include "fst/script/print-impl.h"
|
||||
|
||||
namespace fst {
|
||||
|
||||
/// Returns the highest numbered output symbol id of the FST (or zero
|
||||
/// for an empty FST.
|
||||
template <class Arc>
|
||||
typename Arc::Label HighestNumberedOutputSymbol(const Fst<Arc> &fst);
|
||||
|
||||
/// Returns the highest numbered input symbol id of the FST (or zero
|
||||
/// for an empty FST.
|
||||
template <class Arc>
|
||||
typename Arc::Label HighestNumberedInputSymbol(const Fst<Arc> &fst);
|
||||
|
||||
/// Returns the total number of arcs in an FST.
|
||||
template <class Arc>
|
||||
typename Arc::StateId NumArcs(const ExpandedFst<Arc> &fst);
|
||||
|
||||
/// GetInputSymbols gets the list of symbols on the input of fst
|
||||
/// (including epsilon, if include_eps == true), as a sorted, unique
|
||||
/// list.
|
||||
template <class Arc, class I>
|
||||
void GetInputSymbols(const Fst<Arc> &fst, bool include_eps,
|
||||
std::vector<I> *symbols);
|
||||
|
||||
/// GetOutputSymbols gets the list of symbols on the output of fst
|
||||
/// (including epsilon, if include_eps == true)
|
||||
template <class Arc, class I>
|
||||
void GetOutputSymbols(const Fst<Arc> &fst, bool include_eps,
|
||||
std::vector<I> *symbols);
|
||||
|
||||
/// ClearSymbols sets all the symbols on the input and/or
|
||||
/// output side of the FST to zero, as specified.
|
||||
/// It does not alter the symbol tables.
|
||||
template <class Arc>
|
||||
void ClearSymbols(bool clear_input, bool clear_output, MutableFst<Arc> *fst);
|
||||
|
||||
template <class I>
|
||||
void GetSymbols(const SymbolTable &symtab, bool include_eps,
|
||||
std::vector<I> *syms_out);
|
||||
|
||||
inline void DeterminizeStarInLog(VectorFst<StdArc> *fst, float delta = kDelta,
|
||||
bool *debug_ptr = NULL, int max_states = -1);
|
||||
|
||||
// e.g. of using this function: PushInLog<REWEIGHT_TO_INITIAL>(fst,
|
||||
// kPushWeights|kPushLabels);
|
||||
|
||||
template <ReweightType rtype> // == REWEIGHT_TO_{INITIAL, FINAL}
|
||||
void PushInLog(VectorFst<StdArc> *fst, uint32 ptype, float delta = kDelta) {
|
||||
// PushInLog pushes the FST
|
||||
// and returns a new pushed FST (labels and weights pushed to the left).
|
||||
VectorFst<LogArc> *fst_log =
|
||||
new VectorFst<LogArc>; // Want to determinize in log semiring.
|
||||
Cast(*fst, fst_log);
|
||||
VectorFst<StdArc> tmp;
|
||||
*fst = tmp; // free up memory.
|
||||
VectorFst<LogArc> *fst_pushed_log = new VectorFst<LogArc>;
|
||||
Push<LogArc, rtype>(*fst_log, fst_pushed_log, ptype, delta);
|
||||
Cast(*fst_pushed_log, fst);
|
||||
delete fst_log;
|
||||
delete fst_pushed_log;
|
||||
}
|
||||
|
||||
// Minimizes after encoding; applicable to all FSTs. It is like what you get
|
||||
// from the Minimize() function, except it will not push the weights, or the
|
||||
// symbols. This is better for our recipes, as we avoid ever pushing the
|
||||
// weights. However, it will only minimize optimally if your graphs are such
|
||||
// that the symbols are as far to the left as they can go, and the weights
|
||||
// in combinable paths are the same... hard to formalize this, but it's
|
||||
// something that is satisified by our normal FSTs.
|
||||
template <class Arc>
|
||||
void MinimizeEncoded(VectorFst<Arc> *fst, float delta = kDelta) {
|
||||
Map(fst, QuantizeMapper<Arc>(delta));
|
||||
EncodeMapper<Arc> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
|
||||
Encode(fst, &encoder);
|
||||
internal::AcceptorMinimize(fst);
|
||||
Decode(fst, encoder);
|
||||
}
|
||||
|
||||
/// GetLinearSymbolSequence gets the symbol sequence from a linear FST.
|
||||
/// If the FST is not just a linear sequence, it returns false. If it is
|
||||
/// a linear sequence (including the empty FST), it returns true. In this
|
||||
/// case it outputs the symbol
|
||||
/// sequences as "isymbols_out" and "osymbols_out" (removing epsilons), and
|
||||
/// the total weight as "tot_weight". The total weight will be Weight::Zero()
|
||||
/// if the FST is empty. If any of the output pointers are NULL, it does not
|
||||
/// create that output.
|
||||
template <class Arc, class I>
|
||||
bool GetLinearSymbolSequence(const Fst<Arc> &fst, std::vector<I> *isymbols_out,
|
||||
std::vector<I> *osymbols_out,
|
||||
typename Arc::Weight *tot_weight_out);
|
||||
|
||||
/// This function converts an FST with a special structure, which is
|
||||
/// output by the OpenFst functions ShortestPath and RandGen, and converts
|
||||
/// them into a std::vector of separate FSTs. This special structure is that
|
||||
/// the only state that has more than one (arcs-out or final-prob) is the
|
||||
/// start state. fsts_out is resized to the appropriate size.
|
||||
template <class Arc>
|
||||
void ConvertNbestToVector(const Fst<Arc> &fst,
|
||||
std::vector<VectorFst<Arc> > *fsts_out);
|
||||
|
||||
/// Takes the n-shortest-paths (using ShortestPath), but outputs
|
||||
/// the result as a vector of up to n fsts. This function will
|
||||
/// size the "fsts_out" vector to however many paths it got
|
||||
/// (which will not exceed n). n must be >= 1.
|
||||
template <class Arc>
|
||||
void NbestAsFsts(const Fst<Arc> &fst, size_t n,
|
||||
std::vector<VectorFst<Arc> > *fsts_out);
|
||||
|
||||
/// Creates unweighted linear acceptor from symbol sequence.
|
||||
template <class Arc, class I>
|
||||
void MakeLinearAcceptor(const std::vector<I> &labels, MutableFst<Arc> *ofst);
|
||||
|
||||
/// Creates an unweighted acceptor with a linear structure, with alternatives
|
||||
/// at each position. Epsilon is treated like a normal symbol here.
|
||||
/// Each position in "labels" must have at least one alternative.
|
||||
template <class Arc, class I>
|
||||
void MakeLinearAcceptorWithAlternatives(
|
||||
const std::vector<std::vector<I> > &labels, MutableFst<Arc> *ofst);
|
||||
|
||||
/// Does PreDeterminize and DeterminizeStar and then removes the disambiguation
|
||||
/// symbols. This is a form of determinization that will never blow up. Note
|
||||
/// that ifst is non-const and can be considered to be destroyed by this
|
||||
/// operation.
|
||||
/// Does not do epsilon removal (RemoveEpsLocal)-- this is so it's safe to cast
|
||||
/// to log and do this, and maintain equivalence in tropical.
|
||||
|
||||
template <class Arc>
|
||||
void SafeDeterminizeWrapper(MutableFst<Arc> *ifst, MutableFst<Arc> *ofst,
|
||||
float delta = kDelta);
|
||||
|
||||
/// SafeDeterminizeMinimizeWapper is as SafeDeterminizeWrapper except that it
|
||||
/// also minimizes (encoded minimization, which is safe). This algorithm will
|
||||
/// destroy "ifst".
|
||||
template <class Arc>
|
||||
void SafeDeterminizeMinimizeWrapper(MutableFst<Arc> *ifst, VectorFst<Arc> *ofst,
|
||||
float delta = kDelta);
|
||||
|
||||
/// SafeDeterminizeMinimizeWapperInLog is as SafeDeterminizeMinimizeWrapper
|
||||
/// except it first casts tothe log semiring.
|
||||
void SafeDeterminizeMinimizeWrapperInLog(VectorFst<StdArc> *ifst,
|
||||
VectorFst<StdArc> *ofst,
|
||||
float delta = kDelta);
|
||||
|
||||
/// RemoveSomeInputSymbols removes any symbol that appears in "to_remove", from
|
||||
/// the input side of the FST, replacing them with epsilon.
|
||||
template <class Arc, class I>
|
||||
void RemoveSomeInputSymbols(const std::vector<I> &to_remove,
|
||||
MutableFst<Arc> *fst);
|
||||
|
||||
// MapInputSymbols will replace any input symbol i that is between 0 and
|
||||
// symbol_map.size()-1, with symbol_map[i]. It removes the input symbol
|
||||
// table of the FST.
|
||||
template <class Arc, class I>
|
||||
void MapInputSymbols(const std::vector<I> &symbol_map, MutableFst<Arc> *fst);
|
||||
|
||||
template <class Arc>
|
||||
void RemoveWeights(MutableFst<Arc> *fst);
|
||||
|
||||
/// Returns true if and only if the FST is such that the input symbols
|
||||
/// on arcs entering any given state all have the same value.
|
||||
/// if "start_is_epsilon", treat start-state as an epsilon input arc
|
||||
/// [i.e. ensure only epsilon can enter start-state].
|
||||
template <class Arc>
|
||||
bool PrecedingInputSymbolsAreSame(bool start_is_epsilon, const Fst<Arc> &fst);
|
||||
|
||||
/// This is as PrecedingInputSymbolsAreSame, but with a functor f that maps
|
||||
/// labels to classes. The function tests whether the symbols preceding any
|
||||
/// given state are in the same class. Formally, f is of a type F that has an
|
||||
/// operator of type F::Result F::operator() (F::Arg a) const; where F::Result
|
||||
/// is an integer type and F::Arc can be constructed from Arc::Label. this must
|
||||
/// apply to valid labels and also to kNoLabel (so we can have a marker for the
|
||||
/// invalid labels.
|
||||
template <class Arc, class F>
|
||||
bool PrecedingInputSymbolsAreSameClass(bool start_is_epsilon,
|
||||
const Fst<Arc> &fst, const F &f);
|
||||
|
||||
/// Returns true if and only if the FST is such that the input symbols
|
||||
/// on arcs exiting any given state all have the same value.
|
||||
/// If end_is_epsilon, treat end-state as an epsilon output arc [i.e. ensure
|
||||
/// end-states cannot have non-epsilon output transitions.]
|
||||
template <class Arc>
|
||||
bool FollowingInputSymbolsAreSame(bool end_is_epsilon, const Fst<Arc> &fst);
|
||||
|
||||
template <class Arc, class F>
|
||||
bool FollowingInputSymbolsAreSameClass(bool end_is_epsilon, const Fst<Arc> &fst,
|
||||
const F &f);
|
||||
|
||||
/// MakePrecedingInputSymbolsSame ensures that all arcs entering any given fst
|
||||
/// state have the same input symbol. It does this by detecting states
|
||||
/// that have differing input symbols going in, and inserting, for each of
|
||||
/// the preceding arcs with non-epsilon input symbol, a new dummy state that
|
||||
/// has an epsilon link to the fst state.
|
||||
/// If "start_is_epsilon", ensure that start-state can have only epsilon-links
|
||||
/// into it.
|
||||
template <class Arc>
|
||||
void MakePrecedingInputSymbolsSame(bool start_is_epsilon, MutableFst<Arc> *fst);
|
||||
|
||||
/// As MakePrecedingInputSymbolsSame, but takes a functor object that maps
|
||||
/// labels to classes.
|
||||
template <class Arc, class F>
|
||||
void MakePrecedingInputSymbolsSameClass(bool start_is_epsilon,
|
||||
MutableFst<Arc> *fst, const F &f);
|
||||
|
||||
/// MakeFollowingInputSymbolsSame ensures that all arcs exiting any given fst
|
||||
/// state have the same input symbol. It does this by detecting states that
|
||||
/// have differing input symbols on arcs that exit it, and inserting, for each
|
||||
/// of the following arcs with non-epsilon input symbol, a new dummy state that
|
||||
/// has an input-epsilon link from the fst state. The output symbol and weight
|
||||
/// stay on the link to the dummy state (in order to keep the FST
|
||||
/// output-deterministic and stochastic, if it already was). If end_is_epsilon,
|
||||
/// treat "being a final-state" like having an epsilon output link.
|
||||
template <class Arc>
|
||||
void MakeFollowingInputSymbolsSame(bool end_is_epsilon, MutableFst<Arc> *fst);
|
||||
|
||||
/// As MakeFollowingInputSymbolsSame, but takes a functor object that maps
|
||||
/// labels to classes.
|
||||
template <class Arc, class F>
|
||||
void MakeFollowingInputSymbolsSameClass(bool end_is_epsilon,
|
||||
MutableFst<Arc> *fst, const F &f);
|
||||
|
||||
/// MakeLoopFst creates an FST that has a state that is both initial and
|
||||
/// final (weight == Weight::One()), and for each non-NULL pointer fsts[i],
|
||||
/// it has an arc out whose output-symbol is i and which goes to a
|
||||
/// sub-graph whose input language is equivalent to fsts[i], where the
|
||||
/// final-state becomes a transition to the loop-state. Each fst in "fsts"
|
||||
/// should be an acceptor. The fst MakeLoopFst returns is output-deterministic,
|
||||
/// but not output-epsilon free necessarily, and arcs are sorted on output
|
||||
/// label. Note: if some of the pointers in the input vector "fsts" have the
|
||||
/// same value, "MakeLoopFst" uses this to speed up the computation.
|
||||
|
||||
/// Formally: suppose I is the set of indexes i such that fsts[i] != NULL.
|
||||
/// Let L[i] be the language that the acceptor fsts[i] accepts.
|
||||
/// Let the language K be the set of input-output pairs i:l such
|
||||
/// that i in I and l in L[i]. Then the FST returned by MakeLoopFst
|
||||
/// accepts the language K*, where * is the Kleene closure (CLOSURE_STAR)
|
||||
/// of K.
|
||||
|
||||
/// We could have implemented this via a combination of "project",
|
||||
/// "concat", "union" and "closure". But that FST would have been
|
||||
/// less well optimized and would have a lot of final-states.
|
||||
|
||||
template <class Arc>
|
||||
VectorFst<Arc> *MakeLoopFst(const std::vector<const ExpandedFst<Arc> *> &fsts);
|
||||
|
||||
/// ApplyProbabilityScale is applicable to FSTs in the log or tropical semiring.
|
||||
/// It multiplies the arc and final weights by "scale" [this is not the Mul
|
||||
/// operation of the semiring, it's actual multiplication, which is equivalent
|
||||
/// to taking a power in the semiring].
|
||||
template <class Arc>
|
||||
void ApplyProbabilityScale(float scale, MutableFst<Arc> *fst);
|
||||
|
||||
/// EqualAlign is similar to RandGen, but it generates a sequence with exactly
|
||||
/// "length" input symbols. It returns true on success, false on failure
|
||||
/// (failure is partly random but should never happen in practice for normal
|
||||
/// speech models.) It generates a random path through the input FST, finds out
|
||||
/// which subset of the states it visits along the way have self-loops with
|
||||
/// inupt symbols on them, and outputs a path with exactly enough self-loops to
|
||||
/// have the requested number of input symbols. Note that EqualAlign does not
|
||||
/// use the probabilities on the FST. It just uses equal probabilities in the
|
||||
/// first stage of selection (since the output will anyway not be a truly random
|
||||
/// sample from the FST). The input fst "ifst" must be connected or this may
|
||||
/// enter an infinite loop.
|
||||
template <class Arc>
|
||||
bool EqualAlign(const Fst<Arc> &ifst, typename Arc::StateId length,
|
||||
int rand_seed, MutableFst<Arc> *ofst, int num_retries = 10);
|
||||
|
||||
// RemoveUselessArcs removes arcs such that there is no input symbol
|
||||
// sequence for which the best path through the FST would contain
|
||||
// those arcs [for these purposes, epsilon is not treated as a real symbol].
|
||||
// This is mainly geared towards decoding-graph FSTs which may contain
|
||||
// transitions that have less likely words on them that would never be
|
||||
// taken. We do not claim that this algorithm removes all such arcs;
|
||||
// it just does the best job it can.
|
||||
// Only works for tropical (not log) semiring as it uses
|
||||
// NaturalLess.
|
||||
template <class Arc>
|
||||
void RemoveUselessArcs(MutableFst<Arc> *fst);
|
||||
|
||||
// PhiCompose is a version of composition where
|
||||
// the right hand FST (fst2) is treated as a backoff
|
||||
// LM, with the phi symbol (e.g. #0) treated as a
|
||||
// "failure transition", only taken when we don't
|
||||
// have a match for the requested symbol.
|
||||
template <class Arc>
|
||||
void PhiCompose(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
|
||||
typename Arc::Label phi_label, MutableFst<Arc> *fst);
|
||||
|
||||
// PropagateFinal propagates final-probs through
|
||||
// "phi" transitions (note that here, phi_label may
|
||||
// be epsilon if you want). If you have a backoff LM
|
||||
// with special symbols ("phi") on the backoff arcs
|
||||
// instead of epsilon, you may use PhiCompose to compose
|
||||
// with it, but this won't do the right thing w.r.t.
|
||||
// final probabilities. You should first call PropagateFinal
|
||||
// on the FST with phi's i it (fst2 in PhiCompose above),
|
||||
// to fix this. If a state does not have a final-prob,
|
||||
// but has a phi transition, it makes the state's final-prob
|
||||
// (phi-prob * final-prob-of-dest-state), and does this
|
||||
// recursively i.e. follows phi transitions on the dest state
|
||||
// first. It behaves as if there were a super-final state
|
||||
// with a special symbol leading to it, from each currently
|
||||
// final state. Note that this may not behave as desired
|
||||
// if there are epsilons in your FST; it might be better
|
||||
// to remove those before calling this function.
|
||||
|
||||
template <class Arc>
|
||||
void PropagateFinal(typename Arc::Label phi_label, MutableFst<Arc> *fst);
|
||||
|
||||
// PhiCompose is a version of composition where
|
||||
// the right hand FST (fst2) has speciall "rho transitions"
|
||||
// which are taken whenever no normal transition matches; these
|
||||
// transitions will be rewritten with whatever symbol was on
|
||||
// the first FST.
|
||||
template <class Arc>
|
||||
void RhoCompose(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
|
||||
typename Arc::Label rho_label, MutableFst<Arc> *fst);
|
||||
|
||||
/** This function returns true if, in the semiring of the FST, the sum (within
|
||||
the semiring) of all the arcs out of each state in the FST is one, to within
|
||||
delta. After MakeStochasticFst, this should be true (for a connected FST).
|
||||
|
||||
@param fst [in] the FST that we are testing.
|
||||
@param delta [in] the tolerance to within which we test equality to 1.
|
||||
@param min_sum [out] if non, NULL, contents will be set to the minimum sum
|
||||
of weights.
|
||||
@param max_sum [out] if non, NULL, contents will be set to the maximum sum
|
||||
of weights.
|
||||
@return Returns true if the FST is stochastic, and false otherwise.
|
||||
*/
|
||||
|
||||
template <class Arc>
|
||||
bool IsStochasticFst(const Fst<Arc> &fst,
|
||||
float delta = kDelta, // kDelta = 1.0/1024.0 by default.
|
||||
typename Arc::Weight *min_sum = NULL,
|
||||
typename Arc::Weight *max_sum = NULL);
|
||||
|
||||
// IsStochasticFstInLog makes sure it's stochastic after casting to log.
|
||||
inline bool IsStochasticFstInLog(
|
||||
const Fst<StdArc> &fst,
|
||||
float delta = kDelta, // kDelta = 1.0/1024.0 by default.
|
||||
StdArc::Weight *min_sum = NULL, StdArc::Weight *max_sum = NULL);
|
||||
|
||||
} // end namespace fst
|
||||
|
||||
#include "fstext/fstext-utils-inl.h"
|
||||
|
||||
#endif // KALDI_FSTEXT_FSTEXT_UTILS_H_
|
@ -0,0 +1,208 @@
|
||||
// fstext/kaldi-fst-io-inl.h
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
// 2012-2015 Johns Hopkins University (Author: Daniel Povey)
|
||||
// 2013 Guoguo Chen
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_FSTEXT_KALDI_FST_IO_INL_H_
|
||||
#define KALDI_FSTEXT_KALDI_FST_IO_INL_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "util/text-utils.h"
|
||||
|
||||
namespace fst {
|
||||
|
||||
template <class Arc>
|
||||
void WriteFstKaldi(std::ostream &os, bool binary, const VectorFst<Arc> &t) {
|
||||
bool ok;
|
||||
if (binary) {
|
||||
// Binary-mode writing.
|
||||
ok = t.Write(os, FstWriteOptions());
|
||||
} else {
|
||||
// Text-mode output. Note: we expect that t.InputSymbols() and
|
||||
// t.OutputSymbols() would always return NULL. The corresponding input
|
||||
// routine would not work if the FST actually had symbols attached. Write a
|
||||
// newline to start the FST; in a table, the first line of the FST will
|
||||
// appear on its own line.
|
||||
os << '\n';
|
||||
bool acceptor = false, write_one = false;
|
||||
FstPrinter<Arc> printer(t, t.InputSymbols(), t.OutputSymbols(), NULL,
|
||||
acceptor, write_one, "\t");
|
||||
printer.Print(&os, "<unknown>");
|
||||
if (os.fail()) KALDI_ERR << "Stream failure detected writing FST to stream";
|
||||
// Write another newline as a terminating character. The read routine will
|
||||
// detect this [this is a Kaldi mechanism, not something in the original
|
||||
// OpenFst code].
|
||||
os << '\n';
|
||||
ok = os.good();
|
||||
}
|
||||
if (!ok) {
|
||||
KALDI_ERR << "Error writing FST to stream";
|
||||
}
|
||||
}
|
||||
|
||||
// Utility function used in ReadFstKaldi
|
||||
template <class W>
|
||||
inline bool StrToWeight(const std::string &s, bool allow_zero, W *w) {
|
||||
std::istringstream strm(s);
|
||||
strm >> *w;
|
||||
if (strm.fail() || (!allow_zero && *w == W::Zero())) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class Arc>
|
||||
void ReadFstKaldi(std::istream &is, bool binary, VectorFst<Arc> *fst) {
|
||||
typedef typename Arc::Weight Weight;
|
||||
typedef typename Arc::StateId StateId;
|
||||
if (binary) {
|
||||
// We don't have access to the filename here, so write [unknown].
|
||||
VectorFst<Arc> *ans =
|
||||
VectorFst<Arc>::Read(is, fst::FstReadOptions(std::string("[unknown]")));
|
||||
if (ans == NULL) {
|
||||
KALDI_ERR << "Error reading FST from stream.";
|
||||
}
|
||||
*fst = *ans; // shallow copy.
|
||||
delete ans;
|
||||
} else {
|
||||
// Consume the \r on Windows, the \n that the text-form FST format starts
|
||||
// with, and any extra spaces that might have got in there somehow.
|
||||
while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
|
||||
if (is.peek() == '\n') {
|
||||
is.get(); // consume the newline.
|
||||
} else { // saw spaces but no newline.. this is not expected.
|
||||
KALDI_ERR << "Reading FST: unexpected sequence of spaces "
|
||||
<< " at file position " << is.tellg();
|
||||
}
|
||||
using kaldi::ConvertStringToInteger;
|
||||
using kaldi::SplitStringToIntegers;
|
||||
using std::string;
|
||||
using std::vector;
|
||||
fst->DeleteStates();
|
||||
string line;
|
||||
size_t nline = 0;
|
||||
string separator = FLAGS_fst_field_separator + "\r\n";
|
||||
while (std::getline(is, line)) {
|
||||
nline++;
|
||||
vector<string> col;
|
||||
// on Windows we'll write in text and read in binary mode.
|
||||
kaldi::SplitStringToVector(line, separator.c_str(), true, &col);
|
||||
if (col.size() == 0) break; // Empty line is a signal to stop, in our
|
||||
// archive format.
|
||||
if (col.size() > 5) {
|
||||
KALDI_ERR << "Bad line in FST: " << line;
|
||||
}
|
||||
StateId s;
|
||||
if (!ConvertStringToInteger(col[0], &s)) {
|
||||
KALDI_ERR << "Bad line in FST: " << line;
|
||||
}
|
||||
while (s >= fst->NumStates()) fst->AddState();
|
||||
if (nline == 1) fst->SetStart(s);
|
||||
|
||||
bool ok = true;
|
||||
Arc arc;
|
||||
Weight w;
|
||||
StateId d = s;
|
||||
switch (col.size()) {
|
||||
case 1:
|
||||
fst->SetFinal(s, Weight::One());
|
||||
break;
|
||||
case 2:
|
||||
if (!StrToWeight(col[1], true, &w))
|
||||
ok = false;
|
||||
else
|
||||
fst->SetFinal(s, w);
|
||||
break;
|
||||
case 3: // 3 columns not ok for Lattice format; it's not an acceptor.
|
||||
ok = false;
|
||||
break;
|
||||
case 4:
|
||||
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
|
||||
ConvertStringToInteger(col[2], &arc.ilabel) &&
|
||||
ConvertStringToInteger(col[3], &arc.olabel);
|
||||
if (ok) {
|
||||
d = arc.nextstate;
|
||||
arc.weight = Weight::One();
|
||||
fst->AddArc(s, arc);
|
||||
}
|
||||
break;
|
||||
case 5:
|
||||
ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
|
||||
ConvertStringToInteger(col[2], &arc.ilabel) &&
|
||||
ConvertStringToInteger(col[3], &arc.olabel) &&
|
||||
StrToWeight(col[4], false, &arc.weight);
|
||||
if (ok) {
|
||||
d = arc.nextstate;
|
||||
fst->AddArc(s, arc);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
ok = false;
|
||||
}
|
||||
while (d >= fst->NumStates()) fst->AddState();
|
||||
if (!ok) KALDI_ERR << "Bad line in FST: " << line;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Arc> // static
|
||||
bool VectorFstTplHolder<Arc>::Write(std::ostream &os, bool binary, const T &t) {
|
||||
try {
|
||||
WriteFstKaldi(os, binary, t);
|
||||
return true;
|
||||
} catch (...) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
template <class Arc> // static
|
||||
bool VectorFstTplHolder<Arc>::Read(std::istream &is) {
|
||||
Clear();
|
||||
int c = is.peek();
|
||||
if (c == -1) {
|
||||
KALDI_WARN << "End of stream detected reading Fst";
|
||||
return false;
|
||||
} else if (isspace(c)) { // The text form of the FST begins
|
||||
// with space (normally, '\n'), so this means it's text (the binary form
|
||||
// cannot begin with space because it starts with the FST Type() which is
|
||||
// not space).
|
||||
try {
|
||||
t_ = new VectorFst<Arc>();
|
||||
ReadFstKaldi(is, false, t_);
|
||||
} catch (...) {
|
||||
Clear();
|
||||
return false;
|
||||
}
|
||||
} else { // reading a binary FST.
|
||||
try {
|
||||
t_ = new VectorFst<Arc>();
|
||||
ReadFstKaldi(is, true, t_);
|
||||
} catch (...) {
|
||||
Clear();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace fst.
|
||||
|
||||
#endif // KALDI_FSTEXT_KALDI_FST_IO_INL_H_
|
@ -0,0 +1,148 @@
|
||||
// fstext/kaldi-fst-io.cc
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
// 2012-2015 Johns Hopkins University (Author: Daniel Povey)
|
||||
// 2013 Guoguo Chen
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fstext/kaldi-fst-io.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "base/kaldi-error.h"
|
||||
#include "base/kaldi-math.h"
|
||||
#include "util/kaldi-io.h"
|
||||
|
||||
namespace fst {
|
||||
|
||||
VectorFst<StdArc> *ReadFstKaldi(std::string rxfilename) {
|
||||
if (rxfilename == "") rxfilename = "-"; // interpret "" as stdin,
|
||||
// for compatibility with OpenFst conventions.
|
||||
kaldi::Input ki(rxfilename);
|
||||
fst::FstHeader hdr;
|
||||
if (!hdr.Read(ki.Stream(), rxfilename))
|
||||
KALDI_ERR << "Reading FST: error reading FST header from "
|
||||
<< kaldi::PrintableRxfilename(rxfilename);
|
||||
FstReadOptions ropts("<unspecified>", &hdr);
|
||||
VectorFst<StdArc> *fst = VectorFst<StdArc>::Read(ki.Stream(), ropts);
|
||||
if (!fst)
|
||||
KALDI_ERR << "Could not read fst from "
|
||||
<< kaldi::PrintableRxfilename(rxfilename);
|
||||
return fst;
|
||||
}
|
||||
|
||||
// Register const fst to load it automatically. Other types like
|
||||
// olabel_lookahead or ngram or compact_fst should be registered
|
||||
// through OpenFst registration API.
|
||||
static fst::FstRegisterer<VectorFst<StdArc>> VectorFst_StdArc_registerer;
|
||||
static fst::FstRegisterer<ConstFst<StdArc>> ConstFst_StdArc_registerer;
|
||||
|
||||
Fst<StdArc> *ReadFstKaldiGeneric(std::string rxfilename, bool throw_on_err) {
|
||||
if (rxfilename == "") rxfilename = "-"; // interpret "" as stdin,
|
||||
// for compatibility with OpenFst conventions.
|
||||
kaldi::Input ki(rxfilename);
|
||||
fst::FstHeader hdr;
|
||||
// Read FstHeader which contains the type of FST
|
||||
if (!hdr.Read(ki.Stream(), rxfilename)) {
|
||||
if (throw_on_err) {
|
||||
KALDI_ERR << "Reading FST: error reading FST header from "
|
||||
<< kaldi::PrintableRxfilename(rxfilename);
|
||||
} else {
|
||||
KALDI_WARN << "We fail to read FST header from "
|
||||
<< kaldi::PrintableRxfilename(rxfilename)
|
||||
<< ". A NULL pointer is returned.";
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
// Check the type of Arc
|
||||
if (hdr.ArcType() != fst::StdArc::Type()) {
|
||||
if (throw_on_err) {
|
||||
KALDI_ERR << "FST with arc type " << hdr.ArcType()
|
||||
<< " is not supported.";
|
||||
} else {
|
||||
KALDI_WARN << "Fst with arc type" << hdr.ArcType()
|
||||
<< " is not supported. A NULL pointer is returned.";
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
// Read the FST
|
||||
FstReadOptions ropts("<unspecified>", &hdr);
|
||||
Fst<StdArc> *fst = Fst<StdArc>::Read(ki.Stream(), ropts);
|
||||
if (!fst) {
|
||||
if (throw_on_err) {
|
||||
KALDI_ERR << "Could not read fst from "
|
||||
<< kaldi::PrintableRxfilename(rxfilename);
|
||||
} else {
|
||||
KALDI_WARN << "Could not read fst from "
|
||||
<< kaldi::PrintableRxfilename(rxfilename)
|
||||
<< ". A NULL pointer is returned.";
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
return fst;
|
||||
}
|
||||
|
||||
VectorFst<StdArc> *CastOrConvertToVectorFst(Fst<StdArc> *fst) {
|
||||
// This version currently supports ConstFst<StdArc> or VectorFst<StdArc>
|
||||
std::string real_type = fst->Type();
|
||||
KALDI_ASSERT(real_type == "vector" || real_type == "const");
|
||||
if (real_type == "vector") {
|
||||
return dynamic_cast<VectorFst<StdArc> *>(fst);
|
||||
} else {
|
||||
// As the 'fst' can't cast to VectorFst, we create a new
|
||||
// VectorFst<StdArc> initialized by 'fst', and delete 'fst'.
|
||||
VectorFst<StdArc> *new_fst = new VectorFst<StdArc>(*fst);
|
||||
delete fst;
|
||||
return new_fst;
|
||||
}
|
||||
}
|
||||
|
||||
void ReadFstKaldi(std::string rxfilename, fst::StdVectorFst *ofst) {
|
||||
fst::StdVectorFst *fst = ReadFstKaldi(rxfilename);
|
||||
*ofst = *fst;
|
||||
delete fst;
|
||||
}
|
||||
|
||||
void WriteFstKaldi(const VectorFst<StdArc> &fst, std::string wxfilename) {
|
||||
if (wxfilename == "") wxfilename = "-"; // interpret "" as stdout,
|
||||
// for compatibility with OpenFst conventions.
|
||||
bool write_binary = true, write_header = false;
|
||||
kaldi::Output ko(wxfilename, write_binary, write_header);
|
||||
FstWriteOptions wopts(kaldi::PrintableWxfilename(wxfilename));
|
||||
fst.Write(ko.Stream(), wopts);
|
||||
}
|
||||
|
||||
fst::VectorFst<fst::StdArc> *ReadAndPrepareLmFst(std::string rxfilename) {
|
||||
// ReadFstKaldi() will die with exception on failure.
|
||||
fst::VectorFst<fst::StdArc> *ans = fst::ReadFstKaldi(rxfilename);
|
||||
if (ans->Properties(fst::kAcceptor, true) == 0) {
|
||||
// If it's not already an acceptor, project on the output, i.e. copy olabels
|
||||
// to ilabels. Generally the G.fst's on disk will have the disambiguation
|
||||
// symbol #0 on the input symbols of the backoff arc, and projection will
|
||||
// replace them with epsilons which is what is on the output symbols of
|
||||
// those arcs.
|
||||
fst::Project(ans, fst::PROJECT_OUTPUT);
|
||||
}
|
||||
if (ans->Properties(fst::kILabelSorted, true) == 0) {
|
||||
// Make sure LM is sorted on ilabel.
|
||||
fst::ILabelCompare<fst::StdArc> ilabel_comp;
|
||||
fst::ArcSort(ans, ilabel_comp);
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // end namespace fst
|
@ -0,0 +1,158 @@
|
||||
// fstext/kaldi-fst-io.h
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
// 2012-2015 Johns Hopkins University (Author: Daniel Povey)
|
||||
// 2013 Guoguo Chen
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_FSTEXT_KALDI_FST_IO_H_
|
||||
#define KALDI_FSTEXT_KALDI_FST_IO_H_
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "fst/fst-decl.h"
|
||||
#include "fst/fstlib.h"
|
||||
#include "fst/script/print-impl.h"
|
||||
|
||||
#include "base/kaldi-common.h"
|
||||
|
||||
// Some functions for writing Fsts.
|
||||
// I/O for FSTs is a bit of a mess, and not very well integrated with Kaldi's
|
||||
// generic I/O mechanisms, because we want files containing just FSTs to
|
||||
// be readable by OpenFST's native binaries, which is not compatible
|
||||
// with the normal \0B header that identifies Kaldi files as containing
|
||||
// binary data.
|
||||
// So use the functions here with your eyes open, and with caution!
|
||||
namespace fst {
|
||||
|
||||
// Read a binary FST using Kaldi I/O mechanisms (pipes, etc.)
|
||||
// On error returns NULL. Only supports VectorFst and exists
|
||||
// mainly for backward code compabibility.
|
||||
VectorFst<StdArc> *ReadFstKaldi(std::string rxfilename);
|
||||
|
||||
// Read a binary FST using Kaldi I/O mechanisms (pipes, etc.)
|
||||
// If it can't read the FST, if throw_on_err == true it throws using KALDI_ERR;
|
||||
// otherwise it prints a warning and returns. Note:this
|
||||
// doesn't support the text-mode option that we generally like to support.
|
||||
// This version currently supports ConstFst<StdArc> or VectorFst<StdArc>
|
||||
// (const-fst can give better performance for decoding). Other
|
||||
// types could be also loaded if registered inside OpenFst.
|
||||
Fst<StdArc> *ReadFstKaldiGeneric(std::string rxfilename,
|
||||
bool throw_on_err = true);
|
||||
|
||||
// This function attempts to dynamic_cast the pointer 'fst' (which will likely
|
||||
// have been returned by ReadFstGeneric()), to the more derived
|
||||
// type VectorFst<StdArc>. If this succeeds, it returns the same pointer;
|
||||
// if it fails, it converts the FST type (by creating a new VectorFst<stdArc>
|
||||
// initialized by 'fst'), prints a warning, and deletes 'fst'.
|
||||
VectorFst<StdArc> *CastOrConvertToVectorFst(Fst<StdArc> *fst);
|
||||
|
||||
// Version of ReadFstKaldi() that writes to a pointer. Assumes
|
||||
// the FST is binary with no binary marker. Crashes on error.
|
||||
void ReadFstKaldi(std::string rxfilename, VectorFst<StdArc> *ofst);
|
||||
|
||||
// Write an FST using Kaldi I/O mechanisms (pipes, etc.)
|
||||
// On error, throws using KALDI_ERR. For use only in code in fstbin/,
|
||||
// as it doesn't support the text-mode option.
|
||||
void WriteFstKaldi(const VectorFst<StdArc> &fst, std::string wxfilename);
|
||||
|
||||
// This is a more general Kaldi-type-IO mechanism of writing FSTs to
|
||||
// streams, supporting binary or text-mode writing. (note: we just
|
||||
// write the integers, symbol tables are not supported).
|
||||
// On error, throws using KALDI_ERR.
|
||||
template <class Arc>
|
||||
void WriteFstKaldi(std::ostream &os, bool binary, const VectorFst<Arc> &fst);
|
||||
|
||||
// A generic Kaldi-type-IO mechanism of reading FSTs from streams,
|
||||
// supporting binary or text-mode reading/writing.
|
||||
template <class Arc>
|
||||
void ReadFstKaldi(std::istream &is, bool binary, VectorFst<Arc> *fst);
|
||||
|
||||
// Read an FST file for LM (G.fst) and make it an acceptor,
|
||||
// and make sure it is sorted on labels
|
||||
fst::VectorFst<fst::StdArc> *ReadAndPrepareLmFst(std::string rxfilename);
|
||||
|
||||
// This is a Holder class with T = VectorFst<Arc>, that meets the requirements
|
||||
// of a Holder class as described in ../util/kaldi-holder.h. This enables us to
|
||||
// read/write collections of FSTs indexed by strings, using the Table concept (
|
||||
// see ../util/kaldi-table.h).
|
||||
// Originally it was only templated on T = VectorFst<StdArc>, but as the keyword
|
||||
// spotting stuff introduced more types of FSTs, we made it also templated on
|
||||
// the arc.
|
||||
template <class Arc>
|
||||
class VectorFstTplHolder {
|
||||
public:
|
||||
typedef VectorFst<Arc> T;
|
||||
|
||||
VectorFstTplHolder() : t_(NULL) {}
|
||||
|
||||
static bool Write(std::ostream &os, bool binary, const T &t);
|
||||
|
||||
void Copy(const T &t) { // copies it into the holder.
|
||||
Clear();
|
||||
t_ = new T(t);
|
||||
}
|
||||
|
||||
// Reads into the holder.
|
||||
bool Read(std::istream &is);
|
||||
|
||||
// It's potentially a binary format, so must read in binary mode (linefeed
|
||||
// translation will corrupt the file. We don't know till we open the file if
|
||||
// it's really binary, so we need to read in binary mode to be on the safe
|
||||
// side. Extra linefeeds won't matter, the text-mode reading code ignores
|
||||
// them.
|
||||
static bool IsReadInBinary() { return true; }
|
||||
|
||||
T &Value() {
|
||||
// code error if !t_.
|
||||
if (!t_) KALDI_ERR << "VectorFstTplHolder::Value() called wrongly.";
|
||||
return *t_;
|
||||
}
|
||||
|
||||
void Clear() {
|
||||
if (t_) {
|
||||
delete t_;
|
||||
t_ = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
void Swap(VectorFstTplHolder<Arc> *other) { std::swap(t_, other->t_); }
|
||||
|
||||
bool ExtractRange(const VectorFstTplHolder<Arc> &other,
|
||||
const std::string &range) {
|
||||
KALDI_ERR << "ExtractRange is not defined for this type of holder.";
|
||||
return false;
|
||||
}
|
||||
|
||||
~VectorFstTplHolder() { Clear(); }
|
||||
// No destructor. Assignment and
|
||||
// copy constructor take their default implementations.
|
||||
private:
|
||||
KALDI_DISALLOW_COPY_AND_ASSIGN(VectorFstTplHolder);
|
||||
T *t_;
|
||||
};
|
||||
|
||||
// Now make the original VectorFstHolder as the typedef of
|
||||
// VectorFstHolder<StdArc>.
|
||||
typedef VectorFstTplHolder<StdArc> VectorFstHolder;
|
||||
|
||||
} // end namespace fst
|
||||
|
||||
#include "fstext/kaldi-fst-io-inl.h"
|
||||
|
||||
#endif // KALDI_FSTEXT_KALDI_FST_IO_H_
|
@ -0,0 +1,267 @@
|
||||
// fstext/lattice-utils-inl.h
|
||||
|
||||
// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author:
|
||||
// Daniel Povey)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_FSTEXT_LATTICE_UTILS_INL_H_
|
||||
#define KALDI_FSTEXT_LATTICE_UTILS_INL_H_
|
||||
// Do not include this file directly. It is included by lattice-utils.h
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace fst {
|
||||
|
||||
/* Convert from FST with arc-type Weight, to one with arc-type
|
||||
CompactLatticeWeight. Uses FactorFst to identify chains
|
||||
of states which can be turned into a single output arc. */
|
||||
|
||||
template <class Weight, class Int>
|
||||
void ConvertLattice(
|
||||
const ExpandedFst<ArcTpl<Weight> > &ifst,
|
||||
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > *ofst,
|
||||
bool invert) {
|
||||
typedef ArcTpl<Weight> Arc;
|
||||
typedef typename Arc::StateId StateId;
|
||||
typedef CompactLatticeWeightTpl<Weight, Int> CompactWeight;
|
||||
typedef ArcTpl<CompactWeight> CompactArc;
|
||||
|
||||
VectorFst<ArcTpl<Weight> > ffst;
|
||||
std::vector<std::vector<Int> > labels;
|
||||
if (invert) { // normal case: want the ilabels as sequences on the arcs of
|
||||
Factor(ifst, &ffst, &labels); // the output... Factor makes seqs of
|
||||
// ilabels.
|
||||
} else {
|
||||
VectorFst<ArcTpl<Weight> > invfst(ifst);
|
||||
Invert(&invfst);
|
||||
Factor(invfst, &ffst, &labels);
|
||||
}
|
||||
|
||||
TopSort(&ffst); // Put the states in ffst in topological order, which is
|
||||
// easier on the eye when reading the text-form lattices and corresponds to
|
||||
// what we get when we generate the lattices in the decoder.
|
||||
|
||||
ofst->DeleteStates();
|
||||
|
||||
// The states will be numbered exactly the same as the original FST.
|
||||
// Add the states to the new FST.
|
||||
StateId num_states = ffst.NumStates();
|
||||
for (StateId s = 0; s < num_states; s++) {
|
||||
StateId news = ofst->AddState();
|
||||
assert(news == s);
|
||||
}
|
||||
ofst->SetStart(ffst.Start());
|
||||
for (StateId s = 0; s < num_states; s++) {
|
||||
Weight final_weight = ffst.Final(s);
|
||||
if (final_weight != Weight::Zero()) {
|
||||
CompactWeight final_compact_weight(final_weight, std::vector<Int>());
|
||||
ofst->SetFinal(s, final_compact_weight);
|
||||
}
|
||||
for (ArcIterator<ExpandedFst<Arc> > iter(ffst, s); !iter.Done();
|
||||
iter.Next()) {
|
||||
const Arc &arc = iter.Value();
|
||||
KALDI_PARANOID_ASSERT(arc.weight != Weight::Zero());
|
||||
// note: zero-weight arcs not allowed anyway so weight should not be zero,
|
||||
// but no harm in checking.
|
||||
CompactArc compact_arc(arc.olabel, arc.olabel,
|
||||
CompactWeight(arc.weight, labels[arc.ilabel]),
|
||||
arc.nextstate);
|
||||
ofst->AddArc(s, compact_arc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Weight, class Int>
|
||||
void ConvertLattice(
|
||||
const ExpandedFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > &ifst,
|
||||
MutableFst<ArcTpl<Weight> > *ofst, bool invert) {
|
||||
typedef ArcTpl<Weight> Arc;
|
||||
typedef typename Arc::StateId StateId;
|
||||
typedef typename Arc::Label Label;
|
||||
typedef CompactLatticeWeightTpl<Weight, Int> CompactWeight;
|
||||
typedef ArcTpl<CompactWeight> CompactArc;
|
||||
ofst->DeleteStates();
|
||||
// make the states in the new FST have the same numbers as
|
||||
// the original ones, and add chains of states as necessary
|
||||
// to encode the string-valued weights.
|
||||
StateId num_states = ifst.NumStates();
|
||||
for (StateId s = 0; s < num_states; s++) {
|
||||
StateId news = ofst->AddState();
|
||||
assert(news == s);
|
||||
}
|
||||
ofst->SetStart(ifst.Start());
|
||||
for (StateId s = 0; s < num_states; s++) {
|
||||
CompactWeight final_weight = ifst.Final(s);
|
||||
if (final_weight != CompactWeight::Zero()) {
|
||||
StateId cur_state = s;
|
||||
size_t string_length = final_weight.String().size();
|
||||
for (size_t n = 0; n < string_length; n++) {
|
||||
StateId next_state = ofst->AddState();
|
||||
Label ilabel = 0;
|
||||
Arc arc(ilabel, final_weight.String()[n],
|
||||
(n == 0 ? final_weight.Weight() : Weight::One()), next_state);
|
||||
if (invert) std::swap(arc.ilabel, arc.olabel);
|
||||
ofst->AddArc(cur_state, arc);
|
||||
cur_state = next_state;
|
||||
}
|
||||
ofst->SetFinal(cur_state,
|
||||
string_length > 0 ? Weight::One() : final_weight.Weight());
|
||||
}
|
||||
for (ArcIterator<ExpandedFst<CompactArc> > iter(ifst, s); !iter.Done();
|
||||
iter.Next()) {
|
||||
const CompactArc &arc = iter.Value();
|
||||
size_t string_length = arc.weight.String().size();
|
||||
StateId cur_state = s;
|
||||
// for all but the last element in the string--
|
||||
// add a temporary state.
|
||||
for (size_t n = 0; n + 1 < string_length; n++) {
|
||||
StateId next_state = ofst->AddState();
|
||||
Label ilabel = (n == 0 ? arc.ilabel : 0),
|
||||
olabel = static_cast<Label>(arc.weight.String()[n]);
|
||||
Weight weight = (n == 0 ? arc.weight.Weight() : Weight::One());
|
||||
Arc new_arc(ilabel, olabel, weight, next_state);
|
||||
if (invert) std::swap(new_arc.ilabel, new_arc.olabel);
|
||||
ofst->AddArc(cur_state, new_arc);
|
||||
cur_state = next_state;
|
||||
}
|
||||
Label ilabel = (string_length <= 1 ? arc.ilabel : 0),
|
||||
olabel = (string_length > 0 ? arc.weight.String()[string_length - 1]
|
||||
: 0);
|
||||
Weight weight =
|
||||
(string_length <= 1 ? arc.weight.Weight() : Weight::One());
|
||||
Arc new_arc(ilabel, olabel, weight, arc.nextstate);
|
||||
if (invert) std::swap(new_arc.ilabel, new_arc.olabel);
|
||||
ofst->AddArc(cur_state, new_arc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This function converts lattices between float and double;
|
||||
// it works for both CompactLatticeWeight and LatticeWeight.
|
||||
template <class WeightIn, class WeightOut>
|
||||
void ConvertLattice(const ExpandedFst<ArcTpl<WeightIn> > &ifst,
|
||||
MutableFst<ArcTpl<WeightOut> > *ofst) {
|
||||
typedef ArcTpl<WeightIn> ArcIn;
|
||||
typedef ArcTpl<WeightOut> ArcOut;
|
||||
typedef typename ArcIn::StateId StateId;
|
||||
ofst->DeleteStates();
|
||||
// The states will be numbered exactly the same as the original FST.
|
||||
// Add the states to the new FST.
|
||||
StateId num_states = ifst.NumStates();
|
||||
for (StateId s = 0; s < num_states; s++) {
|
||||
StateId news = ofst->AddState();
|
||||
assert(news == s);
|
||||
}
|
||||
ofst->SetStart(ifst.Start());
|
||||
for (StateId s = 0; s < num_states; s++) {
|
||||
WeightIn final_iweight = ifst.Final(s);
|
||||
if (final_iweight != WeightIn::Zero()) {
|
||||
WeightOut final_oweight;
|
||||
ConvertLatticeWeight(final_iweight, &final_oweight);
|
||||
ofst->SetFinal(s, final_oweight);
|
||||
}
|
||||
for (ArcIterator<ExpandedFst<ArcIn> > iter(ifst, s); !iter.Done();
|
||||
iter.Next()) {
|
||||
ArcIn arc = iter.Value();
|
||||
KALDI_PARANOID_ASSERT(arc.weight != WeightIn::Zero());
|
||||
ArcOut oarc;
|
||||
ConvertLatticeWeight(arc.weight, &oarc.weight);
|
||||
oarc.ilabel = arc.ilabel;
|
||||
oarc.olabel = arc.olabel;
|
||||
oarc.nextstate = arc.nextstate;
|
||||
ofst->AddArc(s, oarc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Weight, class ScaleFloat>
|
||||
void ScaleLattice(const std::vector<std::vector<ScaleFloat> > &scale,
|
||||
MutableFst<ArcTpl<Weight> > *fst) {
|
||||
assert(scale.size() == 2 && scale[0].size() == 2 && scale[1].size() == 2);
|
||||
if (scale == DefaultLatticeScale()) // nothing to do.
|
||||
return;
|
||||
typedef ArcTpl<Weight> Arc;
|
||||
typedef MutableFst<Arc> Fst;
|
||||
typedef typename Arc::StateId StateId;
|
||||
StateId num_states = fst->NumStates();
|
||||
for (StateId s = 0; s < num_states; s++) {
|
||||
for (MutableArcIterator<Fst> aiter(fst, s); !aiter.Done(); aiter.Next()) {
|
||||
Arc arc = aiter.Value();
|
||||
arc.weight = Weight(ScaleTupleWeight(arc.weight, scale));
|
||||
aiter.SetValue(arc);
|
||||
}
|
||||
Weight final_weight = fst->Final(s);
|
||||
if (final_weight != Weight::Zero())
|
||||
fst->SetFinal(s, Weight(ScaleTupleWeight(final_weight, scale)));
|
||||
}
|
||||
}
|
||||
|
||||
template <class Weight, class Int>
|
||||
void RemoveAlignmentsFromCompactLattice(
|
||||
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > *fst) {
|
||||
typedef CompactLatticeWeightTpl<Weight, Int> W;
|
||||
typedef ArcTpl<W> Arc;
|
||||
typedef MutableFst<Arc> Fst;
|
||||
typedef typename Arc::StateId StateId;
|
||||
StateId num_states = fst->NumStates();
|
||||
for (StateId s = 0; s < num_states; s++) {
|
||||
for (MutableArcIterator<Fst> aiter(fst, s); !aiter.Done(); aiter.Next()) {
|
||||
Arc arc = aiter.Value();
|
||||
arc.weight = W(arc.weight.Weight(), std::vector<Int>());
|
||||
aiter.SetValue(arc);
|
||||
}
|
||||
W final_weight = fst->Final(s);
|
||||
if (final_weight != W::Zero())
|
||||
fst->SetFinal(s, W(final_weight.Weight(), std::vector<Int>()));
|
||||
}
|
||||
}
|
||||
|
||||
template <class Weight, class Int>
|
||||
bool CompactLatticeHasAlignment(
|
||||
const ExpandedFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > &fst) {
|
||||
typedef CompactLatticeWeightTpl<Weight, Int> W;
|
||||
typedef ArcTpl<W> Arc;
|
||||
typedef ExpandedFst<Arc> Fst;
|
||||
typedef typename Arc::StateId StateId;
|
||||
StateId num_states = fst.NumStates();
|
||||
for (StateId s = 0; s < num_states; s++) {
|
||||
for (ArcIterator<Fst> aiter(fst, s); !aiter.Done(); aiter.Next()) {
|
||||
const Arc &arc = aiter.Value();
|
||||
if (!arc.weight.String().empty()) return true;
|
||||
}
|
||||
W final_weight = fst.Final(s);
|
||||
if (!final_weight.String().empty()) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
template <class Real>
|
||||
void ConvertFstToLattice(const ExpandedFst<ArcTpl<TropicalWeight> > &ifst,
|
||||
MutableFst<ArcTpl<LatticeWeightTpl<Real> > > *ofst) {
|
||||
int32 num_states_cache = 50000;
|
||||
fst::CacheOptions cache_opts(true, num_states_cache);
|
||||
fst::MapFstOptions mapfst_opts(cache_opts);
|
||||
StdToLatticeMapper<Real> mapper;
|
||||
MapFst<StdArc, ArcTpl<LatticeWeightTpl<Real> >, StdToLatticeMapper<Real> >
|
||||
map_fst(ifst, mapper, mapfst_opts);
|
||||
*ofst = map_fst;
|
||||
}
|
||||
|
||||
} // namespace fst
|
||||
|
||||
#endif // KALDI_FSTEXT_LATTICE_UTILS_INL_H_
|
@ -0,0 +1,259 @@
|
||||
// fstext/lattice-utils.h
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_FSTEXT_LATTICE_UTILS_H_
|
||||
#define KALDI_FSTEXT_LATTICE_UTILS_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "fst/fstlib.h"
|
||||
#include "fstext/lattice-weight.h"
|
||||
|
||||
namespace fst {
|
||||
|
||||
// The template ConvertLattice does conversions to and from
|
||||
// LatticeWeight FSTs and CompactLatticeWeight FSTs, and
|
||||
// between float and double, and to convert from LatticeWeight
|
||||
// to TropicalWeight. It's used in the I/O code for lattices,
|
||||
// and for converting lattices to standard FSTs (e.g. for creating
|
||||
// decoding graphs from lattices).
|
||||
|
||||
/**
|
||||
Convert lattice from a normal FST to a CompactLattice FST.
|
||||
This is a bit like converting to the Gallic semiring, except
|
||||
the semiring behaves in a different way (designed to take
|
||||
the best path).
|
||||
Note: the ilabels end up as the symbols on the arcs of the
|
||||
output acceptor, and the olabels go to the strings. To make
|
||||
it the other way around (useful for the speech-recognition
|
||||
application), set invert=true [the default].
|
||||
*/
|
||||
template <class Weight, class Int>
|
||||
void ConvertLattice(
|
||||
const ExpandedFst<ArcTpl<Weight> > &ifst,
|
||||
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > *ofst,
|
||||
bool invert = true);
|
||||
|
||||
/**
|
||||
Convert lattice CompactLattice format to Lattice. This is a bit
|
||||
like converting from the Gallic semiring. As for any CompactLattice, "ifst"
|
||||
must be an acceptor (i.e., ilabels and olabels should be identical). If
|
||||
invert=false, the labels on "ifst" become the ilabels on "ofst" and the
|
||||
strings in the weights of "ifst" becomes the olabels. If invert=true
|
||||
[default], this is reversed (useful for speech recognition lattices; our
|
||||
standard non-compact format has the words on the output side to match HCLG).
|
||||
*/
|
||||
template <class Weight, class Int>
|
||||
void ConvertLattice(
|
||||
const ExpandedFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > &ifst,
|
||||
MutableFst<ArcTpl<Weight> > *ofst, bool invert = true);
|
||||
|
||||
/**
|
||||
Convert between CompactLattices and Lattices of different floating point
|
||||
types... this works between any pair of weight types for which
|
||||
ConvertLatticeWeight is defined (c.f. lattice-weight.h), and also includes
|
||||
conversion from LatticeWeight to TropicalWeight.
|
||||
*/
|
||||
template <class WeightIn, class WeightOut>
|
||||
void ConvertLattice(const ExpandedFst<ArcTpl<WeightIn> > &ifst,
|
||||
MutableFst<ArcTpl<WeightOut> > *ofst);
|
||||
|
||||
// Now define some ConvertLattice functions that require two phases of
|
||||
// conversion (don't bother coding these separately as they will be used rarely.
|
||||
|
||||
// Lattice with float to CompactLattice with double.
|
||||
template <class Int>
|
||||
void ConvertLattice(
|
||||
const ExpandedFst<ArcTpl<LatticeWeightTpl<float> > > &ifst,
|
||||
MutableFst<ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<double>, Int> > >
|
||||
*ofst) {
|
||||
VectorFst<ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<float>, Int> > >
|
||||
fst;
|
||||
ConvertLattice(ifst, &fst);
|
||||
ConvertLattice(fst, ofst);
|
||||
}
|
||||
|
||||
// Lattice with double to CompactLattice with float.
|
||||
template <class Int>
|
||||
void ConvertLattice(
|
||||
const ExpandedFst<ArcTpl<LatticeWeightTpl<double> > > &ifst,
|
||||
MutableFst<ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<float>, Int> > >
|
||||
*ofst) {
|
||||
VectorFst<ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<double>, Int> > >
|
||||
fst;
|
||||
ConvertLattice(ifst, &fst);
|
||||
ConvertLattice(fst, ofst);
|
||||
}
|
||||
|
||||
/// Converts CompactLattice with double to Lattice with float.
|
||||
template <class Int>
|
||||
void ConvertLattice(
|
||||
const ExpandedFst<
|
||||
ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<double>, Int> > > &ifst,
|
||||
MutableFst<ArcTpl<LatticeWeightTpl<float> > > *ofst) {
|
||||
VectorFst<ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<float>, Int> > >
|
||||
fst;
|
||||
ConvertLattice(ifst, &fst);
|
||||
ConvertLattice(fst, ofst);
|
||||
}
|
||||
|
||||
/// Converts CompactLattice with float to Lattice with double.
|
||||
template <class Int>
|
||||
void ConvertLattice(
|
||||
const ExpandedFst<
|
||||
ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<float>, Int> > > &ifst,
|
||||
MutableFst<ArcTpl<LatticeWeightTpl<double> > > *ofst) {
|
||||
VectorFst<ArcTpl<CompactLatticeWeightTpl<LatticeWeightTpl<double>, Int> > >
|
||||
fst;
|
||||
ConvertLattice(ifst, &fst);
|
||||
ConvertLattice(fst, ofst);
|
||||
}
|
||||
|
||||
/// Converts TropicalWeight to LatticeWeight (puts all the weight on
|
||||
/// the first float in the lattice's pair).
|
||||
template <class Real>
|
||||
void ConvertFstToLattice(const ExpandedFst<ArcTpl<TropicalWeight> > &ifst,
|
||||
MutableFst<ArcTpl<LatticeWeightTpl<Real> > > *ofst);
|
||||
|
||||
/** Returns a default 2x2 matrix scaling factor for LatticeWeight */
|
||||
inline std::vector<std::vector<double> > DefaultLatticeScale() {
|
||||
std::vector<std::vector<double> > ans(2);
|
||||
ans[0].resize(2, 0.0);
|
||||
ans[1].resize(2, 0.0);
|
||||
ans[0][0] = ans[1][1] = 1.0;
|
||||
return ans;
|
||||
}
|
||||
|
||||
inline std::vector<std::vector<double> > AcousticLatticeScale(double acwt) {
|
||||
std::vector<std::vector<double> > ans(2);
|
||||
ans[0].resize(2, 0.0);
|
||||
ans[1].resize(2, 0.0);
|
||||
ans[0][0] = 1.0;
|
||||
ans[1][1] = acwt;
|
||||
return ans;
|
||||
}
|
||||
|
||||
inline std::vector<std::vector<double> > GraphLatticeScale(double lmwt) {
|
||||
std::vector<std::vector<double> > ans(2);
|
||||
ans[0].resize(2, 0.0);
|
||||
ans[1].resize(2, 0.0);
|
||||
ans[0][0] = lmwt;
|
||||
ans[1][1] = 1.0;
|
||||
return ans;
|
||||
}
|
||||
|
||||
inline std::vector<std::vector<double> > LatticeScale(double lmwt,
|
||||
double acwt) {
|
||||
std::vector<std::vector<double> > ans(2);
|
||||
ans[0].resize(2, 0.0);
|
||||
ans[1].resize(2, 0.0);
|
||||
ans[0][0] = lmwt;
|
||||
ans[1][1] = acwt;
|
||||
return ans;
|
||||
}
|
||||
|
||||
/** Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by
|
||||
viewing the pair (a, b) as a 2-vector and pre-multiplying by the 2x2 matrix
|
||||
in "scale". E.g. typically scale would equal
|
||||
[ 1 0;
|
||||
0 acwt ]
|
||||
if we want to scale the acoustics by "acwt".
|
||||
*/
|
||||
template <class Weight, class ScaleFloat>
|
||||
void ScaleLattice(const std::vector<std::vector<ScaleFloat> > &scale,
|
||||
MutableFst<ArcTpl<Weight> > *fst);
|
||||
|
||||
/// Removes state-level alignments (the strings that are
|
||||
/// part of the weights).
|
||||
template <class Weight, class Int>
|
||||
void RemoveAlignmentsFromCompactLattice(
|
||||
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > *fst);
|
||||
|
||||
/// Returns true if lattice has alignments, i.e. it has
|
||||
/// any nonempty strings inside its weights.
|
||||
template <class Weight, class Int>
|
||||
bool CompactLatticeHasAlignment(
|
||||
const ExpandedFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > &fst);
|
||||
|
||||
/// Class StdToLatticeMapper maps a normal arc (StdArc)
|
||||
/// to a LatticeArc by putting the StdArc weight as the first
|
||||
/// element of the LatticeWeight. Useful when doing LM
|
||||
/// rescoring.
|
||||
template <class Real>
|
||||
class StdToLatticeMapper {
|
||||
typedef LatticeWeightTpl<Real> LatticeWeight;
|
||||
typedef ArcTpl<LatticeWeight> LatticeArc;
|
||||
|
||||
public:
|
||||
LatticeArc operator()(const StdArc &arc) {
|
||||
// Note: we have to check whether the arc's weight is zero below,
|
||||
// and if so return (infinity, infinity) and not (infinity, zero),
|
||||
// because (infinity, zero) is not a valid LatticeWeight, which should
|
||||
// either be both finite, or both infinite (i.e. Zero()).
|
||||
return LatticeArc(
|
||||
arc.ilabel, arc.olabel,
|
||||
LatticeWeight(arc.weight.Value(), arc.weight == StdArc::Weight::Zero()
|
||||
? arc.weight.Value()
|
||||
: 0.0),
|
||||
arc.nextstate);
|
||||
}
|
||||
MapFinalAction FinalAction() { return MAP_NO_SUPERFINAL; }
|
||||
|
||||
MapSymbolsAction InputSymbolsAction() { return MAP_COPY_SYMBOLS; }
|
||||
|
||||
MapSymbolsAction OutputSymbolsAction() { return MAP_COPY_SYMBOLS; }
|
||||
|
||||
// I believe all properties are preserved.
|
||||
uint64 Properties(uint64 props) { return props; }
|
||||
};
|
||||
|
||||
/// Class LatticeToStdMapper maps a LatticeArc to a normal arc (StdArc)
|
||||
/// by adding the elements of the LatticeArc weight.
|
||||
|
||||
template <class Real>
|
||||
class LatticeToStdMapper {
|
||||
typedef LatticeWeightTpl<Real> LatticeWeight;
|
||||
typedef ArcTpl<LatticeWeight> LatticeArc;
|
||||
|
||||
public:
|
||||
StdArc operator()(const LatticeArc &arc) {
|
||||
return StdArc(arc.ilabel, arc.olabel,
|
||||
StdArc::Weight(arc.weight.Value1() + arc.weight.Value2()),
|
||||
arc.nextstate);
|
||||
}
|
||||
MapFinalAction FinalAction() { return MAP_NO_SUPERFINAL; }
|
||||
|
||||
MapSymbolsAction InputSymbolsAction() { return MAP_COPY_SYMBOLS; }
|
||||
|
||||
MapSymbolsAction OutputSymbolsAction() { return MAP_COPY_SYMBOLS; }
|
||||
|
||||
// I believe all properties are preserved.
|
||||
uint64 Properties(uint64 props) { return props; }
|
||||
};
|
||||
|
||||
template <class Weight, class Int>
|
||||
void PruneCompactLattice(
|
||||
Weight beam,
|
||||
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > > *fst);
|
||||
|
||||
} // end namespace fst
|
||||
|
||||
#include "fstext/lattice-utils-inl.h"
|
||||
|
||||
#endif // KALDI_FSTEXT_LATTICE_UTILS_H_
|
@ -0,0 +1,892 @@
|
||||
// fstext/lattice-weight.h
|
||||
// Copyright 2009-2012 Microsoft Corporation
|
||||
// Johns Hopkins University (author: Daniel Povey)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_FSTEXT_LATTICE_WEIGHT_H_
|
||||
#define KALDI_FSTEXT_LATTICE_WEIGHT_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "base/kaldi-common.h"
|
||||
#include "fst/fstlib.h"
|
||||
|
||||
namespace fst {
|
||||
|
||||
// Declare weight type for lattice... will import to namespace kaldi. has two
|
||||
// members, value1_ and value2_, of type BaseFloat (normally equals float). It
|
||||
// is basically the same as the tropical semiring on value1_+value2_, except it
|
||||
// keeps track of a and b separately. More precisely, it is equivalent to the
|
||||
// lexicographic semiring on (value1_+value2_), (value1_-value2_)
|
||||
|
||||
template <class FloatType>
|
||||
class LatticeWeightTpl;
|
||||
|
||||
template <class FloatType>
|
||||
inline std::ostream &operator<<(std::ostream &strm,
|
||||
const LatticeWeightTpl<FloatType> &w);
|
||||
|
||||
template <class FloatType>
|
||||
inline std::istream &operator>>(std::istream &strm,
|
||||
LatticeWeightTpl<FloatType> &w);
|
||||
|
||||
template <class FloatType>
|
||||
class LatticeWeightTpl {
|
||||
public:
|
||||
typedef FloatType T; // normally float.
|
||||
typedef LatticeWeightTpl ReverseWeight;
|
||||
|
||||
inline T Value1() const { return value1_; }
|
||||
|
||||
inline T Value2() const { return value2_; }
|
||||
|
||||
inline void SetValue1(T f) { value1_ = f; }
|
||||
|
||||
inline void SetValue2(T f) { value2_ = f; }
|
||||
|
||||
LatticeWeightTpl() : value1_{}, value2_{} {}
|
||||
|
||||
LatticeWeightTpl(T a, T b) : value1_(a), value2_(b) {}
|
||||
|
||||
LatticeWeightTpl(const LatticeWeightTpl &other)
|
||||
: value1_(other.value1_), value2_(other.value2_) {}
|
||||
|
||||
LatticeWeightTpl &operator=(const LatticeWeightTpl &w) {
|
||||
value1_ = w.value1_;
|
||||
value2_ = w.value2_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
LatticeWeightTpl<FloatType> Reverse() const { return *this; }
|
||||
|
||||
static const LatticeWeightTpl Zero() {
|
||||
return LatticeWeightTpl(std::numeric_limits<T>::infinity(),
|
||||
std::numeric_limits<T>::infinity());
|
||||
}
|
||||
|
||||
static const LatticeWeightTpl One() { return LatticeWeightTpl(0.0, 0.0); }
|
||||
|
||||
static const std::string &Type() {
|
||||
static const std::string type = (sizeof(T) == 4 ? "lattice4" : "lattice8");
|
||||
return type;
|
||||
}
|
||||
|
||||
static const LatticeWeightTpl NoWeight() {
|
||||
return LatticeWeightTpl(std::numeric_limits<FloatType>::quiet_NaN(),
|
||||
std::numeric_limits<FloatType>::quiet_NaN());
|
||||
}
|
||||
|
||||
bool Member() const {
|
||||
// value1_ == value1_ tests for NaN.
|
||||
// also test for no -inf, and either both or neither
|
||||
// must be +inf, and
|
||||
if (value1_ != value1_ || value2_ != value2_) return false; // NaN
|
||||
if (value1_ == -std::numeric_limits<T>::infinity() ||
|
||||
value2_ == -std::numeric_limits<T>::infinity())
|
||||
return false; // -infty not allowed
|
||||
if (value1_ == std::numeric_limits<T>::infinity() ||
|
||||
value2_ == std::numeric_limits<T>::infinity()) {
|
||||
if (value1_ != std::numeric_limits<T>::infinity() ||
|
||||
value2_ != std::numeric_limits<T>::infinity())
|
||||
return false; // both must be +infty;
|
||||
// this is necessary so that the semiring has only one zero.
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
LatticeWeightTpl Quantize(float delta = kDelta) const {
|
||||
if (value1_ + value2_ == -std::numeric_limits<T>::infinity()) {
|
||||
return LatticeWeightTpl(-std::numeric_limits<T>::infinity(),
|
||||
-std::numeric_limits<T>::infinity());
|
||||
} else if (value1_ + value2_ == std::numeric_limits<T>::infinity()) {
|
||||
return LatticeWeightTpl(std::numeric_limits<T>::infinity(),
|
||||
std::numeric_limits<T>::infinity());
|
||||
} else if (value1_ + value2_ != value1_ + value2_) { // NaN
|
||||
return LatticeWeightTpl(value1_ + value2_, value1_ + value2_);
|
||||
} else {
|
||||
return LatticeWeightTpl(floor(value1_ / delta + 0.5F) * delta,
|
||||
floor(value2_ / delta + 0.5F) * delta);
|
||||
}
|
||||
}
|
||||
static constexpr uint64 Properties() {
|
||||
return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent;
|
||||
}
|
||||
|
||||
// This is used in OpenFst for binary I/O. This is OpenFst-style,
|
||||
// not Kaldi-style, I/O.
|
||||
std::istream &Read(std::istream &strm) {
|
||||
// Always read/write as float, even if T is double,
|
||||
// so we can use OpenFst-style read/write and still maintain
|
||||
// compatibility when compiling with different FloatTypes
|
||||
ReadType(strm, &value1_);
|
||||
ReadType(strm, &value2_);
|
||||
return strm;
|
||||
}
|
||||
|
||||
// This is used in OpenFst for binary I/O. This is OpenFst-style,
|
||||
// not Kaldi-style, I/O.
|
||||
std::ostream &Write(std::ostream &strm) const {
|
||||
WriteType(strm, value1_);
|
||||
WriteType(strm, value2_);
|
||||
return strm;
|
||||
}
|
||||
|
||||
size_t Hash() const {
|
||||
size_t ans;
|
||||
union {
|
||||
T f;
|
||||
size_t s;
|
||||
} u;
|
||||
u.s = 0;
|
||||
u.f = value1_;
|
||||
ans = u.s;
|
||||
u.f = value2_;
|
||||
ans += u.s;
|
||||
return ans;
|
||||
}
|
||||
|
||||
protected:
|
||||
inline static void WriteFloatType(std::ostream &strm, const T &f) {
|
||||
if (f == std::numeric_limits<T>::infinity())
|
||||
strm << "Infinity";
|
||||
else if (f == -std::numeric_limits<T>::infinity())
|
||||
strm << "-Infinity";
|
||||
else if (f != f)
|
||||
strm << "BadNumber";
|
||||
else
|
||||
strm << f;
|
||||
}
|
||||
|
||||
// Internal helper function, used in ReadNoParen.
|
||||
inline static void ReadFloatType(std::istream &strm, T &f) { // NOLINT
|
||||
std::string s;
|
||||
strm >> s;
|
||||
if (s == "Infinity") {
|
||||
f = std::numeric_limits<T>::infinity();
|
||||
} else if (s == "-Infinity") {
|
||||
f = -std::numeric_limits<T>::infinity();
|
||||
} else if (s == "BadNumber") {
|
||||
f = std::numeric_limits<T>::quiet_NaN();
|
||||
} else {
|
||||
char *p;
|
||||
f = strtod(s.c_str(), &p);
|
||||
if (p < s.c_str() + s.size()) strm.clear(std::ios::badbit);
|
||||
}
|
||||
}
|
||||
|
||||
// Reads LatticeWeight when there are no parentheses around pair terms...
|
||||
// currently the only form supported.
|
||||
inline std::istream &ReadNoParen(std::istream &strm, char separator) {
|
||||
int c;
|
||||
do {
|
||||
c = strm.get();
|
||||
} while (isspace(c));
|
||||
|
||||
std::string s1;
|
||||
while (c != separator) {
|
||||
if (c == EOF) {
|
||||
strm.clear(std::ios::badbit);
|
||||
return strm;
|
||||
}
|
||||
s1 += c;
|
||||
c = strm.get();
|
||||
}
|
||||
std::istringstream strm1(s1);
|
||||
ReadFloatType(strm1, value1_); // ReadFloatType is class member function
|
||||
// read second element
|
||||
ReadFloatType(strm, value2_);
|
||||
return strm;
|
||||
}
|
||||
|
||||
friend std::istream &operator>>
|
||||
<FloatType>(std::istream &, LatticeWeightTpl<FloatType> &);
|
||||
friend std::ostream &operator<<<FloatType>(
|
||||
std::ostream &, const LatticeWeightTpl<FloatType> &);
|
||||
|
||||
private:
|
||||
T value1_;
|
||||
T value2_;
|
||||
};
|
||||
|
||||
/* ScaleTupleWeight is a function defined for LatticeWeightTpl and
|
||||
CompactLatticeWeightTpl that mutliplies the pair (value1_, value2_) by a 2x2
|
||||
matrix. Used, for example, in applying acoustic scaling.
|
||||
*/
|
||||
template <class FloatType, class ScaleFloatType>
|
||||
inline LatticeWeightTpl<FloatType> ScaleTupleWeight(
|
||||
const LatticeWeightTpl<FloatType> &w,
|
||||
const std::vector<std::vector<ScaleFloatType> > &scale) {
|
||||
// Without the next special case we'd get NaNs from infinity * 0
|
||||
if (w.Value1() == std::numeric_limits<FloatType>::infinity())
|
||||
return LatticeWeightTpl<FloatType>::Zero();
|
||||
return LatticeWeightTpl<FloatType>(
|
||||
scale[0][0] * w.Value1() + scale[0][1] * w.Value2(),
|
||||
scale[1][0] * w.Value1() + scale[1][1] * w.Value2());
|
||||
}
|
||||
|
||||
/* For testing purposes and in case it's ever useful, we define a similar
|
||||
function to apply to LexicographicWeight and the like, templated on
|
||||
TropicalWeight<float> etc.; we use PairWeight which is the base class of
|
||||
LexicographicWeight.
|
||||
*/
|
||||
template <class FloatType, class ScaleFloatType>
|
||||
inline PairWeight<TropicalWeightTpl<FloatType>, TropicalWeightTpl<FloatType> >
|
||||
ScaleTupleWeight(const PairWeight<TropicalWeightTpl<FloatType>,
|
||||
TropicalWeightTpl<FloatType> > &w,
|
||||
const std::vector<std::vector<ScaleFloatType> > &scale) {
|
||||
typedef TropicalWeightTpl<FloatType> BaseType;
|
||||
typedef PairWeight<BaseType, BaseType> PairType;
|
||||
const BaseType zero = BaseType::Zero();
|
||||
// Without the next special case we'd get NaNs from infinity * 0
|
||||
if (w.Value1() == zero || w.Value2() == zero) return PairType(zero, zero);
|
||||
FloatType f1 = w.Value1().Value(), f2 = w.Value2().Value();
|
||||
return PairType(BaseType(scale[0][0] * f1 + scale[0][1] * f2),
|
||||
BaseType(scale[1][0] * f1 + scale[1][1] * f2));
|
||||
}
|
||||
|
||||
template <class FloatType>
|
||||
inline bool operator==(const LatticeWeightTpl<FloatType> &wa,
|
||||
const LatticeWeightTpl<FloatType> &wb) {
|
||||
// Volatile qualifier thwarts over-aggressive compiler optimizations
|
||||
// that lead to problems esp. with NaturalLess().
|
||||
volatile FloatType va1 = wa.Value1(), va2 = wa.Value2(), vb1 = wb.Value1(),
|
||||
vb2 = wb.Value2();
|
||||
return (va1 == vb1 && va2 == vb2);
|
||||
}
|
||||
|
||||
template <class FloatType>
|
||||
inline bool operator!=(const LatticeWeightTpl<FloatType> &wa,
|
||||
const LatticeWeightTpl<FloatType> &wb) {
|
||||
// Volatile qualifier thwarts over-aggressive compiler optimizations
|
||||
// that lead to problems esp. with NaturalLess().
|
||||
volatile FloatType va1 = wa.Value1(), va2 = wa.Value2(), vb1 = wb.Value1(),
|
||||
vb2 = wb.Value2();
|
||||
return (va1 != vb1 || va2 != vb2);
|
||||
}
|
||||
|
||||
// We define a Compare function LatticeWeightTpl even though it's
|
||||
// not required by the semiring standard-- it's just more efficient
|
||||
// to do it this way rather than using the NaturalLess template.
|
||||
|
||||
/// Compare returns -1 if w1 < w2, +1 if w1 > w2, and 0 if w1 == w2.
|
||||
|
||||
template <class FloatType>
|
||||
inline int Compare(const LatticeWeightTpl<FloatType> &w1,
|
||||
const LatticeWeightTpl<FloatType> &w2) {
|
||||
FloatType f1 = w1.Value1() + w1.Value2(), f2 = w2.Value1() + w2.Value2();
|
||||
if (f1 < f2) { // having smaller cost means you're larger
|
||||
return 1;
|
||||
} else if (f1 > f2) { // in the semiring [higher probability]
|
||||
return -1;
|
||||
} else if (w1.Value1() < w2.Value1()) {
|
||||
// mathematically we should be comparing (w1.value1_-w1.value2_ <
|
||||
// w2.value1_-w2.value2_) in the next line, but add w1.value1_+w1.value2_ =
|
||||
// w2.value1_+w2.value2_ to both sides and divide by two, and we get the
|
||||
// simpler equivalent form w1.value1_ < w2.value1_.
|
||||
return 1;
|
||||
} else if (w1.Value1() > w2.Value1()) {
|
||||
return -1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <class FloatType>
|
||||
inline LatticeWeightTpl<FloatType> Plus(const LatticeWeightTpl<FloatType> &w1,
|
||||
const LatticeWeightTpl<FloatType> &w2) {
|
||||
return (Compare(w1, w2) >= 0 ? w1 : w2);
|
||||
}
|
||||
|
||||
// For efficiency, override the NaturalLess template class.
|
||||
template <class FloatType>
|
||||
class NaturalLess<LatticeWeightTpl<FloatType> > {
|
||||
public:
|
||||
typedef LatticeWeightTpl<FloatType> Weight;
|
||||
|
||||
NaturalLess() {}
|
||||
|
||||
bool operator()(const Weight &w1, const Weight &w2) const {
|
||||
// NaturalLess is a negative order (opposite to normal ordering).
|
||||
// This operator () corresponds to "<" in the negative order, which
|
||||
// corresponds to the ">" in the normal order.
|
||||
return (Compare(w1, w2) == 1);
|
||||
}
|
||||
};
|
||||
template <>
|
||||
class NaturalLess<LatticeWeightTpl<float> > {
|
||||
public:
|
||||
typedef LatticeWeightTpl<float> Weight;
|
||||
|
||||
NaturalLess() {}
|
||||
|
||||
bool operator()(const Weight &w1, const Weight &w2) const {
|
||||
// NaturalLess is a negative order (opposite to normal ordering).
|
||||
// This operator () corresponds to "<" in the negative order, which
|
||||
// corresponds to the ">" in the normal order.
|
||||
return (Compare(w1, w2) == 1);
|
||||
}
|
||||
};
|
||||
template <>
|
||||
class NaturalLess<LatticeWeightTpl<double> > {
|
||||
public:
|
||||
typedef LatticeWeightTpl<double> Weight;
|
||||
|
||||
NaturalLess() {}
|
||||
|
||||
bool operator()(const Weight &w1, const Weight &w2) const {
|
||||
// NaturalLess is a negative order (opposite to normal ordering).
|
||||
// This operator () corresponds to "<" in the negative order, which
|
||||
// corresponds to the ">" in the normal order.
|
||||
return (Compare(w1, w2) == 1);
|
||||
}
|
||||
};
|
||||
|
||||
template <class FloatType>
|
||||
inline LatticeWeightTpl<FloatType> Times(
|
||||
const LatticeWeightTpl<FloatType> &w1,
|
||||
const LatticeWeightTpl<FloatType> &w2) {
|
||||
return LatticeWeightTpl<FloatType>(w1.Value1() + w2.Value1(),
|
||||
w1.Value2() + w2.Value2());
|
||||
}
|
||||
|
||||
// divide w1 by w2 (on left/right/any doesn't matter as
|
||||
// commutative).
|
||||
template <class FloatType>
|
||||
inline LatticeWeightTpl<FloatType> Divide(const LatticeWeightTpl<FloatType> &w1,
|
||||
const LatticeWeightTpl<FloatType> &w2,
|
||||
DivideType typ = DIVIDE_ANY) {
|
||||
typedef FloatType T;
|
||||
T a = w1.Value1() - w2.Value1(), b = w1.Value2() - w2.Value2();
|
||||
if (a != a || b != b || a == -std::numeric_limits<T>::infinity() ||
|
||||
b == -std::numeric_limits<T>::infinity()) {
|
||||
KALDI_WARN << "LatticeWeightTpl::Divide, NaN or invalid number produced. "
|
||||
<< "[dividing by zero?] Returning zero";
|
||||
return LatticeWeightTpl<T>::Zero();
|
||||
}
|
||||
if (a == std::numeric_limits<T>::infinity() ||
|
||||
b == std::numeric_limits<T>::infinity())
|
||||
return LatticeWeightTpl<T>::Zero(); // not a valid number if only one is
|
||||
// infinite.
|
||||
return LatticeWeightTpl<T>(a, b);
|
||||
}
|
||||
|
||||
template <class FloatType>
|
||||
inline bool ApproxEqual(const LatticeWeightTpl<FloatType> &w1,
|
||||
const LatticeWeightTpl<FloatType> &w2,
|
||||
float delta = kDelta) {
|
||||
if (w1.Value1() == w2.Value1() && w1.Value2() == w2.Value2())
|
||||
return true; // handles Zero().
|
||||
return (fabs((w1.Value1() + w1.Value2()) - (w2.Value1() + w2.Value2())) <=
|
||||
delta);
|
||||
}
|
||||
|
||||
template <class FloatType>
|
||||
inline std::ostream &operator<<(std::ostream &strm,
|
||||
const LatticeWeightTpl<FloatType> &w) {
|
||||
LatticeWeightTpl<FloatType>::WriteFloatType(strm, w.Value1());
|
||||
CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
|
||||
strm << FLAGS_fst_weight_separator[0]; // comma by default;
|
||||
// may or may not be settable from Kaldi programs.
|
||||
LatticeWeightTpl<FloatType>::WriteFloatType(strm, w.Value2());
|
||||
return strm;
|
||||
}
|
||||
|
||||
template <class FloatType>
|
||||
inline std::istream &operator>>(std::istream &strm,
|
||||
LatticeWeightTpl<FloatType> &w1) {
|
||||
CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
|
||||
// separator defaults to ','
|
||||
return w1.ReadNoParen(strm, FLAGS_fst_weight_separator[0]);
|
||||
}
|
||||
|
||||
// CompactLattice will be an acceptor (accepting the words/output-symbols),
|
||||
// with the weights and input-symbol-seqs on the arcs.
|
||||
// There must be a total order on W. We assume for the sake of efficiency
|
||||
// that there is a function
|
||||
// Compare(W w1, W w2) that returns -1 if w1 < w2, +1 if w1 > w2, and
|
||||
// zero if w1 == w2, and Plus for type W returns (Compare(w1,w2) >= 0 ? w1 :
|
||||
// w2).
|
||||
|
||||
template <class WeightType, class IntType>
|
||||
class CompactLatticeWeightTpl {
|
||||
public:
|
||||
typedef WeightType W;
|
||||
|
||||
typedef CompactLatticeWeightTpl<WeightType, IntType> ReverseWeight;
|
||||
|
||||
// Plus is like LexicographicWeight on the pair (weight_, string_), but where
|
||||
// we use standard lexicographic order on string_ [this is not the same as
|
||||
// NaturalLess on the StringWeight equivalent, which does not define a
|
||||
// total order].
|
||||
// Times, Divide obvious... (support both left & right division..)
|
||||
// CommonDivisor would need to be coded separately.
|
||||
|
||||
CompactLatticeWeightTpl() {}
|
||||
|
||||
CompactLatticeWeightTpl(const WeightType &w, const std::vector<IntType> &s)
|
||||
: weight_(w), string_(s) {}
|
||||
|
||||
CompactLatticeWeightTpl &operator=(
|
||||
const CompactLatticeWeightTpl<WeightType, IntType> &w) {
|
||||
weight_ = w.weight_;
|
||||
string_ = w.string_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
const W &Weight() const { return weight_; }
|
||||
|
||||
const std::vector<IntType> &String() const { return string_; }
|
||||
|
||||
void SetWeight(const W &w) { weight_ = w; }
|
||||
|
||||
void SetString(const std::vector<IntType> &s) { string_ = s; }
|
||||
|
||||
static const CompactLatticeWeightTpl<WeightType, IntType> Zero() {
|
||||
return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::Zero(),
|
||||
std::vector<IntType>());
|
||||
}
|
||||
|
||||
static const CompactLatticeWeightTpl<WeightType, IntType> One() {
|
||||
return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::One(),
|
||||
std::vector<IntType>());
|
||||
}
|
||||
|
||||
inline static std::string GetIntSizeString() {
|
||||
char buf[2];
|
||||
buf[0] = '0' + sizeof(IntType);
|
||||
buf[1] = '\0';
|
||||
return buf;
|
||||
}
|
||||
static const std::string &Type() {
|
||||
static const std::string type =
|
||||
"compact" + WeightType::Type() + GetIntSizeString();
|
||||
return type;
|
||||
}
|
||||
|
||||
static const CompactLatticeWeightTpl<WeightType, IntType> NoWeight() {
|
||||
return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::NoWeight(),
|
||||
std::vector<IntType>());
|
||||
}
|
||||
|
||||
CompactLatticeWeightTpl<WeightType, IntType> Reverse() const {
|
||||
size_t s = string_.size();
|
||||
std::vector<IntType> v(s);
|
||||
for (size_t i = 0; i < s; i++) v[i] = string_[s - i - 1];
|
||||
return CompactLatticeWeightTpl<WeightType, IntType>(weight_, v);
|
||||
}
|
||||
|
||||
bool Member() const {
|
||||
// a semiring has only one zero, this is the important property
|
||||
// we're trying to maintain here. So force string_ to be empty if
|
||||
// w_ == zero.
|
||||
if (!weight_.Member()) return false;
|
||||
if (weight_ == WeightType::Zero())
|
||||
return string_.empty();
|
||||
else
|
||||
return true;
|
||||
}
|
||||
|
||||
CompactLatticeWeightTpl Quantize(float delta = kDelta) const {
|
||||
return CompactLatticeWeightTpl(weight_.Quantize(delta), string_);
|
||||
}
|
||||
|
||||
static constexpr uint64 Properties() {
|
||||
return kLeftSemiring | kRightSemiring | kPath | kIdempotent;
|
||||
}
|
||||
|
||||
// This is used in OpenFst for binary I/O. This is OpenFst-style,
|
||||
// not Kaldi-style, I/O.
|
||||
std::istream &Read(std::istream &strm) {
|
||||
weight_.Read(strm);
|
||||
if (strm.fail()) {
|
||||
return strm;
|
||||
}
|
||||
int32 sz;
|
||||
ReadType(strm, &sz);
|
||||
if (strm.fail()) {
|
||||
return strm;
|
||||
}
|
||||
if (sz < 0) {
|
||||
KALDI_WARN << "Negative string size! Read failure";
|
||||
strm.clear(std::ios::badbit);
|
||||
return strm;
|
||||
}
|
||||
string_.resize(sz);
|
||||
for (int32 i = 0; i < sz; i++) {
|
||||
ReadType(strm, &(string_[i]));
|
||||
}
|
||||
return strm;
|
||||
}
|
||||
|
||||
// This is used in OpenFst for binary I/O. This is OpenFst-style,
|
||||
// not Kaldi-style, I/O.
|
||||
std::ostream &Write(std::ostream &strm) const {
|
||||
weight_.Write(strm);
|
||||
if (strm.fail()) {
|
||||
return strm;
|
||||
}
|
||||
int32 sz = static_cast<int32>(string_.size());
|
||||
WriteType(strm, sz);
|
||||
for (int32 i = 0; i < sz; i++) WriteType(strm, string_[i]);
|
||||
return strm;
|
||||
}
|
||||
size_t Hash() const {
|
||||
size_t ans = weight_.Hash();
|
||||
// any weird numbers here are largish primes
|
||||
size_t sz = string_.size(), mult = 6967;
|
||||
for (size_t i = 0; i < sz; i++) {
|
||||
ans += string_[i] * mult;
|
||||
mult *= 7499;
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
W weight_;
|
||||
std::vector<IntType> string_;
|
||||
};
|
||||
|
||||
template <class WeightType, class IntType>
|
||||
inline bool operator==(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
|
||||
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
|
||||
return (w1.Weight() == w2.Weight() && w1.String() == w2.String());
|
||||
}
|
||||
|
||||
template <class WeightType, class IntType>
|
||||
inline bool operator!=(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
|
||||
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
|
||||
return (w1.Weight() != w2.Weight() || w1.String() != w2.String());
|
||||
}
|
||||
|
||||
template <class WeightType, class IntType>
|
||||
inline bool ApproxEqual(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
|
||||
const CompactLatticeWeightTpl<WeightType, IntType> &w2,
|
||||
float delta = kDelta) {
|
||||
return (ApproxEqual(w1.Weight(), w2.Weight(), delta) &&
|
||||
w1.String() == w2.String());
|
||||
}
|
||||
|
||||
// Compare is not part of the standard for weight types, but used internally for
|
||||
// efficiency. The comparison here first compares the weight; if this is the
|
||||
// same, it compares the string. The comparison on strings is: first compare
|
||||
// the length, if this is the same, use lexicographical order. We can't just
|
||||
// use the lexicographical order because this would destroy the distributive
|
||||
// property of multiplication over addition, taking into account that addition
|
||||
// uses Compare. The string element of "Compare" isn't super-important in
|
||||
// practical terms; it's only needed to ensure that Plus always give consistent
|
||||
// answers and is symmetric. It's essentially for tie-breaking, but we need to
|
||||
// make sure all the semiring axioms are satisfied otherwise OpenFst might
|
||||
// break.
|
||||
|
||||
template <class WeightType, class IntType>
|
||||
inline int Compare(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
|
||||
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
|
||||
int c1 = Compare(w1.Weight(), w2.Weight());
|
||||
if (c1 != 0) return c1;
|
||||
int l1 = w1.String().size(), l2 = w2.String().size();
|
||||
// Use opposite order on the string lengths, so that if the costs are the
|
||||
// same, the shorter string wins.
|
||||
if (l1 > l2)
|
||||
return -1;
|
||||
else if (l1 < l2)
|
||||
return 1;
|
||||
for (int i = 0; i < l1; i++) {
|
||||
if (w1.String()[i] < w2.String()[i])
|
||||
return -1;
|
||||
else if (w1.String()[i] > w2.String()[i])
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// For efficiency, override the NaturalLess template class.
|
||||
template <class FloatType, class IntType>
|
||||
class NaturalLess<
|
||||
CompactLatticeWeightTpl<LatticeWeightTpl<FloatType>, IntType> > {
|
||||
public:
|
||||
typedef CompactLatticeWeightTpl<LatticeWeightTpl<FloatType>, IntType> Weight;
|
||||
|
||||
NaturalLess() {}
|
||||
|
||||
bool operator()(const Weight &w1, const Weight &w2) const {
|
||||
// NaturalLess is a negative order (opposite to normal ordering).
|
||||
// This operator () corresponds to "<" in the negative order, which
|
||||
// corresponds to the ">" in the normal order.
|
||||
return (Compare(w1, w2) == 1);
|
||||
}
|
||||
};
|
||||
template <>
|
||||
class NaturalLess<CompactLatticeWeightTpl<LatticeWeightTpl<float>, int32> > {
|
||||
public:
|
||||
typedef CompactLatticeWeightTpl<LatticeWeightTpl<float>, int32> Weight;
|
||||
|
||||
NaturalLess() {}
|
||||
|
||||
bool operator()(const Weight &w1, const Weight &w2) const {
|
||||
// NaturalLess is a negative order (opposite to normal ordering).
|
||||
// This operator () corresponds to "<" in the negative order, which
|
||||
// corresponds to the ">" in the normal order.
|
||||
return (Compare(w1, w2) == 1);
|
||||
}
|
||||
};
|
||||
template <>
|
||||
class NaturalLess<CompactLatticeWeightTpl<LatticeWeightTpl<double>, int32> > {
|
||||
public:
|
||||
typedef CompactLatticeWeightTpl<LatticeWeightTpl<double>, int32> Weight;
|
||||
|
||||
NaturalLess() {}
|
||||
|
||||
bool operator()(const Weight &w1, const Weight &w2) const {
|
||||
// NaturalLess is a negative order (opposite to normal ordering).
|
||||
// This operator () corresponds to "<" in the negative order, which
|
||||
// corresponds to the ">" in the normal order.
|
||||
return (Compare(w1, w2) == 1);
|
||||
}
|
||||
};
|
||||
|
||||
// Make sure Compare is defined for TropicalWeight, so everything works
|
||||
// if we substitute LatticeWeight for TropicalWeight.
|
||||
inline int Compare(const TropicalWeight &w1, const TropicalWeight &w2) {
|
||||
float f1 = w1.Value(), f2 = w2.Value();
|
||||
if (f1 == f2)
|
||||
return 0;
|
||||
else if (f1 > f2)
|
||||
return -1;
|
||||
else
|
||||
return 1;
|
||||
}
|
||||
|
||||
template <class WeightType, class IntType>
|
||||
inline CompactLatticeWeightTpl<WeightType, IntType> Plus(
|
||||
const CompactLatticeWeightTpl<WeightType, IntType> &w1,
|
||||
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
|
||||
return (Compare(w1, w2) >= 0 ? w1 : w2);
|
||||
}
|
||||
|
||||
template <class WeightType, class IntType>
|
||||
inline CompactLatticeWeightTpl<WeightType, IntType> Times(
|
||||
const CompactLatticeWeightTpl<WeightType, IntType> &w1,
|
||||
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
|
||||
WeightType w = Times(w1.Weight(), w2.Weight());
|
||||
if (w == WeightType::Zero()) {
|
||||
return CompactLatticeWeightTpl<WeightType, IntType>::Zero();
|
||||
// special case to ensure zero is unique
|
||||
} else {
|
||||
std::vector<IntType> v;
|
||||
v.resize(w1.String().size() + w2.String().size());
|
||||
typename std::vector<IntType>::iterator iter = v.begin();
|
||||
iter = std::copy(w1.String().begin(), w1.String().end(),
|
||||
iter); // returns end of first range.
|
||||
std::copy(w2.String().begin(), w2.String().end(), iter);
|
||||
return CompactLatticeWeightTpl<WeightType, IntType>(w, v);
|
||||
}
|
||||
}
|
||||
|
||||
template <class WeightType, class IntType>
|
||||
inline CompactLatticeWeightTpl<WeightType, IntType> Divide(
|
||||
const CompactLatticeWeightTpl<WeightType, IntType> &w1,
|
||||
const CompactLatticeWeightTpl<WeightType, IntType> &w2,
|
||||
DivideType div = DIVIDE_ANY) {
|
||||
if (w1.Weight() == WeightType::Zero()) {
|
||||
if (w2.Weight() != WeightType::Zero()) {
|
||||
return CompactLatticeWeightTpl<WeightType, IntType>::Zero();
|
||||
} else {
|
||||
KALDI_ERR << "Division by zero [0/0]";
|
||||
}
|
||||
} else if (w2.Weight() == WeightType::Zero()) {
|
||||
KALDI_ERR << "Error: division by zero";
|
||||
}
|
||||
WeightType w = Divide(w1.Weight(), w2.Weight());
|
||||
|
||||
const std::vector<IntType> v1 = w1.String(), v2 = w2.String();
|
||||
if (v2.size() > v1.size()) {
|
||||
KALDI_ERR << "Cannot divide, length mismatch";
|
||||
}
|
||||
typename std::vector<IntType>::const_iterator v1b = v1.begin(),
|
||||
v1e = v1.end(),
|
||||
v2b = v2.begin(),
|
||||
v2e = v2.end();
|
||||
if (div == DIVIDE_LEFT) {
|
||||
if (!std::equal(v2b, v2e,
|
||||
v1b)) { // v2 must be identical to first part of v1.
|
||||
KALDI_ERR << "Cannot divide, data mismatch";
|
||||
}
|
||||
return CompactLatticeWeightTpl<WeightType, IntType>(
|
||||
w, std::vector<IntType>(v1b + (v2e - v2b),
|
||||
v1e)); // return last part of v1.
|
||||
} else if (div == DIVIDE_RIGHT) {
|
||||
if (!std::equal(
|
||||
v2b, v2e,
|
||||
v1e - (v2e - v2b))) { // v2 must be identical to last part of v1.
|
||||
KALDI_ERR << "Cannot divide, data mismatch";
|
||||
}
|
||||
return CompactLatticeWeightTpl<WeightType, IntType>(
|
||||
w, std::vector<IntType>(
|
||||
v1b, v1e - (v2e - v2b))); // return first part of v1.
|
||||
|
||||
} else {
|
||||
KALDI_ERR << "Cannot divide CompactLatticeWeightTpl with DIVIDE_ANY";
|
||||
}
|
||||
return CompactLatticeWeightTpl<WeightType,
|
||||
IntType>::Zero(); // keep compiler happy.
|
||||
}
|
||||
|
||||
template <class WeightType, class IntType>
|
||||
inline std::ostream &operator<<(
|
||||
std::ostream &strm, const CompactLatticeWeightTpl<WeightType, IntType> &w) {
|
||||
strm << w.Weight();
|
||||
CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
|
||||
strm << FLAGS_fst_weight_separator[0]; // comma by default.
|
||||
for (size_t i = 0; i < w.String().size(); i++) {
|
||||
strm << w.String()[i];
|
||||
if (i + 1 < w.String().size())
|
||||
strm << kStringSeparator; // '_'; defined in string-weight.h in OpenFst
|
||||
// code.
|
||||
}
|
||||
return strm;
|
||||
}
|
||||
|
||||
template <class WeightType, class IntType>
|
||||
inline std::istream &operator>>(
|
||||
std::istream &strm, CompactLatticeWeightTpl<WeightType, IntType> &w) {
|
||||
std::string s;
|
||||
strm >> s;
|
||||
if (strm.fail()) {
|
||||
return strm;
|
||||
}
|
||||
CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
|
||||
size_t pos = s.find_last_of(FLAGS_fst_weight_separator); // normally ","
|
||||
if (pos == std::string::npos) {
|
||||
strm.clear(std::ios::badbit);
|
||||
return strm;
|
||||
}
|
||||
// get parts of str before and after the separator (default: ',');
|
||||
std::string s1(s, 0, pos), s2(s, pos + 1);
|
||||
std::istringstream strm1(s1);
|
||||
WeightType weight;
|
||||
strm1 >> weight;
|
||||
w.SetWeight(weight);
|
||||
if (strm1.fail() || !strm1.eof()) {
|
||||
strm.clear(std::ios::badbit);
|
||||
return strm;
|
||||
}
|
||||
// read string part.
|
||||
std::vector<IntType> string;
|
||||
const char *c = s2.c_str();
|
||||
while (*c != '\0') {
|
||||
if (*c == kStringSeparator) // '_'
|
||||
c++;
|
||||
char *c2;
|
||||
int64_t i = strtol(c, &c2, 10);
|
||||
if (c2 == c || static_cast<int64_t>(static_cast<IntType>(i)) != i) {
|
||||
strm.clear(std::ios::badbit);
|
||||
return strm;
|
||||
}
|
||||
c = c2;
|
||||
string.push_back(static_cast<IntType>(i));
|
||||
}
|
||||
w.SetString(string);
|
||||
return strm;
|
||||
}
|
||||
|
||||
template <class BaseWeightType, class IntType>
|
||||
class CompactLatticeWeightCommonDivisorTpl {
|
||||
public:
|
||||
typedef CompactLatticeWeightTpl<BaseWeightType, IntType> Weight;
|
||||
|
||||
Weight operator()(const Weight &w1, const Weight &w2) const {
|
||||
// First find longest common prefix of the strings.
|
||||
typename std::vector<IntType>::const_iterator s1b = w1.String().begin(),
|
||||
s1e = w1.String().end(),
|
||||
s2b = w2.String().begin(),
|
||||
s2e = w2.String().end();
|
||||
while (s1b < s1e && s2b < s2e && *s1b == *s2b) {
|
||||
s1b++;
|
||||
s2b++;
|
||||
}
|
||||
return Weight(Plus(w1.Weight(), w2.Weight()),
|
||||
std::vector<IntType>(w1.String().begin(), s1b));
|
||||
}
|
||||
};
|
||||
|
||||
/** Scales the pair (a, b) of floating-point weights inside a
|
||||
CompactLatticeWeight by premultiplying it (viewed as a vector)
|
||||
by a 2x2 matrix "scale".
|
||||
Assumes there is a ScaleTupleWeight function that applies to "Weight";
|
||||
this currently only works if Weight equals LatticeWeightTpl<FloatType>
|
||||
for some FloatType.
|
||||
*/
|
||||
template <class Weight, class IntType, class ScaleFloatType>
|
||||
inline CompactLatticeWeightTpl<Weight, IntType> ScaleTupleWeight(
|
||||
const CompactLatticeWeightTpl<Weight, IntType> &w,
|
||||
const std::vector<std::vector<ScaleFloatType> > &scale) {
|
||||
return CompactLatticeWeightTpl<Weight, IntType>(
|
||||
Weight(ScaleTupleWeight(w.Weight(), scale)), w.String());
|
||||
}
|
||||
|
||||
/** Define some ConvertLatticeWeight functions that are used in various lattice
|
||||
conversions... make them all templates, some with no arguments, since some
|
||||
must be templates.*/
|
||||
template <class Float1, class Float2>
|
||||
inline void ConvertLatticeWeight(const LatticeWeightTpl<Float1> &w_in,
|
||||
LatticeWeightTpl<Float2> *w_out) {
|
||||
w_out->SetValue1(w_in.Value1());
|
||||
w_out->SetValue2(w_in.Value2());
|
||||
}
|
||||
|
||||
template <class Float1, class Float2, class Int>
|
||||
inline void ConvertLatticeWeight(
|
||||
const CompactLatticeWeightTpl<LatticeWeightTpl<Float1>, Int> &w_in,
|
||||
CompactLatticeWeightTpl<LatticeWeightTpl<Float2>, Int> *w_out) {
|
||||
LatticeWeightTpl<Float2> weight2(w_in.Weight().Value1(),
|
||||
w_in.Weight().Value2());
|
||||
w_out->SetWeight(weight2);
|
||||
w_out->SetString(w_in.String());
|
||||
}
|
||||
|
||||
// to convert from Lattice to standard FST
|
||||
template <class Float1, class Float2>
|
||||
inline void ConvertLatticeWeight(const LatticeWeightTpl<Float1> &w_in,
|
||||
TropicalWeightTpl<Float2> *w_out) {
|
||||
TropicalWeightTpl<Float2> w1(w_in.Value1());
|
||||
TropicalWeightTpl<Float2> w2(w_in.Value2());
|
||||
*w_out = Times(w1, w2);
|
||||
}
|
||||
|
||||
template <class Float>
|
||||
inline double ConvertToCost(const LatticeWeightTpl<Float> &w) {
|
||||
return static_cast<double>(w.Value1()) + static_cast<double>(w.Value2());
|
||||
}
|
||||
|
||||
template <class Float, class Int>
|
||||
inline double ConvertToCost(
|
||||
const CompactLatticeWeightTpl<LatticeWeightTpl<Float>, Int> &w) {
|
||||
return static_cast<double>(w.Weight().Value1()) +
|
||||
static_cast<double>(w.Weight().Value2());
|
||||
}
|
||||
|
||||
template <class Float>
|
||||
inline double ConvertToCost(const TropicalWeightTpl<Float> &w) {
|
||||
return w.Value();
|
||||
}
|
||||
|
||||
} // namespace fst
|
||||
|
||||
#endif // KALDI_FSTEXT_LATTICE_WEIGHT_H_
|
@ -0,0 +1,798 @@
|
||||
// fstext/pre-determinize-inl.h
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_FSTEXT_PRE_DETERMINIZE_INL_H_
|
||||
#define KALDI_FSTEXT_PRE_DETERMINIZE_INL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
/* Do not include this file directly. It is an implementation file included by
|
||||
* PreDeterminize.h */
|
||||
|
||||
/*
|
||||
Predeterminization
|
||||
|
||||
This is a function that makes an FST compactly determinizable by inserting
|
||||
symbols on the input side as necessary for disambiguation. Note that we do
|
||||
not treat epsilon as a real symbol when measuring determinizability in this
|
||||
sense. The extra symbols are added to the vocabulary, on the input side;
|
||||
these are of the form (prefix)1, (prefix)2, and so on without limit, where
|
||||
(prefix) is some prefix the user provides, e.g. '#' (the function checks that
|
||||
this will not lead to conflicts with symbols already in the FST). The
|
||||
function tells us how many such symbols it created.
|
||||
|
||||
Note that there is a paper "Generalized optimization algorithm for speech
|
||||
recognition transducers" by Allauzen and Mohri, that deals with a similar
|
||||
issue, but this is a very different algorithm that only aims to ensure
|
||||
determinizability, but not *compact* determinizability.
|
||||
|
||||
Our algorithm is slightly heuristic, and probably not optimal, but does
|
||||
ensure that the output is compactly determinizable, possibly at the expense of
|
||||
inserting unnecessary symbols. We considered more sophisticated algorithms,
|
||||
but these were extremely complicated and would give the same output for the
|
||||
kinds of inputs that we envisage.
|
||||
|
||||
Suppose the input FST is T. We want to ensure that in det(T), if we consider
|
||||
the states of det(T) as weighted subsets of states of T, each state of T only
|
||||
appears once in any given subset. This ensures that det(T) is no larger than
|
||||
T in an appropriate sense. The way we do this is as follows. We identify all
|
||||
states in T that have multiple input transitions (counting "being an initial
|
||||
state" as an input transition). Let's call these "problematic" states. For a
|
||||
problematic state p we stipulate that it can never appear in any state of
|
||||
det(T) unless that state equals (p, \bar{1}) [i.e. p, unweighted]. In order
|
||||
to ensure this, we insert input symbols on the transitions to these
|
||||
problematic states (this may necessitate adding extra states).
|
||||
We also stipulate that the path through det(T) should always be sufficient
|
||||
to tell us the path through T (and we insert extra symbols sufficient to make
|
||||
this so). This is to simplify the algorithm, so that we don't have to
|
||||
consider the output symbols or weights when predeterminizing.
|
||||
|
||||
The algorithm is as follows.
|
||||
|
||||
(A) Definitions
|
||||
|
||||
(i) Define a *problematic state* as a state that either has multiple
|
||||
input transitions, or is an initial state and has at least one input
|
||||
transition.
|
||||
|
||||
(ii) For an arc a, define:
|
||||
i[a] = input symbol on a
|
||||
o[a] = output symbol on a
|
||||
n[a] = dest-state of a
|
||||
p[a] = origin-state of a
|
||||
|
||||
For a state q, define
|
||||
E[q] = set of transitions leaving q.
|
||||
For a set of states Q, define
|
||||
E[Q] = set of transitions leaving some q in Q
|
||||
|
||||
(iii) For a state s, define Closure(s) as the union of state s, and all
|
||||
states t that are reachable via sequences of arcs a such that i[a]=epsilon and
|
||||
n[a] is not problematic.
|
||||
|
||||
For a set of states S, define Closure(S) as the union of the closures
|
||||
of states s in S.
|
||||
|
||||
(B) Inputs and outputs.
|
||||
|
||||
(i) Inputs and preconditions. Input is an FST, which should have a symbol
|
||||
table compiled into it, and a prefix (e.g. #) for symbols to be added. We
|
||||
check that the input FST is trim, and that it does not have any symbols that
|
||||
appear on its arcs, that are equal to the prefix followed by digits.
|
||||
|
||||
(ii) Outputs: The algorithm modifies the FST that is given to it, and
|
||||
returns the number of the highest numbered "extra symbol" inserted. The extra
|
||||
symbols are numbered #1, #2 and so on without limit (as integers). They are
|
||||
inserted into the symbol table in a sequential way by calling AvailableKey()
|
||||
for each in turn (this is stipulated in case we need to keep other
|
||||
symbol tables in sync).
|
||||
|
||||
(C) Sub-algorithm: Closure(S). This requires the array p(s), defined
|
||||
below, which is true if s is problematic. This also requires, for efficiency,
|
||||
that the arcs be sorted on input label. Input: a set of states S. [plus, the
|
||||
fst and the array p]. Output: a set of states T. Algorithm: set T <-- S, Q <--
|
||||
S. while Q is nonempty: pop a state s from Q. for each transition a from state
|
||||
s with epsilon on the input label [we can find these efficiently using the
|
||||
sorting on arcs]: If p(n[a]) is false and n[a] is not in T: Insert n[a] into
|
||||
T. Add n[a] to Q. return T.
|
||||
|
||||
|
||||
(D) Main algorithm.
|
||||
|
||||
|
||||
(i) (a) Check preconditions (FST is trim)
|
||||
(b) Make sure there is just one final state (insert epsilon
|
||||
transitions as necessary). (c) Sort arcs on input label (so epsilon arcs are
|
||||
at the start of arc lists).
|
||||
|
||||
|
||||
(ii) Work out the set of problematic states by constructing a boolean
|
||||
array indexed by states, i.e. p(s) which is true if the state is problematic.
|
||||
We can do this by constructing an array t(s) to store the number of
|
||||
transitions into each state [adding one for the initial state], and then
|
||||
setting p(s) = true if t(s) > 1.
|
||||
|
||||
Also create a boolean array d(s), defined for states, and set d(s) =
|
||||
false. This array is purely for sanity-checking that we are processing each
|
||||
state exactly once.
|
||||
|
||||
(iii) Set up an array of integers m(a), indexed by arcs (how exactly we
|
||||
store these is implementation-dependent, but this will probably be a hash from
|
||||
(state, arc-index) to integers. m(a) will store the extra symbol, if any, to
|
||||
be added to that arc (or -1 if no such symbol; we can also simply have the arc
|
||||
not present in the hash). The initial value of m(a) is -1 (if array), or
|
||||
undefined (if hash).
|
||||
|
||||
(iv) Initialize a set of sets-of-states S, and a queue of pairs Q, as
|
||||
follows. The pairs in Q are a pair of (set-of-states, integer), where the
|
||||
integer is the number of "special symbols" already used up for that state.
|
||||
|
||||
Note that we use a special indexing for the sets in both S and Q,
|
||||
rather than using std::set. We use a sorted vector of StateId's. And in S,
|
||||
we index them by the lowest-numbered state-id. Because each state is supposed
|
||||
to only ever be a member of one set, if there is an attempt to add another,
|
||||
different set with the same lowest-numbered state-id, we detect an error.
|
||||
|
||||
Let I be the single initial state (OpenFST only supports one).
|
||||
We set:
|
||||
S = { Closure(I) }
|
||||
Push (Closure(I), 0) onto Q.
|
||||
Then for each state s such that p(s) = true, and s is not an initial
|
||||
state: S <-- S u { Closure(s) } Push (Closure(s), 0) onto Q.
|
||||
|
||||
(v) While Q is nonempty:
|
||||
|
||||
(a) Pop pair (A, n) from Q (queue discipline is arbitrary).
|
||||
|
||||
(b) For each state s in A, check that d(s) is false, and set d(s) to
|
||||
true. This is for sanity checking only.
|
||||
|
||||
(c)
|
||||
Let S_\eps be the set of epsilon-transitions from members of A to
|
||||
problematic states (i.e. S_\eps = \{ a \in E[A]: i[a]=\epsilon, p(n[a]) = true
|
||||
\}).
|
||||
|
||||
Next, we will define, for each t \neq \epsilon, S_t as the set of
|
||||
transitions from some state s in S with t as the input label,
|
||||
i.e.: S_t = \{ a \in E[A]: i[a] = t \} We further define T_t and U_t as the
|
||||
subsets of S where the destination state is problematic and non-problematic
|
||||
respectively, i.e: T_t = \{ a \in E[A]: i[a] = t, p(n[a]) = true \} U_t = \{ a
|
||||
\in E[A]: i[a] = t, p(n[a]) = false \}
|
||||
|
||||
The easiest way to obtain these sets is probably to have a hash
|
||||
indexed by t that maps to a list of pairs (state, arc-offset) that stores S_t.
|
||||
From this we can work out the sizes of T_t and U_t on the fly.
|
||||
|
||||
(d)
|
||||
for each transition a in S_\eps:
|
||||
m(a) <-- n # Will put symbol n on this transition.
|
||||
n <-- n+1 # Note, same n as in pair (A, n)
|
||||
|
||||
(e)
|
||||
next,
|
||||
for each t\neq epsilon s.t. S_t is nonempty,
|
||||
|
||||
if |S_t| > 1 #if-statement is because if |S_t|=|T_t|=1, no need
|
||||
for prefix. k = 0 for each transition a in T_t: set m(a) to k. set k = k+1
|
||||
|
||||
if |U_t| > 0
|
||||
Let V_t be the set of destination-states of arcs in U_t.
|
||||
if Closure(V_t) is not in S:
|
||||
insert Closure(V_t) into S, and add the pair (Closure(V_t),
|
||||
k) to Q.
|
||||
|
||||
(vi) Check that for each state in the FST, d(s) = true.
|
||||
|
||||
(vii) Let n = max_a m(a). This is the highest-numbered extra symbol
|
||||
(extra symbols start from zero, in this numbering which doesn't correspond to
|
||||
the symbol-table numbering). Here we add n+1 extra symbols to the symbol
|
||||
table and store the mappings from 0, 1, ... n to the symbol-id.
|
||||
|
||||
(viii) Set up a hash h from (state, int) to (state-id) such that
|
||||
t = h(s, k)
|
||||
will be the state-id of a newly-created state that has a transition
|
||||
to state s with input-label #k.
|
||||
|
||||
(ix) For each arc a such that m(a) != 0:
|
||||
If i[a] = epsilon (the input label is epsilon):
|
||||
Change i[a] to #m(a). [i.e. prefix then digit m(a)]
|
||||
Otherwise:
|
||||
If t = h(n[a], m(a)) is not defined [where n[a] is the
|
||||
dest-state]: create a new state t with a transition to n[a], with input-label
|
||||
#m(a) and no output-label or weight. Set h(n[a], m(a)) = t. Change n[a] to
|
||||
h(n[a], m(a)).
|
||||
|
||||
|
||||
*/
|
||||
namespace fst {
|
||||
|
||||
namespace pre_determinize_helpers {
|
||||
|
||||
// make it inline to avoid having to put it in a .cc file which most functions
|
||||
// here could not go in.
|
||||
inline bool HasBannedPrefixPlusDigits(SymbolTable *symTable, std::string prefix,
|
||||
std::string *bad_sym) {
|
||||
// returns true if the symbol table contains any string consisting of this
|
||||
// (possibly empty) prefix followed by a nonempty sequence of digits (0 to 9).
|
||||
// requires symTable to be non-NULL.
|
||||
// if bad_sym != NULL, puts the first bad symbol it finds in *bad_sym.
|
||||
assert(symTable != NULL);
|
||||
const char *prefix_ptr = prefix.c_str();
|
||||
size_t prefix_len =
|
||||
strlen(prefix_ptr); // allowed to be zero but not encouraged.
|
||||
for (SymbolTableIterator siter(*symTable); !siter.Done(); siter.Next()) {
|
||||
const std::string &sym = siter.Symbol();
|
||||
if (!strncmp(prefix_ptr, sym.c_str(), prefix_len)) { // has prefix.
|
||||
if (isdigit(sym[prefix_len])) { // we don't allow prefix followed by a
|
||||
// digit, as a symbol.
|
||||
// Has at least one digit.
|
||||
size_t pos;
|
||||
for (pos = prefix_len; sym[pos] != '\0'; pos++)
|
||||
if (!isdigit(sym[pos])) break;
|
||||
if (sym[pos] == '\0') { // All remaining characters were digits.
|
||||
if (bad_sym != NULL) *bad_sym = sym;
|
||||
return true;
|
||||
}
|
||||
} // else OK because prefix was followed by '\0' or a non-digit.
|
||||
}
|
||||
}
|
||||
return false; // doesn't have banned symbol.
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void CopySetToVector(const std::set<T> s, std::vector<T> *v) {
|
||||
// adds members of s to v, in sorted order from lowest to highest
|
||||
// (because the set was in sorted order).
|
||||
assert(v != NULL);
|
||||
v->resize(s.size());
|
||||
typename std::set<T>::const_iterator siter = s.begin();
|
||||
typename std::vector<T>::iterator viter = v->begin();
|
||||
for (; siter != s.end(); ++siter, ++viter) {
|
||||
assert(viter != v->end());
|
||||
*viter = *siter;
|
||||
}
|
||||
}
|
||||
|
||||
// Warning. This function calls 'new'.
|
||||
template <class T>
|
||||
std::vector<T> *InsertMember(const std::vector<T> m,
|
||||
std::vector<std::vector<T> *> *S) {
|
||||
assert(m.size() > 0);
|
||||
T idx = m[0];
|
||||
assert(idx >= (T)0 && idx < (T)S->size());
|
||||
if ((*S)[idx] != NULL) {
|
||||
assert(*((*S)[idx]) == m);
|
||||
// The vectors should be the same. Otherwise this is a bug in the
|
||||
// algorithm. It could either be a programming error or a deeper conceptual
|
||||
// bug.
|
||||
return NULL; // nothing was inserted.
|
||||
} else {
|
||||
std::vector<T> *ret = (*S)[idx] = new std::vector<T>(m); // New copy of m.
|
||||
return ret; // was inserted.
|
||||
}
|
||||
}
|
||||
|
||||
// See definition of Closure(S) in item A(iii) in the comment above. it's the
|
||||
// set of states that are reachable from S via sequences of arcs a such that
|
||||
// i[a]=epsilon and n[a] is not problematic. We assume that the fst is sorted
|
||||
// on input label (so epsilon arcs first) The algorithm is described in section
|
||||
// (C) above. We use the same variable for S and T.
|
||||
template <class Arc>
|
||||
void Closure(MutableFst<Arc> *fst, std::set<typename Arc::StateId> *S,
|
||||
const std::vector<bool> &pVec) {
|
||||
typedef typename Arc::StateId StateId;
|
||||
std::vector<StateId> Q;
|
||||
CopySetToVector(*S, &Q);
|
||||
while (Q.size() != 0) {
|
||||
StateId s = Q.back();
|
||||
Q.pop_back();
|
||||
for (ArcIterator<MutableFst<Arc> > aiter(*fst, s); !aiter.Done();
|
||||
aiter.Next()) {
|
||||
const Arc &arc = aiter.Value();
|
||||
if (arc.ilabel != 0)
|
||||
break; // Break from the loop: due to sorting there will be no
|
||||
// more transitions with epsilons as input labels.
|
||||
if (!pVec[arc.nextstate]) { // Next state is not problematic -> we can
|
||||
// use this transition.
|
||||
std::pair<typename std::set<StateId>::iterator, bool> p =
|
||||
S->insert(arc.nextstate);
|
||||
if (p.second) { // True means: was inserted into S (wasn't already
|
||||
// there).
|
||||
Q.push_back(arc.nextstate);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // end function Closure.
|
||||
|
||||
} // end namespace pre_determinize_helpers.
|
||||
|
||||
template <class Arc, class Int>
|
||||
void PreDeterminize(MutableFst<Arc> *fst, typename Arc::Label first_new_sym,
|
||||
std::vector<Int> *symsOut) {
|
||||
typedef typename Arc::Label Label;
|
||||
typedef typename Arc::StateId StateId;
|
||||
typedef size_t ArcId; // Our own typedef, not standard OpenFst. Use size_t
|
||||
// for compatibility with argument of ArcIterator::Seek().
|
||||
typedef typename Arc::Weight Weight;
|
||||
assert(first_new_sym > 0);
|
||||
assert(fst != NULL);
|
||||
if (fst->Start() == kNoStateId) return; // for empty FST, nothing to do.
|
||||
assert(symsOut != NULL &&
|
||||
symsOut->size() == 0); // we will output the symbols we add into this.
|
||||
|
||||
{ // (D)(i)(a): check is trim (i.e. connected, in OpenFST parlance).
|
||||
KALDI_VLOG(2) << "PreDeterminize: Checking FST properties";
|
||||
uint64 props = fst->Properties(
|
||||
kAccessible | kCoAccessible,
|
||||
true); // true-> computes properties if unknown at time when called.
|
||||
if (props !=
|
||||
(kAccessible | kCoAccessible)) { // All states are not both accessible
|
||||
// and co-accessible...
|
||||
KALDI_ERR << "PreDeterminize: FST is not trim";
|
||||
}
|
||||
}
|
||||
|
||||
{ // (D)(i)(b): make single final state.
|
||||
KALDI_VLOG(2) << "PreDeterminize: creating single final state";
|
||||
CreateSuperFinal(fst);
|
||||
}
|
||||
|
||||
{ // (D)(i)(c): sort arcs on input.
|
||||
KALDI_VLOG(2) << "PreDeterminize: sorting arcs on input";
|
||||
ILabelCompare<Arc> icomp;
|
||||
ArcSort(fst, icomp);
|
||||
}
|
||||
|
||||
StateId n_states = 0,
|
||||
max_state =
|
||||
0; // Compute n_states, max_state = highest-numbered state.
|
||||
{ // compute nStates, maxStates.
|
||||
for (StateIterator<MutableFst<Arc> > iter(*fst); !iter.Done();
|
||||
iter.Next()) {
|
||||
StateId state = iter.Value();
|
||||
assert(state >= 0);
|
||||
n_states++;
|
||||
if (state > max_state) max_state = state;
|
||||
}
|
||||
KALDI_VLOG(2) << "PreDeterminize: n_states = " << (n_states)
|
||||
<< ", max_state =" << (max_state);
|
||||
}
|
||||
|
||||
std::vector<bool> p_vec(max_state + 1, false); // compute this next.
|
||||
{ // D(ii): computing the array p. ["problematic states, i.e. states with >1
|
||||
// input transition,
|
||||
// counting being the initial state as an input transition"].
|
||||
std::vector<bool> seen_vec(
|
||||
max_state + 1,
|
||||
false); // rather than counting incoming transitions we just have a
|
||||
// bool that says we saw at least one.
|
||||
|
||||
seen_vec[fst->Start()] = true;
|
||||
for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
|
||||
siter.Next()) {
|
||||
for (ArcIterator<MutableFst<Arc> > aiter(*fst, siter.Value());
|
||||
!aiter.Done(); aiter.Next()) {
|
||||
const Arc &arc = aiter.Value();
|
||||
assert(arc.nextstate >= 0 && arc.nextstate < max_state + 1);
|
||||
if (seen_vec[arc.nextstate])
|
||||
p_vec[arc.nextstate] =
|
||||
true; // now have >1 transition in, so problematic.
|
||||
else
|
||||
seen_vec[arc.nextstate] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
// D(iii): set up m(a)
|
||||
std::map<std::pair<StateId, ArcId>, size_t> m_map;
|
||||
// This is the array m, indexed by arcs. It maps to the index of the symbol
|
||||
// we add.
|
||||
|
||||
// WARNING: we should be sure to clean up this memory before exiting. Do not
|
||||
// return or throw an exception from this function, later than this point,
|
||||
// without cleaning up! Note that the vectors are shared between Q and S (they
|
||||
// "belong to" S.
|
||||
std::vector<std::vector<StateId> *> S(max_state + 1,
|
||||
(std::vector<StateId> *)(void *)0);
|
||||
std::vector<std::pair<std::vector<StateId> *, size_t> > Q;
|
||||
|
||||
// D(iv): initialize S and Q.
|
||||
{
|
||||
std::vector<StateId>
|
||||
all_seed_states; // all "problematic" states, plus initial state (if
|
||||
// not problematic).
|
||||
if (!p_vec[fst->Start()]) all_seed_states.push_back(fst->Start());
|
||||
for (StateId s = 0; s <= max_state; s++)
|
||||
if (p_vec[s]) all_seed_states.push_back(s);
|
||||
|
||||
for (size_t idx = 0; idx < all_seed_states.size(); idx++) {
|
||||
StateId s = all_seed_states[idx];
|
||||
std::set<StateId> closure_s;
|
||||
closure_s.insert(s); // insert "seed" state.
|
||||
pre_determinize_helpers::Closure(
|
||||
fst, &closure_s,
|
||||
p_vec); // follow epsilons to non-problematic states.
|
||||
// Closure in this case whis will usually not add anything, for typical
|
||||
// topologies in speech
|
||||
std::vector<StateId> closure_s_vec;
|
||||
pre_determinize_helpers::CopySetToVector(closure_s, &closure_s_vec);
|
||||
KALDI_ASSERT(closure_s_vec.size() != 0);
|
||||
std::vector<StateId> *ptr =
|
||||
pre_determinize_helpers::InsertMember(closure_s_vec, &S);
|
||||
KALDI_ASSERT(ptr != NULL); // Or conceptual bug or programming error.
|
||||
Q.push_back(std::pair<std::vector<StateId> *, size_t>(ptr, 0));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<bool> d_vec(max_state + 1,
|
||||
false); // "done vector". Purely for debugging.
|
||||
|
||||
size_t num_extra_det_states = 0;
|
||||
|
||||
// (D)(v)
|
||||
while (Q.size() != 0) {
|
||||
// (D)(v)(a)
|
||||
std::pair<std::vector<StateId> *, size_t> cur_pair(Q.back());
|
||||
Q.pop_back();
|
||||
const std::vector<StateId> &A(*cur_pair.first);
|
||||
size_t n = cur_pair.second; // next special symbol to add.
|
||||
|
||||
// (D)(v)(b)
|
||||
for (size_t idx = 0; idx < A.size(); idx++) {
|
||||
assert(d_vec[A[idx]] == false &&
|
||||
"This state has been seen before. Algorithm error.");
|
||||
d_vec[A[idx]] = true;
|
||||
}
|
||||
|
||||
// From here is (D)(v)(c). We work out S_\eps and S_t (for t\neq eps)
|
||||
// simultaneously at first.
|
||||
std::map<Label, std::set<std::pair<std::pair<StateId, ArcId>, StateId> > >
|
||||
arc_hash;
|
||||
// arc_hash is a hash with info of all arcs from states in the set A to
|
||||
// non-problematic states.
|
||||
// It is a map from ilabel to pair(pair(start-state, arc-offset),
|
||||
// end-state). Here, arc-offset reflects the order in which we accessed the
|
||||
// arc using the ArcIterator (zero for the first arc).
|
||||
|
||||
{ // This block sets up arc_hash
|
||||
for (size_t idx = 0; idx < A.size(); idx++) {
|
||||
StateId s = A[idx];
|
||||
assert(s >= 0 && s <= max_state);
|
||||
ArcId arc_id = 0;
|
||||
for (ArcIterator<MutableFst<Arc> > aiter(*fst, s); !aiter.Done();
|
||||
aiter.Next(), ++arc_id) {
|
||||
const Arc &arc = aiter.Value();
|
||||
|
||||
std::pair<std::pair<StateId, ArcId>, StateId> this_pair(
|
||||
std::pair<StateId, ArcId>(s, arc_id), arc.nextstate);
|
||||
bool inserted = (arc_hash[arc.ilabel].insert(this_pair)).second;
|
||||
assert(inserted); // Otherwise we had a duplicate.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// (D)(v)(d)
|
||||
if (arc_hash.count(0) == 1) { // We have epsilon transitions out.
|
||||
std::set<std::pair<std::pair<StateId, ArcId>, StateId> > &eps_set =
|
||||
arc_hash[0];
|
||||
typedef typename std::set<
|
||||
std::pair<std::pair<StateId, ArcId>, StateId> >::iterator set_iter_t;
|
||||
for (set_iter_t siter = eps_set.begin(); siter != eps_set.end();
|
||||
++siter) {
|
||||
const std::pair<std::pair<StateId, ArcId>, StateId> &this_pr = *siter;
|
||||
if (p_vec[this_pr.second]) { // Eps-transition to problematic state.
|
||||
assert(m_map.count(this_pr.first) == 0);
|
||||
m_map[this_pr.first] = n;
|
||||
n++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// (D)(v)(e)
|
||||
{
|
||||
typedef typename std::map<
|
||||
Label,
|
||||
std::set<std::pair<std::pair<StateId, ArcId>, StateId> > >::iterator
|
||||
map_iter_t;
|
||||
typedef typename std::set<
|
||||
std::pair<std::pair<StateId, ArcId>, StateId> >::iterator set_iter_t2;
|
||||
for (map_iter_t miter = arc_hash.begin(); miter != arc_hash.end();
|
||||
++miter) {
|
||||
Label t = miter->first;
|
||||
std::set<std::pair<std::pair<StateId, ArcId>, StateId> > &S_t =
|
||||
miter->second;
|
||||
if (t != 0) { // For t != epsilon,
|
||||
std::set<StateId> V_t; // set of destination non-problem states. Will
|
||||
// create this set now.
|
||||
|
||||
// exists_noproblem is true iff |U_t| > 0.
|
||||
size_t k = 0;
|
||||
|
||||
// First loop "for each transition a in T_t" (i.e. transitions to
|
||||
// problematic states) The if-statement if (|S_t|>1) is pushed inside
|
||||
// the loop, as the loop also computes the set V_t.
|
||||
for (set_iter_t2 siter = S_t.begin(); siter != S_t.end(); ++siter) {
|
||||
const std::pair<std::pair<StateId, ArcId>, StateId> &this_pr =
|
||||
*siter;
|
||||
if (p_vec[this_pr.second]) { // only consider problematic states
|
||||
// (just set T_t)
|
||||
if (S_t.size() >
|
||||
1) { // This is where we pushed the if-statement in.
|
||||
assert(m_map.count(this_pr.first) == 0);
|
||||
m_map[this_pr.first] = k;
|
||||
k++;
|
||||
num_extra_det_states++;
|
||||
}
|
||||
} else { // Create the set V_t.
|
||||
V_t.insert(this_pr.second);
|
||||
}
|
||||
}
|
||||
if (V_t.size() != 0) {
|
||||
pre_determinize_helpers::Closure(
|
||||
fst, &V_t,
|
||||
p_vec); // follow epsilons to non-problematic states.
|
||||
std::vector<StateId> closure_V_t_vec;
|
||||
pre_determinize_helpers::CopySetToVector(V_t, &closure_V_t_vec);
|
||||
std::vector<StateId> *ptr =
|
||||
pre_determinize_helpers::InsertMember(closure_V_t_vec, &S);
|
||||
if (ptr != NULL) { // was inserted.
|
||||
Q.push_back(std::pair<std::vector<StateId> *, size_t>(ptr, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // end while (Q.size() != 0)
|
||||
|
||||
{ // (D)(vi): Check that for each state in the FST, d(s) = true.
|
||||
for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
|
||||
siter.Next()) {
|
||||
StateId val = siter.Value();
|
||||
assert(d_vec[val] == true);
|
||||
}
|
||||
}
|
||||
|
||||
{ // (D)(vii): compute symbol-table ID's.
|
||||
// sets up symsOut array.
|
||||
int64 n = -1;
|
||||
for (typename std::map<std::pair<StateId, ArcId>, size_t>::iterator m_iter =
|
||||
m_map.begin();
|
||||
m_iter != m_map.end(); ++m_iter) {
|
||||
n = std::max(n,
|
||||
static_cast<int64>(
|
||||
m_iter->second)); // m_iter->second is of type size_t.
|
||||
}
|
||||
// At this point n is the highest symbol-id (type size_t) of symbols we must
|
||||
// add.
|
||||
n++; // This is now the number of symbols we must add.
|
||||
for (size_t i = 0; static_cast<int64>(i) < n; i++)
|
||||
symsOut->push_back(first_new_sym + i);
|
||||
}
|
||||
|
||||
// (D)(viii): set up hash.
|
||||
std::map<std::pair<StateId, size_t>, StateId> h_map;
|
||||
|
||||
{ // D(ix): add extra symbols! This is where the work gets done.
|
||||
// Core part of this is below, search for (*)
|
||||
size_t n_states_added = 0;
|
||||
|
||||
for (typename std::map<std::pair<StateId, ArcId>, size_t>::iterator m_iter =
|
||||
m_map.begin();
|
||||
m_iter != m_map.end(); ++m_iter) {
|
||||
StateId state = m_iter->first.first;
|
||||
ArcId arcpos = m_iter->first.second;
|
||||
size_t m_a = m_iter->second;
|
||||
|
||||
MutableArcIterator<MutableFst<Arc> > aiter(fst, state);
|
||||
aiter.Seek(arcpos);
|
||||
Arc arc = aiter.Value();
|
||||
|
||||
// (*) core part here.
|
||||
if (arc.ilabel == 0) {
|
||||
arc.ilabel = (*symsOut)[m_a];
|
||||
} else {
|
||||
std::pair<StateId, size_t> pr(arc.nextstate, m_a);
|
||||
if (!h_map.count(pr)) {
|
||||
n_states_added++;
|
||||
StateId newstate = fst->AddState();
|
||||
assert(newstate >= 0);
|
||||
Arc new_arc((*symsOut)[m_a], (Label)0, Weight::One(), arc.nextstate);
|
||||
fst->AddArc(newstate, new_arc);
|
||||
h_map[pr] = newstate;
|
||||
}
|
||||
arc.nextstate = h_map[pr];
|
||||
}
|
||||
aiter.SetValue(arc);
|
||||
}
|
||||
|
||||
KALDI_VLOG(2) << "Added " << (n_states_added)
|
||||
<< " new states and added/changed " << (m_map.size())
|
||||
<< " arcs";
|
||||
}
|
||||
// Now free up memory.
|
||||
for (size_t i = 0; i < S.size(); i++) delete S[i];
|
||||
} // end function PreDeterminize
|
||||
|
||||
template <class Label>
|
||||
void CreateNewSymbols(SymbolTable *input_sym_table, int nSym,
|
||||
std::string prefix, std::vector<Label> *symsOut) {
|
||||
// Creates nSym new symbols named (prefix)0, (prefix)1 and so on.
|
||||
// Crashes if it cannot create them because one or more of them were in the
|
||||
// symbol table already.
|
||||
assert(symsOut && symsOut->size() == 0);
|
||||
for (int i = 0; i < nSym; i++) {
|
||||
std::stringstream ss;
|
||||
ss << prefix << i;
|
||||
std::string str = ss.str();
|
||||
if (input_sym_table->Find(str) != -1) { // should not be present.
|
||||
}
|
||||
assert(symsOut);
|
||||
symsOut->push_back((Label)input_sym_table->AddSymbol(str));
|
||||
}
|
||||
}
|
||||
|
||||
// see pre-determinize.h for documentation.
|
||||
template <class Arc>
|
||||
void AddSelfLoops(MutableFst<Arc> *fst,
|
||||
const std::vector<typename Arc::Label> &isyms,
|
||||
const std::vector<typename Arc::Label> &osyms) {
|
||||
assert(fst != NULL);
|
||||
assert(isyms.size() == osyms.size());
|
||||
typedef typename Arc::Label Label;
|
||||
typedef typename Arc::StateId StateId;
|
||||
typedef typename Arc::Weight Weight;
|
||||
size_t n = isyms.size();
|
||||
if (n == 0) return; // Nothing to do.
|
||||
|
||||
// {
|
||||
// the following declarations and statements are for quick detection of these
|
||||
// symbols, which is purely for debugging/checking purposes.
|
||||
Label isyms_min = *std::min_element(isyms.begin(), isyms.end()),
|
||||
isyms_max = *std::max_element(isyms.begin(), isyms.end()),
|
||||
osyms_min = *std::min_element(osyms.begin(), osyms.end()),
|
||||
osyms_max = *std::max_element(osyms.begin(), osyms.end());
|
||||
std::set<Label> isyms_set, osyms_set;
|
||||
for (size_t i = 0; i < isyms.size(); i++) {
|
||||
assert(isyms[i] > 0 &&
|
||||
osyms[i] > 0); // should not have epsilon or invalid symbols.
|
||||
isyms_set.insert(isyms[i]);
|
||||
osyms_set.insert(osyms[i]);
|
||||
}
|
||||
assert(isyms_set.size() == n && osyms_set.size() == n);
|
||||
// } end block.
|
||||
|
||||
for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
|
||||
siter.Next()) {
|
||||
StateId state = siter.Value();
|
||||
bool this_state_needs_self_loops = (fst->Final(state) != Weight::Zero());
|
||||
for (ArcIterator<MutableFst<Arc> > aiter(*fst, state); !aiter.Done();
|
||||
aiter.Next()) {
|
||||
const Arc &arc = aiter.Value();
|
||||
// If one of the following asserts fails, it means that the input FST
|
||||
// already had the symbols we are inserting. This is contrary to the
|
||||
// preconditions of this algorithm.
|
||||
assert(!(arc.ilabel >= isyms_min && arc.ilabel <= isyms_max &&
|
||||
isyms_set.count(arc.ilabel) != 0));
|
||||
assert(!(arc.olabel >= osyms_min && arc.olabel <= osyms_max &&
|
||||
osyms_set.count(arc.olabel) != 0));
|
||||
if (arc.olabel != 0) // Has non-epsilon output label -> need self loops.
|
||||
this_state_needs_self_loops = true;
|
||||
}
|
||||
if (this_state_needs_self_loops) {
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
Arc arc;
|
||||
arc.ilabel = isyms[i];
|
||||
arc.olabel = osyms[i];
|
||||
arc.weight = Weight::One();
|
||||
arc.nextstate = state;
|
||||
fst->AddArc(state, arc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Arc>
|
||||
int64 DeleteISymbols(MutableFst<Arc> *fst,
|
||||
std::vector<typename Arc::Label> isyms) {
|
||||
// We could do this using the Mapper concept, but this is much easier to
|
||||
// understand.
|
||||
|
||||
typedef typename Arc::Label Label;
|
||||
typedef typename Arc::StateId StateId;
|
||||
|
||||
int64 num_deleted = 0;
|
||||
|
||||
if (isyms.size() == 0) return 0;
|
||||
Label isyms_min = *std::min_element(isyms.begin(), isyms.end()),
|
||||
isyms_max = *std::max_element(isyms.begin(), isyms.end());
|
||||
bool isyms_consecutive =
|
||||
(isyms_max + 1 - isyms_min == static_cast<Label>(isyms.size()));
|
||||
std::set<Label> isyms_set;
|
||||
if (!isyms_consecutive) {
|
||||
for (size_t i = 0; i < isyms.size(); i++) isyms_set.insert(isyms[i]);
|
||||
}
|
||||
|
||||
for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
|
||||
siter.Next()) {
|
||||
StateId state = siter.Value();
|
||||
for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state); !aiter.Done();
|
||||
aiter.Next()) {
|
||||
const Arc &arc = aiter.Value();
|
||||
if (arc.ilabel >= isyms_min && arc.ilabel <= isyms_max) {
|
||||
if (isyms_consecutive || isyms_set.count(arc.ilabel) != 0) {
|
||||
num_deleted++;
|
||||
Arc mod_arc(arc);
|
||||
mod_arc.ilabel = 0; // change label to epsilon.
|
||||
aiter.SetValue(mod_arc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return num_deleted;
|
||||
}
|
||||
|
||||
template <class Arc>
|
||||
typename Arc::StateId CreateSuperFinal(MutableFst<Arc> *fst) {
|
||||
typedef typename Arc::StateId StateId;
|
||||
typedef typename Arc::Weight Weight;
|
||||
assert(fst != NULL);
|
||||
StateId num_states = fst->NumStates();
|
||||
StateId num_final = 0;
|
||||
std::vector<StateId> final_states;
|
||||
for (StateId s = 0; s < num_states; s++) {
|
||||
if (fst->Final(s) != Weight::Zero()) {
|
||||
num_final++;
|
||||
final_states.push_back(s);
|
||||
}
|
||||
}
|
||||
if (final_states.size() == 1) {
|
||||
if (fst->Final(final_states[0]) == Weight::One()) {
|
||||
ArcIterator<MutableFst<Arc> > iter(*fst, final_states[0]);
|
||||
if (iter.Done()) {
|
||||
// We already have a final state w/ no transitions out and unit weight.
|
||||
// So we're done.
|
||||
return final_states[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StateId final_state = fst->AddState();
|
||||
fst->SetFinal(final_state, Weight::One());
|
||||
for (size_t idx = 0; idx < final_states.size(); idx++) {
|
||||
StateId s = final_states[idx];
|
||||
Weight weight = fst->Final(s);
|
||||
fst->SetFinal(s, Weight::Zero());
|
||||
Arc arc;
|
||||
arc.ilabel = 0;
|
||||
arc.olabel = 0;
|
||||
arc.nextstate = final_state;
|
||||
arc.weight = weight;
|
||||
fst->AddArc(s, arc);
|
||||
}
|
||||
return final_state;
|
||||
}
|
||||
|
||||
} // namespace fst
|
||||
|
||||
#endif // KALDI_FSTEXT_PRE_DETERMINIZE_INL_H_
|
@ -0,0 +1,98 @@
|
||||
// fstext/pre-determinize.h
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_FSTEXT_PRE_DETERMINIZE_H_
|
||||
#define KALDI_FSTEXT_PRE_DETERMINIZE_H_
|
||||
|
||||
#include <fst/fst-decl.h>
|
||||
#include <fst/fstlib.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "base/kaldi-common.h"
|
||||
|
||||
namespace fst {
|
||||
|
||||
/* PreDeterminize inserts extra symbols on the input side of an FST as necessary
|
||||
to ensure that, after epsilon removal, it will be compactly determinizable by
|
||||
the determinize* algorithm. By compactly determinizable we mean that no
|
||||
original FST state is represented in more than one determinized state).
|
||||
|
||||
Caution: this code is now only used in testing.
|
||||
|
||||
The new symbols start from the value "first_new_symbol", which should be
|
||||
higher than the largest-numbered symbol currently in the FST. The new
|
||||
symbols added are put in the array syms_out, which should be empty at start.
|
||||
*/
|
||||
|
||||
template <class Arc, class Int>
|
||||
void PreDeterminize(MutableFst<Arc> *fst, typename Arc::Label first_new_symbol,
|
||||
std::vector<Int> *syms_out);
|
||||
|
||||
/* CreateNewSymbols is a helper function used inside PreDeterminize, and is also
|
||||
useful when you need to add a number of extra symbols to a different
|
||||
vocabulary from the one modified by PreDeterminize. */
|
||||
|
||||
template <class Label>
|
||||
void CreateNewSymbols(SymbolTable *inputSymTable, int nSym, std::string prefix,
|
||||
std::vector<Label> *syms_out);
|
||||
|
||||
/** AddSelfLoops is a function you will probably want to use alongside
|
||||
PreDeterminize, to add self-loops to any FSTs that you compose on the left
|
||||
hand side of the one modified by PreDeterminize.
|
||||
|
||||
This function inserts loops with "special symbols" [e.g. \#0, \#1] into an
|
||||
FST. This is done at each final state and each state with non-epsilon output
|
||||
symbols on at least one arc out of it. This is to ensure that these symbols,
|
||||
when inserted into the input side of an FST we will compose with on the
|
||||
right, can "pass through" this FST.
|
||||
|
||||
At input, isyms and osyms must be vectors of the same size n, corresponding
|
||||
to symbols that currently do not exist in 'fst'. For each state in n that
|
||||
has non-epsilon symbols on the output side of arcs leaving it, or which is a
|
||||
final state, this function inserts n self-loops with unit weight and one of
|
||||
the n pairs of symbols on its input and output.
|
||||
*/
|
||||
template <class Arc>
|
||||
void AddSelfLoops(MutableFst<Arc> *fst,
|
||||
const std::vector<typename Arc::Label> &isyms,
|
||||
const std::vector<typename Arc::Label> &osyms);
|
||||
|
||||
/* DeleteSymbols replaces any instances of symbols in the vector symsIn,
|
||||
appearing on the input side, with epsilon. */
|
||||
/* It returns the number of instances of symbols deleted. */
|
||||
template <class Arc>
|
||||
int64 DeleteISymbols(MutableFst<Arc> *fst,
|
||||
std::vector<typename Arc::Label> symsIn);
|
||||
|
||||
/* CreateSuperFinal takes an FST, and creates an equivalent FST with a single
|
||||
final state with no transitions out and unit final weight, by inserting
|
||||
epsilon transitions as necessary. */
|
||||
template <class Arc>
|
||||
typename Arc::StateId CreateSuperFinal(MutableFst<Arc> *fst);
|
||||
|
||||
} // end namespace fst
|
||||
|
||||
#include "fstext/pre-determinize-inl.h"
|
||||
|
||||
#endif // KALDI_FSTEXT_PRE_DETERMINIZE_H_
|
@ -0,0 +1,318 @@
|
||||
// fstext/remove-eps-local-inl.h
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
// 2014 Johns Hopkins University (author: Daniel Povey
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
|
||||
#define KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace fst {
|
||||
|
||||
template <class Weight>
|
||||
struct ReweightPlusDefault {
|
||||
inline Weight operator()(const Weight &a, const Weight &b) {
|
||||
return Plus(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
struct ReweightPlusLogArc {
|
||||
inline TropicalWeight operator()(const TropicalWeight &a,
|
||||
const TropicalWeight &b) {
|
||||
LogWeight a_log(a.Value()), b_log(b.Value());
|
||||
return TropicalWeight(Plus(a_log, b_log).Value());
|
||||
}
|
||||
};
|
||||
|
||||
template <class Arc,
|
||||
class ReweightPlus = ReweightPlusDefault<typename Arc::Weight> >
|
||||
class RemoveEpsLocalClass {
|
||||
typedef typename Arc::StateId StateId;
|
||||
typedef typename Arc::Label Label;
|
||||
typedef typename Arc::Weight Weight;
|
||||
|
||||
public:
|
||||
explicit RemoveEpsLocalClass(MutableFst<Arc> *fst) : fst_(fst) {
|
||||
if (fst_->Start() == kNoStateId) return; // empty.
|
||||
non_coacc_state_ = fst_->AddState();
|
||||
InitNumArcs();
|
||||
StateId num_states = fst_->NumStates();
|
||||
for (StateId s = 0; s < num_states; s++)
|
||||
for (size_t pos = 0; pos < fst_->NumArcs(s); pos++) RemoveEps(s, pos);
|
||||
assert(CheckNumArcs());
|
||||
Connect(fst); // remove inaccessible states.
|
||||
}
|
||||
|
||||
private:
|
||||
MutableFst<Arc> *fst_;
|
||||
StateId non_coacc_state_; // use this to delete arcs: make it nextstate
|
||||
std::vector<StateId> num_arcs_in_; // The number of arcs into the state, plus
|
||||
// one if it's the start state.
|
||||
std::vector<StateId> num_arcs_out_; // The number of arcs out of the state,
|
||||
// plus one if it's a final state.
|
||||
ReweightPlus reweight_plus_;
|
||||
|
||||
bool CanCombineArcs(const Arc &a, const Arc &b, Arc *c) {
|
||||
if (a.ilabel != 0 && b.ilabel != 0) return false;
|
||||
if (a.olabel != 0 && b.olabel != 0) return false;
|
||||
c->weight = Times(a.weight, b.weight);
|
||||
c->ilabel = (a.ilabel != 0 ? a.ilabel : b.ilabel);
|
||||
c->olabel = (a.olabel != 0 ? a.olabel : b.olabel);
|
||||
c->nextstate = b.nextstate;
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool CanCombineFinal(const Arc &a, Weight final_prob,
|
||||
Weight *final_prob_out) {
|
||||
if (a.ilabel != 0 || a.olabel != 0) {
|
||||
return false;
|
||||
} else {
|
||||
*final_prob_out = Times(a.weight, final_prob);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
void InitNumArcs() { // init num transitions in/out of each state.
|
||||
StateId num_states = fst_->NumStates();
|
||||
num_arcs_in_.resize(num_states);
|
||||
num_arcs_out_.resize(num_states);
|
||||
num_arcs_in_[fst_->Start()]++; // count start as trans in.
|
||||
for (StateId s = 0; s < num_states; s++) {
|
||||
if (fst_->Final(s) != Weight::Zero())
|
||||
num_arcs_out_[s]++; // count final as transition.
|
||||
for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done();
|
||||
aiter.Next()) {
|
||||
num_arcs_in_[aiter.Value().nextstate]++;
|
||||
num_arcs_out_[s]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool CheckNumArcs() { // check num arcs in/out of each state, at end. Debug.
|
||||
num_arcs_in_[fst_->Start()]--; // count start as trans in.
|
||||
StateId num_states = fst_->NumStates();
|
||||
for (StateId s = 0; s < num_states; s++) {
|
||||
if (s == non_coacc_state_) continue;
|
||||
if (fst_->Final(s) != Weight::Zero())
|
||||
num_arcs_out_[s]--; // count final as transition.
|
||||
for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done();
|
||||
aiter.Next()) {
|
||||
if (aiter.Value().nextstate == non_coacc_state_) continue;
|
||||
num_arcs_in_[aiter.Value().nextstate]--;
|
||||
num_arcs_out_[s]--;
|
||||
}
|
||||
}
|
||||
for (StateId s = 0; s < num_states; s++) {
|
||||
assert(num_arcs_in_[s] == 0);
|
||||
assert(num_arcs_out_[s] == 0);
|
||||
}
|
||||
return true; // always does this. so we can assert it w/o warnings.
|
||||
}
|
||||
|
||||
inline void GetArc(StateId s, size_t pos, Arc *arc) const {
|
||||
ArcIterator<MutableFst<Arc> > aiter(*fst_, s);
|
||||
aiter.Seek(pos);
|
||||
*arc = aiter.Value();
|
||||
}
|
||||
|
||||
inline void SetArc(StateId s, size_t pos, const Arc &arc) {
|
||||
MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
|
||||
aiter.Seek(pos);
|
||||
aiter.SetValue(arc);
|
||||
}
|
||||
|
||||
void Reweight(StateId s, size_t pos, Weight reweight) {
|
||||
// Reweight is called from RemoveEpsPattern1; it is a step we
|
||||
// do to preserve stochasticity. This function multiplies the
|
||||
// arc at (s, pos) by reweight and divides all the arcs [+final-prob]
|
||||
// out of the next state by the same. This is only valid if
|
||||
// the next state has only one arc in and is not the start state.
|
||||
assert(reweight != Weight::Zero());
|
||||
MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
|
||||
aiter.Seek(pos);
|
||||
Arc arc = aiter.Value();
|
||||
assert(num_arcs_in_[arc.nextstate] == 1);
|
||||
arc.weight = Times(arc.weight, reweight);
|
||||
aiter.SetValue(arc);
|
||||
|
||||
for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, arc.nextstate);
|
||||
!aiter_next.Done(); aiter_next.Next()) {
|
||||
Arc nextarc = aiter_next.Value();
|
||||
if (nextarc.nextstate != non_coacc_state_) {
|
||||
nextarc.weight = Divide(nextarc.weight, reweight, DIVIDE_LEFT);
|
||||
aiter_next.SetValue(nextarc);
|
||||
}
|
||||
}
|
||||
Weight final = fst_->Final(arc.nextstate);
|
||||
if (final != Weight::Zero()) {
|
||||
fst_->SetFinal(arc.nextstate, Divide(final, reweight, DIVIDE_LEFT));
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveEpsPattern1 applies where this arc, which is not a
|
||||
// self-loop, enters a state which has only one input transition
|
||||
// [and is not the start state], and has multiple output
|
||||
// transitions [counting being the final-state as a final-transition].
|
||||
|
||||
void RemoveEpsPattern1(StateId s, size_t pos, Arc arc) {
|
||||
const StateId nextstate = arc.nextstate;
|
||||
Weight total_removed = Weight::Zero(),
|
||||
total_kept = Weight::Zero(); // totals out of nextstate.
|
||||
std::vector<Arc> arcs_to_add; // to add to state s.
|
||||
for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
|
||||
!aiter_next.Done(); aiter_next.Next()) {
|
||||
Arc nextarc = aiter_next.Value();
|
||||
if (nextarc.nextstate == non_coacc_state_) continue; // deleted.
|
||||
Arc combined;
|
||||
if (CanCombineArcs(arc, nextarc, &combined)) {
|
||||
total_removed = reweight_plus_(total_removed, nextarc.weight);
|
||||
num_arcs_out_[nextstate]--;
|
||||
num_arcs_in_[nextarc.nextstate]--;
|
||||
nextarc.nextstate = non_coacc_state_;
|
||||
aiter_next.SetValue(nextarc);
|
||||
arcs_to_add.push_back(combined);
|
||||
} else {
|
||||
total_kept = reweight_plus_(total_kept, nextarc.weight);
|
||||
}
|
||||
}
|
||||
|
||||
{ // now final-state.
|
||||
Weight next_final = fst_->Final(nextstate);
|
||||
if (next_final != Weight::Zero()) {
|
||||
Weight new_final;
|
||||
if (CanCombineFinal(arc, next_final, &new_final)) {
|
||||
total_removed = reweight_plus_(total_removed, next_final);
|
||||
if (fst_->Final(s) == Weight::Zero())
|
||||
num_arcs_out_[s]++; // final is counted as arc.
|
||||
fst_->SetFinal(s, Plus(fst_->Final(s), new_final));
|
||||
num_arcs_out_[nextstate]--;
|
||||
fst_->SetFinal(nextstate, Weight::Zero());
|
||||
} else {
|
||||
total_kept = reweight_plus_(total_kept, next_final);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (total_removed != Weight::Zero()) { // did something...
|
||||
if (total_kept == Weight::Zero()) { // removed everything: remove arc.
|
||||
num_arcs_out_[s]--;
|
||||
num_arcs_in_[arc.nextstate]--;
|
||||
arc.nextstate = non_coacc_state_;
|
||||
SetArc(s, pos, arc);
|
||||
} else {
|
||||
// Have to reweight.
|
||||
Weight total = reweight_plus_(total_removed, total_kept);
|
||||
Weight reweight = Divide(total_kept, total, DIVIDE_LEFT); // <=1
|
||||
Reweight(s, pos, reweight);
|
||||
}
|
||||
}
|
||||
// Now add the arcs we were going to add.
|
||||
for (size_t i = 0; i < arcs_to_add.size(); i++) {
|
||||
num_arcs_out_[s]++;
|
||||
num_arcs_in_[arcs_to_add[i].nextstate]++;
|
||||
fst_->AddArc(s, arcs_to_add[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void RemoveEpsPattern2(StateId s, size_t pos, Arc arc) {
|
||||
// Pattern 2 is where "nextstate" has only one arc out, counting
|
||||
// being-the-final-state as an arc, but possibly multiple arcs in.
|
||||
// Also, nextstate != s.
|
||||
|
||||
const StateId nextstate = arc.nextstate;
|
||||
bool can_delete_next = (num_arcs_in_[nextstate] == 1); // if
|
||||
// we combine, can delete the corresponding out-arc/final-prob
|
||||
// of nextstate.
|
||||
bool delete_arc = false; // set to true if this arc to be deleted.
|
||||
|
||||
Weight next_final = fst_->Final(arc.nextstate);
|
||||
if (next_final !=
|
||||
Weight::Zero()) { // nextstate has no actual arcs out, only final-prob.
|
||||
Weight new_final;
|
||||
if (CanCombineFinal(arc, next_final, &new_final)) {
|
||||
if (fst_->Final(s) == Weight::Zero())
|
||||
num_arcs_out_[s]++; // final is counted as arc.
|
||||
fst_->SetFinal(s, Plus(fst_->Final(s), new_final));
|
||||
delete_arc = true; // will delete "arc".
|
||||
if (can_delete_next) {
|
||||
num_arcs_out_[nextstate]--;
|
||||
fst_->SetFinal(nextstate, Weight::Zero());
|
||||
}
|
||||
}
|
||||
} else { // has an arc but no final prob.
|
||||
MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
|
||||
assert(!aiter_next.Done());
|
||||
while (aiter_next.Value().nextstate == non_coacc_state_) {
|
||||
aiter_next.Next();
|
||||
assert(!aiter_next.Done());
|
||||
}
|
||||
// now aiter_next points to a real arc out of nextstate.
|
||||
Arc nextarc = aiter_next.Value();
|
||||
Arc combined;
|
||||
if (CanCombineArcs(arc, nextarc, &combined)) {
|
||||
delete_arc = true;
|
||||
if (can_delete_next) { // do it before we invalidate iterators
|
||||
num_arcs_out_[nextstate]--;
|
||||
num_arcs_in_[nextarc.nextstate]--;
|
||||
nextarc.nextstate = non_coacc_state_;
|
||||
aiter_next.SetValue(nextarc);
|
||||
}
|
||||
num_arcs_out_[s]++;
|
||||
num_arcs_in_[combined.nextstate]++;
|
||||
fst_->AddArc(s, combined);
|
||||
}
|
||||
}
|
||||
if (delete_arc) {
|
||||
num_arcs_out_[s]--;
|
||||
num_arcs_in_[nextstate]--;
|
||||
arc.nextstate = non_coacc_state_;
|
||||
SetArc(s, pos, arc);
|
||||
}
|
||||
}
|
||||
|
||||
void RemoveEps(StateId s, size_t pos) {
|
||||
// Tries to do local epsilon-removal for arc sequences starting with this
|
||||
// arc
|
||||
Arc arc;
|
||||
GetArc(s, pos, &arc);
|
||||
StateId nextstate = arc.nextstate;
|
||||
if (nextstate == non_coacc_state_) return; // deleted arc.
|
||||
if (nextstate == s) return; // don't handle self-loops: too complex.
|
||||
|
||||
if (num_arcs_in_[nextstate] == 1 && num_arcs_out_[nextstate] > 1) {
|
||||
RemoveEpsPattern1(s, pos, arc);
|
||||
} else if (num_arcs_out_[nextstate] == 1) {
|
||||
RemoveEpsPattern2(s, pos, arc);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class Arc>
|
||||
void RemoveEpsLocal(MutableFst<Arc> *fst) {
|
||||
RemoveEpsLocalClass<Arc> c(fst); // work gets done in initializer.
|
||||
}
|
||||
|
||||
void RemoveEpsLocalSpecial(MutableFst<StdArc> *fst) {
|
||||
// work gets done in initializer.
|
||||
RemoveEpsLocalClass<StdArc, ReweightPlusLogArc> c(fst);
|
||||
}
|
||||
|
||||
} // end namespace fst.
|
||||
|
||||
#endif // KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
|
@ -0,0 +1,57 @@
|
||||
// fstext/remove-eps-local.h
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
// 2014 Johns Hopkins University (author: Daniel Povey)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_FSTEXT_REMOVE_EPS_LOCAL_H_
|
||||
#define KALDI_FSTEXT_REMOVE_EPS_LOCAL_H_
|
||||
|
||||
#include <fst/fst-decl.h>
|
||||
#include <fst/fstlib.h>
|
||||
|
||||
namespace fst {
|
||||
|
||||
/// RemoveEpsLocal remove some (but not necessarily all) epsilons in an FST,
|
||||
/// using an algorithm that is guaranteed to never increase the number of arcs
|
||||
/// in the FST (and will also never increase the number of states). The
|
||||
/// algorithm is not optimal but is reasonably clever. It does not just remove
|
||||
/// epsilon arcs;it also combines pairs of input-epsilon and output-epsilon arcs
|
||||
/// into one.
|
||||
/// The algorithm preserves equivalence and stochasticity in the given semiring.
|
||||
/// If you want to preserve stochasticity in a different semiring (e.g. log),
|
||||
/// then use RemoveEpsLocalSpecial, which only works for StdArc but which
|
||||
/// preserves stochasticity, where possible (*) in the LogArc sense. The reason
|
||||
/// that we can't just cast to a different semiring is that in that case we
|
||||
/// would no longer be able to guarantee equivalence in the original semiring
|
||||
/// (this arises from what happens when we combine identical arcs).
|
||||
/// (*) by "where possible".. there are situations where we wouldn't be able to
|
||||
/// preserve stochasticity in the LogArc sense while maintaining equivalence in
|
||||
/// the StdArc sense, so in these situations we maintain equivalence.
|
||||
|
||||
template <class Arc>
|
||||
void RemoveEpsLocal(MutableFst<Arc> *fst);
|
||||
|
||||
/// As RemoveEpsLocal but takes care to preserve stochasticity
|
||||
/// when cast to LogArc.
|
||||
inline void RemoveEpsLocalSpecial(MutableFst<StdArc> *fst);
|
||||
|
||||
} // namespace fst
|
||||
|
||||
#include "fstext/remove-eps-local-inl.h"
|
||||
|
||||
#endif // KALDI_FSTEXT_REMOVE_EPS_LOCAL_H_
|
@ -0,0 +1,387 @@
|
||||
// fstext/table-matcher.h
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_FSTEXT_TABLE_MATCHER_H_
|
||||
#define KALDI_FSTEXT_TABLE_MATCHER_H_
|
||||
|
||||
#include <fst/fst-decl.h>
|
||||
#include <fst/fstlib.h>
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace fst {
|
||||
|
||||
/// TableMatcher is a matcher specialized for the case where the output
|
||||
/// side of the left FST always has either all-epsilons coming out of
|
||||
/// a state, or a majority of the symbol table. Therefore we can
|
||||
/// either store nothing (for the all-epsilon case) or store a lookup
|
||||
/// table from Labels to arc offsets. Since the TableMatcher has to
|
||||
/// iterate over all arcs in each left-hand state the first time it sees
|
||||
/// it, this matcher type is not efficient if you compose with
|
||||
/// something very small on the right-- unless you do it multiple
|
||||
/// times and keep the matcher around. To do this requires using the
|
||||
/// most advanced form of ComposeFst in Compose.h, that initializes
|
||||
/// with ComposeFstImplOptions.
|
||||
|
||||
struct TableMatcherOptions {
|
||||
float
|
||||
table_ratio; // we construct the table if it would be at least this full.
|
||||
int min_table_size;
|
||||
TableMatcherOptions() : table_ratio(0.25), min_table_size(4) {}
|
||||
};
|
||||
|
||||
// Introducing an "impl" class for TableMatcher because
|
||||
// we need to do a shallow copy of the Matcher for when
|
||||
// we want to cache tables for multiple compositions.
|
||||
template <class F, class BackoffMatcher = SortedMatcher<F> >
|
||||
class TableMatcherImpl : public MatcherBase<typename F::Arc> {
|
||||
public:
|
||||
typedef F FST;
|
||||
typedef typename F::Arc Arc;
|
||||
typedef typename Arc::Label Label;
|
||||
typedef typename Arc::StateId StateId;
|
||||
typedef StateId
|
||||
ArcId; // Use this type to store arc offsets [it's actually size_t
|
||||
// in the Seek function of ArcIterator, but StateId should be big enough].
|
||||
typedef typename Arc::Weight Weight;
|
||||
|
||||
public:
|
||||
TableMatcherImpl(const FST &fst, MatchType match_type,
|
||||
const TableMatcherOptions &opts = TableMatcherOptions())
|
||||
: match_type_(match_type),
|
||||
fst_(fst.Copy()),
|
||||
loop_(match_type == MATCH_INPUT
|
||||
? Arc(kNoLabel, 0, Weight::One(), kNoStateId)
|
||||
: Arc(0, kNoLabel, Weight::One(), kNoStateId)),
|
||||
aiter_(NULL),
|
||||
s_(kNoStateId),
|
||||
opts_(opts),
|
||||
backoff_matcher_(fst, match_type) {
|
||||
assert(opts_.min_table_size > 0);
|
||||
if (match_type == MATCH_INPUT)
|
||||
assert(fst_->Properties(kILabelSorted, true) == kILabelSorted);
|
||||
else if (match_type == MATCH_OUTPUT)
|
||||
assert(fst_->Properties(kOLabelSorted, true) == kOLabelSorted);
|
||||
else
|
||||
assert(0 && "Invalid FST properties");
|
||||
}
|
||||
|
||||
virtual const FST &GetFst() const { return *fst_; }
|
||||
|
||||
virtual ~TableMatcherImpl() {
|
||||
std::vector<ArcId> *const empty =
|
||||
((std::vector<ArcId> *)(NULL)) + 1; // special marker.
|
||||
for (size_t i = 0; i < tables_.size(); i++) {
|
||||
if (tables_[i] != NULL && tables_[i] != empty) delete tables_[i];
|
||||
}
|
||||
delete aiter_;
|
||||
delete fst_;
|
||||
}
|
||||
|
||||
virtual MatchType Type(bool test) const { return match_type_; }
|
||||
|
||||
void SetState(StateId s) {
|
||||
if (aiter_) {
|
||||
delete aiter_;
|
||||
aiter_ = NULL;
|
||||
}
|
||||
if (match_type_ == MATCH_NONE) LOG(FATAL) << "TableMatcher: bad match type";
|
||||
s_ = s;
|
||||
std::vector<ArcId> *const empty =
|
||||
((std::vector<ArcId> *)(NULL)) + 1; // special marker.
|
||||
if (static_cast<size_t>(s) >= tables_.size()) {
|
||||
assert(s >= 0);
|
||||
tables_.resize(s + 1, NULL);
|
||||
}
|
||||
std::vector<ArcId> *&this_table_ = tables_[s]; // note: ref to ptr.
|
||||
if (this_table_ == empty) {
|
||||
backoff_matcher_.SetState(s);
|
||||
return;
|
||||
} else if (this_table_ == NULL) { // NULL means has not been set.
|
||||
ArcId num_arcs = fst_->NumArcs(s);
|
||||
if (num_arcs == 0 || num_arcs < opts_.min_table_size) {
|
||||
this_table_ = empty;
|
||||
backoff_matcher_.SetState(s);
|
||||
return;
|
||||
}
|
||||
ArcIterator<FST> aiter(*fst_, s);
|
||||
aiter.SetFlags(
|
||||
kArcNoCache |
|
||||
(match_type_ == MATCH_OUTPUT ? kArcOLabelValue : kArcILabelValue),
|
||||
kArcNoCache | kArcValueFlags);
|
||||
// the statement above, says: "Don't cache stuff; and I only need the
|
||||
// ilabel/olabel to be computed.
|
||||
aiter.Seek(num_arcs - 1);
|
||||
Label highest_label =
|
||||
(match_type_ == MATCH_OUTPUT ? aiter.Value().olabel
|
||||
: aiter.Value().ilabel);
|
||||
if ((highest_label + 1) * opts_.table_ratio > num_arcs) {
|
||||
this_table_ = empty;
|
||||
backoff_matcher_.SetState(s);
|
||||
return; // table would be too sparse.
|
||||
}
|
||||
// OK, now we are creating the table.
|
||||
this_table_ = new std::vector<ArcId>(highest_label + 1, kNoStateId);
|
||||
ArcId pos = 0;
|
||||
for (aiter.Seek(0); !aiter.Done(); aiter.Next(), pos++) {
|
||||
Label label = (match_type_ == MATCH_OUTPUT ? aiter.Value().olabel
|
||||
: aiter.Value().ilabel);
|
||||
assert(static_cast<size_t>(label) <=
|
||||
static_cast<size_t>(highest_label)); // also checks >= 0.
|
||||
if ((*this_table_)[label] == kNoStateId) (*this_table_)[label] = pos;
|
||||
// set this_table_[label] to first position where arc has this
|
||||
// label.
|
||||
}
|
||||
}
|
||||
// At this point in the code, this_table_ != NULL and != empty.
|
||||
aiter_ = new ArcIterator<FST>(*fst_, s);
|
||||
aiter_->SetFlags(kArcNoCache,
|
||||
kArcNoCache); // don't need to cache arcs as may only
|
||||
// need a small subset.
|
||||
loop_.nextstate = s;
|
||||
// aiter_ = NULL;
|
||||
// backoff_matcher_.SetState(s);
|
||||
}
|
||||
|
||||
bool Find(Label match_label) {
|
||||
if (!aiter_) {
|
||||
return backoff_matcher_.Find(match_label);
|
||||
} else {
|
||||
match_label_ = match_label;
|
||||
current_loop_ = (match_label == 0);
|
||||
// kNoLabel means the implicit loop on the other FST --
|
||||
// matches real epsilons but not the self-loop.
|
||||
match_label_ = (match_label_ == kNoLabel ? 0 : match_label_);
|
||||
if (static_cast<size_t>(match_label_) < tables_[s_]->size() &&
|
||||
(*(tables_[s_]))[match_label_] != kNoStateId) {
|
||||
aiter_->Seek((*(tables_[s_]))[match_label_]); // label exists.
|
||||
return true;
|
||||
}
|
||||
return current_loop_;
|
||||
}
|
||||
}
|
||||
const Arc &Value() const {
|
||||
if (aiter_)
|
||||
return current_loop_ ? loop_ : aiter_->Value();
|
||||
else
|
||||
return backoff_matcher_.Value();
|
||||
}
|
||||
|
||||
void Next() {
|
||||
if (aiter_) {
|
||||
if (current_loop_)
|
||||
current_loop_ = false;
|
||||
else
|
||||
aiter_->Next();
|
||||
} else {
|
||||
backoff_matcher_.Next();
|
||||
}
|
||||
}
|
||||
|
||||
bool Done() const {
|
||||
if (aiter_ != NULL) {
|
||||
if (current_loop_) return false;
|
||||
if (aiter_->Done()) return true;
|
||||
Label label = (match_type_ == MATCH_OUTPUT ? aiter_->Value().olabel
|
||||
: aiter_->Value().ilabel);
|
||||
return (label != match_label_);
|
||||
} else {
|
||||
return backoff_matcher_.Done();
|
||||
}
|
||||
}
|
||||
const Arc &Value() {
|
||||
if (aiter_ != NULL) {
|
||||
return (current_loop_ ? loop_ : aiter_->Value());
|
||||
} else {
|
||||
return backoff_matcher_.Value();
|
||||
}
|
||||
}
|
||||
|
||||
virtual TableMatcherImpl<FST> *Copy(bool safe = false) const {
|
||||
assert(0); // shouldn't be called. This is not a "real" matcher,
|
||||
// although we derive from MatcherBase for convenience.
|
||||
return NULL;
|
||||
}
|
||||
|
||||
virtual uint64 Properties(uint64 props) const {
|
||||
return props;
|
||||
} // simple matcher that does
|
||||
// not change its FST, so properties are properties of FST it is applied to
|
||||
|
||||
private:
|
||||
virtual void SetState_(StateId s) { SetState(s); }
|
||||
virtual bool Find_(Label label) { return Find(label); }
|
||||
virtual bool Done_() const { return Done(); }
|
||||
virtual const Arc &Value_() const { return Value(); }
|
||||
virtual void Next_() { Next(); }
|
||||
|
||||
MatchType match_type_;
|
||||
FST *fst_;
|
||||
bool current_loop_;
|
||||
Label match_label_;
|
||||
Arc loop_;
|
||||
ArcIterator<FST> *aiter_;
|
||||
StateId s_;
|
||||
std::vector<std::vector<ArcId> *> tables_;
|
||||
TableMatcherOptions opts_;
|
||||
BackoffMatcher backoff_matcher_;
|
||||
};
|
||||
|
||||
template <class F, class BackoffMatcher = SortedMatcher<F> >
|
||||
class TableMatcher : public MatcherBase<typename F::Arc> {
|
||||
public:
|
||||
typedef F FST;
|
||||
typedef typename F::Arc Arc;
|
||||
typedef typename Arc::Label Label;
|
||||
typedef typename Arc::StateId StateId;
|
||||
typedef StateId
|
||||
ArcId; // Use this type to store arc offsets [it's actually size_t
|
||||
// in the Seek function of ArcIterator, but StateId should be big enough].
|
||||
typedef typename Arc::Weight Weight;
|
||||
typedef TableMatcherImpl<F, BackoffMatcher> Impl;
|
||||
|
||||
TableMatcher(const FST &fst, MatchType match_type,
|
||||
const TableMatcherOptions &opts = TableMatcherOptions())
|
||||
: impl_(std::make_shared<Impl>(fst, match_type, opts)) {}
|
||||
|
||||
TableMatcher(const TableMatcher<FST, BackoffMatcher> &matcher,
|
||||
bool safe = false)
|
||||
: impl_(matcher.impl_) {
|
||||
if (safe == true) {
|
||||
LOG(FATAL) << "TableMatcher: Safe copy not supported";
|
||||
}
|
||||
}
|
||||
|
||||
virtual const FST &GetFst() const { return impl_->GetFst(); }
|
||||
|
||||
virtual MatchType Type(bool test) const { return impl_->Type(test); }
|
||||
|
||||
void SetState(StateId s) { return impl_->SetState(s); }
|
||||
|
||||
bool Find(Label match_label) { return impl_->Find(match_label); }
|
||||
|
||||
const Arc &Value() const { return impl_->Value(); }
|
||||
|
||||
void Next() { return impl_->Next(); }
|
||||
|
||||
bool Done() const { return impl_->Done(); }
|
||||
|
||||
const Arc &Value() { return impl_->Value(); }
|
||||
|
||||
virtual TableMatcher<FST, BackoffMatcher> *Copy(bool safe = false) const {
|
||||
return new TableMatcher<FST, BackoffMatcher>(*this, safe);
|
||||
}
|
||||
|
||||
virtual uint64 Properties(uint64 props) const {
|
||||
return impl_->Properties(props);
|
||||
} // simple matcher that does
|
||||
// not change its FST, so properties are properties of FST it is applied to
|
||||
private:
|
||||
std::shared_ptr<Impl> impl_;
|
||||
|
||||
virtual void SetState_(StateId s) { impl_->SetState(s); }
|
||||
virtual bool Find_(Label label) { return impl_->Find(label); }
|
||||
virtual bool Done_() const { return impl_->Done(); }
|
||||
virtual const Arc &Value_() const { return impl_->Value(); }
|
||||
virtual void Next_() { impl_->Next(); }
|
||||
|
||||
TableMatcher &operator=(const TableMatcher &) = delete;
|
||||
};
|
||||
|
||||
struct TableComposeOptions : public TableMatcherOptions {
|
||||
bool connect; // Connect output
|
||||
ComposeFilter filter_type; // Which pre-defined filter to use
|
||||
MatchType table_match_type;
|
||||
|
||||
explicit TableComposeOptions(const TableMatcherOptions &mo, bool c = true,
|
||||
ComposeFilter ft = SEQUENCE_FILTER,
|
||||
MatchType tms = MATCH_OUTPUT)
|
||||
: TableMatcherOptions(mo),
|
||||
connect(c),
|
||||
filter_type(ft),
|
||||
table_match_type(tms) {}
|
||||
TableComposeOptions()
|
||||
: connect(true),
|
||||
filter_type(SEQUENCE_FILTER),
|
||||
table_match_type(MATCH_OUTPUT) {}
|
||||
};
|
||||
|
||||
template <class Arc>
|
||||
void TableCompose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
|
||||
MutableFst<Arc> *ofst,
|
||||
const TableComposeOptions &opts = TableComposeOptions()) {
|
||||
typedef Fst<Arc> F;
|
||||
CacheOptions nopts;
|
||||
nopts.gc_limit = 0; // Cache only the last state for fastest copy.
|
||||
if (opts.table_match_type == MATCH_OUTPUT) {
|
||||
// ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
|
||||
ComposeFstImplOptions<TableMatcher<F>, SortedMatcher<F> > impl_opts(nopts);
|
||||
impl_opts.matcher1 = new TableMatcher<F>(ifst1, MATCH_OUTPUT, opts);
|
||||
*ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
|
||||
} else {
|
||||
assert(opts.table_match_type == MATCH_INPUT);
|
||||
// ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
|
||||
ComposeFstImplOptions<SortedMatcher<F>, TableMatcher<F> > impl_opts(nopts);
|
||||
impl_opts.matcher2 = new TableMatcher<F>(ifst2, MATCH_INPUT, opts);
|
||||
*ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
|
||||
}
|
||||
if (opts.connect) Connect(ofst);
|
||||
}
|
||||
|
||||
/// TableComposeCache lets us do multiple compositions while caching the same
|
||||
/// matcher.
|
||||
template <class F>
|
||||
struct TableComposeCache {
|
||||
TableMatcher<F> *matcher;
|
||||
TableComposeOptions opts;
|
||||
explicit TableComposeCache(
|
||||
const TableComposeOptions &opts = TableComposeOptions())
|
||||
: matcher(NULL), opts(opts) {}
|
||||
~TableComposeCache() { delete (matcher); }
|
||||
};
|
||||
|
||||
template <class Arc>
|
||||
void TableCompose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
|
||||
MutableFst<Arc> *ofst, TableComposeCache<Fst<Arc> > *cache) {
|
||||
typedef Fst<Arc> F;
|
||||
assert(cache != NULL);
|
||||
CacheOptions nopts;
|
||||
nopts.gc_limit = 0; // Cache only the last state for fastest copy.
|
||||
if (cache->opts.table_match_type == MATCH_OUTPUT) {
|
||||
ComposeFstImplOptions<TableMatcher<F>, SortedMatcher<F> > impl_opts(nopts);
|
||||
if (cache->matcher == NULL)
|
||||
cache->matcher = new TableMatcher<F>(ifst1, MATCH_OUTPUT, cache->opts);
|
||||
impl_opts.matcher1 = cache->matcher->Copy(); // not passing "safe": may not
|
||||
// be thread-safe-- anway I don't understand this part.
|
||||
*ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
|
||||
} else {
|
||||
assert(cache->opts.table_match_type == MATCH_INPUT);
|
||||
ComposeFstImplOptions<SortedMatcher<F>, TableMatcher<F> > impl_opts(nopts);
|
||||
if (cache->matcher == NULL)
|
||||
cache->matcher = new TableMatcher<F>(ifst2, MATCH_INPUT, cache->opts);
|
||||
impl_opts.matcher2 = cache->matcher->Copy();
|
||||
*ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
|
||||
}
|
||||
if (cache->opts.connect) Connect(ofst);
|
||||
}
|
||||
|
||||
} // namespace fst
|
||||
|
||||
#endif // KALDI_FSTEXT_TABLE_MATCHER_H_
|
@ -0,0 +1,6 @@
|
||||
|
||||
add_library(kaldi-lat
|
||||
determinize-lattice-pruned.cc
|
||||
lattice-functions.cc
|
||||
)
|
||||
target_link_libraries(kaldi-lat PUBLIC kaldi-util)
|
@ -1,147 +0,0 @@
|
||||
// lat/determinize-lattice-pruned-test.cc
|
||||
|
||||
// Copyright 2009-2012 Microsoft Corporation
|
||||
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "lat/determinize-lattice-pruned.h"
|
||||
#include "fstext/lattice-utils.h"
|
||||
#include "fstext/fst-test-utils.h"
|
||||
#include "lat/kaldi-lattice.h"
|
||||
#include "lat/lattice-functions.h"
|
||||
|
||||
namespace fst {
|
||||
// Caution: these tests are not as generic as you might think from all the
|
||||
// templates in the code. They are basically only valid for LatticeArc.
|
||||
// This is partly due to the fact that certain templates need to be instantiated
|
||||
// in other .cc files in this directory.
|
||||
|
||||
// test that determinization proceeds correctly on general
|
||||
// FSTs (not guaranteed determinzable, but we use the
|
||||
// max-states option to stop it getting out of control).
|
||||
template<class Arc> void TestDeterminizeLatticePruned() {
|
||||
typedef kaldi::int32 Int;
|
||||
typedef typename Arc::Weight Weight;
|
||||
typedef ArcTpl<CompactLatticeWeightTpl<Weight, Int> > CompactArc;
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
RandFstOptions opts;
|
||||
opts.n_states = 4;
|
||||
opts.n_arcs = 10;
|
||||
opts.n_final = 2;
|
||||
opts.allow_empty = false;
|
||||
opts.weight_multiplier = 0.5; // impt for the randomly generated weights
|
||||
opts.acyclic = true;
|
||||
// to be exactly representable in float,
|
||||
// or this test fails because numerical differences can cause symmetry in
|
||||
// weights to be broken, which causes the wrong path to be chosen as far
|
||||
// as the string part is concerned.
|
||||
|
||||
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
|
||||
|
||||
bool sorted = TopSort(fst);
|
||||
KALDI_ASSERT(sorted);
|
||||
|
||||
ILabelCompare<Arc> ilabel_comp;
|
||||
if (kaldi::Rand() % 2 == 0)
|
||||
ArcSort(fst, ilabel_comp);
|
||||
|
||||
std::cout << "FST before lattice-determinizing is:\n";
|
||||
{
|
||||
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
|
||||
fstprinter.Print(&std::cout, "standard output");
|
||||
}
|
||||
VectorFst<Arc> det_fst;
|
||||
try {
|
||||
DeterminizeLatticePrunedOptions lat_opts;
|
||||
lat_opts.max_mem = ((kaldi::Rand() % 2 == 0) ? 100 : 1000);
|
||||
lat_opts.max_states = ((kaldi::Rand() % 2 == 0) ? -1 : 20);
|
||||
lat_opts.max_arcs = ((kaldi::Rand() % 2 == 0) ? -1 : 30);
|
||||
bool ans = DeterminizeLatticePruned<Weight>(*fst, 10.0, &det_fst, lat_opts);
|
||||
|
||||
std::cout << "FST after lattice-determinizing is:\n";
|
||||
{
|
||||
FstPrinter<Arc> fstprinter(det_fst, NULL, NULL, NULL, false, true, "\t");
|
||||
fstprinter.Print(&std::cout, "standard output");
|
||||
}
|
||||
KALDI_ASSERT(det_fst.Properties(kIDeterministic, true) & kIDeterministic);
|
||||
// OK, now determinize it a different way and check equivalence.
|
||||
// [note: it's not normal determinization, it's taking the best path
|
||||
// for any input-symbol sequence....
|
||||
|
||||
|
||||
VectorFst<Arc> pruned_fst(*fst);
|
||||
if (pruned_fst.NumStates() != 0)
|
||||
kaldi::PruneLattice(10.0, &pruned_fst);
|
||||
|
||||
VectorFst<CompactArc> compact_pruned_fst, compact_pruned_det_fst;
|
||||
ConvertLattice<Weight, Int>(pruned_fst, &compact_pruned_fst, false);
|
||||
std::cout << "Compact pruned FST is:\n";
|
||||
{
|
||||
FstPrinter<CompactArc> fstprinter(compact_pruned_fst, NULL, NULL, NULL, false, true, "\t");
|
||||
fstprinter.Print(&std::cout, "standard output");
|
||||
}
|
||||
ConvertLattice<Weight, Int>(det_fst, &compact_pruned_det_fst, false);
|
||||
|
||||
std::cout << "Compact version of determinized FST is:\n";
|
||||
{
|
||||
FstPrinter<CompactArc> fstprinter(compact_pruned_det_fst, NULL, NULL, NULL, false, true, "\t");
|
||||
fstprinter.Print(&std::cout, "standard output");
|
||||
}
|
||||
|
||||
if (ans)
|
||||
KALDI_ASSERT(RandEquivalent(compact_pruned_det_fst, compact_pruned_fst, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length, max*/));
|
||||
} catch (...) {
|
||||
std::cout << "Failed to lattice-determinize this FST (probably not determinizable)\n";
|
||||
}
|
||||
delete fst;
|
||||
}
|
||||
}
|
||||
|
||||
// test that determinization proceeds without crash on acyclic FSTs
|
||||
// (guaranteed determinizable in this sense).
|
||||
template<class Arc> void TestDeterminizeLatticePruned2() {
|
||||
typedef typename Arc::Weight Weight;
|
||||
RandFstOptions opts;
|
||||
opts.acyclic = true;
|
||||
for(int i = 0; i < 100; i++) {
|
||||
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
|
||||
std::cout << "FST before lattice-determinizing is:\n";
|
||||
{
|
||||
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
|
||||
fstprinter.Print(&std::cout, "standard output");
|
||||
}
|
||||
VectorFst<Arc> ofst;
|
||||
DeterminizeLatticePruned<Weight>(*fst, 10.0, &ofst);
|
||||
std::cout << "FST after lattice-determinizing is:\n";
|
||||
{
|
||||
FstPrinter<Arc> fstprinter(ofst, NULL, NULL, NULL, false, true, "\t");
|
||||
fstprinter.Print(&std::cout, "standard output");
|
||||
}
|
||||
delete fst;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
} // end namespace fst
|
||||
|
||||
int main() {
|
||||
using namespace fst;
|
||||
TestDeterminizeLatticePruned<kaldi::LatticeArc>();
|
||||
TestDeterminizeLatticePruned2<kaldi::LatticeArc>();
|
||||
std::cout << "Tests succeeded\n";
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,100 @@
|
||||
// fstbin/fstaddselfloops.cc
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "base/kaldi-common.h"
|
||||
#include "fst/fstlib.h"
|
||||
#include "fstext/determinize-star.h"
|
||||
#include "fstext/fstext-utils.h"
|
||||
#include "fstext/kaldi-fst-io.h"
|
||||
#include "util/parse-options.h"
|
||||
#include "util/simple-io-funcs.h"
|
||||
|
||||
/* some test examples:
|
||||
pushd ~/tmpdir
|
||||
( echo 3; echo 4) > in.list
|
||||
( echo 5; echo 6) > out.list
|
||||
( echo "0 0 0 0"; echo "0 0" ) | fstcompile | fstaddselfloops in.list out.list
|
||||
| fstprint ( echo "0 1 0 1"; echo " 0 2 1 0"; echo "1 0"; echo "2 0"; ) |
|
||||
fstcompile | fstaddselfloops in.list out.list | fstprint
|
||||
*/
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
try {
|
||||
using namespace kaldi; // NOLINT
|
||||
using namespace fst; // NOLINT
|
||||
using kaldi::int32;
|
||||
|
||||
const char *usage =
|
||||
"Adds self-loops to states of an FST to propagate disambiguation "
|
||||
"symbols through it\n"
|
||||
"They are added on each final state and each state with non-epsilon "
|
||||
"output symbols\n"
|
||||
"on at least one arc out of the state. Useful in conjunction with "
|
||||
"predeterminize\n"
|
||||
"\n"
|
||||
"Usage: fstaddselfloops in-disambig-list out-disambig-list [in.fst "
|
||||
"[out.fst] ]\n"
|
||||
"E.g: fstaddselfloops in.list out.list < in.fst > withloops.fst\n"
|
||||
"in.list and out.list are lists of integers, one per line, of the\n"
|
||||
"same length.\n";
|
||||
|
||||
ParseOptions po(usage);
|
||||
po.Read(argc, argv);
|
||||
|
||||
if (po.NumArgs() < 2 || po.NumArgs() > 4) {
|
||||
po.PrintUsage();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::string disambig_in_rxfilename = po.GetArg(1),
|
||||
disambig_out_rxfilename = po.GetArg(2),
|
||||
fst_in_filename = po.GetOptArg(3),
|
||||
fst_out_filename = po.GetOptArg(4);
|
||||
|
||||
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
|
||||
|
||||
std::vector<int32> disambig_in;
|
||||
if (!ReadIntegerVectorSimple(disambig_in_rxfilename, &disambig_in))
|
||||
KALDI_ERR
|
||||
<< "fstaddselfloops: Could not read disambiguation symbols from "
|
||||
<< kaldi::PrintableRxfilename(disambig_in_rxfilename);
|
||||
|
||||
std::vector<int32> disambig_out;
|
||||
if (!ReadIntegerVectorSimple(disambig_out_rxfilename, &disambig_out))
|
||||
KALDI_ERR
|
||||
<< "fstaddselfloops: Could not read disambiguation symbols from "
|
||||
<< kaldi::PrintableRxfilename(disambig_out_rxfilename);
|
||||
|
||||
if (disambig_in.size() != disambig_out.size())
|
||||
KALDI_ERR
|
||||
<< "fstaddselfloops: mismatch in size of disambiguation symbols";
|
||||
|
||||
AddSelfLoops(fst, disambig_in, disambig_out);
|
||||
|
||||
WriteFstKaldi(*fst, fst_out_filename);
|
||||
|
||||
delete fst;
|
||||
|
||||
return 0;
|
||||
} catch (const std::exception &e) {
|
||||
std::cerr << e.what();
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
@ -0,0 +1,114 @@
|
||||
// fstbin/fstdeterminizestar.cc
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "base/kaldi-common.h"
|
||||
#include "fst/fstlib.h"
|
||||
#include "fstext/determinize-star.h"
|
||||
#include "fstext/fstext-utils.h"
|
||||
#include "fstext/kaldi-fst-io.h"
|
||||
#include "util/parse-options.h"
|
||||
#if !defined(_MSC_VER) && !defined(__APPLE__)
|
||||
#include <signal.h> // Comment this line and the call to signal below if
|
||||
// it causes compilation problems. It is only to enable a debugging procedure
|
||||
// when determinization does not terminate. We are disabling this code if
|
||||
// compiling on Windows because signal.h is not available there, and on
|
||||
// MacOS due to a problem with <signal.h> in the initial release of Sierra.
|
||||
#endif
|
||||
|
||||
/* some test examples:
|
||||
( echo "0 0 0 0"; echo "0 0" ) | fstcompile | fstdeterminizestar | fstprint
|
||||
( echo "0 0 1 0"; echo "0 0" ) | fstcompile | fstdeterminizestar | fstprint
|
||||
( echo "0 0 1 0"; echo "0 1 1 0"; echo "0 0" ) | fstcompile |
|
||||
fstdeterminizestar | fstprint # this last one fails [correctly]: ( echo "0 0 0
|
||||
1"; echo "0 0" ) | fstcompile | fstdeterminizestar | fstprint
|
||||
|
||||
cd ~/tmpdir
|
||||
while true; do
|
||||
fstrand > 1.fst
|
||||
fstpredeterminize out.lst 1.fst | fstdeterminizestar | fstrmsymbols out.lst
|
||||
> 2.fst fstequivalent --random=true 1.fst 2.fst || echo "Test failed" echo -n
|
||||
"." done
|
||||
|
||||
Test of debugging [with non-determinizable input]:
|
||||
( echo " 0 0 1 0 1.0"; echo "0 1 1 0"; echo "1 1 1 0 0"; echo "0 2 2 0"; echo
|
||||
"2"; echo "1" ) | fstcompile | fstdeterminizestar kill -SIGUSR1 [the process-id
|
||||
of fstdeterminizestar] # prints out a bunch of debugging output showing the
|
||||
mess it got itself into.
|
||||
*/
|
||||
|
||||
bool debug_location = false;
|
||||
void signal_handler(int) { debug_location = true; }
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
try {
|
||||
using namespace kaldi; // NOLINT
|
||||
using namespace fst; // NOLINT
|
||||
using kaldi::int32;
|
||||
|
||||
const char *usage =
|
||||
"Removes epsilons and determinizes in one step\n"
|
||||
"\n"
|
||||
"Usage: fstdeterminizestar [in.fst [out.fst] ]\n"
|
||||
"\n"
|
||||
"See also: fstdeterminizelog, lattice-determinize\n";
|
||||
|
||||
float delta = kDelta;
|
||||
int max_states = -1;
|
||||
bool use_log = false;
|
||||
ParseOptions po(usage);
|
||||
po.Register("use-log", &use_log, "Determinize in log semiring.");
|
||||
po.Register("delta", &delta,
|
||||
"Delta value used to determine equivalence of weights.");
|
||||
po.Register(
|
||||
"max-states", &max_states,
|
||||
"Maximum number of states in determinized FST before it will abort.");
|
||||
po.Read(argc, argv);
|
||||
|
||||
if (po.NumArgs() > 2) {
|
||||
po.PrintUsage();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::string fst_in_str = po.GetOptArg(1), fst_out_str = po.GetOptArg(2);
|
||||
|
||||
// This enables us to get traceback info from determinization that is
|
||||
// not seeming to terminate.
|
||||
#if !defined(_MSC_VER) && !defined(__APPLE__)
|
||||
signal(SIGUSR1, signal_handler);
|
||||
#endif
|
||||
// Normal case: just files.
|
||||
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_str);
|
||||
|
||||
ArcSort(fst, ILabelCompare<StdArc>()); // improves speed.
|
||||
if (use_log) {
|
||||
DeterminizeStarInLog(fst, delta, &debug_location, max_states);
|
||||
} else {
|
||||
VectorFst<StdArc> det_fst;
|
||||
DeterminizeStar(*fst, &det_fst, delta, &debug_location, max_states);
|
||||
*fst = det_fst; // will do shallow copy and then det_fst goes
|
||||
// out of scope anyway.
|
||||
}
|
||||
WriteFstKaldi(*fst, fst_out_str);
|
||||
delete fst;
|
||||
return 0;
|
||||
} catch (const std::exception &e) {
|
||||
std::cerr << e.what();
|
||||
return -1;
|
||||
}
|
||||
}
|
@ -0,0 +1,91 @@
|
||||
// fstbin/fstisstochastic.cc
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "base/kaldi-common.h"
|
||||
#include "fst/fstlib.h"
|
||||
#include "fstext/fstext-utils.h"
|
||||
#include "fstext/kaldi-fst-io.h"
|
||||
#include "util/kaldi-io.h"
|
||||
#include "util/parse-options.h"
|
||||
|
||||
// e.g. of test:
|
||||
// echo " 0 0" | fstcompile | fstisstochastic
|
||||
// should return 0 and print "0 0" [meaning, min and
|
||||
// max weight are one = exp(0)]
|
||||
// echo " 0 1" | fstcompile | fstisstochastic
|
||||
// should return 1, not stochastic, and print 1 1
|
||||
// (echo "0 0 0 0 0.693147 "; echo "0 1 0 0 0.693147 "; echo "1 0" ) |
|
||||
// fstcompile | fstisstochastic should return 0, stochastic; it prints "0
|
||||
// -1.78e-07" for me (echo "0 0 0 0 0.693147 "; echo "0 1 0 0 0.693147 "; echo
|
||||
// "1 0" ) | fstcompile | fstisstochastic --test-in-log=false should return 1,
|
||||
// not stochastic in tropical; it prints "0 0.693147" for me (echo "0 0 0 0 0 ";
|
||||
// echo "0 1 0 0 0 "; echo "1 0" ) | fstcompile | fstisstochastic
|
||||
// --test-in-log=false should return 0, stochastic in tropical; it prints "0 0"
|
||||
// for me (echo "0 0 0 0 0.693147 "; echo "0 1 0 0 0.693147 "; echo "1 0" ) |
|
||||
// fstcompile | fstisstochastic --test-in-log=false --delta=1 returns 0 even
|
||||
// though not stochastic because we gave it an absurdly large delta.
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
try {
|
||||
using namespace kaldi; // NOLINT
|
||||
using namespace fst; // NOLINT
|
||||
using kaldi::int32;
|
||||
|
||||
const char *usage =
|
||||
"Checks whether an FST is stochastic and exits with success if so.\n"
|
||||
"Prints out maximum error (in log units).\n"
|
||||
"\n"
|
||||
"Usage: fstisstochastic [ in.fst ]\n";
|
||||
|
||||
float delta = 0.01;
|
||||
bool test_in_log = true;
|
||||
|
||||
ParseOptions po(usage);
|
||||
po.Register("delta", &delta, "Maximum error to accept.");
|
||||
po.Register("test-in-log", &test_in_log,
|
||||
"Test stochasticity in log semiring.");
|
||||
po.Read(argc, argv);
|
||||
|
||||
if (po.NumArgs() > 1) {
|
||||
po.PrintUsage();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::string fst_in_filename = po.GetOptArg(1);
|
||||
|
||||
Fst<StdArc> *fst = ReadFstKaldiGeneric(fst_in_filename);
|
||||
|
||||
bool ans;
|
||||
StdArc::Weight min, max;
|
||||
if (test_in_log)
|
||||
ans = IsStochasticFstInLog(*fst, delta, &min, &max);
|
||||
else
|
||||
ans = IsStochasticFst(*fst, delta, &min, &max);
|
||||
|
||||
std::cout << min.Value() << " " << max.Value() << '\n';
|
||||
delete fst;
|
||||
if (ans)
|
||||
return 0; // success;
|
||||
else
|
||||
return 1;
|
||||
} catch (const std::exception &e) {
|
||||
std::cerr << e.what();
|
||||
return -1;
|
||||
}
|
||||
}
|
@ -0,0 +1,74 @@
|
||||
// fstbin/fstminimizeencoded.cc
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "base/kaldi-common.h"
|
||||
#include "fst/fstlib.h"
|
||||
#include "fstext/determinize-star.h"
|
||||
#include "fstext/fstext-utils.h"
|
||||
#include "fstext/kaldi-fst-io.h"
|
||||
#include "util/kaldi-io.h"
|
||||
#include "util/parse-options.h"
|
||||
#include "util/text-utils.h"
|
||||
|
||||
/* some test examples:
|
||||
( echo "0 0 0 0"; echo "0 0" ) | fstcompile | fstminimizeencoded | fstprint
|
||||
( echo "0 1 0 0"; echo " 0 2 0 0"; echo "1 0"; echo "2 0"; ) | fstcompile |
|
||||
fstminimizeencoded | fstprint
|
||||
*/
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
try {
|
||||
using namespace kaldi; // NOLINT
|
||||
using namespace fst; // NOLINT
|
||||
using kaldi::int32;
|
||||
|
||||
const char *usage =
|
||||
"Minimizes FST after encoding [similar to fstminimize, but no "
|
||||
"weight-pushing]\n"
|
||||
"\n"
|
||||
"Usage: fstminimizeencoded [in.fst [out.fst] ]\n";
|
||||
|
||||
float delta = kDelta;
|
||||
ParseOptions po(usage);
|
||||
po.Register("delta", &delta,
|
||||
"Delta likelihood used for quantization of weights");
|
||||
po.Read(argc, argv);
|
||||
|
||||
if (po.NumArgs() > 2) {
|
||||
po.PrintUsage();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::string fst_in_filename = po.GetOptArg(1),
|
||||
fst_out_filename = po.GetOptArg(2);
|
||||
|
||||
VectorFst<StdArc> *fst = ReadFstKaldi(fst_in_filename);
|
||||
|
||||
MinimizeEncoded(fst, delta);
|
||||
|
||||
WriteFstKaldi(*fst, fst_out_filename);
|
||||
|
||||
delete fst;
|
||||
return 0;
|
||||
} catch (const std::exception &e) {
|
||||
std::cerr << e.what();
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
@ -0,0 +1,133 @@
|
||||
// fstbin/fsttablecompose.cc
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
// 2013 Johns Hopkins University (author: Daniel Povey)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "base/kaldi-common.h"
|
||||
#include "fst/fstlib.h"
|
||||
#include "fstext/fstext-utils.h"
|
||||
#include "fstext/kaldi-fst-io.h"
|
||||
#include "fstext/table-matcher.h"
|
||||
#include "util/parse-options.h"
|
||||
|
||||
/*
|
||||
cd ~/tmpdir
|
||||
while true; do
|
||||
fstrand | fstarcsort --sort_type=olabel > 1.fst; fstrand | fstarcsort
|
||||
> 2.fst fstcompose 1.fst 2.fst > 3a.fst fsttablecompose 1.fst 2.fst > 3b.fst
|
||||
fstequivalent --random=true 3a.fst 3b.fst || echo "Test failed"
|
||||
echo -n "."
|
||||
done
|
||||
|
||||
*/
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
try {
|
||||
using namespace kaldi; // NOLINT
|
||||
using namespace fst; // NOLINT
|
||||
using kaldi::int32;
|
||||
/*
|
||||
fsttablecompose should always give equivalent results to compose,
|
||||
but it is more efficient for certain kinds of inputs.
|
||||
In particular, it is useful when, say, the left FST has states
|
||||
that typically either have epsilon olabels, or
|
||||
one transition out for each of the possible symbols (as the
|
||||
olabel). The same with the input symbols of the right-hand FST
|
||||
is possible.
|
||||
*/
|
||||
|
||||
const char *usage =
|
||||
"Composition algorithm [between two FSTs of standard type, in "
|
||||
"tropical\n"
|
||||
"semiring] that is more efficient for certain cases-- in particular,\n"
|
||||
"where one of the FSTs (the left one, if --match-side=left) has large\n"
|
||||
"out-degree\n"
|
||||
"\n"
|
||||
"Usage: fsttablecompose (fst1-rxfilename|fst1-rspecifier) "
|
||||
"(fst2-rxfilename|fst2-rspecifier) [(out-rxfilename|out-rspecifier)]\n";
|
||||
|
||||
ParseOptions po(usage);
|
||||
|
||||
TableComposeOptions opts;
|
||||
std::string match_side = "left";
|
||||
std::string compose_filter = "sequence";
|
||||
|
||||
po.Register("connect", &opts.connect, "If true, trim FST before output.");
|
||||
po.Register("match-side", &match_side,
|
||||
"Side of composition to do table "
|
||||
"match, one of: \"left\" or \"right\".");
|
||||
po.Register("compose-filter", &compose_filter,
|
||||
"Composition filter to use, "
|
||||
"one of: \"alt_sequence\", \"auto\", \"match\", \"sequence\"");
|
||||
|
||||
po.Read(argc, argv);
|
||||
|
||||
if (match_side == "left") {
|
||||
opts.table_match_type = MATCH_OUTPUT;
|
||||
} else if (match_side == "right") {
|
||||
opts.table_match_type = MATCH_INPUT;
|
||||
} else {
|
||||
KALDI_ERR << "Invalid match-side option: " << match_side;
|
||||
}
|
||||
|
||||
if (compose_filter == "alt_sequence") {
|
||||
opts.filter_type = ALT_SEQUENCE_FILTER;
|
||||
} else if (compose_filter == "auto") {
|
||||
opts.filter_type = AUTO_FILTER;
|
||||
} else if (compose_filter == "match") {
|
||||
opts.filter_type = MATCH_FILTER;
|
||||
} else if (compose_filter == "sequence") {
|
||||
opts.filter_type = SEQUENCE_FILTER;
|
||||
} else {
|
||||
KALDI_ERR << "Invalid compose-filter option: " << compose_filter;
|
||||
}
|
||||
|
||||
if (po.NumArgs() < 2 || po.NumArgs() > 3) {
|
||||
po.PrintUsage();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::string fst1_in_str = po.GetArg(1), fst2_in_str = po.GetArg(2),
|
||||
fst_out_str = po.GetOptArg(3);
|
||||
|
||||
VectorFst<StdArc> *fst1 = ReadFstKaldi(fst1_in_str);
|
||||
|
||||
VectorFst<StdArc> *fst2 = ReadFstKaldi(fst2_in_str);
|
||||
|
||||
// Checks if <fst1> is olabel sorted and <fst2> is ilabel sorted.
|
||||
if (fst1->Properties(fst::kOLabelSorted, true) == 0) {
|
||||
KALDI_WARN << "The first FST is not olabel sorted.";
|
||||
}
|
||||
if (fst2->Properties(fst::kILabelSorted, true) == 0) {
|
||||
KALDI_WARN << "The second FST is not ilabel sorted.";
|
||||
}
|
||||
|
||||
VectorFst<StdArc> composed_fst;
|
||||
|
||||
TableCompose(*fst1, *fst2, &composed_fst, opts);
|
||||
|
||||
delete fst1;
|
||||
delete fst2;
|
||||
|
||||
WriteFstKaldi(composed_fst, fst_out_str);
|
||||
return 0;
|
||||
} catch (const std::exception &e) {
|
||||
std::cerr << e.what();
|
||||
return -1;
|
||||
}
|
||||
}
|
@ -0,0 +1,97 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
current_path=`pwd`
|
||||
current_dir=`basename "$current_path"`
|
||||
|
||||
if [ "tools" != "$current_dir" ]; then
|
||||
echo "You should run this script in tools/ directory!!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d liblbfgs-1.10 ]; then
|
||||
echo Installing libLBFGS library to support MaxEnt LMs
|
||||
bash extras/install_liblbfgs.sh || exit 1
|
||||
fi
|
||||
|
||||
! command -v gawk > /dev/null && \
|
||||
echo "GNU awk is not installed so SRILM will probably not work correctly: refusing to install" && exit 1;
|
||||
|
||||
if [ $# -ne 3 ]; then
|
||||
echo "SRILM download requires some information about you"
|
||||
echo
|
||||
echo "Usage: $0 <name> <organization> <email>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
srilm_url="http://www.speech.sri.com/projects/srilm/srilm_download.php"
|
||||
post_data="WWW_file=srilm-1.7.3.tar.gz&WWW_name=$1&WWW_org=$2&WWW_email=$3"
|
||||
|
||||
if ! wget --post-data "$post_data" -O ./srilm.tar.gz "$srilm_url"; then
|
||||
echo 'There was a problem downloading the file.'
|
||||
echo 'Check you internet connection and try again.'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mkdir -p srilm
|
||||
cd srilm
|
||||
|
||||
|
||||
if [ -f ../srilm.tgz ]; then
|
||||
tar -xvzf ../srilm.tgz # Old SRILM format
|
||||
elif [ -f ../srilm.tar.gz ]; then
|
||||
tar -xvzf ../srilm.tar.gz # Changed format type from tgz to tar.gz
|
||||
fi
|
||||
|
||||
major=`gawk -F. '{ print $1 }' RELEASE`
|
||||
minor=`gawk -F. '{ print $2 }' RELEASE`
|
||||
micro=`gawk -F. '{ print $3 }' RELEASE`
|
||||
|
||||
if [ $major -le 1 ] && [ $minor -le 7 ] && [ $micro -le 1 ]; then
|
||||
echo "Detected version 1.7.1 or earlier. Applying patch."
|
||||
patch -p0 < ../extras/srilm.patch
|
||||
fi
|
||||
|
||||
# set the SRILM variable in the top-level Makefile to this directory.
|
||||
cp Makefile tmpf
|
||||
|
||||
cat tmpf | gawk -v pwd=`pwd` '/SRILM =/{printf("SRILM = %s\n", pwd); next;} {print;}' \
|
||||
> Makefile || exit 1
|
||||
rm tmpf
|
||||
|
||||
mtype=`sbin/machine-type`
|
||||
|
||||
echo HAVE_LIBLBFGS=1 >> common/Makefile.machine.$mtype
|
||||
grep ADDITIONAL_INCLUDES common/Makefile.machine.$mtype | \
|
||||
sed 's|$| -I$(SRILM)/../liblbfgs-1.10/include|' \
|
||||
>> common/Makefile.machine.$mtype
|
||||
|
||||
grep ADDITIONAL_LDFLAGS common/Makefile.machine.$mtype | \
|
||||
sed 's|$| -L$(SRILM)/../liblbfgs-1.10/lib/ -Wl,-rpath -Wl,$(SRILM)/../liblbfgs-1.10/lib/|' \
|
||||
>> common/Makefile.machine.$mtype
|
||||
|
||||
make || exit
|
||||
|
||||
cd ..
|
||||
(
|
||||
[ ! -z "${SRILM}" ] && \
|
||||
echo >&2 "SRILM variable is aleady defined. Undefining..." && \
|
||||
unset SRILM
|
||||
|
||||
[ -f ./env.sh ] && . ./env.sh
|
||||
|
||||
[ ! -z "${SRILM}" ] && \
|
||||
echo >&2 "SRILM config is already in env.sh" && exit
|
||||
|
||||
wd=`pwd`
|
||||
wd=`readlink -f $wd || pwd`
|
||||
|
||||
echo "export SRILM=$wd/srilm"
|
||||
dirs="\${PATH}"
|
||||
for directory in $(cd srilm && find bin -type d ) ; do
|
||||
dirs="$dirs:\${SRILM}/$directory"
|
||||
done
|
||||
echo "export PATH=$dirs"
|
||||
) >> env.sh
|
||||
|
||||
echo >&2 "Installation of SRILM finished successfully"
|
||||
echo >&2 "Please source the tools/env.sh in your path.sh to enable it"
|
@ -0,0 +1,5 @@
|
||||
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
|
||||
|
||||
add_executable(arpa2fst ${CMAKE_CURRENT_SOURCE_DIR}/arpa2fst.cc)
|
||||
target_include_directories(arpa2fst PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
|
||||
target_link_libraries(arpa2fst )
|
@ -0,0 +1,145 @@
|
||||
// bin/arpa2fst.cc
|
||||
//
|
||||
// Copyright 2009-2011 Gilles Boulianne.
|
||||
//
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABILITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "lm/arpa-lm-compiler.h"
|
||||
#include "util/kaldi-io.h"
|
||||
#include "util/parse-options.h"
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
using namespace kaldi; // NOLINT
|
||||
try {
|
||||
const char *usage =
|
||||
"Convert an ARPA format language model into an FST\n"
|
||||
"Usage: arpa2fst [opts] <input-arpa> <output-fst>\n"
|
||||
" e.g.: arpa2fst --disambig-symbol=#0 --read-symbol-table="
|
||||
"data/lang/words.txt lm/input.arpa G.fst\n\n"
|
||||
"Note: When called without switches, the output G.fst will contain\n"
|
||||
"an embedded symbol table. This is compatible with the way a previous\n"
|
||||
"version of arpa2fst worked.\n";
|
||||
|
||||
ParseOptions po(usage);
|
||||
|
||||
ArpaParseOptions options;
|
||||
options.Register(&po);
|
||||
|
||||
// Option flags.
|
||||
std::string bos_symbol = "<s>";
|
||||
std::string eos_symbol = "</s>";
|
||||
std::string disambig_symbol;
|
||||
std::string read_syms_filename;
|
||||
std::string write_syms_filename;
|
||||
bool keep_symbols = false;
|
||||
bool ilabel_sort = true;
|
||||
|
||||
po.Register("bos-symbol", &bos_symbol, "Beginning of sentence symbol");
|
||||
po.Register("eos-symbol", &eos_symbol, "End of sentence symbol");
|
||||
po.Register("disambig-symbol", &disambig_symbol,
|
||||
"Disambiguator. If provided (e. g. #0), used on input side of "
|
||||
"backoff links, and <s> and </s> are replaced with epsilons");
|
||||
po.Register("read-symbol-table", &read_syms_filename,
|
||||
"Use existing symbol table");
|
||||
po.Register("write-symbol-table", &write_syms_filename,
|
||||
"Write generated symbol table to a file");
|
||||
po.Register("keep-symbols", &keep_symbols,
|
||||
"Store symbol table with FST. Symbols always saved to FST if "
|
||||
"symbol tables are neither read or written (otherwise symbols "
|
||||
"would be lost entirely)");
|
||||
po.Register("ilabel-sort", &ilabel_sort, "Ilabel-sort the output FST");
|
||||
|
||||
po.Read(argc, argv);
|
||||
|
||||
if (po.NumArgs() != 1 && po.NumArgs() != 2) {
|
||||
po.PrintUsage();
|
||||
exit(1);
|
||||
}
|
||||
std::string arpa_rxfilename = po.GetArg(1),
|
||||
fst_wxfilename = po.GetOptArg(2);
|
||||
|
||||
int64 disambig_symbol_id = 0;
|
||||
|
||||
fst::SymbolTable *symbols;
|
||||
if (!read_syms_filename.empty()) {
|
||||
// Use existing symbols. Required symbols must be in the table.
|
||||
kaldi::Input kisym(read_syms_filename);
|
||||
symbols = fst::SymbolTable::ReadText(
|
||||
kisym.Stream(), PrintableWxfilename(read_syms_filename));
|
||||
if (symbols == NULL)
|
||||
KALDI_ERR << "Could not read symbol table from file "
|
||||
<< read_syms_filename;
|
||||
|
||||
options.oov_handling = ArpaParseOptions::kSkipNGram;
|
||||
if (!disambig_symbol.empty()) {
|
||||
disambig_symbol_id = symbols->Find(disambig_symbol);
|
||||
if (disambig_symbol_id == -1) // fst::kNoSymbol
|
||||
KALDI_ERR << "Symbol table " << read_syms_filename
|
||||
<< " has no symbol for " << disambig_symbol;
|
||||
}
|
||||
} else {
|
||||
// Create a new symbol table and populate it from ARPA file.
|
||||
symbols = new fst::SymbolTable(PrintableWxfilename(fst_wxfilename));
|
||||
options.oov_handling = ArpaParseOptions::kAddToSymbols;
|
||||
symbols->AddSymbol("<eps>", 0);
|
||||
if (!disambig_symbol.empty()) {
|
||||
disambig_symbol_id = symbols->AddSymbol(disambig_symbol);
|
||||
}
|
||||
}
|
||||
|
||||
// Add or use existing BOS and EOS.
|
||||
options.bos_symbol = symbols->AddSymbol(bos_symbol);
|
||||
options.eos_symbol = symbols->AddSymbol(eos_symbol);
|
||||
|
||||
// If producing new (not reading existing) symbols and not saving them,
|
||||
// need to keep symbols with FST, otherwise they would be lost.
|
||||
if (read_syms_filename.empty() && write_syms_filename.empty())
|
||||
keep_symbols = true;
|
||||
|
||||
// Actually compile LM.
|
||||
KALDI_ASSERT(symbols != NULL);
|
||||
ArpaLmCompiler lm_compiler(options, disambig_symbol_id, symbols);
|
||||
{
|
||||
Input ki(arpa_rxfilename);
|
||||
lm_compiler.Read(ki.Stream());
|
||||
}
|
||||
|
||||
// Sort the FST in-place if requested by options.
|
||||
if (ilabel_sort) {
|
||||
fst::ArcSort(lm_compiler.MutableFst(), fst::StdILabelCompare());
|
||||
}
|
||||
|
||||
// Write symbols if requested.
|
||||
if (!write_syms_filename.empty()) {
|
||||
kaldi::Output kosym(write_syms_filename, false);
|
||||
symbols->WriteText(kosym.Stream());
|
||||
}
|
||||
|
||||
// Write LM FST.
|
||||
bool write_binary = true, write_header = false;
|
||||
kaldi::Output kofst(fst_wxfilename, write_binary, write_header);
|
||||
fst::FstWriteOptions wopts(PrintableWxfilename(fst_wxfilename));
|
||||
wopts.write_isymbols = wopts.write_osymbols = keep_symbols;
|
||||
lm_compiler.Fst().Write(kofst.Stream(), wopts);
|
||||
|
||||
delete symbols;
|
||||
} catch (const std::exception &e) {
|
||||
std::cerr << e.what();
|
||||
return -1;
|
||||
}
|
||||
}
|
Loading…
Reference in new issue