diff --git a/decoders/swig/ctc_beam_search_decoder.cpp b/decoders/swig/ctc_beam_search_decoder.cpp index 36d16987..5c8373be 100644 --- a/decoders/swig/ctc_beam_search_decoder.cpp +++ b/decoders/swig/ctc_beam_search_decoder.cpp @@ -18,8 +18,8 @@ using FSTMATCH = fst::SortedMatcher; std::vector> ctc_beam_search_decoder( const std::vector> &probs_seq, + const std::vector &vocabulary, size_t beam_size, - std::vector vocabulary, double cutoff_prob, size_t cutoff_top_n, Scorer *ext_scorer) { @@ -36,8 +36,7 @@ std::vector> ctc_beam_search_decoder( size_t blank_id = vocabulary.size(); // assign space id - std::vector::iterator it = - std::find(vocabulary.begin(), vocabulary.end(), " "); + auto it = std::find(vocabulary.begin(), vocabulary.end(), " "); int space_id = it - vocabulary.begin(); // if no space in vocabulary if ((size_t)space_id >= vocabulary.size()) { @@ -173,11 +172,11 @@ std::vector> ctc_beam_search_decoder( std::vector>> ctc_beam_search_decoder_batch( const std::vector>> &probs_split, - const size_t beam_size, const std::vector &vocabulary, - const size_t num_processes, - const double cutoff_prob, - const size_t cutoff_top_n, + size_t beam_size, + size_t num_processes, + double cutoff_prob, + size_t cutoff_top_n, Scorer *ext_scorer) { VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); // thread pool @@ -190,8 +189,8 @@ ctc_beam_search_decoder_batch( for (size_t i = 0; i < batch_size; ++i) { res.emplace_back(pool.enqueue(ctc_beam_search_decoder, probs_split[i], - beam_size, vocabulary, + beam_size, cutoff_prob, cutoff_top_n, ext_scorer)); diff --git a/decoders/swig/ctc_beam_search_decoder.h b/decoders/swig/ctc_beam_search_decoder.h index c800384e..6fdd1551 100644 --- a/decoders/swig/ctc_beam_search_decoder.h +++ b/decoders/swig/ctc_beam_search_decoder.h @@ -12,8 +12,8 @@ * Parameters: * probs_seq: 2-D vector that each element is a vector of probabilities * over vocabulary of one time step. - * beam_size: The width of beam search. * vocabulary: A vector of vocabulary. + * beam_size: The width of beam search. * cutoff_prob: Cutoff probability for pruning. * cutoff_top_n: Cutoff number for pruning. * ext_scorer: External scorer to evaluate a prefix, which consists of @@ -25,8 +25,8 @@ */ std::vector> ctc_beam_search_decoder( const std::vector> &probs_seq, + const std::vector &vocabulary, size_t beam_size, - std::vector vocabulary, double cutoff_prob = 1.0, size_t cutoff_top_n = 40, Scorer *ext_scorer = nullptr); @@ -36,9 +36,8 @@ std::vector> ctc_beam_search_decoder( * Parameters: * probs_seq: 3-D vector that each element is a 2-D vector that can be used * by ctc_beam_search_decoder(). - * . - * beam_size: The width of beam search. * vocabulary: A vector of vocabulary. + * beam_size: The width of beam search. * num_processes: Number of threads for beam search. * cutoff_prob: Cutoff probability for pruning. * cutoff_top_n: Cutoff number for pruning. @@ -52,8 +51,8 @@ std::vector> ctc_beam_search_decoder( std::vector>> ctc_beam_search_decoder_batch( const std::vector>> &probs_split, - size_t beam_size, const std::vector &vocabulary, + size_t beam_size, size_t num_processes, double cutoff_prob = 1.0, size_t cutoff_top_n = 40, diff --git a/decoders/swig/path_trie.cpp b/decoders/swig/path_trie.cpp index 6a1f6170..fdff3286 100644 --- a/decoders/swig/path_trie.cpp +++ b/decoders/swig/path_trie.cpp @@ -15,32 +15,32 @@ PathTrie::PathTrie() { log_prob_nb_cur = -NUM_FLT_INF; score = -NUM_FLT_INF; - _ROOT = -1; - character = _ROOT; - _exists = true; + ROOT_ = -1; + character = ROOT_; + exists_ = true; parent = nullptr; - _dictionary = nullptr; - _dictionary_state = 0; - _has_dictionary = false; - _matcher = nullptr; + dictionary_ = nullptr; + dictionary_state_ = 0; + has_dictionary_ = false; + matcher_ = nullptr; } PathTrie::~PathTrie() { - for (auto child : _children) { + for (auto child : children_) { delete child.second; } } PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { - auto child = _children.begin(); - for (child = _children.begin(); child != _children.end(); ++child) { + auto child = children_.begin(); + for (child = children_.begin(); child != children_.end(); ++child) { if (child->first == new_char) { break; } } - if (child != _children.end()) { - if (!child->second->_exists) { - child->second->_exists = true; + if (child != children_.end()) { + if (!child->second->exists_) { + child->second->exists_ = true; child->second->log_prob_b_prev = -NUM_FLT_INF; child->second->log_prob_nb_prev = -NUM_FLT_INF; child->second->log_prob_b_cur = -NUM_FLT_INF; @@ -48,47 +48,47 @@ PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { } return (child->second); } else { - if (_has_dictionary) { - _matcher->SetState(_dictionary_state); - bool found = _matcher->Find(new_char); + if (has_dictionary_) { + matcher_->SetState(dictionary_state_); + bool found = matcher_->Find(new_char); if (!found) { // Adding this character causes word outside dictionary auto FSTZERO = fst::TropicalWeight::Zero(); - auto final_weight = _dictionary->Final(_dictionary_state); + auto final_weight = dictionary_->Final(dictionary_state_); bool is_final = (final_weight != FSTZERO); if (is_final && reset) { - _dictionary_state = _dictionary->Start(); + dictionary_state_ = dictionary_->Start(); } return nullptr; } else { PathTrie* new_path = new PathTrie; new_path->character = new_char; new_path->parent = this; - new_path->_dictionary = _dictionary; - new_path->_dictionary_state = _matcher->Value().nextstate; - new_path->_has_dictionary = true; - new_path->_matcher = _matcher; - _children.push_back(std::make_pair(new_char, new_path)); + new_path->dictionary_ = dictionary_; + new_path->dictionary_state_ = matcher_->Value().nextstate; + new_path->has_dictionary_ = true; + new_path->matcher_ = matcher_; + children_.push_back(std::make_pair(new_char, new_path)); return new_path; } } else { PathTrie* new_path = new PathTrie; new_path->character = new_char; new_path->parent = this; - _children.push_back(std::make_pair(new_char, new_path)); + children_.push_back(std::make_pair(new_char, new_path)); return new_path; } } } PathTrie* PathTrie::get_path_vec(std::vector& output) { - return get_path_vec(output, _ROOT); + return get_path_vec(output, ROOT_); } PathTrie* PathTrie::get_path_vec(std::vector& output, int stop, size_t max_steps) { - if (character == stop || character == _ROOT || output.size() == max_steps) { + if (character == stop || character == ROOT_ || output.size() == max_steps) { std::reverse(output.begin(), output.end()); return this; } else { @@ -98,7 +98,7 @@ PathTrie* PathTrie::get_path_vec(std::vector& output, } void PathTrie::iterate_to_vec(std::vector& output) { - if (_exists) { + if (exists_) { log_prob_b_prev = log_prob_b_cur; log_prob_nb_prev = log_prob_nb_cur; @@ -108,25 +108,25 @@ void PathTrie::iterate_to_vec(std::vector& output) { score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); output.push_back(this); } - for (auto child : _children) { + for (auto child : children_) { child.second->iterate_to_vec(output); } } void PathTrie::remove() { - _exists = false; + exists_ = false; - if (_children.size() == 0) { - auto child = parent->_children.begin(); - for (child = parent->_children.begin(); child != parent->_children.end(); + if (children_.size() == 0) { + auto child = parent->children_.begin(); + for (child = parent->children_.begin(); child != parent->children_.end(); ++child) { if (child->first == character) { - parent->_children.erase(child); + parent->children_.erase(child); break; } } - if (parent->_children.size() == 0 && !parent->_exists) { + if (parent->children_.size() == 0 && !parent->exists_) { parent->remove(); } @@ -135,12 +135,12 @@ void PathTrie::remove() { } void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { - _dictionary = dictionary; - _dictionary_state = dictionary->Start(); - _has_dictionary = true; + dictionary_ = dictionary; + dictionary_state_ = dictionary->Start(); + has_dictionary_ = true; } using FSTMATCH = fst::SortedMatcher; void PathTrie::set_matcher(std::shared_ptr matcher) { - _matcher = matcher; + matcher_ = matcher; } diff --git a/decoders/swig/path_trie.h b/decoders/swig/path_trie.h index b4f5bc4b..7fd715d2 100644 --- a/decoders/swig/path_trie.h +++ b/decoders/swig/path_trie.h @@ -36,7 +36,7 @@ public: void set_matcher(std::shared_ptr>); - bool is_empty() { return _ROOT == character; } + bool is_empty() { return ROOT_ == character; } // remove current path from root void remove(); @@ -51,17 +51,17 @@ public: PathTrie* parent; private: - int _ROOT; - bool _exists; - bool _has_dictionary; + int ROOT_; + bool exists_; + bool has_dictionary_; - std::vector> _children; + std::vector> children_; // pointer to dictionary of FST - fst::StdVectorFst* _dictionary; - fst::StdVectorFst::StateId _dictionary_state; + fst::StdVectorFst* dictionary_; + fst::StdVectorFst::StateId dictionary_state_; // true if finding ars in FST - std::shared_ptr> _matcher; + std::shared_ptr> matcher_; }; #endif // PATH_TRIE_H diff --git a/decoders/swig/scorer.cpp b/decoders/swig/scorer.cpp index 6b280344..27c31fa7 100644 --- a/decoders/swig/scorer.cpp +++ b/decoders/swig/scorer.cpp @@ -19,19 +19,19 @@ Scorer::Scorer(double alpha, const std::vector& vocab_list) { this->alpha = alpha; this->beta = beta; - _is_character_based = true; - _language_model = nullptr; + is_character_based_ = true; + language_model_ = nullptr; dictionary = nullptr; - _max_order = 0; - _dict_size = 0; - _SPACE_ID = -1; + max_order_ = 0; + dict_size_ = 0; + SPACE_ID_ = -1; setup(lm_path, vocab_list); } Scorer::~Scorer() { - if (_language_model != nullptr) { - delete static_cast(_language_model); + if (language_model_ != nullptr) { + delete static_cast(language_model_); } if (dictionary != nullptr) { delete static_cast(dictionary); @@ -57,20 +57,20 @@ void Scorer::load_lm(const std::string& lm_path) { RetriveStrEnumerateVocab enumerate; lm::ngram::Config config; config.enumerate_vocab = &enumerate; - _language_model = lm::ngram::LoadVirtual(filename, config); - _max_order = static_cast(_language_model)->Order(); - _vocabulary = enumerate.vocabulary; - for (size_t i = 0; i < _vocabulary.size(); ++i) { - if (_is_character_based && _vocabulary[i] != UNK_TOKEN && - _vocabulary[i] != START_TOKEN && _vocabulary[i] != END_TOKEN && + language_model_ = lm::ngram::LoadVirtual(filename, config); + max_order_ = static_cast(language_model_)->Order(); + vocabulary_ = enumerate.vocabulary; + for (size_t i = 0; i < vocabulary_.size(); ++i) { + if (is_character_based_ && vocabulary_[i] != UNK_TOKEN && + vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN && get_utf8_str_len(enumerate.vocabulary[i]) > 1) { - _is_character_based = false; + is_character_based_ = false; } } } double Scorer::get_log_cond_prob(const std::vector& words) { - lm::base::Model* model = static_cast(_language_model); + lm::base::Model* model = static_cast(language_model_); double cond_prob; lm::ngram::State state, tmp_state, out_state; // avoid to inserting in begin @@ -93,11 +93,11 @@ double Scorer::get_log_cond_prob(const std::vector& words) { double Scorer::get_sent_log_prob(const std::vector& words) { std::vector sentence; if (words.size() == 0) { - for (size_t i = 0; i < _max_order; ++i) { + for (size_t i = 0; i < max_order_; ++i) { sentence.push_back(START_TOKEN); } } else { - for (size_t i = 0; i < _max_order - 1; ++i) { + for (size_t i = 0; i < max_order_ - 1; ++i) { sentence.push_back(START_TOKEN); } sentence.insert(sentence.end(), words.begin(), words.end()); @@ -107,11 +107,11 @@ double Scorer::get_sent_log_prob(const std::vector& words) { } double Scorer::get_log_prob(const std::vector& words) { - assert(words.size() > _max_order); + assert(words.size() > max_order_); double score = 0.0; - for (size_t i = 0; i < words.size() - _max_order + 1; ++i) { + for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) { std::vector ngram(words.begin() + i, - words.begin() + i + _max_order); + words.begin() + i + max_order_); score += get_log_cond_prob(ngram); } return score; @@ -125,7 +125,7 @@ void Scorer::reset_params(float alpha, float beta) { std::string Scorer::vec2str(const std::vector& input) { std::string word; for (auto ind : input) { - word += _char_list[ind]; + word += char_list_[ind]; } return word; } @@ -135,7 +135,7 @@ std::vector Scorer::split_labels(const std::vector& labels) { std::string s = vec2str(labels); std::vector words; - if (_is_character_based) { + if (is_character_based_) { words = split_utf8_str(s); } else { words = split_str(s, " "); @@ -144,15 +144,15 @@ std::vector Scorer::split_labels(const std::vector& labels) { } void Scorer::set_char_map(const std::vector& char_list) { - _char_list = char_list; - _char_map.clear(); - - for (unsigned int i = 0; i < _char_list.size(); i++) { - if (_char_list[i] == " ") { - _SPACE_ID = i; - _char_map[' '] = i; - } else if (_char_list[i].size() == 1) { - _char_map[_char_list[i][0]] = i; + char_list_ = char_list; + char_map_.clear(); + + for (size_t i = 0; i < char_list_.size(); i++) { + if (char_list_[i] == " ") { + SPACE_ID_ = i; + char_map_[' '] = i; + } else if (char_list_[i].size() == 1) { + char_map_[char_list_[i][0]] = i; } } } @@ -162,14 +162,14 @@ std::vector Scorer::make_ngram(PathTrie* prefix) { PathTrie* current_node = prefix; PathTrie* new_node = nullptr; - for (int order = 0; order < _max_order; order++) { + for (int order = 0; order < max_order_; order++) { std::vector prefix_vec; - if (_is_character_based) { - new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID, 1); + if (is_character_based_) { + new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1); current_node = new_node; } else { - new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID); + new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_); current_node = new_node->parent; // Skipping spaces } @@ -179,7 +179,7 @@ std::vector Scorer::make_ngram(PathTrie* prefix) { if (new_node->character == -1) { // No more spaces, but still need order - for (int i = 0; i < _max_order - order - 1; i++) { + for (int i = 0; i < max_order_ - order - 1; i++) { ngram.push_back(START_TOKEN); } break; @@ -193,19 +193,19 @@ void Scorer::fill_dictionary(bool add_space) { fst::StdVectorFst dictionary; // First reverse char_list so ints can be accessed by chars std::unordered_map char_map; - for (unsigned int i = 0; i < _char_list.size(); i++) { - char_map[_char_list[i]] = i; + for (size_t i = 0; i < char_list_.size(); i++) { + char_map[char_list_[i]] = i; } // For each unigram convert to ints and put in trie int dict_size = 0; - for (const auto& word : _vocabulary) { + for (const auto& word : vocabulary_) { bool added = add_word_to_dictionary( - word, char_map, add_space, _SPACE_ID, &dictionary); + word, char_map, add_space, SPACE_ID_, &dictionary); dict_size += added ? 1 : 0; } - _dict_size = dict_size; + dict_size_ = dict_size; /* Simplify FST diff --git a/decoders/swig/scorer.h b/decoders/swig/scorer.h index 72544da7..61836463 100644 --- a/decoders/swig/scorer.h +++ b/decoders/swig/scorer.h @@ -18,7 +18,7 @@ const std::string START_TOKEN = ""; const std::string UNK_TOKEN = ""; const std::string END_TOKEN = ""; -// Implement a callback to retrive string vocabulary. +// Implement a callback to retrive the dictionary of language model. class RetriveStrEnumerateVocab : public lm::EnumerateVocab { public: RetriveStrEnumerateVocab() {} @@ -50,13 +50,14 @@ public: double get_sent_log_prob(const std::vector &words); - size_t get_max_order() const { return _max_order; } + // return the max order + size_t get_max_order() const { return max_order_; } - size_t get_dict_size() const { return _dict_size; } + // return the dictionary size of language model + size_t get_dict_size() const { return dict_size_; } - bool is_char_map_empty() const { return _char_map.size() == 0; } - - bool is_character_based() const { return _is_character_based; } + // retrun true if the language model is character based + bool is_character_based() const { return is_character_based_; } // reset params alpha & beta void reset_params(float alpha, float beta); @@ -68,20 +69,23 @@ public: // the vector of characters (character based lm) std::vector split_labels(const std::vector &labels); - // expose to decoder + // language model weight double alpha; + // word insertion weight double beta; - // fst dictionary + // pointer to the dictionary of FST void *dictionary; protected: + // necessary setup: load language model, set char map, fill FST's dictionary void setup(const std::string &lm_path, const std::vector &vocab_list); + // load language model from given path void load_lm(const std::string &lm_path); - // fill dictionary for fst + // fill dictionary for FST void fill_dictionary(bool add_space); // set char map @@ -89,19 +93,20 @@ protected: double get_log_prob(const std::vector &words); + // translate the vector in index to string std::string vec2str(const std::vector &input); private: - void *_language_model; - bool _is_character_based; - size_t _max_order; - size_t _dict_size; + void *language_model_; + bool is_character_based_; + size_t max_order_; + size_t dict_size_; - int _SPACE_ID; - std::vector _char_list; - std::unordered_map _char_map; + int SPACE_ID_; + std::vector char_list_; + std::unordered_map char_map_; - std::vector _vocabulary; + std::vector vocabulary_; }; #endif // SCORER_H_ diff --git a/decoders/swig_wrapper.py b/decoders/swig_wrapper.py index 5ebcd133..0a921125 100644 --- a/decoders/swig_wrapper.py +++ b/decoders/swig_wrapper.py @@ -39,8 +39,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary): def ctc_beam_search_decoder(probs_seq, - beam_size, vocabulary, + beam_size, cutoff_prob=1.0, cutoff_top_n=40, ext_scoring_func=None): @@ -50,10 +50,10 @@ def ctc_beam_search_decoder(probs_seq, step, with each element being a list of normalized probabilities over vocabulary and blank. :type probs_seq: 2-D list - :param beam_size: Width for beam search. - :type beam_size: int :param vocabulary: Vocabulary list. :type vocabulary: list + :param beam_size: Width for beam search. + :type beam_size: int :param cutoff_prob: Cutoff probability in pruning, default 1.0, no pruning. :type cutoff_prob: float @@ -69,14 +69,14 @@ def ctc_beam_search_decoder(probs_seq, results, in descending order of the probability. :rtype: list """ - return swig_decoders.ctc_beam_search_decoder(probs_seq.tolist(), beam_size, - vocabulary, cutoff_prob, + return swig_decoders.ctc_beam_search_decoder(probs_seq.tolist(), vocabulary, + beam_size, cutoff_prob, cutoff_top_n, ext_scoring_func) def ctc_beam_search_decoder_batch(probs_split, - beam_size, vocabulary, + beam_size, num_processes, cutoff_prob=1.0, cutoff_top_n=40, @@ -86,10 +86,10 @@ def ctc_beam_search_decoder_batch(probs_split, :param probs_seq: 3-D list with each element as an instance of 2-D list of probabilities used by ctc_beam_search_decoder(). :type probs_seq: 3-D list - :param beam_size: Width for beam search. - :type beam_size: int :param vocabulary: Vocabulary list. :type vocabulary: list + :param beam_size: Width for beam search. + :type beam_size: int :param num_processes: Number of parallel processes. :type num_processes: int :param cutoff_prob: Cutoff probability in vocabulary pruning, @@ -112,5 +112,5 @@ def ctc_beam_search_decoder_batch(probs_split, probs_split = [probs_seq.tolist() for probs_seq in probs_split] return swig_decoders.ctc_beam_search_decoder_batch( - probs_split, beam_size, vocabulary, num_processes, cutoff_prob, + probs_split, vocabulary, beam_size, num_processes, cutoff_prob, cutoff_top_n, ext_scoring_func)