From 9a79b41bcdd2262590fd3d14daf91731430e42e1 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 29 Aug 2017 18:54:15 +0800 Subject: [PATCH] streamline source code --- deploy/ctc_decoders.cpp | 67 +++++++++++++++++----------------------- deploy/decoder_utils.cpp | 27 ++++++++++++++-- deploy/decoder_utils.h | 19 ++++++++---- deploy/path_trie.cpp | 27 +++++++--------- deploy/scorer.cpp | 65 +++++++------------------------------- deploy/scorer.h | 9 ++---- 6 files changed, 92 insertions(+), 122 deletions(-) diff --git a/deploy/ctc_decoders.cpp b/deploy/ctc_decoders.cpp index d84f5b16..da37708a 100644 --- a/deploy/ctc_decoders.cpp +++ b/deploy/ctc_decoders.cpp @@ -10,8 +10,6 @@ #include "path_trie.h" #include "ThreadPool.h" -typedef float log_prob_type; - std::string ctc_best_path_decoder(std::vector > probs_seq, std::vector vocabulary) { @@ -19,8 +17,8 @@ std::string ctc_best_path_decoder(std::vector > probs_seq, int num_time_steps = probs_seq.size(); for (int i=0; i > probs_seq, std::vector max_idx_vec; double max_prob = 0.0; int max_idx = 0; - for (int i=0; i > probs_seq, } std::vector idx_vec; - for (int i=0; i0) && max_idx_vec[i]!=max_idx_vec[i-1])) { + for (int i = 0; i < max_idx_vec.size(); i++) { + if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i-1])) { idx_vec.push_back(max_idx_vec[i]); } } std::string best_path_result; - for (int i=0; i > { // dimension check int num_time_steps = probs_seq.size(); - for (int i=0; i > std::vector::iterator it = std::find(vocabulary.begin(), vocabulary.end(), " "); int space_id = it - vocabulary.begin(); + // if no space in vocabulary if(space_id >= vocabulary.size()) { - std::cout << " The character space is not in the vocabulary!"<::max(); - static log_prob_type NEG_INF = -POS_INF; - static log_prob_type NUM_MIN = std::numeric_limits::min(); - // init PathTrie root; - root._log_prob_b_prev = 0.0; - root._score = 0.0; + root._score = root._log_prob_b_prev = 0.0; std::vector prefixes; prefixes.push_back(&root); @@ -140,17 +133,17 @@ std::vector > prob_idx.begin() + cutoff_len); } - std::vector > log_prob_idx; - for (int i=0; i - (prob_idx[i].first, log(prob_idx[i].second + NUM_MIN))); + std::vector > log_prob_idx; + for (int i = 0; i < cutoff_len; i++) { + log_prob_idx.push_back(std::pair + (prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN))); } // loop over chars for (int index = 0; index < log_prob_idx.size(); index++) { auto c = log_prob_idx[index].first; - log_prob_type log_prob_c = log_prob_idx[index].second; - //log_prob_type log_probs_prev; + float log_prob_c = log_prob_idx[index].second; + //float log_probs_prev; for (int i = 0; i < prefixes.size() && i > if (c == prefix->_character) { prefix->_log_prob_nb_cur = log_sum_exp( prefix->_log_prob_nb_cur, - log_prob_c + prefix->_log_prob_nb_prev - ); + log_prob_c + prefix->_log_prob_nb_prev); } // get new prefix auto prefix_new = prefix->get_path_trie(c); if (prefix_new != nullptr) { - float log_p = NEG_INF; + float log_p = -NUM_FLT_INF; if (c == prefix->_character - && prefix->_log_prob_b_prev > NEG_INF) { + && prefix->_log_prob_b_prev > -NUM_FLT_INF) { log_p = log_prob_c + prefix->_log_prob_b_prev; } else if (c != prefix->_character) { log_p = log_prob_c + prefix->_score; @@ -201,7 +193,6 @@ std::vector > log_p += score; log_p += ext_scorer->beta; - } prefix_new->_log_prob_nb_cur = log_sum_exp( prefix_new->_log_prob_nb_cur, log_p); @@ -273,7 +264,7 @@ std::vector > } -std::vector>> +std::vector > > ctc_beam_search_decoder_batch( std::vector>> probs_split, int beam_size, @@ -292,12 +283,12 @@ std::vector>> // number of samples int batch_size = probs_split.size(); // dictionary init - if ( ext_scorer != nullptr) { - if (ext_scorer->_dictionary == nullptr) { - // TODO: init dictionary - ext_scorer->set_char_map(vocabulary); - ext_scorer->fill_dictionary(true); - } + if ( ext_scorer != nullptr + && !ext_scorer->is_character_based() + && ext_scorer->_dictionary == nullptr) { + // init dictionary + ext_scorer->set_char_map(vocabulary); + ext_scorer->fill_dictionary(true); } // enqueue the tasks of decoding std::vector>>> res; @@ -308,7 +299,7 @@ std::vector>> ); } // get decoding results - std::vector>> batch_results; + std::vector > > batch_results; for (int i = 0; i < batch_size; i++) { batch_results.emplace_back(res[i].get()); } diff --git a/deploy/decoder_utils.cpp b/deploy/decoder_utils.cpp index 0ec86d6b..39beb811 100644 --- a/deploy/decoder_utils.cpp +++ b/deploy/decoder_utils.cpp @@ -15,7 +15,7 @@ size_t get_utf8_str_len(const std::string& str) { //Splits string into vector of strings representing //UTF-8 characters (not same as chars) //------------------------------------------------------ -std::vector UTF8_split(const std::string& str) +std::vector split_utf8_str(const std::string& str) { std::vector result; std::string out_str; @@ -37,6 +37,29 @@ std::vector UTF8_split(const std::string& str) return result; } +// Split a string into a list of strings on a given string +// delimiter. NB: delimiters on beginning / end of string are +// trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"]. +std::vector split_str(const std::string &s, + const std::string &delim) { + std::vector result; + std::size_t start = 0, delim_len = delim.size(); + while (true) { + std::size_t end = s.find(delim, start); + if (end == std::string::npos) { + if (start < s.size()) { + result.push_back(s.substr(start)); + } + break; + } + if (end > start) { + result.push_back(s.substr(start, end - start)); + } + start = end + delim_len; + } + return result; +} + //------------------------------------------------------- // Overriding less than operator for sorting //------------------------------------------------------- @@ -80,7 +103,7 @@ bool add_word_to_dictionary(const std::string& word, bool add_space, int SPACE, fst::StdVectorFst* dictionary) { - auto characters = UTF8_split(word); + auto characters = split_utf8_str(word); std::vector int_word; diff --git a/deploy/decoder_utils.h b/deploy/decoder_utils.h index b61cdfbf..93660586 100644 --- a/deploy/decoder_utils.h +++ b/deploy/decoder_utils.h @@ -4,14 +4,19 @@ #include #include "path_trie.h" +const float NUM_FLT_INF = std::numeric_limits::max(); +const float NUM_FLT_MIN = std::numeric_limits::min(); + template -bool pair_comp_first_rev(const std::pair &a, const std::pair &b) +bool pair_comp_first_rev(const std::pair &a, + const std::pair &b) { return a.first > b.first; } template -bool pair_comp_second_rev(const std::pair &a, const std::pair &b) +bool pair_comp_second_rev(const std::pair &a, + const std::pair &b) { return a.second > b.second; } @@ -26,16 +31,18 @@ T log_sum_exp(const T &x, const T &y) return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax; } -//------------------------------------------------------- -// Overriding less than operator for sorting -//------------------------------------------------------- + +// Functor for prefix comparsion bool prefix_compare(const PathTrie* x, const PathTrie* y); // Get length of utf8 encoding string // See: http://stackoverflow.com/a/4063229 size_t get_utf8_str_len(const std::string& str); -std::vector UTF8_split(const std::string &str); +std::vector split_str(const std::string &s, + const std::string &delim); + +std::vector split_utf8_str(const std::string &str); void add_word_to_fst(const std::vector& word, fst::StdVectorFst* dictionary); diff --git a/deploy/path_trie.cpp b/deploy/path_trie.cpp index 6cf7ae51..b841831d 100644 --- a/deploy/path_trie.cpp +++ b/deploy/path_trie.cpp @@ -8,12 +8,11 @@ #include "decoder_utils.h" PathTrie::PathTrie() { - float lowest = -1.0*std::numeric_limits::max(); - _log_prob_b_prev = lowest; - _log_prob_nb_prev = lowest; - _log_prob_b_cur = lowest; - _log_prob_nb_cur = lowest; - _score = lowest; + _log_prob_b_prev = -NUM_FLT_INF; + _log_prob_nb_prev = -NUM_FLT_INF; + _log_prob_b_cur = -NUM_FLT_INF; + _log_prob_nb_cur = -NUM_FLT_INF; + _score = -NUM_FLT_INF; _ROOT = -1; _character = _ROOT; @@ -41,11 +40,10 @@ PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { if ( child != _children.end() ) { if (!child->second->_exists) { child->second->_exists = true; - float lowest = -1.0*std::numeric_limits::max(); - child->second->_log_prob_b_prev = lowest; - child->second->_log_prob_nb_prev = lowest; - child->second->_log_prob_b_cur = lowest; - child->second->_log_prob_nb_cur = lowest; + child->second->_log_prob_b_prev = -NUM_FLT_INF; + child->second->_log_prob_nb_prev = -NUM_FLT_INF; + child->second->_log_prob_b_cur = -NUM_FLT_INF; + child->second->_log_prob_nb_cur = -NUM_FLT_INF; } return (child->second); } else { @@ -106,8 +104,8 @@ void PathTrie::iterate_to_vec( _log_prob_b_prev = _log_prob_b_cur; _log_prob_nb_prev = _log_prob_nb_cur; - _log_prob_b_cur = -1.0 * std::numeric_limits::max(); - _log_prob_nb_cur = -1.0 * std::numeric_limits::max(); + _log_prob_b_cur = -NUM_FLT_INF; + _log_prob_nb_cur = -NUM_FLT_INF; _score = log_sum_exp(_log_prob_b_prev, _log_prob_nb_prev); output.push_back(this); @@ -117,9 +115,6 @@ void PathTrie::iterate_to_vec( } } -//------------------------------------------------------- -// Effectively removes node -//------------------------------------------------------- void PathTrie::remove() { _exists = false; diff --git a/deploy/scorer.cpp b/deploy/scorer.cpp index ad33a0cd..41f3894a 100644 --- a/deploy/scorer.cpp +++ b/deploy/scorer.cpp @@ -17,7 +17,7 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) { _language_model = nullptr; _dictionary = nullptr; _max_order = 0; - _SPACE = -1; + _SPACE_ID = -1; // load language model load_LM(lm_path.c_str()); } @@ -61,7 +61,7 @@ double Scorer::get_log_cond_prob(const std::vector& words) { lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]); // encounter OOV if (word_index == 0) { - return OOV_SCOER; + return OOV_SCORE; } cond_prob = model->BaseScore(&state, word_index, &out_state); tmp_state = state; @@ -197,64 +197,27 @@ Scorer::split_labels(const std::vector &labels) { std::string s = vec2str(labels); std::vector words; if (_is_character_based) { - words = UTF8_split(s); + words = split_utf8_str(s); } else { words = split_str(s, " "); } return words; } -// Split a string into a list of strings on a given string -// delimiter. NB: delimiters on beginning / end of string are -// trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"]. -std::vector Scorer::split_str(const std::string &s, - const std::string &delim) { - std::vector result; - std::size_t start = 0, delim_len = delim.size(); - while (true) { - std::size_t end = s.find(delim, start); - if (end == std::string::npos) { - if (start < s.size()) { - result.push_back(s.substr(start)); - } - break; - } - if (end > start) { - result.push_back(s.substr(start, end - start)); - } - start = end + delim_len; - } - return result; -} - -//--------------------------------------------------- -// Add index to char list for searching language model -//--------------------------------------------------- void Scorer::set_char_map(std::vector char_list) { _char_list = char_list; - std::string _SPACE_STR = " "; - - for (unsigned int i = 0; i < _char_list.size(); i++) { - // if (_char_list[i] == _BLANK_STR) { - // _BLANK = i; - // } else - if (_char_list[i] == _SPACE_STR) { - _SPACE = i; - } - } - _char_map.clear(); + for(unsigned int i = 0; i < _char_list.size(); i++) { - if(i == (unsigned int)_SPACE){ + if (_char_list[i] == " ") { + _SPACE_ID = i; _char_map[' '] = i; - } - else if(_char_list[i].size() == 1){ + } else if(_char_list[i].size() == 1){ _char_map[_char_list[i][0]] = i; } } - -} //------------- End of set_char_map ---------------- +} std::vector Scorer::make_ngram(PathTrie* prefix) { std::vector ngram; @@ -265,10 +228,10 @@ std::vector Scorer::make_ngram(PathTrie* prefix) { std::vector prefix_vec; if (_is_character_based) { - new_node = current_node->get_path_vec(prefix_vec, _SPACE, 1); + new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID, 1); current_node = new_node; } else { - new_node = current_node->get_path_vec(prefix_vec, _SPACE); + new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID); current_node = new_node->_parent; // Skipping spaces } @@ -279,7 +242,7 @@ std::vector Scorer::make_ngram(PathTrie* prefix) { if (new_node->_character == -1) { // No more spaces, but still need order for (int i = 0; i < _max_order - order - 1; i++) { - ngram.push_back(""); + ngram.push_back(START_TOKEN); } break; } @@ -288,10 +251,6 @@ std::vector Scorer::make_ngram(PathTrie* prefix) { return ngram; } -//--------------------------------------------------------- -// Helper function to populate Trie with a vocab using the -// char_list for maping from string to int -//--------------------------------------------------------- void Scorer::fill_dictionary(bool add_space) { fst::StdVectorFst dictionary; @@ -307,7 +266,7 @@ void Scorer::fill_dictionary(bool add_space) { bool added = add_word_to_dictionary(word, char_map, add_space, - _SPACE, + _SPACE_ID, &dictionary); vocab_size += added ? 1 : 0; } diff --git a/deploy/scorer.h b/deploy/scorer.h index 9ba55dd6..17a5f1aa 100644 --- a/deploy/scorer.h +++ b/deploy/scorer.h @@ -11,7 +11,7 @@ #include "util/string_piece.hh" #include "path_trie.h" -const double OOV_SCOER = -1000.0; +const double OOV_SCORE = -1000.0; const std::string START_TOKEN = ""; const std::string UNK_TOKEN = ""; const std::string END_TOKEN = ""; @@ -68,18 +68,13 @@ protected: double get_log_prob(const std::vector& words); std::string vec2str(const std::vector &input); std::vector split_labels(const std::vector &labels); - std::vector split_str(const std::string &s, - const std::string &delim); private: - void _init_char_list(); - void _init_char_map(); - void* _language_model; bool _is_character_based; size_t _max_order; - unsigned int _SPACE; + int _SPACE_ID; std::vector _char_list; std::unordered_map _char_map;