diff --git a/deploy.py b/deploy.py index 76b61605..833c5c20 100644 --- a/deploy.py +++ b/deploy.py @@ -18,7 +18,7 @@ import time parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--num_samples", - default=32, + default=5, type=int, help="Number of samples for inference. (default: %(default)s)") parser.add_argument( @@ -79,7 +79,7 @@ parser.add_argument( "(default: %(default)s)") parser.add_argument( "--beam_size", - default=200, + default=20, type=int, help="Width for beam search decoding. (default: %(default)d)") parser.add_argument( @@ -104,7 +104,7 @@ parser.add_argument( help="Parameter associated with word count. (default: %(default)f)") parser.add_argument( "--cutoff_prob", - default=0.99, + default=1.0, type=float, help="The cutoff probability of pruning" "in beam search. (default: %(default)f)") @@ -183,7 +183,8 @@ def infer(): vocabulary=data_generator.vocab_list, blank_id=len(data_generator.vocab_list), cutoff_prob=args.cutoff_prob, - ext_scoring_func=ext_scorer, ) + # ext_scoring_func=ext_scorer, + ) batch_beam_results += [beam_result] else: batch_beam_results = ctc_beam_search_decoder_batch( diff --git a/deploy/ctc_decoders.cpp b/deploy/ctc_decoders.cpp index fd553be6..30e85525 100644 --- a/deploy/ctc_decoders.cpp +++ b/deploy/ctc_decoders.cpp @@ -4,11 +4,13 @@ #include #include #include +#include "fst/fstlib.h" #include "ctc_decoders.h" #include "decoder_utils.h" +#include "path_trie.h" #include "ThreadPool.h" -typedef double log_prob_type; +typedef float log_prob_type; std::string ctc_best_path_decoder(std::vector > probs_seq, std::vector vocabulary) @@ -89,24 +91,30 @@ std::vector > exit(1); } - // initialize - // two sets containing selected and candidate prefixes respectively - std::map prefix_set_prev, prefix_set_next; - // probability of prefixes ending with blank and non-blank - std::map log_probs_b_prev, log_probs_nb_prev; - std::map log_probs_b_cur, log_probs_nb_cur; - - static log_prob_type NUM_MAX = std::numeric_limits::max(); - prefix_set_prev["\t"] = 0.0; - log_probs_b_prev["\t"] = 0.0; - log_probs_nb_prev["\t"] = -NUM_MAX; - - for (int time_step=0; time_step prob = probs_seq[time_step]; + static log_prob_type POS_INF = std::numeric_limits::max(); + static log_prob_type NEG_INF = -POS_INF; + static log_prob_type NUM_MIN = std::numeric_limits::min(); + + // init + PathTrie root; + root._log_prob_b_prev = 0.0; + root._score = 0.0; + std::vector prefixes; + prefixes.push_back(&root); + + if ( ext_scorer != nullptr && !ext_scorer->is_character_based()) { + if (ext_scorer->dictionary == nullptr) { + // TODO: init dictionary + } + auto fst_dict = static_cast(ext_scorer->dictionary); + fst::StdVectorFst* dict_ptr = fst_dict->Copy(true); + root.set_dictionary(dict_ptr); + auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); + root.set_matcher(matcher); + } + for (int time_step = 0; time_step < num_time_steps; time_step++) { + std::vector prob = probs_seq[time_step]; std::vector > prob_idx; for (int i=0; i(i, prob[i])); @@ -132,113 +140,134 @@ std::vector > std::vector > log_prob_idx; for (int i=0; i - (prob_idx[i].first, log(prob_idx[i].second))); + (prob_idx[i].first, log(prob_idx[i].second + NUM_MIN))); } - // extend prefix - for (std::map::iterator - it = prefix_set_prev.begin(); - it != prefix_set_prev.end(); it++) { - std::string l = it->first; - if( prefix_set_next.find(l) == prefix_set_next.end()) { - log_probs_b_cur[l] = log_probs_nb_cur[l] = -NUM_MAX; - } + // loop over chars + for (int index = 0; index < log_prob_idx.size(); index++) { + auto c = log_prob_idx[index].first; + log_prob_type log_prob_c = log_prob_idx[index].second; + //log_prob_type log_probs_prev; - for (int index=0; index_log_prob_b_cur = log_sum_exp( + prefix->_log_prob_b_cur, + log_prob_c + prefix->_score); + continue; + } + // repeated character + if (c == prefix->_character) { + prefix->_log_prob_nb_cur = log_sum_exp( + prefix->_log_prob_nb_cur, + log_prob_c + prefix->_log_prob_nb_prev + ); + } + // get new prefix + auto prefix_new = prefix->get_path_trie(c); + + if (prefix_new != nullptr) { + float log_p = NEG_INF; + + if (c == prefix->_character + && prefix->_log_prob_b_prev > NEG_INF) { + log_p = log_prob_c + prefix->_log_prob_b_prev; + } else if (c != prefix->_character) { + log_p = log_prob_c + prefix->_score; } - if (last_char == new_char) { - log_probs_nb_cur[l_plus] = log_sum_exp( - log_probs_nb_cur[l_plus], - log_prob_c+log_probs_b_prev[l] - ); - log_probs_nb_cur[l] = log_sum_exp( - log_probs_nb_cur[l], - log_prob_c+log_probs_nb_prev[l] - ); - } else if (new_char == " ") { - float score = 0.0; - if (ext_scorer != NULL && l.size() > 1) { - score = ext_scorer->get_score(l.substr(1), true); + + // language model scoring + if (ext_scorer != nullptr && + (c == space_id || ext_scorer->is_character_based()) ) { + PathTrie *prefix_to_score = nullptr; + + // don't score the space + if (ext_scorer->is_character_based()) { + prefix_to_score = prefix_new; + } else { + prefix_to_score = prefix; } - log_probs_prev = log_sum_exp(log_probs_b_prev[l], - log_probs_nb_prev[l]); - log_probs_nb_cur[l_plus] = log_sum_exp( - log_probs_nb_cur[l_plus], - score + log_prob_c + log_probs_prev - ); - } else { - log_probs_prev = log_sum_exp(log_probs_b_prev[l], - log_probs_nb_prev[l]); - log_probs_nb_cur[l_plus] = log_sum_exp( - log_probs_nb_cur[l_plus], - log_prob_c+log_probs_prev - ); + + double score = 0.0; + std::vector ngram; + ngram = ext_scorer->make_ngram(prefix_to_score); + score = ext_scorer->get_log_cond_prob(ngram) * + ext_scorer->alpha; + + log_p += score; + log_p += ext_scorer->beta; + } - prefix_set_next[l_plus] = log_sum_exp( - log_probs_nb_cur[l_plus], - log_probs_b_cur[l_plus] - ); + prefix_new->_log_prob_nb_cur = log_sum_exp( + prefix_new->_log_prob_nb_cur, log_p); } } - prefix_set_next[l] = log_sum_exp(log_probs_b_cur[l], - log_probs_nb_cur[l]); + } // end of loop over chars + + prefixes.clear(); + // update log probabilities + root.iterate_to_vec(prefixes); + + // sort prefixes by score + if (prefixes.size() >= beam_size) { + std::nth_element(prefixes.begin(), + prefixes.begin() + beam_size, + prefixes.end(), + prefix_compare); + + for (size_t i = beam_size; i < prefixes.size(); i++) { + prefixes[i]->remove(); + } + } + } + + for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) { + double approx_ctc = prefixes[i]->_score; + + // remove word insert: + std::vector output; + prefixes[i]->get_path_vec(output); + size_t prefix_length = output.size(); + // remove language model weight: + if (ext_scorer != nullptr) { + // auto words = split_labels(output); + // approx_ctc = approx_ctc - path_length * ext_scorer->beta; + // approx_ctc -= (_lm->get_sent_log_prob(words)) * ext_scorer->alpha; } - log_probs_b_prev = log_probs_b_cur; - log_probs_nb_prev = log_probs_nb_cur; - std::vector > - prefix_vec_next(prefix_set_next.begin(), - prefix_set_next.end()); - std::sort(prefix_vec_next.begin(), - prefix_vec_next.end(), - pair_comp_second_rev); - int num_prefixes_next = prefix_vec_next.size(); - int k = beam_size ( - prefix_vec_next.begin(), - prefix_vec_next.begin() + k - ); + prefixes[i]->_approx_ctc = approx_ctc; } - // post processing - std::vector > beam_result; - for (std::map::iterator - it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) { - if (it->second > -NUM_MAX && it->first.size() > 1) { - log_prob_type log_prob = it->second; - std::string sentence = it->first.substr(1); - // scoring the last word - if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') { - log_prob = log_prob + ext_scorer->get_score(sentence, true); - } - if (log_prob > -NUM_MAX) { - std::pair cur_result(log_prob, sentence); - beam_result.push_back(cur_result); - } + // allow for the post processing + std::vector space_prefixes; + if (space_prefixes.empty()) { + for (size_t i = 0; i < beam_size && i< prefixes.size(); i++) { + space_prefixes.push_back(prefixes[i]); } } - // sort the result and return - std::sort(beam_result.begin(), beam_result.end(), - pair_comp_first_rev); - return beam_result; -} + + std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); + std::vector > output_vecs; + for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) { + std::vector output; + space_prefixes[i]->get_path_vec(output); + // convert index to string + std::string output_str; + for (int j = 0; j < output.size(); j++) { + output_str += vocabulary[output[j]]; + } + std::pair output_pair(space_prefixes[i]->_score, + output_str); + output_vecs.emplace_back( + output_pair + ); + } + + return output_vecs; + } std::vector>> @@ -250,8 +279,7 @@ std::vector>> int num_processes, double cutoff_prob, Scorer *ext_scorer - ) -{ + ) { if (num_processes <= 0) { std::cout << "num_processes must be nonnegative!" << std::endl; exit(1); diff --git a/deploy/decoder_utils.cpp b/deploy/decoder_utils.cpp index d616d7c6..366c8d35 100644 --- a/deploy/decoder_utils.cpp +++ b/deploy/decoder_utils.cpp @@ -10,3 +10,73 @@ size_t get_utf8_str_len(const std::string& str) { } return str_len; } + +//------------------------------------------------------- +// Overriding less than operator for sorting +//------------------------------------------------------- +bool prefix_compare(const PathTrie* x, const PathTrie* y) { + if (x->_score == y->_score) { + if (x->_character == y->_character) { + return false; + } else { + return (x->_character < y->_character); + } + } else { + return x->_score > y->_score; + } +} //---------- End path_compare --------------------------- + +// -------------------------------------------------------------- +// Adds word to fst without copying entire dictionary +// -------------------------------------------------------------- +void add_word_to_fst(const std::vector& word, + fst::StdVectorFst* dictionary) { + if (dictionary->NumStates() == 0) { + fst::StdVectorFst::StateId start = dictionary->AddState(); + assert(start == 0); + dictionary->SetStart(start); + } + fst::StdVectorFst::StateId src = dictionary->Start(); + fst::StdVectorFst::StateId dst; + for (auto c : word) { + dst = dictionary->AddState(); + dictionary->AddArc(src, fst::StdArc(c, c, 0, dst)); + src = dst; + } + dictionary->SetFinal(dst, fst::StdArc::Weight::One()); +} // ------------ End of add_word_to_fst ----------------------- + +// --------------------------------------------------------- +// Adds a word to the dictionary FST based on char_map +// --------------------------------------------------------- +bool addWordToDictionary(const std::string& word, + const std::unordered_map& char_map, + bool add_space, + int SPACE, + fst::StdVectorFst* dictionary) { + /* + auto characters = UTF8_split(word); + + std::vector int_word; + + for (auto& c : characters) { + if (c == " ") { + int_word.push_back(SPACE); + } else { + auto int_c = char_map.find(c); + if (int_c != char_map.end()) { + int_word.push_back(int_c->second); + } else { + return false; // return without adding + } + } + } + + if (add_space) { + int_word.push_back(SPACE); + } + + add_word_to_fst(int_word, dictionary); + */ + return true; +} // -------------- End of addWordToDictionary ------------ diff --git a/deploy/decoder_utils.h b/deploy/decoder_utils.h index 9419e005..d5e7d186 100644 --- a/deploy/decoder_utils.h +++ b/deploy/decoder_utils.h @@ -2,6 +2,7 @@ #define DECODER_UTILS_H_ #include +#include "path_trie.h" template bool pair_comp_first_rev(const std::pair &a, const std::pair &b) @@ -25,8 +26,21 @@ T log_sum_exp(const T &x, const T &y) return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax; } +//------------------------------------------------------- +// Overriding less than operator for sorting +//------------------------------------------------------- +bool prefix_compare(const PathTrie* x, const PathTrie* y); + // Get length of utf8 encoding string // See: http://stackoverflow.com/a/4063229 size_t get_utf8_str_len(const std::string& str); +void add_word_to_fst(const std::vector& word, + fst::StdVectorFst* dictionary); + +bool addWordToDictionary(const std::string& word, + const std::unordered_map& char_map, + bool add_space, + int SPACE, + fst::StdVectorFst* dictionary); #endif // DECODER_UTILS_H diff --git a/deploy/path_trie.cpp b/deploy/path_trie.cpp new file mode 100644 index 00000000..6cf7ae51 --- /dev/null +++ b/deploy/path_trie.cpp @@ -0,0 +1,153 @@ +#include +#include +#include +#include +#include + +#include "path_trie.h" +#include "decoder_utils.h" + +PathTrie::PathTrie() { + float lowest = -1.0*std::numeric_limits::max(); + _log_prob_b_prev = lowest; + _log_prob_nb_prev = lowest; + _log_prob_b_cur = lowest; + _log_prob_nb_cur = lowest; + _score = lowest; + + _ROOT = -1; + _character = _ROOT; + _exists = true; + _parent = nullptr; + _dictionary = nullptr; + _dictionary_state = 0; + _has_dictionary = false; + _matcher = nullptr; // finds arcs in FST +} + +PathTrie::~PathTrie() { + for (auto child : _children) { + delete child.second; + } +} + +PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { + auto child = _children.begin(); + for (child = _children.begin(); child != _children.end(); ++child) { + if (child->first == new_char) { + break; + } + } + if ( child != _children.end() ) { + if (!child->second->_exists) { + child->second->_exists = true; + float lowest = -1.0*std::numeric_limits::max(); + child->second->_log_prob_b_prev = lowest; + child->second->_log_prob_nb_prev = lowest; + child->second->_log_prob_b_cur = lowest; + child->second->_log_prob_nb_cur = lowest; + } + return (child->second); + } else { + if (_has_dictionary) { + _matcher->SetState(_dictionary_state); + bool found = _matcher->Find(new_char); + if (!found) { + // Adding this character causes word outside dictionary + auto FSTZERO = fst::TropicalWeight::Zero(); + auto final_weight = _dictionary->Final(_dictionary_state); + bool is_final = (final_weight != FSTZERO); + if (is_final && reset) { + _dictionary_state = _dictionary->Start(); + } + return nullptr; + } else { + PathTrie* new_path = new PathTrie; + new_path->_character = new_char; + new_path->_parent = this; + new_path->_dictionary = _dictionary; + new_path->_dictionary_state = _matcher->Value().nextstate; + new_path->_has_dictionary = true; + new_path->_matcher = _matcher; + _children.push_back(std::make_pair(new_char, new_path)); + return new_path; + } + } else { + PathTrie* new_path = new PathTrie; + new_path->_character = new_char; + new_path->_parent = this; + _children.push_back(std::make_pair(new_char, new_path)); + return new_path; + } + } +} + +PathTrie* PathTrie::get_path_vec(std::vector& output) { + return get_path_vec(output, _ROOT); +} + +PathTrie* PathTrie::get_path_vec(std::vector& output, + int stop, + size_t max_steps /*= std::numeric_limits::max() */) { + if (_character == stop || + _character == _ROOT || + output.size() == max_steps) { + std::reverse(output.begin(), output.end()); + return this; + } else { + output.push_back(_character); + return _parent->get_path_vec(output, stop, max_steps); + } +} + +void PathTrie::iterate_to_vec( + std::vector& output) { + if (_exists) { + _log_prob_b_prev = _log_prob_b_cur; + _log_prob_nb_prev = _log_prob_nb_cur; + + _log_prob_b_cur = -1.0 * std::numeric_limits::max(); + _log_prob_nb_cur = -1.0 * std::numeric_limits::max(); + + _score = log_sum_exp(_log_prob_b_prev, _log_prob_nb_prev); + output.push_back(this); + } + for (auto child : _children) { + child.second->iterate_to_vec(output); + } +} + +//------------------------------------------------------- +// Effectively removes node +//------------------------------------------------------- +void PathTrie::remove() { + _exists = false; + + if (_children.size() == 0) { + auto child = _parent->_children.begin(); + for (child = _parent->_children.begin(); + child != _parent->_children.end(); ++child) { + if (child->first == _character) { + _parent->_children.erase(child); + break; + } + } + + if ( _parent->_children.size() == 0 && !_parent->_exists ) { + _parent->remove(); + } + + delete this; + } +} + +void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { + _dictionary = dictionary; + _dictionary_state = dictionary->Start(); + _has_dictionary = true; +} + +using FSTMATCH = fst::SortedMatcher; +void PathTrie::set_matcher(std::shared_ptr matcher) { + _matcher = matcher; +} diff --git a/deploy/path_trie.h b/deploy/path_trie.h new file mode 100644 index 00000000..7b378e3f --- /dev/null +++ b/deploy/path_trie.h @@ -0,0 +1,59 @@ +#ifndef PATH_TRIE_H +#define PATH_TRIE_H +#pragma once +#include +#include +#include +#include +#include +#include + +using FSTMATCH = fst::SortedMatcher; + +class PathTrie { +public: + PathTrie(); + ~PathTrie(); + + PathTrie* get_path_trie(int new_char, bool reset = true); + + PathTrie* get_path_vec(std::vector &output); + + PathTrie* get_path_vec(std::vector& output, + int stop, + size_t max_steps = std::numeric_limits::max()); + + void iterate_to_vec(std::vector &output); + + void set_dictionary(fst::StdVectorFst* dictionary); + + void set_matcher(std::shared_ptr matcher); + + bool is_empty() { + return _ROOT == _character; + } + + void remove(); + + float _log_prob_b_prev; + float _log_prob_nb_prev; + float _log_prob_b_cur; + float _log_prob_nb_cur; + float _score; + float _approx_ctc; + + + int _ROOT; + int _character; + bool _exists; + + PathTrie *_parent; + std::vector > _children; + + fst::StdVectorFst* _dictionary; + fst::StdVectorFst::StateId _dictionary_state; + bool _has_dictionary; + std::shared_ptr _matcher; +}; + +#endif // PATH_TRIE_H diff --git a/deploy/scorer.cpp b/deploy/scorer.cpp index a1be7e0f..4dc8b253 100644 --- a/deploy/scorer.cpp +++ b/deploy/scorer.cpp @@ -175,3 +175,42 @@ double Scorer::get_score(std::string sentence, bool log) { } return final_score; } + +//-------------------------------------------------- +// Turn indices back into strings of chars +//-------------------------------------------------- +std::vector Scorer::make_ngram(PathTrie* prefix) { + /* + std::vector ngram; + PathTrie* current_node = prefix; + PathTrie* new_node = nullptr; + + for (int order = 0; order < _max_order; order++) { + std::vector prefix_vec; + + if (_is_character_based) { + new_node = current_node->get_path_vec(prefix_vec, ' ', 1); + current_node = new_node; + } else { + new_node = current_node->getPathVec(prefix_vec, ' '); + current_node = new_node->_parent; // Skipping spaces + } + + // reconstruct word + std::string word = vec2str(prefix_vec); + ngram.push_back(word); + + if (new_node->_character == -1) { + // No more spaces, but still need order + for (int i = 0; i < max_order - order - 1; i++) { + ngram.push_back(""); + } + break; + } + } + std::reverse(ngram.begin(), ngram.end()); + */ + std::vector ngram; + ngram.push_back("this"); + return ngram; +} //---------------- End makeNgrams ------------------ diff --git a/deploy/scorer.h b/deploy/scorer.h index a5242004..f0efbca9 100644 --- a/deploy/scorer.h +++ b/deploy/scorer.h @@ -4,10 +4,12 @@ #include #include #include +#include #include "lm/enumerate_vocab.hh" #include "lm/word_index.hh" #include "lm/virtual_interface.hh" #include "util/string_piece.hh" +#include "path_trie.h" const double OOV_SCOER = -1000.0; const std::string START_TOKEN = ""; @@ -49,18 +51,29 @@ public: void reset_params(float alpha, float beta); // get the final score double get_score(std::string, bool log=false); + // make ngram + std::vector make_ngram(PathTrie* prefix); // expose to decoder double alpha; double beta; + // fst dictionary + void* dictionary; protected: void load_LM(const char* filename); double get_log_prob(const std::vector& words); private: + void _init_char_list(); + void _init_char_map(); + void* _language_model; bool _is_character_based; size_t _max_order; + + std::vector _char_list; + std::unordered_map _char_map; + std::vector _vocabulary; };