From eef364d17c3d8e4402d95960153ebd49d539b594 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 23 Aug 2017 16:57:25 +0800 Subject: [PATCH] adapt to the last three commits --- deploy/README.md | 2 +- deploy/scorer.cpp | 85 +++++++++++++++++++++++++++++++++++++++++++++++ deploy/scorer.h | 10 +++++- 3 files changed, 95 insertions(+), 2 deletions(-) diff --git a/deploy/README.md b/deploy/README.md index 90809ad3..9f2be76e 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -14,7 +14,7 @@ wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz tar -xzvf openfst-1.6.3.tar.gz ``` -- [**swig**]: Compiling for python interface requires swig, please make sure swig being installed. +- [**SWIG**](http://www.swig.org): Compiling for python interface requires swig, please make sure swig being installed. - [**ThreadPool**](http://progsch.net/wordpress/): A library for C++ thread pool diff --git a/deploy/scorer.cpp b/deploy/scorer.cpp index 233b4766..a1be7e0f 100644 --- a/deploy/scorer.cpp +++ b/deploy/scorer.cpp @@ -3,9 +3,13 @@ #include "lm/config.hh" #include "lm/state.hh" #include "lm/model.hh" +#include "util/tokenize_piece.hh" +#include "util/string_piece.hh" #include "scorer.h" #include "decoder_utils.h" +using namespace lm::ngram; + Scorer::Scorer(double alpha, double beta, const std::string& lm_path) { this->alpha = alpha; this->beta = beta; @@ -90,3 +94,84 @@ double Scorer::get_log_prob(const std::vector& words) { } return score; } + +/* Strip a input sentence + * Parameters: + * str: A reference to the objective string + * ch: The character to prune + * Return: + * void + */ +inline void strip(std::string &str, char ch=' ') { + if (str.size() == 0) return; + int start = 0; + int end = str.size()-1; + for (int i=0; i=0; i--) { + if (str[i] == ch) { + end --; + } else { + break; + } + } + + if (start == 0 && end == str.size()-1) return; + if (start > end) { + std::string emp_str; + str = emp_str; + } else { + str = str.substr(start, end-start+1); + } +} + +int Scorer::word_count(std::string sentence) { + strip(sentence); + int cnt = 1; + for (int i=0; i_language_model; + State state, out_state; + lm::FullScoreReturn ret; + model->BeginSentenceWrite(&state); + + for (util::TokenIter it(sentence, ' '); it; ++it){ + lm::WordIndex wid = model->BaseVocabulary().Index(*it); + ret = model->BaseFullScore(&state, wid, &out_state); + state = out_state; + } + //log10 prob + double log_prob = ret.prob; + return log_prob; +} + +void Scorer::reset_params(float alpha, float beta) { + this->alpha = alpha; + this->beta = beta; +} + +double Scorer::get_score(std::string sentence, bool log) { + double lm_score = get_log_cond_prob(sentence); + int word_cnt = word_count(sentence); + + double final_score = 0.0; + if (log == false) { + final_score = pow(10, alpha * lm_score) * pow(word_cnt, beta); + } else { + final_score = alpha * lm_score * std::log(10) + + beta * std::log(word_cnt); + } + return final_score; +} diff --git a/deploy/scorer.h b/deploy/scorer.h index a650d375..a5242004 100644 --- a/deploy/scorer.h +++ b/deploy/scorer.h @@ -30,6 +30,7 @@ public: // Example: // Scorer scorer(alpha, beta, "path_of_language_model"); // scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" }); +// scorer.get_log_cond_prob("this a sentence"); // scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); class Scorer{ public: @@ -40,7 +41,14 @@ public: size_t get_max_order() { return _max_order; } bool is_character_based() { return _is_character_based; } std::vector get_vocab() { return _vocabulary; } - + // word insertion term + int word_count(std::string); + // get the log cond prob of the last word + double get_log_cond_prob(std::string); + // reset params alpha & beta + void reset_params(float alpha, float beta); + // get the final score + double get_score(std::string, bool log=false); // expose to decoder double alpha; double beta;