From 3ee020397cafca64cace4c71123c53b4fe8999a0 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 23 Aug 2017 11:06:27 +0800 Subject: [PATCH] Refactor scorer and move utility functions to decoder_util.h --- deploy/README.md | 2 + deploy/ctc_decoders.cpp | 23 ------ deploy/decoder_utils.cpp | 7 ++ deploy/decoder_utils.h | 33 ++++++--- deploy/decoders.i | 9 ++- deploy/scorer.cpp | 148 ++++++++++++++++++--------------------- deploy/scorer.h | 69 ++++++++++++------ 7 files changed, 154 insertions(+), 137 deletions(-) diff --git a/deploy/README.md b/deploy/README.md index cf0c0439..162a396a 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -7,6 +7,8 @@ wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz tar -xzvf openfst-1.6.3.tar.gz ``` +Compiling for python interface requires swig, please make sure swig being installed. + Then run the setup ```shell diff --git a/deploy/ctc_decoders.cpp b/deploy/ctc_decoders.cpp index 75555c01..836fb435 100644 --- a/deploy/ctc_decoders.cpp +++ b/deploy/ctc_decoders.cpp @@ -9,29 +9,6 @@ typedef double log_prob_type; - -template -bool pair_comp_first_rev(const std::pair a, const std::pair b) -{ - return a.first > b.first; -} - -template -bool pair_comp_second_rev(const std::pair a, const std::pair b) -{ - return a.second > b.second; -} - -template -T log_sum_exp(T x, T y) -{ - static T num_min = -std::numeric_limits::max(); - if (x <= num_min) return y; - if (y <= num_min) return x; - T xmax = std::max(x, y); - return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax; -} - std::string ctc_best_path_decoder(std::vector > probs_seq, std::vector vocabulary) { // dimension check diff --git a/deploy/decoder_utils.cpp b/deploy/decoder_utils.cpp index 82e4cd14..d616d7c6 100644 --- a/deploy/decoder_utils.cpp +++ b/deploy/decoder_utils.cpp @@ -3,3 +3,10 @@ #include #include "decoder_utils.h" +size_t get_utf8_str_len(const std::string& str) { + size_t str_len = 0; + for (char c : str) { + str_len += ((c & 0xc0) != 0x80); + } + return str_len; +} diff --git a/deploy/decoder_utils.h b/deploy/decoder_utils.h index 6d58bf1f..9419e005 100644 --- a/deploy/decoder_utils.h +++ b/deploy/decoder_utils.h @@ -1,15 +1,32 @@ -#ifndef DECODER_UTILS_H -#define DECODER_UTILS_H -#pragma once +#ifndef DECODER_UTILS_H_ +#define DECODER_UTILS_H_ + #include -/* template -bool pair_comp_first_rev(const std::pair a, const std::pair b); +bool pair_comp_first_rev(const std::pair &a, const std::pair &b) +{ + return a.first > b.first; +} template -bool pair_comp_second_rev(const std::pair a, const std::pair b); +bool pair_comp_second_rev(const std::pair &a, const std::pair &b) +{ + return a.second > b.second; +} + +template +T log_sum_exp(const T &x, const T &y) +{ + static T num_min = -std::numeric_limits::max(); + if (x <= num_min) return y; + if (y <= num_min) return x; + T xmax = std::max(x, y); + return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax; +} + +// Get length of utf8 encoding string +// See: http://stackoverflow.com/a/4063229 +size_t get_utf8_str_len(const std::string& str); -template T log_sum_exp(T x, T y); -*/ #endif // DECODER_UTILS_H diff --git a/deploy/decoders.i b/deploy/decoders.i index 04736e09..ed7c85e6 100644 --- a/deploy/decoders.i +++ b/deploy/decoders.i @@ -2,13 +2,15 @@ %{ #include "scorer.h" #include "ctc_decoders.h" +#include "decoder_utils.h" %} %include "std_vector.i" %include "std_pair.i" %include "std_string.i" +%import "decoder_utils.h" -namespace std{ +namespace std { %template(DoubleVector) std::vector; %template(IntVector) std::vector; %template(StringVector) std::vector; @@ -19,6 +21,9 @@ namespace std{ %template(PairDoubleStringVector) std::vector >; } -%import decoder_utils.h +%template(IntDoublePairCompSecondRev) pair_comp_second_rev; +%template(StringDoublePairCompSecondRev) pair_comp_second_rev; +%template(DoubleStringPairCompFirstRev) pair_comp_first_rev; + %include "scorer.h" %include "ctc_decoders.h" diff --git a/deploy/scorer.cpp b/deploy/scorer.cpp index e9a74b98..17bb6e10 100644 --- a/deploy/scorer.cpp +++ b/deploy/scorer.cpp @@ -1,103 +1,89 @@ #include #include #include "scorer.h" -#include "lm/model.hh" -#include "util/tokenize_piece.hh" -#include "util/string_piece.hh" +#include "decoder_utils.h" -using namespace lm::ngram; - -Scorer::Scorer(float alpha, float beta, std::string lm_model_path) { - this->_alpha = alpha; - this->_beta = beta; - - if (access(lm_model_path.c_str(), F_OK) != 0) { - std::cout<<"Invalid language model path!"<_language_model = LoadVirtual(lm_model_path.c_str()); +Scorer::Scorer(double alpha, double beta, const std::string& lm_path) { + this->alpha = alpha; + this->beta = beta; + _is_character_based = true; + _language_model = nullptr; + _max_order = 0; + // load language model + load_LM(lm_path.c_str()); } -Scorer::~Scorer(){ - delete (lm::base::Model *)this->_language_model; +Scorer::~Scorer() { + if (_language_model != nullptr) + delete static_cast(_language_model); } -/* Strip a input sentence - * Parameters: - * str: A reference to the objective string - * ch: The character to prune - * Return: - * void - */ -inline void strip(std::string &str, char ch=' ') { - if (str.size() == 0) return; - int start = 0; - int end = str.size()-1; - for (int i=0; i=0; i--) { - if (str[i] == ch) { - end --; - } else { - break; + RetriveStrEnumerateVocab enumerate; + Config config; + config.enumerate_vocab = &enumerate; + _language_model = lm::ngram::LoadVirtual(filename, config); + _max_order = static_cast(_language_model)->Order(); + _vocabulary = enumerate.vocabulary; + for (size_t i = 0; i < _vocabulary.size(); ++i) { + if (_is_character_based + && _vocabulary[i] != UNK_TOKEN + && _vocabulary[i] != START_TOKEN + && _vocabulary[i] != END_TOKEN + && get_utf8_str_len(enumerate.vocabulary[i]) > 1) { + _is_character_based = false; } } - - if (start == 0 && end == str.size()-1) return; - if (start > end) { - std::string emp_str; - str = emp_str; - } else { - str = str.substr(start, end-start+1); - } } -int Scorer::word_count(std::string sentence) { - strip(sentence); - int cnt = 1; - for (int i=0; i& words) { + lm::base::Model* model = static_cast(_language_model); + double cond_prob; + State state, tmp_state, out_state; + // avoid to inserting in begin + model->NullContextWrite(&state); + for (size_t i = 0; i < words.size(); ++i) { + lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]); + // encounter OOV + if (word_index == 0) { + return OOV_SCOER; } - } - return cnt; -} - -double Scorer::language_model_score(std::string sentence) { - lm::base::Model *model = (lm::base::Model *)this->_language_model; - State state, out_state; - lm::FullScoreReturn ret; - model->BeginSentenceWrite(&state); - - for (util::TokenIter it(sentence, ' '); it; ++it){ - lm::WordIndex wid = model->BaseVocabulary().Index(*it); - ret = model->BaseFullScore(&state, wid, &out_state); + cond_prob = model->BaseScore(&state, word_index, &out_state); + tmp_state = state; state = out_state; + out_state = tmp_state; } - //log10 prob - double log_prob = ret.prob; - return log_prob; + // log10 prob + return cond_prob; } -void Scorer::reset_params(float alpha, float beta) { - this->_alpha = alpha; - this->_beta = beta; +double Scorer::get_sent_log_prob(const std::vector& words) { + std::vector sentence; + if (words.size() == 0) { + for (size_t i = 0; i < _max_order; ++i) { + sentence.push_back(START_TOKEN); + } + } else { + for (size_t i = 0; i < _max_order - 1; ++i) { + sentence.push_back(START_TOKEN); + } + sentence.insert(sentence.end(), words.begin(), words.end()); + } + sentence.push_back(END_TOKEN); + return get_log_prob(sentence); } -double Scorer::get_score(std::string sentence, bool log) { - double lm_score = language_model_score(sentence); - int word_cnt = word_count(sentence); - - double final_score = 0.0; - if (log == false) { - final_score = pow(10, _alpha*lm_score) * pow(word_cnt, _beta); - } else { - final_score = _alpha*lm_score*std::log(10) + _beta*std::log(word_cnt); +double Scorer::get_log_prob(const std::vector& words) { + assert(words.size() > _max_order); + double score = 0.0; + for (size_t i = 0; i < words.size() - _max_order + 1; ++i) { + std::vector ngram(words.begin() + i, + words.begin() + i + _max_order); + score += get_log_cond_prob(ngram); } - return final_score; + return score; } diff --git a/deploy/scorer.h b/deploy/scorer.h index a18e119b..a650d375 100644 --- a/deploy/scorer.h +++ b/deploy/scorer.h @@ -2,35 +2,58 @@ #define SCORER_H_ #include +#include +#include +#include "lm/enumerate_vocab.hh" +#include "lm/word_index.hh" +#include "lm/virtual_interface.hh" +#include "util/string_piece.hh" -/* External scorer to evaluate a prefix or a complete sentence - * when a new word appended during decoding, consisting of word - * count and language model scoring. +const double OOV_SCOER = -1000.0; +const std::string START_TOKEN = ""; +const std::string UNK_TOKEN = ""; +const std::string END_TOKEN = ""; - * Example: - * Scorer ext_scorer(alpha, beta, "path_to_language_model.klm"); - * double score = ext_scorer.get_score("sentence_to_score"); - */ -class Scorer{ -private: - float _alpha; - float _beta; - void *_language_model; + // Implement a callback to retrive string vocabulary. +class RetriveStrEnumerateVocab : public lm::EnumerateVocab { +public: + RetriveStrEnumerateVocab() {} - // word insertion term - int word_count(std::string); - // n-gram language model scoring - double language_model_score(std::string); + void Add(lm::WordIndex index, const StringPiece& str) { + vocabulary.push_back(std::string(str.data(), str.length())); + } + + std::vector vocabulary; +}; +// External scorer to query languange score for n-gram or sentence. +// Example: +// Scorer scorer(alpha, beta, "path_of_language_model"); +// scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" }); +// scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); +class Scorer{ public: - Scorer(){} - Scorer(float alpha, float beta, std::string lm_model_path); + Scorer(double alpha, double beta, const std::string& lm_path); ~Scorer(); + double get_log_cond_prob(const std::vector& words); + double get_sent_log_prob(const std::vector& words); + size_t get_max_order() { return _max_order; } + bool is_character_based() { return _is_character_based; } + std::vector get_vocab() { return _vocabulary; } + + // expose to decoder + double alpha; + double beta; - // reset params alpha & beta - void reset_params(float alpha, float beta); - // get the final score - double get_score(std::string, bool log=false); +protected: + void load_LM(const char* filename); + double get_log_prob(const std::vector& words); + +private: + void* _language_model; + bool _is_character_based; + size_t _max_order; + std::vector _vocabulary; }; -#endif //SCORER_H_ +#endif // SCORER_H_