adapt to the last three commits

pull/2/head
Yibing Liu 7 years ago
parent 8dc0b2b0b0
commit eef364d17c

@ -14,7 +14,7 @@ wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
tar -xzvf openfst-1.6.3.tar.gz
```
- [**swig**]: Compiling for python interface requires swig, please make sure swig being installed.
- [**SWIG**](http://www.swig.org): Compiling the Python interface requires SWIG; please make sure SWIG is installed.
- [**ThreadPool**](http://progsch.net/wordpress/): A library for C++ thread pool

@ -3,9 +3,13 @@
#include "lm/config.hh"
#include "lm/state.hh"
#include "lm/model.hh"
#include "util/tokenize_piece.hh"
#include "util/string_piece.hh"
#include "scorer.h"
#include "decoder_utils.h"
using namespace lm::ngram;
Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
this->alpha = alpha;
this->beta = beta;
@ -90,3 +94,84 @@ double Scorer::get_log_prob(const std::vector<std::string>& words) {
}
return score;
}
/* Strip an input sentence in place.
 * Removes all leading and trailing occurrences of `ch`; interior
 * characters are left untouched.
 * Parameters:
 *     str: A reference to the string to trim (modified in place)
 *     ch: The character to prune (default: space)
 * Return:
 *     void
 */
inline void strip(std::string &str, char ch=' ') {
  // find_first_not_of returns npos when the string is empty or
  // consists entirely of `ch` — both cases collapse to the empty string.
  const std::size_t start = str.find_first_not_of(ch);
  if (start == std::string::npos) {
    str.clear();
    return;
  }
  // Guaranteed != npos here, since at least one non-`ch` char exists.
  const std::size_t end = str.find_last_not_of(ch);
  str = str.substr(start, end - start + 1);
}
// Count whitespace-separated words in `sentence` (word-insertion term).
// Leading/trailing spaces are stripped first; runs of spaces count as
// a single separator. Returns 0 for an empty (or all-space) sentence —
// the original returned 1 in that case, which over-counted.
int Scorer::word_count(std::string sentence) {
  strip(sentence);
  if (sentence.empty()) {
    return 0;
  }
  int cnt = 1;
  // Start at 1 so sentence[i-1] is always in bounds.
  for (size_t i = 1; i < sentence.size(); i++) {
    // A space preceded by a non-space marks the end of a word.
    if (sentence[i] == ' ' && sentence[i-1] != ' ') {
      cnt++;
    }
  }
  return cnt;
}
double Scorer::get_log_cond_prob(std::string sentence) {
lm::base::Model *model = (lm::base::Model *)this->_language_model;
State state, out_state;
lm::FullScoreReturn ret;
model->BeginSentenceWrite(&state);
for (util::TokenIter<util::SingleCharacter, true> it(sentence, ' '); it; ++it){
lm::WordIndex wid = model->BaseVocabulary().Index(*it);
ret = model->BaseFullScore(&state, wid, &out_state);
state = out_state;
}
//log10 prob
double log_prob = ret.prob;
return log_prob;
}
void Scorer::reset_params(float alpha, float beta) {
this->alpha = alpha;
this->beta = beta;
}
// Combined decoder score for `sentence`:
//   log  == true:  alpha * ln(P_lm) + beta * ln(word_count)
//   log  == false: P_lm^alpha * word_count^beta
// where P_lm is the (log10) conditional probability from the language
// model; ln(P_lm) is recovered via the log10 -> ln change of base.
double Scorer::get_score(std::string sentence, bool log) {
  const double lm_score = get_log_cond_prob(sentence);
  const int word_cnt = word_count(sentence);
  if (log) {
    return alpha * lm_score * std::log(10)
           + beta * std::log(word_cnt);
  }
  return pow(10, alpha * lm_score) * pow(word_cnt, beta);
}

@ -30,6 +30,7 @@ public:
// Example:
// Scorer scorer(alpha, beta, "path_of_language_model");
// scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
// scorer.get_log_cond_prob("this a sentence");
// scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
class Scorer{
public:
@ -40,7 +41,14 @@ public:
size_t get_max_order() { return _max_order; }
bool is_character_based() { return _is_character_based; }
std::vector<std::string> get_vocab() { return _vocabulary; }
// word insertion term
int word_count(std::string);
// get the log cond prob of the last word
double get_log_cond_prob(std::string);
// reset params alpha & beta
void reset_params(float alpha, float beta);
// get the final score
double get_score(std::string, bool log=false);
// expose to decoder
double alpha;
double beta;

Loading…
Cancel
Save