add decodable & ctc_beam_search_decoder

pull/1400/head
SmileGoat 3 years ago
parent e57efcb314
commit d14ee80065

@ -0,0 +1,7 @@
#include <utility>
#include <vector>

#include "base/basic_types.h"
// Final decoding output: the overall acoustic score, the decoded word ids,
// and a per-word (begin, end) time stamp.
struct DecoderResult {
    BaseFloat acoustic_score;      // total acoustic log-likelihood of the hypothesis
    std::vector<int32> words_idx;  // decoded word indices into the vocabulary
    // (begin, end) time stamp for each decoded word; was `pair` without the
    // std:: qualifier, which does not compile outside namespace std.
    std::vector<std::pair<int32, int32>> time_stamp;
};

@ -0,0 +1,264 @@
#include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h"
#include "decoder/ctc_decoders/decoder_utils.h"
namespace ppspeech {
using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
// Builds the decoder: loads the vocabulary from opts->dict_file and the
// external LM scorer from opts->lm_path. Fixes the undeclared `_opts`
// references (the member is `opts_`).
CTCBeamSearch::CTCBeamSearch(std::shared_ptr<CTCBeamSearchOptions> opts) :
    opts_(*opts),
    init_ext_scorer_(nullptr),
    vocabulary_(nullptr),
    blank_id(-1),
    space_id(-1),
    root(nullptr) {
    LOG(INFO) << "dict path: " << opts_.dict_file;
    vocabulary_ = std::make_shared<vector<string>>();
    // NOTE(review): the `basr` namespace looks like a typo (base?) — confirm
    // against the utility header that declares ReadDictToVector.
    if (!basr::ReadDictToVector(opts_.dict_file, *vocabulary_)) {
        // A missing dictionary is an error, not an informational event.
        LOG(ERROR) << "load the dict failed";
    }
    LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_->size();
    LOG(INFO) << "language model path: " << opts_.lm_path;
    init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha,
                                                opts_.beta,
                                                opts_.lm_path,
                                                *vocabulary_);
}
void CTCBeamSearch::InitDecoder() {
blank_id = 0;
auto it = std::find(vocabulary_->begin(), vocabulary_->end(), " ");
space_id = it - vocabulary_->begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary_->size()) {
space_id = -2;
}
clear_prefixes();
root = std::make_shared<PathTrie>();
root->score = root->log_prob_b_prev = 0.0;
prefixes.push_back(root.get());
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
auto fst_dict =
static_cast<fst::StdVectorFst *>(init_ext_scorer_->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root->set_matcher(matcher);
}
}
// Frees all prefix trie nodes and empties the prefix list. The original left
// the vector full of nullptrs after deletion, so later size()/iteration still
// saw the dead entries.
// NOTE(review): `prefixes` may contain root.get(), which is also owned by the
// shared_ptr `root` — confirm the ownership model to rule out a double free.
void CTCBeamSearch::ResetPrefixes() {
    for (size_t i = 0; i < prefixes.size(); i++) {
        if (prefixes[i] != nullptr) {
            delete prefixes[i];
            prefixes[i] = nullptr;
        }
    }
    prefixes.clear();
}
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs,
vector<string>& nbest_words) {
std::thread::id this_id = std::this_thread::get_id();
Timer timer;
vector<vector<double>> double_probs(probs.size(), vector<double>(probs[0].size(), 0));
int row = probs.size();
int col = probs[0].size();
for(int i = 0; i < row; i++) {
for (int j = 0; j < col; j++){
double_probs[i][j] = static_cast<double>(probs[i][j]);
}
}
timer.Reset();
vector<std::pair<double, string>> results = AdvanceDecoding(double_probs);
LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f;
for (const auto& item : results) {
nbest_words.push_back(item.second);
}
return 0;
}
// Runs CTC prefix beam search over the whole probability sequence and returns
// the beam_size best (score, transcript) pairs.
// BUGFIX: in the original the time-step loop was never closed before the
// rescoring code, so the prefix rebuild/pruning ran once per *symbol* instead
// of once per time step, and LMRescore/CalculateApproxScore/return executed
// inside the loop — the function returned after the first frame. Also fixes
// the undeclared `_opts` references.
vector<std::pair<double, string>> CTCBeamSearch::AdvanceDecoding(
        const vector<vector<double>>& probs_seq) {
    size_t num_time_steps = probs_seq.size();
    size_t beam_size = opts_.beam_size;
    double cutoff_prob = opts_.cutoff_prob;
    size_t cutoff_top_n = opts_.cutoff_top_n;
    for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
        const auto& prob = probs_seq[time_step];
        float min_cutoff = -NUM_FLT_INF;
        bool full_beam = false;
        if (init_ext_scorer_ != nullptr) {
            size_t num_prefixes = std::min(prefixes.size(), beam_size);
            std::sort(prefixes.begin(), prefixes.begin() + num_prefixes,
                      prefix_compare);
            if (num_prefixes == 0) {
                continue;
            }
            // A candidate can only enter a full beam if it beats the worst
            // kept prefix extended by the blank symbol.
            min_cutoff = prefixes[num_prefixes - 1]->score +
                         std::log(prob[blank_id]) -
                         std::max(0.0, init_ext_scorer_->beta);
            full_beam = (num_prefixes == beam_size);
        }
        // Prune the emission distribution to the likeliest symbols.
        vector<std::pair<size_t, float>> log_prob_idx =
            get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
        // Extend every live prefix with every surviving symbol.
        for (size_t index = 0; index < log_prob_idx.size(); index++) {
            SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
        }
        // Rebuild the prefix list from the trie once per time step.
        prefixes.clear();
        root->iterate_to_vec(prefixes);
        // Only preserve the top beam_size prefixes.
        if (prefixes.size() >= beam_size) {
            std::nth_element(prefixes.begin(),
                             prefixes.begin() + beam_size,
                             prefixes.end(),
                             prefix_compare);
            for (size_t i = beam_size; i < prefixes.size(); ++i) {
                prefixes[i]->remove();
            }
        }
    }  // for each time step
    // Score the last word of each prefix that doesn't end with space.
    LMRescore();
    CalculateApproxScore();
    return get_beam_search_result(prefixes, *vocabulary_, beam_size);
}
int CTCBeamSearch::SearchOneChar(const bool& full_beam,
const std::pair<size_t, float>& log_prob_idx,
const float& min_cutoff) {
size_t beam_size = _opts.beam_size;
const auto& c = log_prob_idx.first;
const auto& log_prob_c = log_prob_idx.second;
size_t prefixes_len = std::min(prefixes.size(), beam_size);
for (size_t i = 0; i < prefixes_len; ++i) {
auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
if (c == blank_id) {
prefix->log_prob_b_cur = log_sum_exp(
prefix->log_prob_b_cur,
log_prob_c +
prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
// p_{nb}(l;x_{1:t}) = p(c;x_{t})p(l;x_{1:t-1})
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur,
log_prob_c +
prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t})p_{b}(l;x_{1:t-1})
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
// p_{nb}(l^{+};x_{1:t}) = p(c;x_{t}) p(l;x_{1:t-1})
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (init_ext_scorer_ != nullptr &&
(c == space_id || init_ext_scorer_->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (init_ext_scorer_->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
vector<string> ngram;
ngram = init_ext_scorer_->make_ngram(prefix_to_score);
// lm score: p_{lm}(W)^{\alpha} + \beta
score = init_ext_scorer_->get_log_cond_prob(ngram) *
init_ext_scorer_->alpha;
log_p += score;
log_p += init_ext_scorer_->beta;
}
// p_{nb}(l;x_{1:t})
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur,
log_p);
}
} // end of loop over prefix
return 0;
}
void CTCBeamSearch::CalculateApproxScore() {
size_t beam_size = _opts.beam_size;
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
if (init_ext_scorer_ != nullptr) {
vector<int> output;
prefixes[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = init_ext_scorer_->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * init_ext_scorer_->beta;
// remove language model weight:
approx_ctc -=
(init_ext_scorer_->get_sent_log_prob(words)) * init_ext_scorer_->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}
}
void CTCBeamSearch::LMRescore() {
size_t beam_size = _opts.beam_size;
if (init_ext_scorer_ != nullptr && !init_ext_scorer_->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
score = init_ext_scorer_->get_log_cond_prob(ngram) * init_ext_scorer_->alpha;
score += init_ext_scorer_->beta;
prefix->score += score;
}
}
}
}
} // namespace ppspeech

@ -0,0 +1,74 @@
#include "base/basic_types.h"
#pragma once
namespace ppspeech {
// Options for CTC prefix beam search decoding with an external n-gram LM.
struct CTCBeamSearchOptions {
    std::string dict_file;  // vocabulary (one symbol per line)
    std::string lm_path;    // ARPA / binary language model
    BaseFloat alpha;        // LM weight
    BaseFloat beta;         // per-word insertion bonus
    BaseFloat cutoff_prob;  // cumulative-probability cutoff for emission pruning
    int beam_size;          // number of prefixes kept per time step
    int cutoff_top_n;       // max number of symbols kept per time step
    int num_proc_bsearch;   // worker count for batch beam search
    // BUGFIX: the initializer list now follows member declaration order;
    // the original listed beam_size before cutoff_prob (-Wreorder).
    CTCBeamSearchOptions() :
        dict_file("./model/words.txt"),
        lm_path("./model/lm.arpa"),
        alpha(1.9f),
        beta(5.0f),
        cutoff_prob(0.99f),
        beam_size(300),
        cutoff_top_n(40),
        num_proc_bsearch(0) {
    }
    // Registers every option with the kaldi-style command-line parser.
    void Register(kaldi::OptionsItf* opts) {
        opts->Register("dict", &dict_file, "dict file ");
        opts->Register("lm-path", &lm_path, "language model file");
        opts->Register("alpha", &alpha, "alpha");
        opts->Register("beta", &beta, "beta");
        opts->Register("beam-size", &beam_size, "beam size for beam search method");
        opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
        opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
        opts->Register("num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
    }
};
class CTCBeamSearch {
public:
CTCBeamSearch(std::shared_ptr<CTCBeamSearchOptions> opts);
~CTCBeamSearch() {
}
bool InitDecoder();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>&probs,
std::vector<std::string>& nbest_words);
std::vector<DecodeResult>& GetDecodeResult() {
return decoder_results_;
}
private:
void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff);
void CalculateApproxScore();
void LMRescore();
std::vector<std::pair<double, std::string>>
AdvanceDecoding(const std::vector<std::vector<double>>& probs_seq);
CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
std::vector<DecodeResult> decoder_results_;
std::vector<std::vector<std::string>> vocabulary_; // todo remove later
size_t blank_id;
int space_id;
std::shared_ptr<PathTrie> root;
std::vector<PathTrie*> prefixes;
};
} // namespace ppspeech

@ -0,0 +1 @@
../../../third_party/ctc_decoders

File diff suppressed because it is too large Load Diff

@ -0,0 +1,549 @@
// decoder/lattice-faster-decoder.h
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#include "decoder/grammar-fst.h"
#include "fst/fstlib.h"
#include "fst/memory.h"
#include "fstext/fstext-lib.h"
#include "itf/decodable-itf.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
#include "util/hash-list.h"
#include "util/stl-utils.h"
namespace kaldi {
// Configuration for LatticeFasterDecoderTpl; every field is also exposed on
// the command line via Register().
struct LatticeFasterDecoderConfig {
// Decoding beam: larger is slower but more accurate.
BaseFloat beam;
// Maximum number of active decoder states: larger is slower, more accurate.
int32 max_active;
// Minimum number of active decoder states.
int32 min_active;
// Lattice generation beam: larger gives slower decoding and deeper lattices.
BaseFloat lattice_beam;
// Interval (in frames) at which tokens are pruned.
int32 prune_interval;
bool determinize_lattice; // not inspected by this class... used in
// command-line program.
// Increment used in decoding; relates to a speedup in how the max-active
// constraint is applied (see Register() help text below).
BaseFloat beam_delta;
// Controls hash behavior in the decoder's token hash.
BaseFloat hash_ratio;
// Note: we don't make prune_scale configurable on the command line, it's not
// a very important parameter. It affects the algorithm that prunes the
// tokens as we go.
BaseFloat prune_scale;
// Number of elements in the block for Token and ForwardLink memory
// pool allocation.
int32 memory_pool_tokens_block_size;
int32 memory_pool_links_block_size;
// Most of the options inside det_opts are not actually queried by the
// LatticeFasterDecoder class itself, but by the code that calls it, for
// example in the function DecodeUtteranceLatticeFaster.
fst::DeterminizeLatticePhonePrunedOptions det_opts;
LatticeFasterDecoderConfig()
: beam(16.0),
max_active(std::numeric_limits<int32>::max()),
min_active(200),
lattice_beam(10.0),
prune_interval(25),
determinize_lattice(true),
beam_delta(0.5),
hash_ratio(2.0),
prune_scale(0.1),
memory_pool_tokens_block_size(1 << 8),
memory_pool_links_block_size(1 << 8) {}
void Register(OptionsItf *opts) {
det_opts.Register(opts);
opts->Register("beam", &beam, "Decoding beam. Larger->slower, more accurate.");
opts->Register("max-active", &max_active, "Decoder max active states. Larger->slower; "
"more accurate");
opts->Register("min-active", &min_active, "Decoder minimum #active states.");
opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam. Larger->slower, "
"and deeper lattices");
opts->Register("prune-interval", &prune_interval, "Interval (in frames) at "
"which to prune tokens");
opts->Register("determinize-lattice", &determinize_lattice, "If true, "
"determinize the lattice (lattice-determinization, keeping only "
"best pdf-sequence for each word-sequence).");
opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this "
"parameter is obscure and relates to a speedup in the way the "
"max-active constraint is applied. Larger is more accurate.");
opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to "
"control hash behavior");
opts->Register("memory-pool-tokens-block-size", &memory_pool_tokens_block_size,
"Memory pool block size suggestion for storing tokens (in elements). "
"Smaller uses less memory but increases cache misses.");
opts->Register("memory-pool-links-block-size", &memory_pool_links_block_size,
"Memory pool block size suggestion for storing links (in elements). "
"Smaller uses less memory but increases cache misses.");
}
// Asserts that all configured values are mutually consistent.
void Check() const {
KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0
&& min_active <= max_active
&& prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0
&& prune_scale > 0.0 && prune_scale < 1.0);
}
};
namespace decoder {
// We will template the decoder on the token type as well as the FST type; this
// is a mechanism so that we can use the same underlying decoder code for
// versions of the decoder that support quickly getting the best path
// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
// those that do not (LatticeFasterDecoder).
// ForwardLinks are the links from a token to a token on the next frame.
// or sometimes on the current frame (for input-epsilon links).
// A ForwardLink connects a token to a token on the next frame, or sometimes
// to a token on the same frame (for input-epsilon arcs). Links form the
// state-level lattice.
template <typename Token>
struct ForwardLink {
  using Label = fst::StdArc::Label;

