From 8ff6221d00e8cc8bd5082a86d3d7f383c05b1430 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Tue, 29 Aug 2017 12:27:30 +0800
Subject: [PATCH] enable finite-state transducer in beam search decoding

---
 deploy.py                |   8 +--
 deploy/ctc_decoders.cpp  |  15 +++-
 deploy/decoder_utils.cpp |  30 +++++++-
 deploy/decoder_utils.h   |   4 +-
 deploy/scorer.cpp        | 143 ++++++++++++++++++++++++++++++++++++---
 deploy/scorer.h          |  11 ++-
 6 files changed, 189 insertions(+), 22 deletions(-)

diff --git a/deploy.py b/deploy.py
index 833c5c20..d43ab1e0 100644
--- a/deploy.py
+++ b/deploy.py
@@ -18,7 +18,7 @@ import time
 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument(
     "--num_samples",
-    default=5,
+    default=4,
     type=int,
     help="Number of samples for inference. (default: %(default)s)")
 parser.add_argument(
@@ -89,7 +89,8 @@ parser.add_argument(
     help="Number of output per sample in beam search. (default: %(default)d)")
 parser.add_argument(
     "--language_model_path",
-    default="lm/data/common_crawl_00.prune01111.trie.klm",
+    default="/home/work/liuyibing/lm_bak/common_crawl_00.prune01111.trie.klm",
+    #default="ptb_all.arpa",
     type=str,
     help="Path for language model. (default: %(default)s)")
 parser.add_argument(
@@ -183,8 +184,7 @@ def infer():
                 vocabulary=data_generator.vocab_list,
                 blank_id=len(data_generator.vocab_list),
                 cutoff_prob=args.cutoff_prob,
-                # ext_scoring_func=ext_scorer,
-                )
+                ext_scoring_func=ext_scorer, )
             batch_beam_results += [beam_result]
     else:
         batch_beam_results = ctc_beam_search_decoder_batch(
diff --git a/deploy/ctc_decoders.cpp b/deploy/ctc_decoders.cpp
index 30e85525..d84f5b16 100644
--- a/deploy/ctc_decoders.cpp
+++ b/deploy/ctc_decoders.cpp
@@ -103,10 +103,13 @@ std::vector<std::pair<double, std::string> >
   prefixes.push_back(&root);
 
   if ( ext_scorer != nullptr && !ext_scorer->is_character_based()) {
-    if (ext_scorer->dictionary == nullptr) {
+    if (ext_scorer->_dictionary == nullptr) {
       // TODO: init dictionary
+      ext_scorer->set_char_map(vocabulary);
+      // add_space should be true?
+      ext_scorer->fill_dictionary(true);
     }
-    auto fst_dict = static_cast<fst::StdVectorFst*>(ext_scorer->dictionary);
+    auto fst_dict = static_cast<fst::StdVectorFst*>(ext_scorer->_dictionary);
     fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
     root.set_dictionary(dict_ptr);
     auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
@@ -288,6 +291,14 @@ std::vector<std::vector<std::pair<double, std::string>>>
   ThreadPool pool(num_processes);
   // number of samples
   int batch_size = probs_split.size();
+  // dictionary init
+  if ( ext_scorer != nullptr) {
+    if (ext_scorer->_dictionary == nullptr) {
+      // TODO: init dictionary
+      ext_scorer->set_char_map(vocabulary);
+      ext_scorer->fill_dictionary(true);
+    }
+  }
   // enqueue the tasks of decoding
   std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
   for (int i = 0; i < batch_size; i++) {
diff --git a/deploy/decoder_utils.cpp b/deploy/decoder_utils.cpp
index 366c8d35..0ec86d6b 100644
--- a/deploy/decoder_utils.cpp
+++ b/deploy/decoder_utils.cpp
@@ -11,6 +11,32 @@ size_t get_utf8_str_len(const std::string& str) {
   return str_len;
 }
 
+//------------------------------------------------------
+// Splits string into vector of strings representing
+// UTF-8 characters (not same as chars)
+//------------------------------------------------------
+std::vector<std::string> UTF8_split(const std::string& str)
+{
+  std::vector<std::string> result;
+  std::string out_str;
+
+  for (char c : str)
+  {
+    if ((c & 0xc0) != 0x80)  // new UTF-8 character
+    {
+      if (!out_str.empty())
+      {
+        result.push_back(out_str);
+        out_str.clear();
+      }
+    }
+
+    out_str.append(1, c);
+  }
+  result.push_back(out_str);
+  return result;
+}
+
 //-------------------------------------------------------
 // Overriding less than operator for sorting
 //-------------------------------------------------------
@@ -49,12 +75,11 @@ void add_word_to_fst(const std::vector<int>& word,
 // ---------------------------------------------------------
 // Adds a word to the dictionary FST based on char_map
 // ---------------------------------------------------------
-bool addWordToDictionary(const std::string& word,
+bool add_word_to_dictionary(const std::string& word,
                          const std::unordered_map<std::string, int>& char_map,
                          bool add_space,
                          int SPACE,
                          fst::StdVectorFst* dictionary) {
-  /*
   auto characters = UTF8_split(word);
 
   std::vector<int> int_word;
@@ -77,6 +102,5 @@ bool addWordToDictionary(const std::string& word,
   }
 
   add_word_to_fst(int_word, dictionary);
-  */
   return true;
 } // -------------- End of addWordToDictionary ------------
diff --git a/deploy/decoder_utils.h b/deploy/decoder_utils.h
index d5e7d186..b61cdfbf 100644
--- a/deploy/decoder_utils.h
+++ b/deploy/decoder_utils.h
@@ -35,10 +35,12 @@ bool prefix_compare(const PathTrie* x, const PathTrie* y);
 // See: http://stackoverflow.com/a/4063229
 size_t get_utf8_str_len(const std::string& str);
 
+std::vector<std::string> UTF8_split(const std::string &str);
+
 void add_word_to_fst(const std::vector<int>& word,
                      fst::StdVectorFst* dictionary);
 
-bool addWordToDictionary(const std::string& word,
+bool add_word_to_dictionary(const std::string& word,
                          const std::unordered_map<std::string, int>& char_map,
                          bool add_space,
                          int SPACE,
diff --git a/deploy/scorer.cpp b/deploy/scorer.cpp
index 4dc8b253..ad33a0cd 100644
--- a/deploy/scorer.cpp
+++ b/deploy/scorer.cpp
@@ -15,7 +15,9 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
   this->beta = beta;
   _is_character_based = true;
   _language_model = nullptr;
+  _dictionary = nullptr;
   _max_order = 0;
+  _SPACE = -1;
   // load language model
   load_LM(lm_path.c_str());
 }
@@ -23,6 +25,8 @@
 Scorer::~Scorer() {
   if (_language_model != nullptr)
     delete static_cast<lm::base::Model*>(_language_model);
+  if (_dictionary != nullptr)
+    delete static_cast<fst::StdVectorFst*>(_dictionary);
 }
 
 void Scorer::load_LM(const char* filename) {
@@ -176,11 +180,83 @@ double Scorer::get_score(std::string sentence, bool log) {
   return final_score;
 }
 
-//--------------------------------------------------
-// Turn indices back into strings of chars
-//--------------------------------------------------
+std::string Scorer::vec2str(const std::vector<int>& input) {
+  std::string word;
+  for (auto ind : input) {
+    word += _char_list[ind];
+  }
+  return word;
+}
+
+
+std::vector<std::string>
+Scorer::split_labels(const std::vector<int> &labels) {
+  if (labels.empty())
+    return {};
+
+  std::string s = vec2str(labels);
+  std::vector<std::string> words;
+  if (_is_character_based) {
+    words = UTF8_split(s);
+  } else {
+    words = split_str(s, " ");
+  }
+  return words;
+}
+
+// Split a string into a list of strings on a given string
+// delimiter. NB: delimiters on beginning / end of string are
+// trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
+std::vector<std::string> Scorer::split_str(const std::string &s,
+                                           const std::string &delim) {
+  std::vector<std::string> result;
+  std::size_t start = 0, delim_len = delim.size();
+  while (true) {
+    std::size_t end = s.find(delim, start);
+    if (end == std::string::npos) {
+      if (start < s.size()) {
+        result.push_back(s.substr(start));
+      }
+      break;
+    }
+    if (end > start) {
+      result.push_back(s.substr(start, end - start));
+    }
+    start = end + delim_len;
+  }
+  return result;
+}
+
+//---------------------------------------------------
+// Add index to char list for searching language model
+//---------------------------------------------------
+void Scorer::set_char_map(std::vector<std::string> char_list) {
+  _char_list = char_list;
+  std::string _SPACE_STR = " ";
+
+  for (unsigned int i = 0; i < _char_list.size(); i++) {
+    // if (_char_list[i] == _BLANK_STR) {
+    //   _BLANK = i;
+    // } else
+    if (_char_list[i] == _SPACE_STR) {
+      _SPACE = i;
+    }
+  }
+
+  _char_map.clear();
+  for (unsigned int i = 0; i < _char_list.size(); i++)
+  {
+    if (i == (unsigned int)_SPACE) {
+      _char_map[' '] = i;
+    }
+    else if (_char_list[i].size() == 1) {
+      _char_map[_char_list[i][0]] = i;
+    }
+  }
+
+} //------------- End of set_char_map ----------------
+
 std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
-  /*
   std::vector<std::string> ngram;
   PathTrie* current_node = prefix;
   PathTrie* new_node = nullptr;
@@ -189,10 +265,10 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
     std::vector<int> prefix_vec;
 
     if (_is_character_based) {
-      new_node = current_node->get_path_vec(prefix_vec, ' ', 1);
+      new_node = current_node->get_path_vec(prefix_vec, _SPACE, 1);
       current_node = new_node;
     } else {
-      new_node = current_node->getPathVec(prefix_vec, ' ');
+      new_node = current_node->get_path_vec(prefix_vec, _SPACE);
       current_node = new_node->_parent;  // Skipping spaces
     }
 
@@ -202,15 +278,60 @@
 
     if (new_node->_character == -1) {
       // No more spaces, but still need order
-      for (int i = 0; i < max_order - order - 1; i++) {
+      for (int i = 0; i < _max_order - order - 1; i++) {
        ngram.push_back("");
       }
       break;
     }
   }
   std::reverse(ngram.begin(), ngram.end());
-  */
-  std::vector<std::string> ngram;
-  ngram.push_back("this");
   return ngram;
-} //---------------- End makeNgrams ------------------
+}
+
+//---------------------------------------------------------
+// Helper function to populate Trie with a vocab using the
+// char_list for mapping from string to int
+//---------------------------------------------------------
+void Scorer::fill_dictionary(bool add_space) {
+
+  fst::StdVectorFst dictionary;
+  // First reverse char_list so ints can be accessed by chars
+  std::unordered_map<std::string, int> char_map;
+  for (unsigned int i = 0; i < _char_list.size(); i++) {
+    char_map[_char_list[i]] = i;
+  }
+
+  // For each unigram convert to ints and put in trie
+  int vocab_size = 0;
+  for (const auto& word : _vocabulary) {
+    bool added = add_word_to_dictionary(word,
+                                        char_map,
+                                        add_space,
+                                        _SPACE,
+                                        &dictionary);
+    vocab_size += added ? 1 : 0;
+  }
+
+  std::cerr << "Vocab Size " << vocab_size << std::endl;
+
+  // Simplify FST
+
+  // This gets rid of "epsilon" transitions in the FST.
+  // These are transitions that don't require a string input to be taken.
+  // Getting rid of them is necessary to make the FST deterministic, but
+  // can greatly increase the size of the FST
+  fst::RmEpsilon(&dictionary);
+  fst::StdVectorFst* new_dict = new fst::StdVectorFst;
+
+  // This makes the FST deterministic, meaning for any string input there's
+  // only one possible state the FST could be in. It is assumed our
+  // dictionary is deterministic when using it.
+  // (lest we'd have to check for multiple transitions at each state)
+  fst::Determinize(dictionary, new_dict);
+
+  // Finds the simplest equivalent fst. This is unnecessary but decreases
+  // memory usage of the dictionary
+  fst::Minimize(new_dict);
+  _dictionary = new_dict;
+
+}
diff --git a/deploy/scorer.h b/deploy/scorer.h
index f0efbca9..9ba55dd6 100644
--- a/deploy/scorer.h
+++ b/deploy/scorer.h
@@ -53,15 +53,23 @@ public:
   double get_score(std::string, bool log=false);
   // make ngram
   std::vector<std::string> make_ngram(PathTrie* prefix);
+  // fill dictionary for fst
+  void fill_dictionary(bool add_space);
+  // set char map
+  void set_char_map(std::vector<std::string> char_list);
   // expose to decoder
   double alpha;
   double beta;
   // fst dictionary
-  void* dictionary;
+  void* _dictionary;
 
 protected:
   void load_LM(const char* filename);
   double get_log_prob(const std::vector<std::string>& words);
+  std::string vec2str(const std::vector<int> &input);
+  std::vector<std::string> split_labels(const std::vector<int> &labels);
+  std::vector<std::string> split_str(const std::string &s,
+                                     const std::string &delim);
 
 private:
   void _init_char_list();
@@ -71,6 +79,7 @@
   bool _is_character_based;
   size_t _max_order;
 
+  unsigned int _SPACE;
  std::vector<std::string> _char_list;
  std::unordered_map<char, int> _char_map;
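
The dictionary built by Scorer::fill_dictionary is an ordinary OpenFst word acceptor: each vocabulary word becomes a chain of character-label arcs, and RmEpsilon / Determinize / Minimize collapse those chains into a compact trie that the decoder walks one label at a time with a SortedMatcher. The standalone sketch below reproduces that pipeline under simplifying assumptions: it uses raw byte labels and a toy three-word vocabulary instead of the patch's char_map and add_word_to_fst (whose body is not part of this diff), so it is an illustration of the technique rather than code from the patch.

#include <fst/fstlib.h>

#include <iostream>
#include <string>

// Append one word to the acceptor as a linear chain of arcs. Labels are
// byte value + 1 because label 0 is reserved for epsilon in OpenFst.
// (The patch itself maps characters through char_map / add_word_to_fst.)
void add_word(const std::string& word, fst::StdVectorFst* dict) {
  int src = dict->Start();
  if (src == fst::kNoStateId) {
    src = dict->AddState();
    dict->SetStart(src);
  }
  for (char c : word) {
    int label = static_cast<unsigned char>(c) + 1;
    int dst = dict->AddState();
    dict->AddArc(src, fst::StdArc(label, label, fst::StdArc::Weight::One(), dst));
    src = dst;
  }
  dict->SetFinal(src, fst::StdArc::Weight::One());
}

int main() {
  fst::StdVectorFst dictionary;
  for (const std::string& w : {"cat", "car", "dog"}) {  // toy vocabulary
    add_word(w, &dictionary);
  }

  // Same simplification pipeline as Scorer::fill_dictionary above.
  fst::RmEpsilon(&dictionary);
  fst::StdVectorFst det;
  fst::Determinize(dictionary, &det);
  fst::Minimize(&det);
  fst::ArcSort(&det, fst::ILabelCompare<fst::StdArc>());  // SortedMatcher needs sorted arcs

  // Walk the trie label by label, as the decoder does for each prefix.
  fst::SortedMatcher<fst::StdVectorFst> matcher(det, fst::MATCH_INPUT);
  int state = det.Start();
  bool in_dict = true;
  for (char c : std::string("car")) {
    matcher.SetState(state);
    if (!matcher.Find(static_cast<unsigned char>(c) + 1)) {
      in_dict = false;
      break;
    }
    state = matcher.Value().nextstate;
  }
  std::cout << (in_dict ? "prefix is in the dictionary"
                        : "prefix is not in the dictionary") << std::endl;
  return 0;
}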