append some comments

pull/2/head
Yibing Liu 7 years ago
parent 8ec4a96523
commit adab01bbf6

@ -14,8 +14,8 @@
#include "path_trie.h" #include "path_trie.h"
std::string ctc_greedy_decoder( std::string ctc_greedy_decoder(
const std::vector<std::vector<double>>& probs_seq, const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string>& vocabulary) { const std::vector<std::string> &vocabulary) {
// dimension check // dimension check
int num_time_steps = probs_seq.size(); int num_time_steps = probs_seq.size();
for (int i = 0; i < num_time_steps; i++) { for (int i = 0; i < num_time_steps; i++) {
@ -60,7 +60,7 @@ std::string ctc_greedy_decoder(
} }
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>>& probs_seq, const std::vector<std::vector<double>> &probs_seq,
int beam_size, int beam_size,
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
int blank_id, int blank_id,
@ -104,7 +104,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
} }
if (!extscorer->is_character_based()) { if (!extscorer->is_character_based()) {
if (extscorer->dictionary == nullptr) { if (extscorer->dictionary == nullptr) {
// fill dictionary for fst // fill dictionary for fst with space
extscorer->fill_dictionary(true); extscorer->fill_dictionary(true);
} }
auto fst_dict = static_cast<fst::StdVectorFst *>(extscorer->dictionary); auto fst_dict = static_cast<fst::StdVectorFst *>(extscorer->dictionary);
@ -282,9 +282,9 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::vector<std::pair<double, std::string>>> std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch( ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>>& probs_split, const std::vector<std::vector<std::vector<double>>> &probs_split,
int beam_size, int beam_size,
const std::vector<std::string>& vocabulary, const std::vector<std::string> &vocabulary,
int blank_id, int blank_id,
int num_processes, int num_processes,
double cutoff_prob, double cutoff_prob,
@ -304,8 +304,7 @@ ctc_beam_search_decoder_batch(
if (extscorer->is_char_map_empty()) { if (extscorer->is_char_map_empty()) {
extscorer->set_char_map(vocabulary); extscorer->set_char_map(vocabulary);
} }
if (!extscorer->is_character_based() && if (!extscorer->is_character_based() && extscorer->dictionary == nullptr) {
extscorer->dictionary == nullptr) {
// init dictionary // init dictionary
extscorer->fill_dictionary(true); extscorer->fill_dictionary(true);
} }

@ -14,12 +14,11 @@
* over vocabulary of one time step. * over vocabulary of one time step.
* vocabulary: A vector of vocabulary. * vocabulary: A vector of vocabulary.
* Return: * Return:
* A vector that each element is a pair of score and decoding result, * The decoding result in string
* in descending order.
*/ */
std::string ctc_greedy_decoder( std::string ctc_greedy_decoder(
const std::vector<std::vector<double>>& probs_seq, const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string>& vocabulary); const std::vector<std::string> &vocabulary);
/* CTC Beam Search Decoder /* CTC Beam Search Decoder
@ -37,7 +36,7 @@ std::string ctc_greedy_decoder(
* in descending order. * in descending order.
*/ */
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
const std::vector<std::vector<double>>& probs_seq, const std::vector<std::vector<double>> &probs_seq,
int beam_size, int beam_size,
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
int blank_id, int blank_id,
@ -59,14 +58,14 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
* cutoff_top_n: Cutoff number for pruning. * cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix. * ext_scorer: External scorer to evaluate a prefix.
* Return: * Return:
* A 2-D vector that each element is a vector of decoding result for one * A 2-D vector that each element is a vector of beam search decoding
* sample. * result for one audio sample.
*/ */
std::vector<std::vector<std::pair<double, std::string>>> std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch( ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>>& probs_split, const std::vector<std::vector<std::vector<double>>> &probs_split,
int beam_size, int beam_size,
const std::vector<std::string>& vocabulary, const std::vector<std::string> &vocabulary,
int blank_id, int blank_id,
int num_processes, int num_processes,
double cutoff_prob = 1.0, double cutoff_prob = 1.0,

@ -4,7 +4,7 @@
#include <cmath> #include <cmath>
#include <limits> #include <limits>
size_t get_utf8_str_len(const std::string& str) { size_t get_utf8_str_len(const std::string &str) {
size_t str_len = 0; size_t str_len = 0;
for (char c : str) { for (char c : str) {
str_len += ((c & 0xc0) != 0x80); str_len += ((c & 0xc0) != 0x80);
@ -12,7 +12,7 @@ size_t get_utf8_str_len(const std::string& str) {
return str_len; return str_len;
} }
std::vector<std::string> split_utf8_str(const std::string& str) { std::vector<std::string> split_utf8_str(const std::string &str) {
std::vector<std::string> result; std::vector<std::string> result;
std::string out_str; std::string out_str;
@ -31,8 +31,8 @@ std::vector<std::string> split_utf8_str(const std::string& str) {
return result; return result;
} }
std::vector<std::string> split_str(const std::string& s, std::vector<std::string> split_str(const std::string &s,
const std::string& delim) { const std::string &delim) {
std::vector<std::string> result; std::vector<std::string> result;
std::size_t start = 0, delim_len = delim.size(); std::size_t start = 0, delim_len = delim.size();
while (true) { while (true) {
@ -51,7 +51,7 @@ std::vector<std::string> split_str(const std::string& s,
return result; return result;
} }
bool prefix_compare(const PathTrie* x, const PathTrie* y) { bool prefix_compare(const PathTrie *x, const PathTrie *y) {
if (x->score == y->score) { if (x->score == y->score) {
if (x->character == y->character) { if (x->character == y->character) {
return false; return false;
@ -63,8 +63,8 @@ bool prefix_compare(const PathTrie* x, const PathTrie* y) {
} }
} }
void add_word_to_fst(const std::vector<int>& word, void add_word_to_fst(const std::vector<int> &word,
fst::StdVectorFst* dictionary) { fst::StdVectorFst *dictionary) {
if (dictionary->NumStates() == 0) { if (dictionary->NumStates() == 0) {
fst::StdVectorFst::StateId start = dictionary->AddState(); fst::StdVectorFst::StateId start = dictionary->AddState();
assert(start == 0); assert(start == 0);
@ -81,16 +81,16 @@ void add_word_to_fst(const std::vector<int>& word,
} }
bool add_word_to_dictionary( bool add_word_to_dictionary(
const std::string& word, const std::string &word,
const std::unordered_map<std::string, int>& char_map, const std::unordered_map<std::string, int> &char_map,
bool add_space, bool add_space,
int SPACE_ID, int SPACE_ID,
fst::StdVectorFst* dictionary) { fst::StdVectorFst *dictionary) {
auto characters = split_utf8_str(word); auto characters = split_utf8_str(word);
std::vector<int> int_word; std::vector<int> int_word;
for (auto& c : characters) { for (auto &c : characters) {
if (c == " ") { if (c == " ") {
int_word.push_back(SPACE_ID); int_word.push_back(SPACE_ID);
} else { } else {
@ -108,5 +108,5 @@ bool add_word_to_dictionary(
} }
add_word_to_fst(int_word, dictionary); add_word_to_fst(int_word, dictionary);
return true; return true; // return with successful adding
} }

@ -14,12 +14,14 @@ bool pair_comp_first_rev(const std::pair<T1, T2> &a,
return a.first > b.first; return a.first > b.first;
} }
// Function template for comparing two pairs
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a, bool pair_comp_second_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) { const std::pair<T1, T2> &b) {
return a.second > b.second; return a.second > b.second;
} }
// Return the sum of two probabilities in log scale
template <typename T> template <typename T>
T log_sum_exp(const T &x, const T &y) { T log_sum_exp(const T &x, const T &y) {
static T num_min = -std::numeric_limits<T>::max(); static T num_min = -std::numeric_limits<T>::max();
@ -32,18 +34,21 @@ T log_sum_exp(const T &x, const T &y) {
// Functor for prefix comparison // Functor for prefix comparison
bool prefix_compare(const PathTrie *x, const PathTrie *y); bool prefix_compare(const PathTrie *x, const PathTrie *y);
// Get length of utf8 encoding string /* Get length of utf8 encoding string
// See: http://stackoverflow.com/a/4063229 * See: http://stackoverflow.com/a/4063229
*/
size_t get_utf8_str_len(const std::string &str); size_t get_utf8_str_len(const std::string &str);
// Split a string into a list of strings on a given string /* Split a string into a list of strings on a given string
// delimiter. NB: delimiters on beginning / end of string are * delimiter. NB: delimiters on beginning / end of string are
// trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"]. * trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
*/
std::vector<std::string> split_str(const std::string &s, std::vector<std::string> split_str(const std::string &s,
const std::string &delim); const std::string &delim);
// Splits string into vector of strings representing /* Splits string into vector of strings representing
// UTF-8 characters (not same as chars) * UTF-8 characters (not same as chars)
*/
std::vector<std::string> split_utf8_str(const std::string &str); std::vector<std::string> split_utf8_str(const std::string &str);
// Add a word in index to the dictionary of fst // Add a word in index to the dictionary of fst

@ -22,7 +22,7 @@ PathTrie::PathTrie() {
_dictionary = nullptr; _dictionary = nullptr;
_dictionary_state = 0; _dictionary_state = 0;
_has_dictionary = false; _has_dictionary = false;
_matcher = nullptr; // finds arcs in FST _matcher = nullptr;
} }
PathTrie::~PathTrie() { PathTrie::~PathTrie() {

@ -10,27 +10,36 @@
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
/* Trie tree for prefix storing and manipulating, with a dictionary in
* finite-state transducer for spelling correction.
*/
class PathTrie { class PathTrie {
public: public:
PathTrie(); PathTrie();
~PathTrie(); ~PathTrie();
// get new prefix after appending new char
PathTrie* get_path_trie(int new_char, bool reset = true); PathTrie* get_path_trie(int new_char, bool reset = true);
// get the prefix in index from root to current node
PathTrie* get_path_vec(std::vector<int>& output); PathTrie* get_path_vec(std::vector<int>& output);
// get the prefix in index from some stop node to current node
PathTrie* get_path_vec(std::vector<int>& output, PathTrie* get_path_vec(std::vector<int>& output,
int stop, int stop,
size_t max_steps = std::numeric_limits<size_t>::max()); size_t max_steps = std::numeric_limits<size_t>::max());
// update log probs
void iterate_to_vec(std::vector<PathTrie*>& output); void iterate_to_vec(std::vector<PathTrie*>& output);
// set dictionary for FST
void set_dictionary(fst::StdVectorFst* dictionary); void set_dictionary(fst::StdVectorFst* dictionary);
void set_matcher(std::shared_ptr<FSTMATCH> matcher); void set_matcher(std::shared_ptr<FSTMATCH> matcher);
bool is_empty() { return _ROOT == character; } bool is_empty() { return _ROOT == character; }
// remove current path from root
void remove(); void remove();
float log_prob_b_prev; float log_prob_b_prev;
@ -49,8 +58,10 @@ private:
std::vector<std::pair<int, PathTrie*>> _children; std::vector<std::pair<int, PathTrie*>> _children;
// pointer to dictionary of FST
fst::StdVectorFst* _dictionary; fst::StdVectorFst* _dictionary;
fst::StdVectorFst::StateId _dictionary_state; fst::StdVectorFst::StateId _dictionary_state;
// true if finding arcs in FST
std::shared_ptr<FSTMATCH> _matcher; std::shared_ptr<FSTMATCH> _matcher;
}; };

@ -68,7 +68,7 @@ double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
state = out_state; state = out_state;
out_state = tmp_state; out_state = tmp_state;
} }
// log10 prob // return log10 prob
return cond_prob; return cond_prob;
} }
@ -189,23 +189,26 @@ void Scorer::fill_dictionary(bool add_space) {
std::cerr << "Vocab Size " << vocab_size << std::endl; std::cerr << "Vocab Size " << vocab_size << std::endl;
// Simplify FST /* Simplify FST
// This gets rid of "epsilon" transitions in the FST. * This gets rid of "epsilon" transitions in the FST.
// These are transitions that don't require a string input to be taken. * These are transitions that don't require a string input to be taken.
// Getting rid of them is necessary to make the FST deterministic, but * Getting rid of them is necessary to make the FST deterministic, but
// can greatly increase the size of the FST * can greatly increase the size of the FST
*/
fst::RmEpsilon(&dictionary); fst::RmEpsilon(&dictionary);
fst::StdVectorFst* new_dict = new fst::StdVectorFst; fst::StdVectorFst* new_dict = new fst::StdVectorFst;
// This makes the FST deterministic, meaning for any string input there's /* This makes the FST deterministic, meaning for any string input there's
// only one possible state the FST could be in. It is assumed our * only one possible state the FST could be in. It is assumed our
// dictionary is deterministic when using it. * dictionary is deterministic when using it.
// (lest we'd have to check for multiple transitions at each state) * (lest we'd have to check for multiple transitions at each state)
*/
fst::Determinize(dictionary, new_dict); fst::Determinize(dictionary, new_dict);
// Finds the simplest equivalent fst. This is unnecessary but decreases /* Finds the simplest equivalent fst. This is unnecessary but decreases
// memory usage of the dictionary * memory usage of the dictionary
*/
fst::Minimize(new_dict); fst::Minimize(new_dict);
this->dictionary = new_dict; this->dictionary = new_dict;
} }

@ -23,14 +23,15 @@ class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public: public:
RetriveStrEnumerateVocab() {} RetriveStrEnumerateVocab() {}
void Add(lm::WordIndex index, const StringPiece& str) { void Add(lm::WordIndex index, const StringPiece &str) {
vocabulary.push_back(std::string(str.data(), str.length())); vocabulary.push_back(std::string(str.data(), str.length()));
} }
std::vector<std::string> vocabulary; std::vector<std::string> vocabulary;
}; };
/* External scorer to query language score for n-gram or sentence. /* External scorer to query score for n-gram or sentence, including language
* model scoring and word insertion.
* *
* Example: * Example:
* Scorer scorer(alpha, beta, "path_of_language_model"); * Scorer scorer(alpha, beta, "path_of_language_model");
@ -39,12 +40,12 @@ public:
*/ */
class Scorer { class Scorer {
public: public:
Scorer(double alpha, double beta, const std::string& lm_path); Scorer(double alpha, double beta, const std::string &lm_path);
~Scorer(); ~Scorer();
double get_log_cond_prob(const std::vector<std::string>& words); double get_log_cond_prob(const std::vector<std::string> &words);
double get_sent_log_prob(const std::vector<std::string>& words); double get_sent_log_prob(const std::vector<std::string> &words);
size_t get_max_order() { return _max_order; } size_t get_max_order() { return _max_order; }
@ -56,32 +57,32 @@ public:
void reset_params(float alpha, float beta); void reset_params(float alpha, float beta);
// make ngram // make ngram
std::vector<std::string> make_ngram(PathTrie* prefix); std::vector<std::string> make_ngram(PathTrie *prefix);
// fill dictionary for fst // fill dictionary for fst
void fill_dictionary(bool add_space); void fill_dictionary(bool add_space);
// set char map // set char map
void set_char_map(const std::vector<std::string>& char_list); void set_char_map(const std::vector<std::string> &char_list);
std::vector<std::string> split_labels(const std::vector<int>& labels); std::vector<std::string> split_labels(const std::vector<int> &labels);
// expose to decoder // expose to decoder
double alpha; double alpha;
double beta; double beta;
// fst dictionary // fst dictionary
void* dictionary; void *dictionary;
protected: protected:
void load_LM(const char* filename); void load_LM(const char *filename);
double get_log_prob(const std::vector<std::string>& words); double get_log_prob(const std::vector<std::string> &words);
std::string vec2str(const std::vector<int>& input); std::string vec2str(const std::vector<int> &input);
private: private:
void* _language_model; void *_language_model;
bool _is_character_based; bool _is_character_based;
size_t _max_order; size_t _max_order;

Loading…
Cancel
Save