  Token *next_tok;          // destination token; NULL marks a final-state link
  Label ilabel;             // input label on the arc
  Label olabel;             // output label on the arc
  BaseFloat graph_cost;     // graph part of the arc cost (LM etc.)
  BaseFloat acoustic_cost;  // pre-scaled acoustic part of the arc cost
  ForwardLink *next;        // next link in this token's singly linked list

  inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
                     BaseFloat graph_cost, BaseFloat acoustic_cost,
                     ForwardLink *next)
      : next_tok(next_tok), ilabel(ilabel), olabel(olabel),
        graph_cost(graph_cost), acoustic_cost(acoustic_cost), next(next) {}
};
// Standard token for LatticeFasterDecoder: one token per active HCLG
// (decoding-graph) state per frame. Use BackpointerToken instead when fast
// best-path traceback is required.
struct StdToken {
  using ForwardLinkT = ForwardLink<StdToken>;
  using Token = StdToken;

  // Total (LM + acoustic) cost from the start of the utterance to this point
  // (see cost_offset_ in the decoder, which is subtracted to keep this in a
  // good numerical range).
  BaseFloat tot_cost;

  // Always >= 0 after PruneForwardLinks: the minimum difference between the
  // cost of the best path through this token and the cost of the absolute
  // best path, under the assumption that any currently active state at the
  // decoding front may eventually succeed.
  BaseFloat extra_cost;

  // Head of this token's singly linked list of ForwardLinks, used for
  // lattice generation.
  ForwardLinkT *links;

  // Next token in the singly linked list of this frame's tokens.
  Token *next;

  // Intentionally a no-op (and expected to be optimized out): it lets the
  // shared decoder template also drive LatticeFasterOnlineDecoderTpl, whose
  // token type records backpointers.
  inline void SetBackpointer(Token *backpointer) {}

  // The 'backpointer' argument is deliberately ignored here; it exists only
  // so StdToken and BackpointerToken share a constructor signature.
  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
                  Token *next, Token *backpointer)
      : tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next) {}
};
// Like StdToken (one token per active HCLG state per frame), but additionally
// records the best preceding token so LatticeFasterOnlineDecoderTpl can
// implement an efficient GetBestPath.
struct BackpointerToken {
  using ForwardLinkT = ForwardLink<BackpointerToken>;
  using Token = BackpointerToken;

  // Total (LM + acoustic) cost from the start of the utterance to this point
  // (see cost_offset_ in the decoder, which is subtracted to keep this in a
  // good numerical range).
  BaseFloat tot_cost;

  // Always >= 0 after PruneForwardLinks: the minimum difference between the
  // cost of the best path this token is on and the cost of the absolute best
  // path, under the assumption that any currently active state at the
  // decoding front may eventually succeed.
  BaseFloat extra_cost;

  // Head of this token's singly linked list of ForwardLinks, used for
  // lattice generation.
  ForwardLinkT *links;

  // Next token in the singly linked list of this frame's tokens.
  BackpointerToken *next;

  // Best preceding token — possibly on this frame via an epsilon transition,
  // or on a previous frame. Only needed for fast GetBestPath in
  // LatticeFasterOnlineDecoderTpl; it plays no part in lattice generation
  // (the 'links' list stores the forward links for that).
  Token *backpointer;

  inline void SetBackpointer(Token *backpointer) {
    this->backpointer = backpointer;
  }

  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost,
                          ForwardLinkT *links, Token *next, Token *backpointer)
      : tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
        backpointer(backpointer) {}
};
} // namespace decoder
/** This is the "normal" lattice-generating decoder.
See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
for more information.
The decoder is templated on the FST type and the token type. The token type
will normally be StdToken, but also may be BackpointerToken which is to support
quick lookup of the current best path (see lattice-faster-online-decoder.h)
The FST you invoke this decoder which is expected to equal
Fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst. If you invoke it with
FST == StdFst and it notices that the actual FST type is
fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder object
will internally cast itself to one that is templated on those more specific
types; this is an optimization for speed.
*/
template <typename FST, typename Token = decoder::StdToken>
class LatticeFasterDecoderTpl {
public:
using Arc = typename FST::Arc;
using Label = typename Arc::Label;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using ForwardLinkT = decoder::ForwardLink<Token>;
// Instantiate this class once for each thing you have to decode.
// This version of the constructor does not take ownership of
// 'fst'.
LatticeFasterDecoderTpl(const FST &fst,
const LatticeFasterDecoderConfig &config);
// This version of the constructor takes ownership of the fst, and will delete
// it when this object is destroyed.
LatticeFasterDecoderTpl(const LatticeFasterDecoderConfig &config,
FST *fst);
// Overwrites the current decoding options.
void SetOptions(const LatticeFasterDecoderConfig &config) {
config_ = config;
}
// Returns the decoding options currently in use.
const LatticeFasterDecoderConfig &GetOptions() const {
return config_;
}
~LatticeFasterDecoderTpl();
/// Decodes until there are no more frames left in the "decodable" object..
/// note, this may block waiting for input if the "decodable" object blocks.
/// Returns true if any kind of traceback is available (not necessarily from a
/// final state).
bool Decode(DecodableInterface *decodable);
/// says whether a final-state was active on the last frame. If it was not, the
/// lattice (or traceback) will end with states that are not final-states.
bool ReachedFinal() const {
return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
}
/// Outputs an FST corresponding to the single best path through the lattice.
/// Returns true if result is nonempty (using the return status is deprecated,
/// it will become void). If "use_final_probs" is true AND we reached the
/// final-state of the graph then it will include those as final-probs, else
/// it will treat all final-probs as one. Note: this just calls GetRawLattice()
/// and figures out the shortest path.
bool GetBestPath(Lattice *ofst,
bool use_final_probs = true) const;
/// Outputs an FST corresponding to the raw, state-level
/// tracebacks. Returns true if result is nonempty.
/// If "use_final_probs" is true AND we reached the final-state
/// of the graph then it will include those as final-probs, else
/// it will treat all final-probs as one.
/// The raw lattice will be topologically sorted.
///
/// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
/// which also supports a pruning beam, in case for some reason
/// you want it pruned tighter than the regular lattice beam.
/// We could put that here in future needed.
bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const;
/// [Deprecated, users should now use GetRawLattice and determinize it
/// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper].
/// Outputs an FST corresponding to the lattice-determinized
/// lattice (one path per word sequence). Returns true if result is nonempty.
/// If "use_final_probs" is true AND we reached the final-state of the graph
/// then it will include those as final-probs, else it will treat all
/// final-probs as one.
bool GetLattice(CompactLattice *ofst,
bool use_final_probs = true) const;
/// InitDecoding initializes the decoding, and should only be used if you
/// intend to call AdvanceDecoding(). If you call Decode(), you don't need to
/// call this. You can also call InitDecoding if you have already decoded an
/// utterance and want to start with a new utterance.
void InitDecoding();
/// This will decode until there are no more frames ready in the decodable
/// object. You can keep calling it each time more frames become available.
/// If max_num_frames is specified, it specifies the maximum number of frames
/// the function will decode before returning.
void AdvanceDecoding(DecodableInterface *decodable,
int32 max_num_frames = -1);
/// This function may be optionally called after AdvanceDecoding(), when you
/// do not plan to decode any further. It does an extra pruning step that
/// will help to prune the lattices output by GetLattice and (particularly)
/// GetRawLattice more completely, particularly toward the end of the
/// utterance. If you call this, you cannot call AdvanceDecoding again (it
/// will fail), and you cannot call GetLattice() and related functions with
/// use_final_probs = false. Used to be called PruneActiveTokensFinal().
void FinalizeDecoding();
/// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives
/// more information. It returns the difference between the best (final-cost
/// plus cost) of any token on the final frame, and the best cost of any token
/// on the final frame. If it is infinity it means no final-states were
/// present on the final frame. It will usually be nonnegative. If it not
/// too positive (e.g. < 5 is my first guess, but this is not tested) you can
/// take it as a good indication that we reached the final-state with
/// reasonable likelihood.
BaseFloat FinalRelativeCost() const;
// Returns the number of frames decoded so far. The value returned changes
// whenever we call ProcessEmitting().
inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
protected:
// we make things protected instead of private, as code in
// LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
// internals.
// Deletes the elements of the singly linked list tok->links.
void DeleteForwardLinks(Token *tok);
// head of per-frame list of Tokens (list is in topological order),
// and something saying whether we ever pruned it using PruneForwardLinks.
struct TokenList {
Token *toks;
bool must_prune_forward_links;
bool must_prune_tokens;
TokenList(): toks(NULL), must_prune_forward_links(true),
must_prune_tokens(true) { }
};
using Elem = typename HashList<StateId, Token*>::Elem;
// Equivalent to:
// struct Elem {
// StateId key;
// Token *val;
// Elem *tail;
// };
void PossiblyResizeHash(size_t num_toks);
// FindOrAddToken either locates a token in hash of toks_, or if necessary
// inserts a new, empty token (i.e. with no forward links) for the current
// frame. [note: it's inserted if necessary into hash toks_ and also into the
// singly linked list of tokens active on this frame (whose head is at
// active_toks_[frame]). The frame_plus_one argument is the acoustic frame
// index plus one, which is used to index into the active_toks_ array.
// Returns the Token pointer. Sets "changed" (if non-NULL) to true if the
// token was newly created or the cost changed.
// If Token == StdToken, the 'backpointer' argument has no purpose (and will
// hopefully be optimized out).
inline Elem *FindOrAddToken(StateId state, int32 frame_plus_one,
BaseFloat tot_cost, Token *backpointer,
bool *changed);
// prunes outgoing links for all tokens in active_toks_[frame]
// it's called by PruneActiveTokens
// all links, that have link_extra_cost > lattice_beam are pruned
// delta is the amount by which the extra_costs must change
// before we set *extra_costs_changed = true.
// If delta is larger, we'll tend to go back less far
// toward the beginning of the file.
// extra_costs_changed is set to true if extra_cost was changed for any token
// links_pruned is set to true if any link in any token was pruned
void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
bool *links_pruned,
BaseFloat delta);
// This function computes the final-costs for tokens active on the final
// frame. It outputs to final-costs, if non-NULL, a map from the Token*
// pointer to the final-prob of the corresponding state, for all Tokens
// that correspond to states that have final-probs. This map will be
// empty if there were no final-probs. It outputs to
// final_relative_cost, if non-NULL, the difference between the best
// forward-cost including the final-prob cost, and the best forward-cost
// without including the final-prob cost (this will usually be positive), or
// infinity if there were no final-probs. [c.f. FinalRelativeCost(), which
// outputs this quanitity]. It outputs to final_best_cost, if
// non-NULL, the lowest for any token t active on the final frame, of
// forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
// the graph of the state corresponding to token t, or the best of
// forward-cost[t] if there were no final-probs active on the final frame.
// You cannot call this after FinalizeDecoding() has been called; in that
// case you should get the answer from class-member variables.
void ComputeFinalCosts(unordered_map<Token*, BaseFloat> *final_costs,
BaseFloat *final_relative_cost,
BaseFloat *final_best_cost) const;
// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
// on the final frame. If there are final tokens active, it uses
// the final-probs for pruning, otherwise it treats all tokens as final.
void PruneForwardLinksFinal();
// Prune away any tokens on this frame that have no forward links.
// [we don't do this in PruneForwardLinks because it would give us
// a problem with dangling pointers].
// It's called by PruneActiveTokens if any forward links have been pruned
void PruneTokensForFrame(int32 frame_plus_one);
// Go backwards through still-alive tokens, pruning them if the
// forward+backward cost is more than lat_beam away from the best path. It's
// possible to prove that this is "correct" in the sense that we won't lose
// anything outside of lat_beam, regardless of what happens in the future.
// delta controls when it considers a cost to have changed enough to continue
// going backward and propagating the change. larger delta -> will recurse
// less far.
void PruneActiveTokens(BaseFloat delta);
/// Gets the weight cutoff. Also counts the active tokens.
BaseFloat GetCutoff(Elem *list_head, size_t *tok_count,
BaseFloat *adaptive_beam, Elem **best_elem);
/// Processes emitting arcs for one frame. Propagates from prev_toks_ to
/// cur_toks_. Returns the cost cutoff for subsequent ProcessNonemitting() to
/// use.
BaseFloat ProcessEmitting(DecodableInterface *decodable);
/// Processes nonemitting (epsilon) arcs for one frame. Called after
/// ProcessEmitting() on each frame. The cost cutoff is computed by the
/// preceding ProcessEmitting().
void ProcessNonemitting(BaseFloat cost_cutoff);
// HashList defined in ../util/hash-list.h. It actually allows us to maintain
// more than one list (e.g. for current and previous frames), but only one of
// them at a time can be indexed by StateId. It is indexed by frame-index
// plus one, where the frame-index is zero-based, as used in decodable object.
// That is, the emitting probs of frame t are accounted for in tokens at
// toks_[t+1]. The zeroth frame is for nonemitting transition at the start of
// the graph.
HashList<StateId, Token*> toks_;
std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
// frame (members of TokenList are toks, must_prune_forward_links,
// must_prune_tokens).
std::vector<const Elem* > queue_; // temp variable used in ProcessNonemitting,
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
// fst_ is a pointer to the FST we are decoding from.
const FST *fst_;
// delete_fst_ is true if the pointer fst_ needs to be deleted when this
// object is destroyed.
bool delete_fst_;
std::vector<BaseFloat> cost_offsets_; // This contains, for each
// frame, an offset that was added to the acoustic log-likelihoods on that
// frame in order to keep everything in a nice dynamic range i.e. close to
// zero, to reduce roundoff errors.
LatticeFasterDecoderConfig config_;
int32 num_toks_; // current total #toks allocated...
bool warned_;
/// decoding_finalized_ is true if someone called FinalizeDecoding(). [note,
/// calling this is optional]. If true, it's forbidden to decode more. Also,
/// if this is set, then the output of ComputeFinalCosts() is in the next
/// three variables. The reason we need to do this is that after
/// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
/// of the tokens on the last frame are freed, so we free the list from toks_
/// to avoid having dangling pointers hanging around.
bool decoding_finalized_;
/// For the meaning of the next 3 variables, see the comment for
/// decoding_finalized_ above., and ComputeFinalCosts().
unordered_map<Token*, BaseFloat> final_costs_;
BaseFloat final_relative_cost_;
BaseFloat final_best_cost_;
// Memory pools for storing tokens and forward links.
// We use it to decrease the work put on allocator and to move some of data
// together. Too small block sizes will result in more work to allocator but
// bigger ones increase the memory usage.
fst::MemoryPool<Token> token_pool_;
fst::MemoryPool<ForwardLinkT> forward_link_pool_;
// There are various cleanup tasks... the toks_ structure contains
// singly linked lists of Token pointers, where Elem is the list type.
// It also indexes them in a hash, indexed by state (this hash is only
// maintained for the most recent frame). toks_.Clear()
// deletes them from the hash and returns the list of Elems. The
// function DeleteElems calls toks_.Delete(elem) for each elem in
// the list, which returns ownership of the Elem to the toks_ structure
// for reuse, but does not delete the Token pointer. The Token pointers
// are reference-counted and are ultimately deleted in PruneTokensForFrame,
// but are also linked together on each frame by their own linked-list,
// using the "next" pointer. We delete them manually.
void DeleteElems(Elem *list);
// This function takes a singly linked list of tokens for a single frame, and
// outputs a list of them in topological order (it will crash if no such order
// can be found, which will typically be due to decoding graphs with epsilon
// cycles, which are not allowed). Note: the output list may contain NULLs,
// which the caller should pass over; it just happens to be more efficient for
// the algorithm to output a list that contains NULLs.
static void TopSortTokens(Token *tok_list,
std::vector<Token*> *topsorted_list);
void ClearActiveTokens();
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderTpl);
};
// Conventional decoder instantiation: StdFst graph with plain
// (non-backpointer) StdToken tokens.
typedef LatticeFasterDecoderTpl<fst::StdFst, decoder::StdToken> LatticeFasterDecoder;
} // end namespace kaldi.
#endif

