diff --git a/models/swig_decoders/ctc_decoders.cpp b/models/swig_decoders/ctc_decoders.cpp
index 4c9a45d9..10979912 100644
--- a/models/swig_decoders/ctc_decoders.cpp
+++ b/models/swig_decoders/ctc_decoders.cpp
@@ -14,8 +14,8 @@
 #include "path_trie.h"
 
 std::string ctc_greedy_decoder(
-    const std::vector<std::vector<double>>& probs_seq,
-    const std::vector<std::string>& vocabulary) {
+    const std::vector<std::vector<double>> &probs_seq,
+    const std::vector<std::string> &vocabulary) {
   // dimension check
   int num_time_steps = probs_seq.size();
   for (int i = 0; i < num_time_steps; i++) {
@@ -60,7 +60,7 @@ std::string ctc_greedy_decoder(
 }
 
 std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
-    const std::vector<std::vector<double>>& probs_seq,
+    const std::vector<std::vector<double>> &probs_seq,
     int beam_size,
     std::vector<std::string> vocabulary,
     int blank_id,
@@ -104,7 +104,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
   }
   if (!extscorer->is_character_based()) {
     if (extscorer->dictionary == nullptr) {
-      // fill dictionary for fst
+      // fill dictionary for fst with space
       extscorer->fill_dictionary(true);
     }
     auto fst_dict = static_cast<fst::StdVectorFst *>(extscorer->dictionary);
@@ -282,9 +282,9 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
 
 std::vector<std::vector<std::pair<double, std::string>>>
 ctc_beam_search_decoder_batch(
-    const std::vector<std::vector<std::vector<double>>>& probs_split,
+    const std::vector<std::vector<std::vector<double>>> &probs_split,
     int beam_size,
-    const std::vector<std::string>& vocabulary,
+    const std::vector<std::string> &vocabulary,
     int blank_id,
     int num_processes,
     double cutoff_prob,
@@ -304,8 +304,7 @@ ctc_beam_search_decoder_batch(
   if (extscorer->is_char_map_empty()) {
     extscorer->set_char_map(vocabulary);
   }
-  if (!extscorer->is_character_based() &&
-      extscorer->dictionary == nullptr) {
+  if (!extscorer->is_character_based() && extscorer->dictionary == nullptr) {
     // init dictionary
     extscorer->fill_dictionary(true);
   }
diff --git a/models/swig_decoders/ctc_decoders.h b/models/swig_decoders/ctc_decoders.h
index 5b4bb793..b8c512bd 100644
--- a/models/swig_decoders/ctc_decoders.h
+++ b/models/swig_decoders/ctc_decoders.h
@@ -14,12 +14,11 @@
  *     over vocabulary of one time step.
  *     vocabulary: A vector of vocabulary.
  * Return:
- *     A vector that each element is a pair of score and decoding result,
- *     in desending order.
+ *     The decoding result in string
 */
 std::string ctc_greedy_decoder(
-    const std::vector<std::vector<double>>& probs_seq,
-    const std::vector<std::string>& vocabulary);
+    const std::vector<std::vector<double>> &probs_seq,
+    const std::vector<std::string> &vocabulary);
 
 /* CTC Beam Search Decoder
@@ -37,7 +36,7 @@ std::string ctc_greedy_decoder(
  *     in desending order.
 */
 std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
-    const std::vector<std::vector<double>>& probs_seq,
+    const std::vector<std::vector<double>> &probs_seq,
     int beam_size,
     std::vector<std::string> vocabulary,
     int blank_id,
@@ -59,14 +58,14 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
  *     cutoff_top_n: Cutoff number for pruning.
  *     ext_scorer: External scorer to evaluate a prefix.
  * Return:
- *     A 2-D vector that each element is a vector of decoding result for one
- *     sample.
+ *     A 2-D vector that each element is a vector of beam search decoding
+ *     result for one audio sample.
 */
 std::vector<std::vector<std::pair<double, std::string>>>
 ctc_beam_search_decoder_batch(
-    const std::vector<std::vector<std::vector<double>>>& probs_split,
+    const std::vector<std::vector<std::vector<double>>> &probs_split,
     int beam_size,
-    const std::vector<std::string>& vocabulary,
+    const std::vector<std::string> &vocabulary,
     int blank_id,
     int num_processes,
     double cutoff_prob = 1.0,
diff --git a/models/swig_decoders/decoder_utils.cpp b/models/swig_decoders/decoder_utils.cpp
index d25c4deb..989b067e 100644
--- a/models/swig_decoders/decoder_utils.cpp
+++ b/models/swig_decoders/decoder_utils.cpp
@@ -4,7 +4,7 @@
 #include <cmath>
 #include <limits>
 
-size_t get_utf8_str_len(const std::string& str) {
+size_t get_utf8_str_len(const std::string &str) {
   size_t str_len = 0;
   for (char c : str) {
     str_len += ((c & 0xc0) != 0x80);
@@ -12,7 +12,7 @@ size_t get_utf8_str_len(const std::string& str) {
   return str_len;
 }
 
-std::vector<std::string> split_utf8_str(const std::string& str) {
+std::vector<std::string> split_utf8_str(const std::string &str) {
   std::vector<std::string> result;
   std::string out_str;
 
@@ -31,8 +31,8 @@ std::vector<std::string> split_utf8_str(const std::string& str) {
   return result;
 }
 
-std::vector<std::string> split_str(const std::string& s,
-                                   const std::string& delim) {
+std::vector<std::string> split_str(const std::string &s,
+                                   const std::string &delim) {
   std::vector<std::string> result;
   std::size_t start = 0, delim_len = delim.size();
   while (true) {
@@ -51,7 +51,7 @@ std::vector<std::string> split_str(const std::string& s,
   return result;
 }
 
-bool prefix_compare(const PathTrie* x, const PathTrie* y) {
+bool prefix_compare(const PathTrie *x, const PathTrie *y) {
   if (x->score == y->score) {
     if (x->character == y->character) {
       return false;
@@ -63,8 +63,8 @@ bool prefix_compare(const PathTrie* x, const PathTrie* y) {
   }
 }
 
-void add_word_to_fst(const std::vector<int>& word,
-                     fst::StdVectorFst* dictionary) {
+void add_word_to_fst(const std::vector<int> &word,
+                     fst::StdVectorFst *dictionary) {
   if (dictionary->NumStates() == 0) {
     fst::StdVectorFst::StateId start = dictionary->AddState();
     assert(start == 0);
@@ -81,16 +81,16 @@ void add_word_to_fst(const std::vector<int>& word,
 }
 
 bool add_word_to_dictionary(
-    const std::string& word,
-    const std::unordered_map<std::string, int>& char_map,
+    const std::string &word,
+    const std::unordered_map<std::string, int> &char_map,
     bool add_space,
     int SPACE_ID,
-    fst::StdVectorFst* dictionary) {
+    fst::StdVectorFst *dictionary) {
   auto characters = split_utf8_str(word);
 
   std::vector<int> int_word;
 
-  for (auto& c : characters) {
+  for (auto &c : characters) {
     if (c == " ") {
       int_word.push_back(SPACE_ID);
     } else {
@@ -108,5 +108,5 @@ bool add_word_to_dictionary(
   }
 
   add_word_to_fst(int_word, dictionary);
-  return true;
+  return true;  // word added successfully
 }
diff --git a/models/swig_decoders/decoder_utils.h b/models/swig_decoders/decoder_utils.h
index 51985c86..d4ee36e1 100644
--- a/models/swig_decoders/decoder_utils.h
+++ b/models/swig_decoders/decoder_utils.h
@@ -14,12 +14,14 @@ bool pair_comp_first_rev(const std::pair<T1, T2> &a,
   return a.first > b.first;
 }
 
+// Function template for comparing two pairs
 template <typename T1, typename T2>
 bool pair_comp_second_rev(const std::pair<T1, T2> &a,
                           const std::pair<T1, T2> &b) {
   return a.second > b.second;
 }
 
+// Return the sum of two probabilities in log scale
 template <typename T>
 T log_sum_exp(const T &x, const T &y) {
   static T num_min = -std::numeric_limits<T>::max();
@@ -32,18 +34,21 @@ T log_sum_exp(const T &x, const T &y) {
 // Functor for prefix comparsion
 bool prefix_compare(const PathTrie *x, const PathTrie *y);
 
-// Get length of utf8 encoding string
-// See: http://stackoverflow.com/a/4063229
+/* Get length of utf8 encoding string
+ * See: http://stackoverflow.com/a/4063229
+ */
 size_t get_utf8_str_len(const std::string &str);
 
-// Split a string into a list of strings on a given string
-// delimiter. NB: delimiters on beginning / end of string are
-// trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
+/* Split a string into a list of strings on a given string
+ * delimiter. NB: delimiters on beginning / end of string are
+ * trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
+ */
 std::vector<std::string> split_str(const std::string &s,
                                    const std::string &delim);
 
-// Splits string into vector of strings representing
-// UTF-8 characters (not same as chars)
+/* Splits string into vector of strings representing
+ * UTF-8 characters (not same as chars)
+ */
 std::vector<std::string> split_utf8_str(const std::string &str);
 
 // Add a word in index to the dicionary of fst
diff --git a/models/swig_decoders/path_trie.cpp b/models/swig_decoders/path_trie.cpp
index 9e68c0f1..6a1f6170 100644
--- a/models/swig_decoders/path_trie.cpp
+++ b/models/swig_decoders/path_trie.cpp
@@ -22,7 +22,7 @@ PathTrie::PathTrie() {
   _dictionary = nullptr;
   _dictionary_state = 0;
   _has_dictionary = false;
-  _matcher = nullptr;  // finds arcs in FST
+  _matcher = nullptr;
 }
 
 PathTrie::~PathTrie() {
diff --git a/models/swig_decoders/path_trie.h b/models/swig_decoders/path_trie.h
index e581ca73..6f150e42 100644
--- a/models/swig_decoders/path_trie.h
+++ b/models/swig_decoders/path_trie.h
@@ -10,27 +10,36 @@ using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
 
+/* Trie tree for prefix storing and manipulating, with a dictionary in
+ * finite-state transducer for spelling correction.
+ */
 class PathTrie {
 public:
   PathTrie();
   ~PathTrie();
 
+  // get new prefix after appending new char
   PathTrie* get_path_trie(int new_char, bool reset = true);
 
+  // get the prefix in index from root to current node
   PathTrie* get_path_vec(std::vector<int>& output);
 
+  // get the prefix in index from some stop node to current node
   PathTrie* get_path_vec(std::vector<int>& output,
                          int stop,
                          size_t max_steps = std::numeric_limits<size_t>::max());
 
+  // update log probs
   void iterate_to_vec(std::vector<PathTrie*>& output);
 
+  // set dictionary for FST
   void set_dictionary(fst::StdVectorFst* dictionary);
 
   void set_matcher(std::shared_ptr<FSTMATCH> matcher);
 
   bool is_empty() { return _ROOT == character; }
 
+  // remove current path from root
   void remove();
 
   float log_prob_b_prev;
@@ -49,8 +58,10 @@ private:
   std::vector<std::pair<int, PathTrie*>> _children;
 
+  // pointer to dictionary of FST
   fst::StdVectorFst* _dictionary;
   fst::StdVectorFst::StateId _dictionary_state;
+  // matcher for finding arcs in FST
   std::shared_ptr<FSTMATCH> _matcher;
 };
diff --git a/models/swig_decoders/scorer.cpp b/models/swig_decoders/scorer.cpp
index a713b0df..75919c3c 100644
--- a/models/swig_decoders/scorer.cpp
+++ b/models/swig_decoders/scorer.cpp
@@ -68,7 +68,7 @@ double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
     state = out_state;
     out_state = tmp_state;
   }
-  // log10 prob
+  // return log10 prob
   return cond_prob;
 }
 
@@ -189,23 +189,26 @@ void Scorer::fill_dictionary(bool add_space) {
   std::cerr << "Vocab Size " << vocab_size << std::endl;
 
-  // Simplify FST
+  /* Simplify FST
 
-  // This gets rid of "epsilon" transitions in the FST.
-  // These are transitions that don't require a string input to be taken.
-  // Getting rid of them is necessary to make the FST determinisitc, but
-  // can greatly increase the size of the FST
+   * This gets rid of "epsilon" transitions in the FST.
+   * These are transitions that don't require a string input to be taken.
+   * Getting rid of them is necessary to make the FST deterministic, but
+   * can greatly increase the size of the FST
+   */
   fst::RmEpsilon(&dictionary);
   fst::StdVectorFst* new_dict = new fst::StdVectorFst;
 
-  // This makes the FST deterministic, meaning for any string input there's
-  // only one possible state the FST could be in. It is assumed our
-  // dictionary is deterministic when using it.
-  // (lest we'd have to check for multiple transitions at each state)
+  /* This makes the FST deterministic, meaning for any string input there's
+   * only one possible state the FST could be in. It is assumed our
+   * dictionary is deterministic when using it.
+   * (lest we'd have to check for multiple transitions at each state)
+   */
   fst::Determinize(dictionary, new_dict);
 
-  // Finds the simplest equivalent fst. This is unnecessary but decreases
-  // memory usage of the dictionary
+  /* Finds the simplest equivalent fst. This is unnecessary but decreases
+   * memory usage of the dictionary
+   */
   fst::Minimize(new_dict);
   this->dictionary = new_dict;
 }
diff --git a/models/swig_decoders/scorer.h b/models/swig_decoders/scorer.h
index b99a99b7..1b4857e3 100644
--- a/models/swig_decoders/scorer.h
+++ b/models/swig_decoders/scorer.h
@@ -23,14 +23,15 @@ class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
 public:
   RetriveStrEnumerateVocab() {}
 
-  void Add(lm::WordIndex index, const StringPiece& str) {
+  void Add(lm::WordIndex index, const StringPiece &str) {
     vocabulary.push_back(std::string(str.data(), str.length()));
   }
 
   std::vector<std::string> vocabulary;
 };
 
-/* External scorer to query languange score for n-gram or sentence.
+/* External scorer to query score for n-gram or sentence, including language
+ * model scoring and word insertion.
  *
  * Example:
  *     Scorer scorer(alpha, beta, "path_of_language_model");
@@ -39,12 +40,12 @@ public:
 */
 class Scorer {
 public:
-  Scorer(double alpha, double beta, const std::string& lm_path);
+  Scorer(double alpha, double beta, const std::string &lm_path);
   ~Scorer();
 
-  double get_log_cond_prob(const std::vector<std::string>& words);
+  double get_log_cond_prob(const std::vector<std::string> &words);
 
-  double get_sent_log_prob(const std::vector<std::string>& words);
+  double get_sent_log_prob(const std::vector<std::string> &words);
 
   size_t get_max_order() { return _max_order; }
 
@@ -56,32 +57,32 @@ public:
   void reset_params(float alpha, float beta);
 
   // make ngram
-  std::vector<std::string> make_ngram(PathTrie* prefix);
+  std::vector<std::string> make_ngram(PathTrie *prefix);
 
   // fill dictionary for fst
   void fill_dictionary(bool add_space);
 
   // set char map
-  void set_char_map(const std::vector<std::string>& char_list);
+  void set_char_map(const std::vector<std::string> &char_list);
 
-  std::vector<std::string> split_labels(const std::vector<int>& labels);
+  std::vector<std::string> split_labels(const std::vector<int> &labels);
 
   // expose to decoder
   double alpha;
   double beta;
 
   // fst dictionary
-  void* dictionary;
+  void *dictionary;
 
 protected:
-  void load_LM(const char* filename);
+  void load_LM(const char *filename);
 
-  double get_log_prob(const std::vector<std::string>& words);
+  double get_log_prob(const std::vector<std::string> &words);
 
-  std::string vec2str(const std::vector<int>& input);
+  std::string vec2str(const std::vector<int> &input);
 
 private:
-  void* _language_model;
+  void *_language_model;
   bool _is_character_based;
   size_t _max_order;
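
Usage sketch (not part of the patch): as a quick orientation for the interface changes above, the snippet below calls only the ctc_greedy_decoder signature shown in ctc_decoders.h. The vocabulary, the probability values, and the assumption that the last column of each row holds the blank label are illustrative placeholders, not taken from this diff.

#include <iostream>
#include <string>
#include <vector>

#include "ctc_decoders.h"

int main() {
  // Hypothetical 3-symbol vocabulary; all numbers below are toy values.
  std::vector<std::string> vocabulary = {"a", "b", "c"};

  // Two time steps of per-label probabilities; the extra last column is
  // assumed here to be the blank label (an assumption, not from this diff).
  std::vector<std::vector<double>> probs_seq = {
      {0.7, 0.1, 0.1, 0.1},   // argmax -> "a"
      {0.1, 0.1, 0.1, 0.7}};  // argmax -> blank, nothing emitted

  // Greedy (best path) decoding: pick the argmax per step, collapse
  // repeats, drop blanks.
  std::string result = ctc_greedy_decoder(probs_seq, vocabulary);
  std::cout << "greedy decoding: " << result << std::endl;
  return 0;
}

The beam search and batch entry points consume the same probability layout, with the additional beam_size, blank_id, num_processes, cutoff and external-scorer parameters shown in the header.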