@ -0,0 +1,285 @@
// decoder/lattice-faster-online-decoder.cc
// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2014 IMSL, PKU-HKUST (author: Wei Shi)
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
// see note at the top of lattice-faster-decoder.cc, about how to maintain this
// file in sync with lattice-faster-decoder.cc
#include "decoder/lattice-faster-online-decoder.h"
#include "lat/lattice-functions.h"
namespace kaldi {
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::TestGetBestPath(
    bool use_final_probs) const {
  // Self-test: the best path obtained via the efficient traceback
  // (GetBestPath) must be randomly-equivalent to the one found by running
  // ShortestPath on the full raw lattice.
  Lattice best_via_shortest_path;
  {
    Lattice raw_lat;
    this->GetRawLattice(&raw_lat, use_final_probs);
    ShortestPath(raw_lat, &best_via_shortest_path);
  }
  Lattice best_via_traceback;
  GetBestPath(&best_via_traceback, use_final_probs);
  const BaseFloat delta = 0.1;  // tolerance for weight comparison.
  const int32 num_paths = 1;    // number of random paths to compare.
  if (!fst::RandEquivalent(best_via_shortest_path, best_via_traceback,
                           num_paths, delta, rand())) {
    KALDI_WARN << "Best-path test failed";
    return false;
  }
  return true;
}
// Outputs an FST corresponding to the single best path through the lattice.
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetBestPath(Lattice *olat,
                                                     bool use_final_probs) const {
  olat->DeleteStates();
  BaseFloat final_graph_cost;
  BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost);
  if (iter.Done())
    return false;  // BestPathEnd would already have printed a warning.
  // We trace the path backwards: the first state created is the final
  // state; each traced-back arc points at the previously created state,
  // and the last state created becomes the start state.
  StateId cur_state = olat->AddState();
  olat->SetFinal(cur_state, LatticeWeight(final_graph_cost, 0.0));
  while (!iter.Done()) {
    LatticeArc arc;
    iter = TraceBackBestPath(iter, &arc);
    arc.nextstate = cur_state;
    StateId prev_state = olat->AddState();
    olat->AddArc(prev_state, arc);
    cur_state = prev_state;
  }
  olat->SetStart(cur_state);
  return true;
}
// Returns an iterator pointing at the best-cost token on the most recently
// decoded frame (optionally including final-probs in the cost); the caller
// can then trace the best path backwards with TraceBackBestPath().
// If final_cost_out is non-NULL it receives the final-cost component of the
// chosen token's cost (0 if final-probs were not used or not applicable).
template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator LatticeFasterOnlineDecoderTpl<FST>::BestPathEnd(
    bool use_final_probs,
    BaseFloat *final_cost_out) const {
  if (this->decoding_finalized_ && !use_final_probs)
    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
              << "BestPathEnd() with use_final_probs == false";
  KALDI_ASSERT(this->NumFramesDecoded() > 0 &&
               "You cannot call BestPathEnd if no frames were decoded.");
  // 'final_costs' aliases the stored map if decoding is finalized, otherwise
  // a locally computed map (filled in below only if use_final_probs).
  unordered_map<Token*, BaseFloat> final_costs_local;
  const unordered_map<Token*, BaseFloat> &final_costs =
      (this->decoding_finalized_ ? this->final_costs_ :final_costs_local);
  if (!this->decoding_finalized_ && use_final_probs)
    this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
  // Singly linked list of tokens on last frame (access list through "next"
  // pointer).
  BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
  BaseFloat best_final_cost = 0;
  Token *best_tok = NULL;
  for (Token *tok = this->active_toks_.back().toks;
       tok != NULL; tok = tok->next) {
    BaseFloat cost = tok->tot_cost, final_cost = 0.0;
    if (use_final_probs && !final_costs.empty()) {
      // if we are instructed to use final-probs, and any final tokens were
      // active on final frame, include the final-prob in the cost of the token.
      typename unordered_map<Token*, BaseFloat>::const_iterator
          iter = final_costs.find(tok);
      if (iter != final_costs.end()) {
        final_cost = iter->second;
        cost += final_cost;
      } else {
        // Token is not final: exclude it by giving it infinite cost.
        cost = std::numeric_limits<BaseFloat>::infinity();
      }
    }
    if (cost < best_cost) {
      best_cost = cost;
      best_tok = tok;
      best_final_cost = final_cost;
    }
  }
  if (best_tok == NULL) { // this should not happen, and is likely a code error or
    // caused by infinities in likelihoods, but I'm not making
    // it a fatal error for now.
    KALDI_WARN << "No final token found.";
  }
  if (final_cost_out)
    *final_cost_out = best_final_cost;
  // The iterator's frame index is NumFramesDecoded() - 1 (see the comment on
  // BestPathIterator::frame in the header).
  return BestPathIterator(best_tok, this->NumFramesDecoded() - 1);
}
// Traces back one link of the best path: fills *oarc with the ilabel,
// olabel and (graph, acoustic) weight of the best link into the current
// token, and returns an iterator positioned at the predecessor token.
template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator LatticeFasterOnlineDecoderTpl<FST>::TraceBackBestPath(
    BestPathIterator iter, LatticeArc *oarc) const {
  KALDI_ASSERT(!iter.Done() && oarc != NULL);
  Token *tok = static_cast<Token*>(iter.tok);
  // step_t becomes -1 if we cross an emitting (nonzero-ilabel) link, since
  // that consumes one frame when tracing backwards.
  int32 cur_t = iter.frame, step_t = 0;
  if (tok->backpointer != NULL) {
    // retrieve the correct forward link (the one with the best link cost)
    BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
    ForwardLinkT *link;
    for (link = tok->backpointer->links;
         link != NULL; link = link->next) {
      if (link->next_tok == tok) { // this is a link to "tok"
        BaseFloat graph_cost = link->graph_cost,
            acoustic_cost = link->acoustic_cost;
        BaseFloat cost = graph_cost + acoustic_cost;
        if (cost < best_cost) {
          oarc->ilabel = link->ilabel;
          oarc->olabel = link->olabel;
          if (link->ilabel != 0) {
            // Emitting link: undo the per-frame acoustic-cost offset that
            // was added during decoding (see cost_offsets_).
            KALDI_ASSERT(static_cast<size_t>(cur_t) < this->cost_offsets_.size());
            acoustic_cost -= this->cost_offsets_[cur_t];
            step_t = -1;
          } else {
            step_t = 0;
          }
          oarc->weight = LatticeWeight(graph_cost, acoustic_cost);
          best_cost = cost;
        }
      }
    }
    // NOTE(review): 'link' is always NULL once the loop exits, so this
    // condition effectively means "no matching link was found".
    if (link == NULL &&
        best_cost == std::numeric_limits<BaseFloat>::infinity()) { // Did not find correct link.
      KALDI_ERR << "Error tracing best-path back (likely "
                << "bug in token-pruning algorithm)";
    }
  } else {
    // No backpointer: we are at the start of the path; output an epsilon
    // arc with unit weight.
    oarc->ilabel = 0;
    oarc->olabel = 0;
    oarc->weight = LatticeWeight::One(); // zero costs.
  }
  return BestPathIterator(tok->backpointer, cur_t + step_t);
}
// Behaves like GetRawLattice() but only keeps tokens whose extra_cost is
// below 'beam' (see the header comment).  Returns true if the output
// lattice is nonempty.
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned(
    Lattice *ofst,
    bool use_final_probs,
    BaseFloat beam) const {
  typedef LatticeArc Arc;
  typedef Arc::StateId StateId;
  typedef Arc::Weight Weight;
  typedef Arc::Label Label;
  // Note: you can't use the old interface (Decode()) if you want to
  // get the lattice with use_final_probs = false. You'd have to do
  // InitDecoding() and then AdvanceDecoding().
  if (this->decoding_finalized_ && !use_final_probs)
    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
              << "GetRawLattice() with use_final_probs == false";
  // Use the stored final costs if decoding is finalized; otherwise compute
  // them into final_costs_local (only if final probs were requested).
  unordered_map<Token*, BaseFloat> final_costs_local;
  const unordered_map<Token*, BaseFloat> &final_costs =
      (this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
  if (!this->decoding_finalized_ && use_final_probs)
    this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
  ofst->DeleteStates();
  // num-frames plus one (since frames are one-based, and we have
  // an extra frame for the start-state).
  int32 num_frames = this->active_toks_.size() - 1;
  KALDI_ASSERT(num_frames > 0);
  for (int32 f = 0; f <= num_frames; f++) {
    if (this->active_toks_[f].toks == NULL) {
      KALDI_WARN << "No tokens active on frame " << f
                 << ": not producing lattice.\n";
      return false;
    }
  }
  unordered_map<Token*, StateId> tok_map;
  std::queue<std::pair<Token*, int32> > tok_queue;
  // First initialize the queue and states. Put the initial state on the queue;
  // this is the last token in the list active_toks_[0].toks.
  for (Token *tok = this->active_toks_[0].toks;
       tok != NULL; tok = tok->next) {
    if (tok->next == NULL) {
      tok_map[tok] = ofst->AddState();
      ofst->SetStart(tok_map[tok]);
      std::pair<Token*, int32> tok_pair(tok, 0); // #frame = 0
      tok_queue.push(tok_pair);
    }
  }
  // Next create states for "good" tokens (breadth-first traversal over
  // tokens reachable from the start token via links under the beam).
  while (!tok_queue.empty()) {
    std::pair<Token*, int32> cur_tok_pair = tok_queue.front();
    tok_queue.pop();
    Token *cur_tok = cur_tok_pair.first;
    int32 cur_frame = cur_tok_pair.second;
    // NOTE(review): int32 vs size_t comparison below; cur_frame is known
    // non-negative here, so the comparison is safe.
    KALDI_ASSERT(cur_frame >= 0 &&
                 cur_frame <= this->cost_offsets_.size());
    typename unordered_map<Token*, StateId>::const_iterator iter =
        tok_map.find(cur_tok);
    KALDI_ASSERT(iter != tok_map.end());
    StateId cur_state = iter->second;
    for (ForwardLinkT *l = cur_tok->links;
         l != NULL;
         l = l->next) {
      Token *next_tok = l->next_tok;
      if (next_tok->extra_cost < beam) {
        // so both the current and the next token are good; create the arc
        int32 next_frame = l->ilabel == 0 ? cur_frame : cur_frame + 1;
        StateId nextstate;
        if (tok_map.find(next_tok) == tok_map.end()) {
          nextstate = tok_map[next_tok] = ofst->AddState();
          tok_queue.push(std::pair<Token*, int32>(next_tok, next_frame));
        } else {
          nextstate = tok_map[next_tok];
        }
        // Undo the per-frame acoustic-cost offset (added during decoding
        // to keep values in a good dynamic range) on emitting arcs only.
        BaseFloat cost_offset = (l->ilabel != 0 ?
                                 this->cost_offsets_[cur_frame] : 0);
        Arc arc(l->ilabel, l->olabel,
                Weight(l->graph_cost, l->acoustic_cost - cost_offset),
                nextstate);
        ofst->AddArc(cur_state, arc);
      }
    }
    if (cur_frame == num_frames) {
      // Tokens on the last frame may be final: attach final-weights.
      if (use_final_probs && !final_costs.empty()) {
        typename unordered_map<Token*, BaseFloat>::const_iterator iter =
            final_costs.find(cur_tok);
        if (iter != final_costs.end())
          ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
      } else {
        ofst->SetFinal(cur_state, LatticeWeight::One());
      }
    }
  }
  return (ofst->NumStates() != 0);
}
// Instantiate the template for the FST types that we'll need.
// (The member definitions live in this .cc file, so every FST type used
// with LatticeFasterOnlineDecoderTpl must be explicitly instantiated here.)
template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst >;
} // end namespace kaldi.

@ -0,0 +1,147 @@
// decoder/lattice-faster-online-decoder.h
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
// see note at the top of lattice-faster-decoder.h, about how to maintain this
// file in sync with lattice-faster-decoder.h
#ifndef KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_ONLINE_DECODER_H_
#include "util/stl-utils.h"
#include "util/hash-list.h"
#include "fst/fstlib.h"
#include "itf/decodable-itf.h"
#include "fstext/fstext-lib.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
#include "decoder/lattice-faster-decoder.h"
namespace kaldi {
/** LatticeFasterOnlineDecoderTpl is as LatticeFasterDecoderTpl but also
supports an efficient way to get the best path (see the function
BestPathEnd()), which is useful in endpointing and in situations where you
might want to frequently access the best path.
This is only templated on the FST type, since the Token type is required to
be BackpointerToken. Actually it only makes sense to instantiate
LatticeFasterDecoderTpl with Token == BackpointerToken if you do so indirectly via
this child class.
*/
template <typename FST>
class LatticeFasterOnlineDecoderTpl:
    public LatticeFasterDecoderTpl<FST, decoder::BackpointerToken> {
 public:
  // Convenience aliases taken from the FST's arc type.
  using Arc = typename FST::Arc;
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  // This decoder is hard-wired to BackpointerToken (needed for traceback).
  using Token = decoder::BackpointerToken;
  using ForwardLinkT = decoder::ForwardLink<Token>;
  // Instantiate this class once for each thing you have to decode.
  // This version of the constructor does not take ownership of
  // 'fst'.
  LatticeFasterOnlineDecoderTpl(const FST &fst,
                                const LatticeFasterDecoderConfig &config):
      LatticeFasterDecoderTpl<FST, Token>(fst, config) { }
  // This version of the initializer takes ownership of 'fst', and will delete
  // it when this object is destroyed.
  LatticeFasterOnlineDecoderTpl(const LatticeFasterDecoderConfig &config,
                                FST *fst):
      LatticeFasterDecoderTpl<FST, Token>(config, fst) { }
  // Lightweight cursor used by BestPathEnd()/TraceBackBestPath() to walk
  // the best path backwards without materializing the whole lattice.
  struct BestPathIterator {
    void *tok;  // opaque pointer to the current Token.
    int32 frame;
    // note, "frame" is the frame-index of the frame you'll get the
    // transition-id for next time, if you call TraceBackBestPath on this
    // iterator (assuming it's not an epsilon transition). Note that this
    // is one less than you might reasonably expect, e.g. it's -1 for
    // the nonemitting transitions before the first frame.
    BestPathIterator(void *t, int32 f): tok(t), frame(f) { }
    bool Done() const { return tok == NULL; }
  };
  /// Outputs an FST corresponding to the single best path through the lattice.
  /// This is quite efficient because it doesn't get the entire raw lattice and find
  /// the best path through it; instead, it uses the BestPathEnd and BestPathIterator
  /// so it basically traces it back through the lattice.
  /// Returns true if result is nonempty (using the return status is deprecated,
  /// it will become void). If "use_final_probs" is true AND we reached the
  /// final-state of the graph then it will include those as final-probs, else
  /// it will treat all final-probs as one.
  bool GetBestPath(Lattice *ofst,
                   bool use_final_probs = true) const;
  /// This function does a self-test of GetBestPath(). Returns true on
  /// success; returns false and prints a warning on failure.
  bool TestGetBestPath(bool use_final_probs = true) const;
  /// This function returns an iterator that can be used to trace back
  /// the best path. If use_final_probs == true and at least one final state
  /// survived till the end, it will use the final-probs in working out the best
  /// final Token, and will output the final cost to *final_cost (if non-NULL),
  /// else it will use only the forward likelihood, and will put zero in
  /// *final_cost (if non-NULL).
  /// Requires that NumFramesDecoded() > 0.
  BestPathIterator BestPathEnd(bool use_final_probs,
                               BaseFloat *final_cost = NULL) const;
  /// This function can be used in conjunction with BestPathEnd() to trace back
  /// the best path one link at a time (e.g. this can be useful in endpoint
  /// detection). By "link" we mean a link in the graph; not all links cross
  /// frame boundaries, but each time you see a nonzero ilabel you can interpret
  /// that as a frame. The return value is the updated iterator. It outputs
  /// the ilabel and olabel, and the (graph and acoustic) weight to the "arc" pointer,
  /// while leaving its "nextstate" variable unchanged.
  BestPathIterator TraceBackBestPath(
      BestPathIterator iter, LatticeArc *arc) const;
  /// Behaves the same as GetRawLattice but only processes tokens whose
  /// extra_cost is smaller than the best-cost plus the specified beam.
  /// It is only worthwhile to call this function if beam is less than
  /// the lattice_beam specified in the config; otherwise, it would
  /// return essentially the same thing as GetRawLattice, but more slowly.
  bool GetRawLatticePruned(Lattice *ofst,
                           bool use_final_probs,
                           BaseFloat beam) const;
  KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterOnlineDecoderTpl);
};
typedef LatticeFasterOnlineDecoderTpl<fst::StdFst> LatticeFasterOnlineDecoder;
} // end namespace kaldi.
#endif

@ -0,0 +1,147 @@
// lat/determinize-lattice-pruned-test.cc
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/determinize-lattice-pruned.h"
#include "fstext/lattice-utils.h"
#include "fstext/fst-test-utils.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
namespace fst {
// Caution: these tests are not as generic as you might think from all the
// templates in the code. They are basically only valid for LatticeArc.
// This is partly due to the fact that certain templates need to be instantiated
// in other .cc files in this directory.
// test that determinization proceeds correctly on general
// FSTs (not guaranteed determinizable, but we use the
// max-states option to stop it getting out of control).
template<class Arc> void TestDeterminizeLatticePruned() {
  typedef kaldi::int32 Int;
  typedef typename Arc::Weight Weight;
  typedef ArcTpl<CompactLatticeWeightTpl<Weight, Int> > CompactArc;
  for(int i = 0; i < 100; i++) {
    RandFstOptions opts;
    opts.n_states = 4;
    opts.n_arcs = 10;
    opts.n_final = 2;
    opts.allow_empty = false;
    // A weight_multiplier of 0.5 is important for the randomly generated
    // weights to be exactly representable in float, or this test fails
    // because numerical differences can cause symmetry in weights to be
    // broken, which causes the wrong path to be chosen as far as the
    // string part is concerned.
    opts.weight_multiplier = 0.5;
    opts.acyclic = true;
    VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
    bool sorted = TopSort(fst);
    KALDI_ASSERT(sorted);
    ILabelCompare<Arc> ilabel_comp;
    // Randomly arc-sort on ilabel, to exercise both sorted and unsorted
    // input paths through the determinization code.
    if (kaldi::Rand() % 2 == 0)
      ArcSort(fst, ilabel_comp);
    std::cout << "FST before lattice-determinizing is:\n";
    {
      FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
      fstprinter.Print(&std::cout, "standard output");
    }
    VectorFst<Arc> det_fst;
    try {
      // Randomize the resource limits so that the "terminated early" code
      // paths also get exercised.
      DeterminizeLatticePrunedOptions lat_opts;
      lat_opts.max_mem = ((kaldi::Rand() % 2 == 0) ? 100 : 1000);
      lat_opts.max_states = ((kaldi::Rand() % 2 == 0) ? -1 : 20);
      lat_opts.max_arcs = ((kaldi::Rand() % 2 == 0) ? -1 : 30);
      bool ans = DeterminizeLatticePruned<Weight>(*fst, 10.0, &det_fst, lat_opts);
      std::cout << "FST after lattice-determinizing is:\n";
      {
        FstPrinter<Arc> fstprinter(det_fst, NULL, NULL, NULL, false, true, "\t");
        fstprinter.Print(&std::cout, "standard output");
      }
      KALDI_ASSERT(det_fst.Properties(kIDeterministic, true) & kIDeterministic);
      // OK, now determinize it a different way and check equivalence.
      // [note: it's not normal determinization, it's taking the best path
      // for any input-symbol sequence....
      VectorFst<Arc> pruned_fst(*fst);
      if (pruned_fst.NumStates() != 0)
        kaldi::PruneLattice(10.0, &pruned_fst);
      VectorFst<CompactArc> compact_pruned_fst, compact_pruned_det_fst;
      ConvertLattice<Weight, Int>(pruned_fst, &compact_pruned_fst, false);
      std::cout << "Compact pruned FST is:\n";
      {
        FstPrinter<CompactArc> fstprinter(compact_pruned_fst, NULL, NULL, NULL, false, true, "\t");
        fstprinter.Print(&std::cout, "standard output");
      }
      ConvertLattice<Weight, Int>(det_fst, &compact_pruned_det_fst, false);
      std::cout << "Compact version of determinized FST is:\n";
      {
        FstPrinter<CompactArc> fstprinter(compact_pruned_det_fst, NULL, NULL, NULL, false, true, "\t");
        fstprinter.Print(&std::cout, "standard output");
      }
      // Only check equivalence when determinization ran to completion
      // (ans == true); a result truncated by the resource limits need not
      // be equivalent to the input.
      if (ans)
        KALDI_ASSERT(RandEquivalent(compact_pruned_det_fst, compact_pruned_fst, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length, max*/));
    } catch (...) {
      std::cout << "Failed to lattice-determinize this FST (probably not determinizable)\n";
    }
    delete fst;
  }
}
// test that determinization proceeds without crash on acyclic FSTs
// (guaranteed determinizable in this sense).
template<class Arc> void TestDeterminizeLatticePruned2() {
typedef typename Arc::Weight Weight;
RandFstOptions opts;
opts.acyclic = true;
for(int i = 0; i < 100; i++) {
VectorFst<Arc> *fst = RandPairFst<Arc>(opts);
std::cout << "FST before lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
VectorFst<Arc> ofst;
DeterminizeLatticePruned<Weight>(*fst, 10.0, &ofst);
std::cout << "FST after lattice-determinizing is:\n";
{
FstPrinter<Arc> fstprinter(ofst, NULL, NULL, NULL, false, true, "\t");
fstprinter.Print(&std::cout, "standard output");
}
delete fst;
}
}
} // end namespace fst
int main() {
using namespace fst;
TestDeterminizeLatticePruned<kaldi::LatticeArc>();
TestDeterminizeLatticePruned2<kaldi::LatticeArc>();
std::cout << "Tests succeeded\n";
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,296 @@
// lat/determinize-lattice-pruned.h
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
#define KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
#include <fst/fstlib.h>
#include <fst/fst-decl.h>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "fstext/lattice-weight.h"
#include "itf/transition-information.h"
#include "itf/options-itf.h"
#include "lat/kaldi-lattice.h"
namespace fst {
/// \addtogroup fst_extensions
/// @{
// For example of usage, see test-determinize-lattice-pruned.cc
/*
DeterminizeLatticePruned implements a special form of determinization with
epsilon removal, optimized for a phase of lattice generation. This algorithm
also does pruning at the same time-- the combination is more efficient as it
sometimes prevents us from creating a lot of states that would later be pruned
away. This allows us to increase the lattice-beam and not have the algorithm
blow up. Also, because our algorithm processes states in order from those
that appear on high-scoring paths down to those that appear on low-scoring
paths, we can easily terminate the algorithm after a certain specified number
of states or arcs.
The input is an FST with weight-type BaseWeightType (usually a pair of floats,
with a lexicographical type of order, such as LatticeWeightTpl<float>).
Typically this would be a state-level lattice, with input symbols equal to
words, and output-symbols equal to p.d.f's (so like the inverse of HCLG). Imagine representing this as an
acceptor of type CompactLatticeWeightTpl<float>, in which the input/output
symbols are words, and the weights contain the original weights together with
strings (with zero or one symbol in them) containing the original output labels
(the p.d.f.'s). We determinize this using acceptor determinization with
epsilon removal. Remember (from lattice-weight.h) that
CompactLatticeWeightTpl has a special kind of semiring where we always take
the string corresponding to the best cost (of type BaseWeightType), and
discard the other. This corresponds to taking the best output-label sequence
(of p.d.f.'s) for each input-label sequence (of words). We couldn't use the
Gallic weight for this, or it would die as soon as it detected that the input
FST was non-functional. In our case, any acyclic FST (and many cyclic ones)
can be determinized.
We assume that there is a function
Compare(const BaseWeightType &a, const BaseWeightType &b)
that returns (-1, 0, 1) according to whether (a < b, a == b, a > b) in the
total order on the BaseWeightType... this information should be the
same as NaturalLess would give, but it's more efficient to do it this way.
You can define this for things like TropicalWeight if you need to instantiate
this class for that weight type.
We implement this determinization in a special way to make it efficient for
the types of FSTs that we will apply it to. One issue is that if we
explicitly represent the strings (in CompactLatticeWeightTpl) as vectors of
type vector<IntType>, the algorithm takes time quadratic in the length of
words (in states), because propagating each arc involves copying a whole
vector (of integers representing p.d.f.'s). Instead we use a hash structure
where each string is a pointer (Entry*), and uses a hash from (Entry*,
IntType), to the successor string (and a way to get the latest IntType and the
ancestor Entry*). [this is the class LatticeStringRepository].
Another issue is that rather than representing a determinized-state as a
collection of (state, weight), we represent it in a couple of reduced forms.
Suppose a determinized-state is a collection of (state, weight) pairs; call
this the "canonical representation". Note: these collections are always
normalized to remove any common weight and string part. Define end-states as
the subset of states that have an arc out of them with a label on, or are
final. If we represent a determinized-state a the set of just its (end-state,
weight) pairs, this will be a valid and more compact representation, and will
lead to a smaller set of determinized states (like early minimization). Call
this collection of (end-state, weight) pairs the "minimal representation". As
a mechanism to reduce compute, we can also consider another representation.
In the determinization algorithm, we start off with a set of (begin-state,
weight) pairs (where the "begin-states" are initial or have a label on the
transition into them), and the "canonical representation" consists of the
epsilon-closure of this set (i.e. follow epsilons). Call this set of
(begin-state, weight) pairs, appropriately normalized, the "initial
representation". If two initial representations are the same, the "canonical
representation" and hence the "minimal representation" will be the same. We
can use this to reduce compute. Note that if two initial representations are
different, this does not preclude the other representations from being the same.
*/
struct DeterminizeLatticePrunedOptions {
  float delta;  // A small offset used to measure equality of weights.
  int max_mem;  // If >0, determinization will fail and return false
  // when the algorithm's (approximate) memory consumption crosses this threshold.
  int max_loop;  // If >0, can be used to detect non-determinizable input
  // (a case that wouldn't be caught by max_mem).
  int max_states;  // If >0, maximum number of states the output FST may have.
  int max_arcs;    // If >0, maximum number of arcs the output FST may have.
  float retry_cutoff;  // If effective-beam < retry_cutoff * beam, we prune
                       // the raw lattice and retry determinization.
  DeterminizeLatticePrunedOptions(): delta(kDelta),
                                     max_mem(-1),
                                     max_loop(-1),
                                     max_states(-1),
                                     max_arcs(-1),
                                     retry_cutoff(0.5) { }
  void Register (kaldi::OptionsItf *opts) {
    opts->Register("delta", &delta, "Tolerance used in determinization");
    opts->Register("max-mem", &max_mem, "Maximum approximate memory usage in "
                   "determinization (real usage might be many times this)");
    // Fixed: the --max-states help text previously said "number of arcs"
    // (copy-paste error), and both messages lacked a closing parenthesis.
    opts->Register("max-arcs", &max_arcs, "Maximum number of arcs in "
                   "output FST (total, not per state)");
    opts->Register("max-states", &max_states, "Maximum number of states in "
                   "output FST (total, not per state)");
    opts->Register("max-loop", &max_loop, "Option used to detect a particular "
                   "type of determinization failure, typically due to invalid input "
                   "(e.g., negative-cost loops)");
    opts->Register("retry-cutoff", &retry_cutoff, "Controls pruning un-determinized "
                   "lattice and retrying determinization: if effective-beam < "
                   "retry-cutoff * beam, we prune the raw lattice and retry. Avoids "
                   "ever getting empty output for long segments.");
  }
};
struct DeterminizeLatticePhonePrunedOptions {
  /// A small offset used to measure equality of weights.
  float delta = kDelta;
  /// If > 0, determinization will fail and return false when the
  /// algorithm's (approximate) memory consumption crosses this threshold.
  int max_mem = 50000000;
  /// If true, do a first pass determinization on both phones and words.
  bool phone_determinize = true;
  /// If true, do a second pass determinization on words only.
  bool word_determinize = true;
  /// If true, push and minimize after determinization.
  bool minimize = false;

  DeterminizeLatticePhonePrunedOptions() = default;

  void Register(kaldi::OptionsItf *opts) {
    opts->Register("delta", &delta, "Tolerance used in determinization");
    opts->Register("max-mem", &max_mem,
                   "Maximum approximate memory usage in "
                   "determinization (real usage might be many times this).");
    opts->Register("phone-determinize", &phone_determinize,
                   "If true, do an "
                   "initial pass of determinization on both phones and words (see"
                   " also --word-determinize)");
    opts->Register("word-determinize", &word_determinize,
                   "If true, do a second "
                   "pass of determinization on words only (see also "
                   "--phone-determinize)");
    opts->Register("minimize", &minimize,
                   "If true, push and minimize after "
                   "determinization.");
  }
};
/**
This function implements the normal version of DeterminizeLattice, in which the
output strings are represented using sequences of arcs, where all but the
first one has an epsilon on the input side. It also prunes using the beam
in the "prune" parameter. The input FST must be topologically sorted in order
for the algorithm to work. For efficiency it is recommended to sort ilabel as well.
Returns true on success, and false if it had to terminate the determinization
earlier than specified by the "prune" beam-- that is, if it terminated because
of the max_mem, max_loop or max_arcs constraints in the options.
CAUTION: you may want to use the version below which outputs to CompactLattice.
*/
template<class Weight>
bool DeterminizeLatticePruned(
const ExpandedFst<ArcTpl<Weight> > &ifst,
double prune,
MutableFst<ArcTpl<Weight> > *ofst,
DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
/* This is a version of DeterminizeLattice with a slightly more "natural" output format,
where the output sequences are encoded using the CompactLatticeArcTpl template
(i.e. the sequences of output symbols are represented directly as strings). The input
FST must be topologically sorted in order for the algorithm to work. For efficiency
it is recommended to sort the ilabel for the input FST as well.
Returns true on normal success, and false if it had to terminate the determinization
earlier than specified by the "prune" beam-- that is, if it terminated because
of the max_mem, max_loop or max_arcs constraints in the options.
CAUTION: if Lattice is the input, you need to Invert() before calling this,
so words are on the input side.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePruned(
const ExpandedFst<ArcTpl<Weight> >&ifst,
double prune,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
/** This function takes in lattices and inserts phones at phone boundaries. It
uses the transition model to work out the transition_id to phone map. The
returning value is the starting index of the phone label. Typically we pick
(maximum_output_label_index + 1) as this value. The inserted phones are then
mapped to (returning_value + original_phone_label) in the new lattice. The
returning value will be used by DeterminizeLatticeDeletePhones() where it
works out the phones according to this value.
*/
template<class Weight>
typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones(
const kaldi::TransitionInformation &trans_model,
MutableFst<ArcTpl<Weight> > *fst);
/** This function takes in lattices and deletes "phones" from them. The "phones"
here are actually any label that is larger than first_phone_label because
when we insert phones into the lattice, we map the original phone label to
(first_phone_label + original_phone_label). It is supposed to be used
together with DeterminizeLatticeInsertPhones()
*/
template<class Weight>
void DeterminizeLatticeDeletePhones(
typename ArcTpl<Weight>::Label first_phone_label,
MutableFst<ArcTpl<Weight> > *fst);
/** This function is a wrapper of DeterminizeLatticePhonePrunedFirstPass() and
DeterminizeLatticePruned(). If --phone-determinize is set to true, it first
calls DeterminizeLatticePhonePrunedFirstPass() to do the initial pass of
determinization on the phone + word lattices. If --word-determinize is set
true, it then does a second pass of determinization on the word lattices by
calling DeterminizeLatticePruned(). If both are set to false, then it gives
a warning and copies the lattices without determinization.
Note: the point of doing first a phone-level determinization pass and then
a word-level determinization pass is that it allows us to determinize
deeper lattices without "failing early" and returning a too-small lattice
due to the max-mem constraint. The result should be the same as word-level
determinization in general, but for deeper lattices it is a bit faster,
despite the fact that we now have two passes of determinization by default.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned(
const kaldi::TransitionInformation &trans_model,
const ExpandedFst<ArcTpl<Weight> > &ifst,
double prune,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions());
/** "Destructive" version of DeterminizeLatticePhonePruned() where the input
lattice might be changed.
*/
template<class Weight, class IntType>
bool DeterminizeLatticePhonePruned(
const kaldi::TransitionInformation &trans_model,
MutableFst<ArcTpl<Weight> > *ifst,
double prune,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions());
/** This function is a wrapper of DeterminizeLatticePhonePruned() that works for
Lattice type FSTs. It simplifies the calling process by calling
TopSort() Invert() and ArcSort() for you.
Unlike other determinization routines, the function
requires "ifst" to have transition-id's on the input side and words on the
output side.
This function can be used as the top-level interface to all the determinization
code.
*/
bool DeterminizeLatticePhonePrunedWrapper(
const kaldi::TransitionInformation &trans_model,
MutableFst<kaldi::LatticeArc> *ifst,
double prune,
MutableFst<kaldi::CompactLatticeArc> *ofst,
DeterminizeLatticePhonePrunedOptions opts
= DeterminizeLatticePhonePrunedOptions());
/// @} end "addtogroup fst_extensions"
} // end namespace fst
#endif

@ -0,0 +1,506 @@
// lat/kaldi-lattice.cc
// Copyright 2009-2011 Microsoft Corporation
// 2013 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/kaldi-lattice.h"
#include "fst/script/print-impl.h"
namespace kaldi {
/// Converts lattice types if necessary, deleting its input.
/// Takes ownership of *ifst (which is deleted) and returns a freshly
/// allocated CompactLattice the caller owns; returns NULL if ifst is NULL.
template<class OrigWeightType>
CompactLattice* ConvertToCompactLattice(fst::VectorFst<OrigWeightType> *ifst) {
  if (ifst == NULL)
    return NULL;
  CompactLattice *converted = new CompactLattice();
  ConvertLattice(*ifst, converted);
  delete ifst;  // we own the input; free it after conversion.
  return converted;
}
// This overrides the template if there is no type conversion going on
// (for efficiency).
template<>
CompactLattice* ConvertToCompactLattice(CompactLattice *ifst) {
  // Already a CompactLattice: pass ownership straight through, no copy.
  return ifst;
}
/// Converts lattice types if necessary, deleting its input.
/// Takes ownership of *ifst (which is deleted) and returns a freshly
/// allocated Lattice the caller owns; returns NULL if ifst is NULL.
template<class OrigWeightType>
Lattice* ConvertToLattice(fst::VectorFst<OrigWeightType> *ifst) {
  if (ifst == NULL)
    return NULL;
  Lattice *converted = new Lattice();
  ConvertLattice(*ifst, converted);
  delete ifst;  // we own the input; free it after conversion.
  return converted;
}
// This overrides the template if there is no type conversion going on
// (for efficiency).
template<>
Lattice* ConvertToLattice(Lattice *ifst) {
  // Already a Lattice: pass ownership straight through, no copy.
  return ifst;
}
// Writes a CompactLattice to 'os' in binary or text mode; returns false on
// stream failure.  In text mode a leading and a trailing newline are emitted
// (the trailing one is the terminator the Kaldi read routine looks for).
bool WriteCompactLattice(std::ostream &os, bool binary,
                         const CompactLattice &t) {
  if (binary) {
    // Leave all the options default.  Normally these lattices wouldn't have
    // any osymbols/isymbols so no point directing it not to write them (who
    // knows what we'd want to do if we had them).
    fst::FstWriteOptions opts;
    return t.Write(os, opts);
  }
  // Text-mode output.  Note: we expect that t.InputSymbols() and
  // t.OutputSymbols() would always return NULL.  The corresponding input
  // routine would not work if the FST actually had symbols attached.
  // Write a newline after the key, so the first line of the FST appears
  // on its own line.
  os << '\n';
  const bool acceptor = true;    // CompactLattice is acceptor-format.
  const bool write_one = false;
  fst::FstPrinter<CompactLatticeArc> printer(t, t.InputSymbols(),
                                             t.OutputSymbols(),
                                             NULL, acceptor, write_one, "\t");
  printer.Print(&os, "<unknown>");
  if (os.fail())
    KALDI_WARN << "Stream failure detected.";
  // Write another newline as a terminating character.  The read routine will
  // detect this [this is a Kaldi mechanism, not something in the original
  // OpenFst code].
  os << '\n';
  return os.good();
}
/// LatticeReader provides (static) functions for reading both Lattice
/// and CompactLattice, in text form.
class LatticeReader {
  typedef LatticeArc Arc;
  typedef LatticeWeight Weight;
  typedef CompactLatticeArc CArc;
  typedef CompactLatticeWeight CWeight;
  typedef Arc::Label Label;
  typedef Arc::StateId StateId;
 public:
  // everything is static in this class.

  /** This function reads from the FST text format; it does not know in advance
      whether it's a Lattice or CompactLattice in the stream so it tries to
      read both formats until it becomes clear which is the correct one.
      It parses each line speculatively into BOTH a Lattice and a
      CompactLattice; whichever one a line fails to parse for is deleted and
      set to NULL.  At most one element of the returned pair is non-NULL on
      success; both are NULL on error.  The caller owns whatever is returned.
  */
  static std::pair<Lattice*, CompactLattice*> ReadText(
      std::istream &is) {
    typedef std::pair<Lattice*, CompactLattice*> PairT;
    using std::string;
    using std::vector;
    Lattice *fst = new Lattice();
    CompactLattice *cfst = new CompactLattice();
    string line;
    size_t nline = 0;  // line counter; line 1 determines the start state.
    string separator = FLAGS_fst_field_separator + "\r\n";
    while (std::getline(is, line)) {
      nline++;
      vector<string> col;
      // on Windows we'll write in text and read in binary mode.
      SplitStringToVector(line, separator.c_str(), true, &col);
      if (col.size() == 0) break;  // Empty line is a signal to stop, in our
                                   // archive format.
      if (col.size() > 5) {
        // No valid line in either format has more than 5 fields.
        KALDI_WARN << "Reading lattice: bad line in FST: " << line;
        delete fst;
        delete cfst;
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
      StateId s;  // the source state; first field on every line.
      if (!ConvertStringToInteger(col[0], &s)) {
        KALDI_WARN << "FstCompiler: bad line in FST: " << line;
        delete fst;
        delete cfst;
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
      // Make sure the source state exists in whichever FSTs we still hold.
      if (fst)
        while (s >= fst->NumStates())
          fst->AddState();
      if (cfst)
        while (s >= cfst->NumStates())
          cfst->AddState();
      if (nline == 1) {
        // By convention the first line's state is the start state.
        if (fst) fst->SetStart(s);
        if (cfst) cfst->SetStart(s);
      }

      if (fst) {  // we still have fst; try to read that arc.
        bool ok = true;
        Arc arc;
        Weight w;
        StateId d = s;  // destination state, if the line defines an arc.
        switch (col.size()) {
          case 1 :
            // "state" alone: final state with weight One().
            fst->SetFinal(s, Weight::One());
            break;
          case 2:
            // "state weight": final state with the given weight.
            if (!StrToWeight(col[1], true, &w)) ok = false;
            else fst->SetFinal(s, w);
            break;
          case 3:  // 3 columns not ok for Lattice format; it's not an acceptor.
            ok = false;
            break;
          case 4:
            // "src dest ilabel olabel": arc with weight One().
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                ConvertStringToInteger(col[3], &arc.olabel);
            if (ok) {
              d = arc.nextstate;
              arc.weight = Weight::One();
              fst->AddArc(s, arc);
            }
            break;
          case 5:
            // "src dest ilabel olabel weight": fully specified arc.
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                ConvertStringToInteger(col[3], &arc.olabel) &&
                StrToWeight(col[4], false, &arc.weight);
            if (ok) {
              d = arc.nextstate;
              fst->AddArc(s, arc);
            }
            break;
          default:
            ok = false;
        }
        // Ensure the destination state exists.
        while (d >= fst->NumStates())
          fst->AddState();
        if (!ok) {
          // This line ruled out the Lattice interpretation.
          delete fst;
          fst = NULL;
        }
      }
      if (cfst) {  // we still have cfst; try the CompactLattice interpretation.
        bool ok = true;
        CArc arc;
        CWeight w;
        StateId d = s;
        switch (col.size()) {
          case 1 :
            cfst->SetFinal(s, CWeight::One());
            break;
          case 2:
            if (!StrToCWeight(col[1], true, &w)) ok = false;
            else cfst->SetFinal(s, w);
            break;
          case 3:  // compact-lattice is acceptor format: state, next-state, label.
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel);
            if (ok) {
              d = arc.nextstate;
              arc.olabel = arc.ilabel;  // acceptor: olabel mirrors ilabel.
              arc.weight = CWeight::One();
              cfst->AddArc(s, arc);
            }
            break;
          case 4:
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                StrToCWeight(col[3], false, &arc.weight);
            if (ok) {
              d = arc.nextstate;
              arc.olabel = arc.ilabel;
              cfst->AddArc(s, arc);
            }
            break;
          case 5: default:
            // 5 fields can never be CompactLattice (acceptor) format.
            ok = false;
        }
        while (d >= cfst->NumStates())
          cfst->AddState();
        if (!ok) {
          // This line ruled out the CompactLattice interpretation.
          delete cfst;
          cfst = NULL;
        }
      }
      if (!fst && !cfst) {
        KALDI_WARN << "Bad line in lattice text format: " << line;
        // read until we get an empty line, so at least we
        // have a chance to read the next one (although this might
        // be a bit futile since the calling code will get unhappy
        // about failing to read this one.
        while (std::getline(is, line)) {
          SplitStringToVector(line, separator.c_str(), true, &col);
          if (col.empty()) break;
        }
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
    }
    return PairT(fst, cfst);
  }

  // Parses a LatticeWeight from text; returns false on parse failure or
  // (when allow_zero is false) if the parsed weight equals Weight::Zero().
  static bool StrToWeight(const std::string &s, bool allow_zero, Weight *w) {
    std::istringstream strm(s);
    strm >> *w;
    if (!strm || (!allow_zero && *w == Weight::Zero())) {
      return false;
    }
    return true;
  }

  // As StrToWeight, but for CompactLatticeWeight.
  static bool StrToCWeight(const std::string &s, bool allow_zero, CWeight *w) {
    std::istringstream strm(s);
    strm >> *w;
    if (!strm || (!allow_zero && *w == CWeight::Zero())) {
      return false;
    }
    return true;
  }
};
// Reads a CompactLattice in text form from 'is'.  Returns a newly allocated
// lattice the caller owns, or NULL on failure (ReadText will have warned).
CompactLattice *ReadCompactLatticeText(std::istream &is) {
  // ReadText() tries both text formats; at most one element of the pair
  // is non-NULL.
  std::pair<Lattice*, CompactLattice*> both = LatticeReader::ReadText(is);
  Lattice *lat = both.first;
  CompactLattice *clat = both.second;
  if (clat != NULL) {
    delete lat;
    return clat;
  }
  if (lat != NULL) {
    // note: ConvertToCompactLattice frees its input.
    return ConvertToCompactLattice(lat);
  }
  return NULL;
}
// Reads a Lattice in text form from 'is'.  Returns a newly allocated
// lattice the caller owns, or NULL on failure (ReadText will have warned).
Lattice *ReadLatticeText(std::istream &is) {
  // ReadText() tries both text formats; at most one element of the pair
  // is non-NULL.
  std::pair<Lattice*, CompactLattice*> both = LatticeReader::ReadText(is);
  Lattice *lat = both.first;
  CompactLattice *clat = both.second;
  if (lat != NULL) {
    delete clat;
    return lat;
  }
  if (clat != NULL) {
    // note: ConvertToLattice frees its input.
    return ConvertToLattice(clat);
  }
  return NULL;
}
// Reads a CompactLattice from 'is' in binary or text mode into *clat, which
// must be NULL on entry.  In binary mode the FST header's arc type is used to
// dispatch among the four supported weight types (float/double lattice and
// compact-lattice weights), converting to CompactLattice if needed.
// Returns true on success; on failure *clat is left NULL and a warning is
// logged.
bool ReadCompactLattice(std::istream &is, bool binary,
                        CompactLattice **clat) {
  KALDI_ASSERT(*clat == NULL);
  if (binary) {
    fst::FstHeader hdr;
    if (!hdr.Read(is, "<unknown>")) {
      KALDI_WARN << "Reading compact lattice: error reading FST header.";
      return false;
    }
    if (hdr.FstType() != "vector") {
      KALDI_WARN << "Reading compact lattice: unsupported FST type: "
                 << hdr.FstType();
      return false;
    }
    fst::FstReadOptions ropts("<unspecified>",
                              &hdr);
    // The four on-disk weight types we can convert from:
    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
    typedef fst::LatticeWeightTpl<float> T3;
    typedef fst::LatticeWeightTpl<double> T4;
    typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
    typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
    typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
    typedef fst::VectorFst<fst::ArcTpl<T4> > F4;
    CompactLattice *ans = NULL;
    // Dispatch on the stored arc type; ConvertToCompactLattice() frees the
    // FST that Read() allocated and handles a NULL read result.
    if (hdr.ArcType() == T1::Type()) {
      ans = ConvertToCompactLattice(F1::Read(is, ropts));
    } else if (hdr.ArcType() == T2::Type()) {
      ans = ConvertToCompactLattice(F2::Read(is, ropts));
    } else if (hdr.ArcType() == T3::Type()) {
      ans = ConvertToCompactLattice(F3::Read(is, ropts));
    } else if (hdr.ArcType() == T4::Type()) {
      ans = ConvertToCompactLattice(F4::Read(is, ropts));
    } else {
      KALDI_WARN << "FST with arc type " << hdr.ArcType()
                 << " cannot be converted to CompactLattice.\n";
      return false;
    }
    if (ans == NULL) {
      KALDI_WARN << "Error reading compact lattice (after reading header).";
      return false;
    }
    *clat = ans;
    return true;
  } else {
    // Text mode.  The writer put a newline after the key; consume it
    // (plus any stray spaces / '\r' on Windows) before parsing the FST text.
    // The next line would normally consume the \r on Windows, plus any
    // extra spaces that might have got in there somehow.
    while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
    if (is.peek() == '\n') is.get();  // consume the newline.
    else {  // saw spaces but no newline.. this is not expected.
      KALDI_WARN << "Reading compact lattice: unexpected sequence of spaces "
                 << " at file position " << is.tellg();
      return false;
    }
    *clat = ReadCompactLatticeText(is);  // that routine will warn on error.
    return (*clat != NULL);
  }
}
// Reads a CompactLattice into this holder, auto-detecting binary vs text:
// since we don't write the binary headers for this holder type, we sniff the
// first character instead.  Returns false on failure.
bool CompactLatticeHolder::Read(std::istream &is) {
  Clear();  // in case anything currently stored.
  const int first_char = is.peek();
  if (first_char == -1) {
    KALDI_WARN << "End of stream detected reading CompactLattice.";
    return false;
  }
  // The text form of the lattice begins with space (normally, '\n'), so
  // space means text (the binary form cannot begin with space because it
  // starts with the FST Type() which is not space).
  if (isspace(first_char))
    return ReadCompactLattice(is, false, &t_);
  // 214 is the first char of the FST magic number on little-endian machines,
  // which is all we support (\326 octal).
  if (first_char != 214) {
    KALDI_WARN << "Reading compact lattice: does not appear to be an FST "
               << " [non-space but no magic number detected], file pos is "
               << is.tellg();
    return false;
  }
  return ReadCompactLattice(is, true, &t_);
}
// Writes a Lattice to 'os' in binary or text mode; returns false on stream
// failure.  In text mode a leading and a trailing newline are emitted (the
// trailing one is the terminator the Kaldi read routine looks for).
bool WriteLattice(std::ostream &os, bool binary, const Lattice &t) {
  if (binary) {
    // Leave all the options default.  Normally these lattices wouldn't have
    // any osymbols/isymbols so no point directing it not to write them (who
    // knows what we'd want to do if we had them).
    fst::FstWriteOptions opts;
    return t.Write(os, opts);
  }
  // Text-mode output.  Note: we expect that t.InputSymbols() and
  // t.OutputSymbols() would always return NULL.  The corresponding input
  // routine would not work if the FST actually had symbols attached.
  // Write a newline after the key, so the first line of the FST appears
  // on its own line.
  os << '\n';
  const bool acceptor = false;   // Lattice is a transducer, not an acceptor.
  const bool write_one = false;
  fst::FstPrinter<LatticeArc> printer(t, t.InputSymbols(),
                                      t.OutputSymbols(),
                                      NULL, acceptor, write_one, "\t");
  printer.Print(&os, "<unknown>");
  if (os.fail())
    KALDI_WARN << "Stream failure detected.";
  // Write another newline as a terminating character.  The read routine will
  // detect this [this is a Kaldi mechanism, not something in the original
  // OpenFst code].
  os << '\n';
  return os.good();
}
// Reads a Lattice from 'is' in binary or text mode into *lat, which must be
// NULL on entry.  In binary mode the FST header's arc type is used to
// dispatch among the four supported weight types (float/double lattice and
// compact-lattice weights), converting to Lattice if needed.
// Returns true on success; on failure *lat is left NULL and a warning is
// logged.
bool ReadLattice(std::istream &is, bool binary,
                 Lattice **lat) {
  KALDI_ASSERT(*lat == NULL);
  if (binary) {
    fst::FstHeader hdr;
    if (!hdr.Read(is, "<unknown>")) {
      KALDI_WARN << "Reading lattice: error reading FST header.";
      return false;
    }
    if (hdr.FstType() != "vector") {
      KALDI_WARN << "Reading lattice: unsupported FST type: "
                 << hdr.FstType();
      return false;
    }
    fst::FstReadOptions ropts("<unspecified>",
                              &hdr);
    // The four on-disk weight types we can convert from:
    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
    typedef fst::LatticeWeightTpl<float> T3;
    typedef fst::LatticeWeightTpl<double> T4;
    typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
    typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
    typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
    typedef fst::VectorFst<fst::ArcTpl<T4> > F4;
    Lattice *ans = NULL;
    // Dispatch on the stored arc type; ConvertToLattice() frees the FST that
    // Read() allocated and handles a NULL read result.
    if (hdr.ArcType() == T1::Type()) {
      ans = ConvertToLattice(F1::Read(is, ropts));
    } else if (hdr.ArcType() == T2::Type()) {
      ans = ConvertToLattice(F2::Read(is, ropts));
    } else if (hdr.ArcType() == T3::Type()) {
      ans = ConvertToLattice(F3::Read(is, ropts));
    } else if (hdr.ArcType() == T4::Type()) {
      ans = ConvertToLattice(F4::Read(is, ropts));
    } else {
      KALDI_WARN << "FST with arc type " << hdr.ArcType()
                 << " cannot be converted to Lattice.\n";
      return false;
    }
    if (ans == NULL) {
      KALDI_WARN << "Error reading lattice (after reading header).";
      return false;
    }
    *lat = ans;
    return true;
  } else {
    // The next line would normally consume the \r on Windows, plus any
    // extra spaces that might have got in there somehow.
    while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
    if (is.peek() == '\n') is.get();  // consume the newline.
    else {  // saw spaces but no newline.. this is not expected.
      // Fixed copy-paste error: this function reads a (non-compact) Lattice,
      // but the warning previously said "Reading compact lattice".
      KALDI_WARN << "Reading lattice: unexpected sequence of spaces "
                 << " at file position " << is.tellg();
      return false;
    }
    *lat = ReadLatticeText(is);  // that routine will warn on error.
    return (*lat != NULL);
  }
}
/* Since we don't write the binary headers for this type of holder,
   we use a different method to work out whether we're in binary mode:
   we sniff the first character of the stream.
*/
bool LatticeHolder::Read(std::istream &is) {
  Clear();  // in case anything currently stored.
  int c = is.peek();
  if (c == -1) {
    KALDI_WARN << "End of stream detected reading Lattice.";
    return false;
  } else if (isspace(c)) {  // The text form of the lattice begins
    // with space (normally, '\n'), so this means it's text (the binary form
    // cannot begin with space because it starts with the FST Type() which is
    // not space).
    return ReadLattice(is, false, &t_);
  } else if (c != 214) {  // 214 is first char of FST magic number,
    // on little-endian machines which is all we support (\326 octal).
    // Fixed copy-paste error: this holder reads a (non-compact) Lattice,
    // but the warning previously said "Reading compact lattice".
    KALDI_WARN << "Reading lattice: does not appear to be an FST "
               << " [non-space but no magic number detected], file pos is "
               << is.tellg();
    return false;
  } else {
    return ReadLattice(is, true, &t_);
  }
}
} // end namespace kaldi

@ -0,0 +1,156 @@
// lat/kaldi-lattice.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_KALDI_LATTICE_H_
#define KALDI_LAT_KALDI_LATTICE_H_
#include "fstext/fstext-lib.h"
#include "base/kaldi-common.h"
#include "util/common-utils.h"
namespace kaldi {
// will import some things above...
typedef fst::LatticeWeightTpl<BaseFloat> LatticeWeight;
// careful: kaldi::int32 is not always the same C type as fst::int32
typedef fst::CompactLatticeWeightTpl<LatticeWeight, int32> CompactLatticeWeight;
typedef fst::CompactLatticeWeightCommonDivisorTpl<LatticeWeight, int32>
CompactLatticeWeightCommonDivisor;
typedef fst::ArcTpl<LatticeWeight> LatticeArc;
typedef fst::ArcTpl<CompactLatticeWeight> CompactLatticeArc;
typedef fst::VectorFst<LatticeArc> Lattice;
typedef fst::VectorFst<CompactLatticeArc> CompactLattice;
// The following functions for writing and reading lattices in binary or text
// form are provided here in case you need to include lattices in larger,
// Kaldi-type objects with their own Read and Write functions. Caution: these
// functions return false on stream failure rather than throwing an exception as
// most similar Kaldi functions would do.
bool WriteCompactLattice(std::ostream &os, bool binary,
const CompactLattice &clat);
bool WriteLattice(std::ostream &os, bool binary,
const Lattice &lat);
// the following function requires that *clat be
// NULL when called.
bool ReadCompactLattice(std::istream &is, bool binary,
CompactLattice **clat);
// the following function requires that *lat be
// NULL when called.
bool ReadLattice(std::istream &is, bool binary,
Lattice **lat);
/// Kaldi Table holder for CompactLattice, enabling CompactLattice archives
/// (see the Writer/Reader typedefs below).  Owns the lattice it holds and
/// frees it in Clear()/destructor.
class CompactLatticeHolder {
 public:
  typedef CompactLattice T;

  CompactLatticeHolder() { t_ = NULL; }

  static bool Write(std::ostream &os, bool binary, const T &t) {
    // Note: we don't include the binary-mode header when writing
    // this object to disk; this ensures that if we write to single
    // files, the result can be read by OpenFst.
    return WriteCompactLattice(os, binary, t);
  }

  // Defined in kaldi-lattice.cc; auto-detects binary vs text form.
  bool Read(std::istream &is);

  static bool IsReadInBinary() { return true; }

  // Returns the held lattice; dies (KALDI_ASSERT) if the holder is empty.
  T &Value() {
    KALDI_ASSERT(t_ != NULL && "Called Value() on empty CompactLatticeHolder");
    return *t_;
  }

  void Clear() { delete t_; t_ = NULL; }

  void Swap(CompactLatticeHolder *other) {
    std::swap(t_, other->t_);
  }

  // Range extraction is unsupported for lattices; always dies (KALDI_ERR).
  bool ExtractRange(const CompactLatticeHolder &other, const std::string &range) {
    KALDI_ERR << "ExtractRange is not defined for this type of holder.";
    return false;
  }

  ~CompactLatticeHolder() { Clear(); }

 private:
  T *t_;  // owned; NULL when empty.
};
/// Kaldi Table holder for Lattice, enabling Lattice archives (see the
/// Writer/Reader typedefs below).  Owns the lattice it holds and frees it
/// in Clear()/destructor.
class LatticeHolder {
 public:
  typedef Lattice T;

  LatticeHolder() { t_ = NULL; }

  static bool Write(std::ostream &os, bool binary, const T &t) {
    // Note: we don't include the binary-mode header when writing
    // this object to disk; this ensures that if we write to single
    // files, the result can be read by OpenFst.
    return WriteLattice(os, binary, t);
  }

  // Defined in kaldi-lattice.cc; auto-detects binary vs text form.
  bool Read(std::istream &is);

  static bool IsReadInBinary() { return true; }

  // Returns the held lattice; dies (KALDI_ASSERT) if the holder is empty.
  T &Value() {
    KALDI_ASSERT(t_ != NULL && "Called Value() on empty LatticeHolder");
    return *t_;
  }

  void Clear() { delete t_; t_ = NULL; }

  void Swap(LatticeHolder *other) {
    std::swap(t_, other->t_);
  }

  // Range extraction is unsupported for lattices; always dies (KALDI_ERR).
  bool ExtractRange(const LatticeHolder &other, const std::string &range) {
    KALDI_ERR << "ExtractRange is not defined for this type of holder.";
    return false;
  }

  ~LatticeHolder() { Clear(); }

 private:
  T *t_;  // owned; NULL when empty.
};
typedef TableWriter<LatticeHolder> LatticeWriter;
typedef SequentialTableReader<LatticeHolder> SequentialLatticeReader;
typedef RandomAccessTableReader<LatticeHolder> RandomAccessLatticeReader;
typedef TableWriter<CompactLatticeHolder> CompactLatticeWriter;
typedef SequentialTableReader<CompactLatticeHolder> SequentialCompactLatticeReader;
typedef RandomAccessTableReader<CompactLatticeHolder> RandomAccessCompactLatticeReader;
} // namespace kaldi
#endif // KALDI_LAT_KALDI_LATTICE_H_

File diff suppressed because it is too large Load Diff

@ -0,0 +1,402 @@
// lat/lattice-functions.h
// Copyright 2009-2012 Saarland University (author: Arnab Ghoshal)
// 2012-2013 Johns Hopkins University (Author: Daniel Povey);
// Bagher BabaAli
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_LATTICE_FUNCTIONS_H_
#define KALDI_LAT_LATTICE_FUNCTIONS_H_
#include <vector>
#include <map>
#include "base/kaldi-common.h"
#include "fstext/fstext-lib.h"
#include "itf/decodable-itf.h"
#include "itf/transition-information.h"
#include "lat/kaldi-lattice.h"
namespace kaldi {
// Redundant with the typedef in hmm/posterior.h. We want functions
// using the Posterior type to be usable without a dependency on the
// hmm library.
typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior;
/**
This function extracts the per-frame log likelihoods from a linear
lattice (which we refer to as an 'nbest' lattice elsewhere in Kaldi code).
The dimension of *per_frame_loglikes will be set to the
number of input symbols in 'nbest'. The elements of
'*per_frame_loglikes' will be set to the .Value2() elements of the lattice
weights, which represent the acoustic costs; you may want to scale this
vector afterward by -1/acoustic_scale to get the original loglikes.
If there are acoustic costs on input-epsilon arcs or the final-prob in 'nbest'
(and this should not normally be the case in situations where it makes
sense to call this function), they will be included to the cost of the
preceding input symbol, or the following input symbol for input-epsilons
encountered prior to any input symbol. If 'nbest' has no input symbols,
'per_frame_loglikes' will be set to the empty vector.
**/
void GetPerFrameAcousticCosts(const Lattice &nbest,
Vector<BaseFloat> *per_frame_loglikes);
/// This function iterates over the states of a topologically sorted lattice and
/// counts the time instance corresponding to each state. The times are returned
/// in a vector of integers 'times' which is resized to have a size equal to the
/// number of states in the lattice. The function also returns the maximum time
/// in the lattice (this will equal the number of frames in the file).
int32 LatticeStateTimes(const Lattice &lat, std::vector<int32> *times);
/// As LatticeStateTimes, but in the CompactLattice format. Note: must
/// be topologically sorted. Returns length of the utterance in frames, which
/// might not be the same as the maximum time in the lattice, due to frames
/// in the final-prob.
int32 CompactLatticeStateTimes(const CompactLattice &clat,
std::vector<int32> *times);
/// This function does the forward-backward over lattices and computes the
/// posterior probabilities of the arcs. It returns the total log-probability
/// of the lattice. The Posterior quantities contain pairs of (transition-id, weight)
/// on each frame.
/// If the pointer "acoustic_like_sum" is provided, this value is set to
/// the sum over the arcs, of the posterior of the arc times the
/// acoustic likelihood [i.e. negated acoustic score] on that link.
/// This is used in combination with other quantities to work out
/// the objective function in MMI discriminative training.
BaseFloat LatticeForwardBackward(const Lattice &lat,
Posterior *arc_post,
double *acoustic_like_sum = NULL);
// This function is something similar to LatticeForwardBackward(), but it is on
// the CompactLattice lattice format. Also we only need the alpha in the forward
// path, not the posteriors.
bool ComputeCompactLatticeAlphas(const CompactLattice &lat,
std::vector<double> *alpha);
// A sibling of the function CompactLatticeAlphas()... We compute the beta from
// the backward path here.
bool ComputeCompactLatticeBetas(const CompactLattice &lat,
std::vector<double> *beta);
// Computes (normal or Viterbi) alphas and betas; returns (total-prob, or
// best-path negated cost) Note: in either case, the alphas and betas are
// negated costs. Requires that lat be topologically sorted. This code
// will work for either CompactLattice or Lattice.
template<typename LatticeType>
double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
bool viterbi,
std::vector<double> *alpha,
std::vector<double> *beta);
/// Topologically sort the compact lattice if not already topologically sorted.
/// Will crash if the lattice cannot be topologically sorted.
void TopSortCompactLatticeIfNeeded(CompactLattice *clat);
/// Topologically sort the lattice if not already topologically sorted.
/// Will crash if lattice cannot be topologically sorted.
void TopSortLatticeIfNeeded(Lattice *clat);
/// Returns the depth of the lattice, defined as the average number of arcs (or
/// final-prob strings) crossing any given frame. Returns 1 for empty lattices.
/// Requires that clat is topologically sorted!
BaseFloat CompactLatticeDepth(const CompactLattice &clat,
int32 *num_frames = NULL);
/// This function returns, for each frame, the number of arcs crossing that
/// frame.
void CompactLatticeDepthPerFrame(const CompactLattice &clat,
std::vector<int32> *depth_per_frame);
/// This function limits the depth of the lattice, per frame: that means, it
/// does not allow more than a specified number of arcs active on any given
/// frame. This can be used to reduce the size of the "very deep" portions of
/// the lattice.
void CompactLatticeLimitDepth(int32 max_arcs_per_frame,
CompactLattice *clat);
/// Given a lattice, and a transition model to map pdf-ids to phones,
/// outputs for each frame the set of phones active on that frame. If
/// sil_phones (which must be sorted and uniq) is nonempty, it excludes
/// phones in this list.
void LatticeActivePhones(const Lattice &lat, const TransitionInformation &trans,
const std::vector<int32> &sil_phones,
std::vector<std::set<int32> > *active_phones);
/// Given a lattice, and a transition model to map pdf-ids to phones,
/// replace the output symbols (presumably words), with phones; we
/// use the TransitionModel to work out the phone sequence. Note
/// that the phone labels are not exactly aligned with the phone
/// boundaries. We put a phone label to coincide with any transition
/// to the final, nonemitting state of a phone (this state always exists,
/// we ensure this in HmmTopology::Check()). This would be the last
/// transition-id in the phone if reordering is not done (but typically
/// we do reorder).
/// Also see PhoneAlignLattice, in phone-align-lattice.h.
void ConvertLatticeToPhones(const TransitionInformation &trans_model,
Lattice *lat);
/// Prunes a lattice or compact lattice. Returns true on success, false if
/// there was some kind of failure.
template<class LatticeType>
bool PruneLattice(BaseFloat beam, LatticeType *lat);
/// Given a lattice, and a transition model to map pdf-ids to phones,
/// replace the sequences of transition-ids with sequences of phones.
/// Note that this is different from ConvertLatticeToPhones, in that
/// we replace the transition-ids not the words.
void ConvertCompactLatticeToPhones(const TransitionInformation &trans_model,
CompactLattice *clat);
/// Boosts LM probabilities by b * [number of frame errors]; equivalently, adds
/// -b*[number of frame errors] to the graph-component of the cost of each arc/path.
/// There is a frame error if a particular transition-id on a particular frame
/// corresponds to a phone not matching transcription's alignment for that frame.
/// This is used in "margin-inspired" discriminative training, esp. Boosted MMI.
/// The TransitionInformation is used to map transition-ids in the lattice
/// input-side to phones; the phones appearing in
/// "silence_phones" are treated specially in that we replace the frame error f
/// (either zero or 1) for a frame, with the minimum of f or max_silence_error.
/// For the normal recipe, max_silence_error would be zero.
/// Returns true on success, false if there was some kind of mismatch.
/// At input, silence_phones must be sorted and unique.
bool LatticeBoost(const TransitionInformation &trans,
const std::vector<int32> &alignment,
const std::vector<int32> &silence_phones,
BaseFloat b,
BaseFloat max_silence_error,
Lattice *lat);
/**
This function implements either the MPFE (minimum phone frame error) or SMBR
(state-level minimum bayes risk) forward-backward, depending on whether
    "criterion" is "mpfe" or "smbr". It returns the MPFE
    criterion or SMBR criterion for this utterance, and outputs the posteriors (which
may be positive or negative) into "post".
@param [in] trans The transition model. Used to map the
transition-ids to phones or pdfs.
@param [in] silence_phones A list of integer ids of silence phones. The
silence frames i.e. the frames where num_ali
corresponds to a silence phones are treated specially.
The behavior is determined by 'one_silence_class'
being false (traditional behavior) or true.
Usually in our setup, several phones including
the silence, vocalized noise, non-spoken noise
and unk are treated as "silence phones"
@param [in] lat The denominator lattice
@param [in] num_ali The numerator alignment
@param [in] criterion The objective function. Must be "mpfe" or "smbr"
for MPFE (minimum phone frame error) or sMBR
(state minimum bayes risk) training.
@param [in] one_silence_class Determines how the silence frames are treated.
Setting this to false gives the old traditional behavior,
where the silence frames (according to num_ali) are
treated as incorrect. However, this means that the
insertions are not penalized by the objective.
Setting this to true gives the new behaviour, where we
treat silence as any other phone, except that all pdfs
of silence phones are collapsed into a single class for
    the frame-error computation. This can possibly reduce
the insertions in the trained model. This is closer to
the WER metric that we actually care about, since WER is
generally computed after filtering out noises, but
does penalize insertions.
@param [out] post The "MBR posteriors" i.e. derivatives w.r.t to the
pseudo log-likelihoods of states at each frame.
*/
BaseFloat LatticeForwardBackwardMpeVariants(
const TransitionInformation &trans,
const std::vector<int32> &silence_phones,
const Lattice &lat,
const std::vector<int32> &num_ali,
std::string criterion,
bool one_silence_class,
Posterior *post);
/// This function takes a CompactLattice that should only contain a single
/// linear sequence (e.g. derived from lattice-1best), and that should have been
/// processed so that the arcs in the CompactLattice align correctly with the
/// word boundaries (e.g. by lattice-align-words). It outputs 3 vectors of the
/// same size, which give, for each word in the lattice (in sequence), the word
/// label and the begin time and length in frames. This is done even for zero
/// (epsilon) words, generally corresponding to optional silence-- if you don't
/// want them, just ignore them in the output.
/// This function will print a warning and return false, if the lattice
/// did not have the correct format (e.g. if it is empty or it is not
/// linear).
bool CompactLatticeToWordAlignment(const CompactLattice &clat,
std::vector<int32> *words,
std::vector<int32> *begin_times,
std::vector<int32> *lengths);
/// A form of the shortest-path/best-path algorithm that's specially coded for
/// CompactLattice. Requires that clat be acyclic.
void CompactLatticeShortestPath(const CompactLattice &clat,
CompactLattice *shortest_path);
/// This function expands a CompactLattice to ensure high-probability paths
/// have unique histories. Arcs with posteriors larger than epsilon are split.
void ExpandCompactLattice(const CompactLattice &clat,
double epsilon,
CompactLattice *expand_clat);
/// For each state, compute forward and backward best (viterbi) costs and its
/// traceback states (for generating best paths later). The forward best cost
/// for a state is the cost of the best path from the start state to the state.
/// The traceback state of this state is its predecessor state in the best path.
/// The backward best cost for a state is the cost of the best path from the
/// state to a final one. Its traceback state is the successor state in the best
/// path in the forward direction.
/// Note: final weights of states are in backward_best_cost_and_pred.
/// Requires the input CompactLattice clat be acyclic.
typedef std::vector<std::pair<double,
CompactLatticeArc::StateId> > CostTraceType;
void CompactLatticeBestCostsAndTracebacks(
const CompactLattice &clat,
CostTraceType *forward_best_cost_and_pred,
CostTraceType *backward_best_cost_and_pred);
/// This function adds estimated neural language model scores of words in a
/// minimal list of hypotheses that covers a lattice, to the graph scores on the
/// arcs. The list of hypotheses are generated by latbin/lattice-path-cover.
typedef unordered_map<std::pair<int32, int32>, double, PairHasher<int32> > MapT;
void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores,
CompactLattice *clat);
/// This function add the word insertion penalty to graph score of each word
/// in the compact lattice
void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
CompactLattice *clat);
/// This function *adds* the negated scores obtained from the Decodable object,
/// to the acoustic scores on the arcs. If you want to replace them, you should
/// use ScaleCompactLattice to first set the acoustic scores to zero. Returns
/// true on success, false on error (typically some kind of mismatched inputs).
bool RescoreCompactLattice(DecodableInterface *decodable,
CompactLattice *clat);
/// This function returns the number of words in the longest sentence in a
/// Lattice (i.e. the maximum over all paths of the count of
/// olabels on that path).
int32 LongestSentenceLength(const Lattice &lat);
/// This function returns the number of words in the longest sentence in a
/// CompactLattice, i.e. the maximum over all paths of the count of
/// labels on that path... note, in CompactLattice, the ilabels and olabels
/// are identical because it is an acceptor.
int32 LongestSentenceLength(const CompactLattice &lat);
/// This function is like RescoreCompactLattice, but it is modified to avoid
/// computing probabilities on most frames where all the pdf-ids are the same.
/// (it needs the transition-model to work out whether two transition-ids map to
/// the same pdf-id, and it assumes that the lattice has transition-ids on it).
/// The naive thing would be to just set all probabilities to zero on frames
/// where all the pdf-ids are the same (because this value won't affect the
/// lattice posterior). But this would become confusing when we compute
/// corpus-level diagnostics such as the MMI objective function. Instead,
/// imagine speedup_factor = 100 (it must be >= 1.0)... with probability (1.0 /
/// speedup_factor) we compute those likelihoods and multiply them by
/// speedup_factor; otherwise we set them to zero. This gives the right
/// expected probability so our corpus-level diagnostics will be about right.
bool RescoreCompactLatticeSpeedup(
const TransitionInformation &tmodel,
BaseFloat speedup_factor,
DecodableInterface *decodable,
CompactLattice *clat);
/// This function *adds* the negated scores obtained from the Decodable object,
/// to the acoustic scores on the arcs. If you want to replace them, you should
/// use ScaleCompactLattice to first set the acoustic scores to zero. Returns
/// true on success, false on error (e.g. some kind of mismatched inputs).
/// The input labels, if nonzero, are interpreted as transition-ids or whatever
/// other index the Decodable object expects.
bool RescoreLattice(DecodableInterface *decodable,
Lattice *lat);
/// This function Composes a CompactLattice format lattice with a
/// DeterministicOnDemandFst<fst::StdFst> format fst, and outputs another
/// CompactLattice format lattice. The first element (the one that corresponds
/// to LM weight) in CompactLatticeWeight is used for composition.
///
/// Note that the DeterministicOnDemandFst interface is not "const", therefore
/// we cannot use "const" for <det_fst>.
void ComposeCompactLatticeDeterministic(
const CompactLattice& clat,
fst::DeterministicOnDemandFst<fst::StdArc>* det_fst,
CompactLattice* composed_clat);
/// This function computes the mapping from the pair
/// (frame-index, transition-id) to the pair
/// (sum-of-acoustic-scores, num-of-occurrences) over all occurrences of the
/// transition-id in that frame-index in the lattice.
/// This function is useful for retaining the acoustic scores in a
/// non-compact lattice after a process like determinization where the
/// frame-level acoustic scores are typically lost.
/// The function ReplaceAcousticScoresFromMap is used to restore the
/// acoustic scores computed by this function.
///
/// @param [in] lat Input lattice. Expected to be top-sorted. Otherwise the
/// function will crash.
/// @param [out] acoustic_scores
/// Pointer to a map from the pair (frame-index,
/// transition-id) to a pair (sum-of-acoustic-scores,
/// num-of-occurrences).
/// Usually the acoustic scores for a pdf-id (and hence
/// transition-id) on a frame will be the same for all the
/// occurrences of the pdf-id in that frame.
/// But if not, we will take the average of the acoustic
/// scores. Hence, we store both the sum-of-acoustic-scores
/// and the num-of-occurrences of the transition-id in that
/// frame.
void ComputeAcousticScoresMap(
const Lattice &lat,
unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
PairHasher<int32> > *acoustic_scores);
/// This function restores acoustic scores computed using the function
/// ComputeAcousticScoresMap into the lattice.
///
/// @param [in] acoustic_scores
/// A map from the pair (frame-index, transition-id) to a
/// pair (sum-of-acoustic-scores, num-of-occurrences) of
/// the occurrences of the transition-id in that frame.
/// See the comments for ComputeAcousticScoresMap for
/// details.
/// @param [out] lat Pointer to the output lattice.
void ReplaceAcousticScoresFromMap(
const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
PairHasher<int32> > &acoustic_scores,
Lattice *lat);
} // namespace kaldi
#endif // KALDI_LAT_LATTICE_FUNCTIONS_H_

@ -0,0 +1,122 @@
// itf/decodable-itf.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
// Mirko Hannemann; Go Vivace Inc.;
// 2013 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_ITF_DECODABLE_ITF_H_
#define KALDI_ITF_DECODABLE_ITF_H_ 1
#include "base/kaldi-common.h"
namespace kaldi {
/// @ingroup Interfaces
/// @{
/**
DecodableInterface provides a link between the (acoustic-modeling and
feature-processing) code and the decoder. The idea is to make this
interface as small as possible, and to make it as agnostic as possible about
the form of the acoustic model (e.g. don't assume the probabilities are a
function of just a vector of floats), and about the decoder (e.g. don't
assume it accesses frames in strict left-to-right order). For normal
models, without on-line operation, the "decodable" sub-class will just be a
wrapper around a matrix of features and an acoustic model, and it will
answer the question 'what is the acoustic likelihood for this index and this
frame?'.
For online decoding, where the features are coming in in real time, it is
important to understand the IsLastFrame() and NumFramesReady() functions.
There are two ways these are used: the old online-decoding code, in ../online/,
and the new online-decoding code, in ../online2/. In the old online-decoding
code, the decoder would do:
\code{.cc}
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
// Process this frame
}
\endcode
and the call to IsLastFrame would block if the features had not arrived yet.
The decodable object would have to know when to terminate the decoding. This
online-decoding mode is still supported, it is what happens when you call, for
example, LatticeFasterDecoder::Decode().
We realized that this "blocking" mode of decoding is not very convenient
because it forces the program to be multi-threaded and makes it complex to
control endpointing. In the "new" decoding code, you don't call (for example)
LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(),
and then each time you get more features, you provide them to the decodable
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
something like this:
\code{.cc}
while (num_frames_decoded_ < decodable.NumFramesReady()) {
// Decode one more frame [increments num_frames_decoded_]
}
\endcode
So the decodable object never has IsLastFrame() called. For decoding where
you are starting with a matrix of features, the NumFramesReady() function will
always just return the number of frames in the file, and IsLastFrame() will
return true for the last frame.
For truly online decoding, the "old" online decodable objects in ../online/
have a "blocking" IsLastFrame() and will crash if you call NumFramesReady().
The "new" online decodable objects in ../online2/ return the number of frames
currently accessible if you call NumFramesReady(). You will likely not need
to call IsLastFrame(), but we implement it to only return true for the last
frame of the file once we've decided to terminate decoding.
*/
class DecodableInterface {
public:
/// Returns the log likelihood, which will be negated in the decoder.
/// The "frame" starts from zero. You should verify that NumFramesReady() > frame
/// before calling this.
/// NOTE: 'index' is one-based (see NumIndices() below), for OpenFst
/// compatibility.
virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
/// Returns true if this is the last frame. Frames are zero-based, so the
/// first frame is zero. IsLastFrame(-1) will return false, unless the file
/// is empty (which is a case that I'm not sure all the code will handle, so
/// be careful). Caution: the behavior of this function in an online setting
/// is being changed somewhat. In future it may return false in cases where
/// we haven't yet decided to terminate decoding, but later true if we decide
/// to terminate decoding. The plan in future is to rely more on
/// NumFramesReady(), and in future, IsLastFrame() would always return false
/// in an online-decoding setting, and would only return true in a
/// decoding-from-matrix setting where we want to allow the last delta or LDA
/// features to be flushed out for compatibility with the baseline setup.
virtual bool IsLastFrame(int32 frame) const = 0;
/// The call NumFramesReady() will return the number of frames currently available
/// for this decodable object. This is for use in setups where you don't want the
/// decoder to block while waiting for input. This is newly added as of Jan 2014,
/// and I hope, going forward, to rely on this mechanism more than IsLastFrame to
/// know when to stop decoding.
/// Default implementation errors out; subclasses that support the
/// non-blocking protocol override it.
virtual int32 NumFramesReady() const {
KALDI_ERR << "NumFramesReady() not implemented for this decodable type.";
return -1;
// The return above is normally unreachable (KALDI_ERR does not return);
// it only keeps compilers that can't see that happy.
}
/// Returns the number of states in the acoustic model
/// (they will be indexed one-based, i.e. from 1 to NumIndices();
/// this is for compatibility with OpenFst).
virtual int32 NumIndices() const = 0;
/// Virtual destructor so implementations are destroyed correctly when
/// deleted through a DecodableInterface pointer.
virtual ~DecodableInterface() {}
};
/// @}
} // namespace kaldi
#endif // KALDI_ITF_DECODABLE_ITF_H_

@ -0,0 +1,18 @@
#include <memory>

#include "base/common.h"
#include "nnet/decodable-itf.h"
namespace ppsepeech {
struct DecodeableConfig;
class Decodeable : public kaldi::DecodableInterface {
public:
virtual Init(Decodeable config) = 0;
virtual Acceptlikeihood() = 0;
private:
std::share_ptr<FeatureExtractorInterface> frontend_;
std::share_ptr<NnetInterface> nnet_;
//Cache nnet_cache_;
}
} // namespace ppspeech

@ -8,7 +8,7 @@ namespace ppspeech {
class NnetInterface {
public:
virtual ~NnetForwardInterface() {}
virtual ~NnetInterface() {}
virtual void FeedForward(const kaldi::Matrix<BaseFloat>& features,
kaldi::Matrix<kaldi::BaseFloat>* inferences) const = 0;

Loading…
Cancel
Save