parent
d1189a7950
commit
3ee020397c
@ -1,15 +1,32 @@
|
||||
#ifndef DECODER_UTILS_H
|
||||
#define DECODER_UTILS_H
|
||||
#pragma once
|
||||
#ifndef DECODER_UTILS_H_
|
||||
#define DECODER_UTILS_H_
|
||||
|
||||
#include <utility>
|
||||
|
||||
/*
|
||||
template <typename T1, typename T2>
|
||||
bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b);
|
||||
bool pair_comp_first_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
|
||||
{
|
||||
return a.first > b.first;
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
bool pair_comp_second_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b);
|
||||
bool pair_comp_second_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b)
|
||||
{
|
||||
return a.second > b.second;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T log_sum_exp(const T &x, const T &y)
|
||||
{
|
||||
static T num_min = -std::numeric_limits<T>::max();
|
||||
if (x <= num_min) return y;
|
||||
if (y <= num_min) return x;
|
||||
T xmax = std::max(x, y);
|
||||
return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax;
|
||||
}
|
||||
|
||||
// Get length of utf8 encoding string
|
||||
// See: http://stackoverflow.com/a/4063229
|
||||
size_t get_utf8_str_len(const std::string& str);
|
||||
|
||||
template <typename T> T log_sum_exp(T x, T y);
|
||||
*/
|
||||
#endif // DECODER_UTILS_H
|
||||
|
@ -1,103 +1,89 @@
|
||||
#include <iostream>
|
||||
#include <unistd.h>
|
||||
#include "scorer.h"
|
||||
#include "lm/model.hh"
|
||||
#include "util/tokenize_piece.hh"
|
||||
#include "util/string_piece.hh"
|
||||
#include "decoder_utils.h"
|
||||
|
||||
using namespace lm::ngram;
|
||||
|
||||
Scorer::Scorer(float alpha, float beta, std::string lm_model_path) {
|
||||
this->_alpha = alpha;
|
||||
this->_beta = beta;
|
||||
|
||||
if (access(lm_model_path.c_str(), F_OK) != 0) {
|
||||
std::cout<<"Invalid language model path!"<<std::endl;
|
||||
exit(1);
|
||||
}
|
||||
this->_language_model = LoadVirtual(lm_model_path.c_str());
|
||||
Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
|
||||
this->alpha = alpha;
|
||||
this->beta = beta;
|
||||
_is_character_based = true;
|
||||
_language_model = nullptr;
|
||||
_max_order = 0;
|
||||
// load language model
|
||||
load_LM(lm_path.c_str());
|
||||
}
|
||||
|
||||
Scorer::~Scorer(){
|
||||
delete (lm::base::Model *)this->_language_model;
|
||||
Scorer::~Scorer() {
|
||||
if (_language_model != nullptr)
|
||||
delete static_cast<lm::base::Model*>(_language_model);
|
||||
}
|
||||
|
||||
/* Strip a input sentence
|
||||
* Parameters:
|
||||
* str: A reference to the objective string
|
||||
* ch: The character to prune
|
||||
* Return:
|
||||
* void
|
||||
*/
|
||||
inline void strip(std::string &str, char ch=' ') {
|
||||
if (str.size() == 0) return;
|
||||
int start = 0;
|
||||
int end = str.size()-1;
|
||||
for (int i=0; i<str.size(); i++){
|
||||
if (str[i] == ch) {
|
||||
start ++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
void Scorer::load_LM(const char* filename) {
|
||||
if (access(filename, F_OK) != 0) {
|
||||
std::cerr << "Invalid language model file !!!" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
for (int i=str.size()-1; i>=0; i--) {
|
||||
if (str[i] == ch) {
|
||||
end --;
|
||||
} else {
|
||||
break;
|
||||
RetriveStrEnumerateVocab enumerate;
|
||||
Config config;
|
||||
config.enumerate_vocab = &enumerate;
|
||||
_language_model = lm::ngram::LoadVirtual(filename, config);
|
||||
_max_order = static_cast<lm::base::Model*>(_language_model)->Order();
|
||||
_vocabulary = enumerate.vocabulary;
|
||||
for (size_t i = 0; i < _vocabulary.size(); ++i) {
|
||||
if (_is_character_based
|
||||
&& _vocabulary[i] != UNK_TOKEN
|
||||
&& _vocabulary[i] != START_TOKEN
|
||||
&& _vocabulary[i] != END_TOKEN
|
||||
&& get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
|
||||
_is_character_based = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (start == 0 && end == str.size()-1) return;
|
||||
if (start > end) {
|
||||
std::string emp_str;
|
||||
str = emp_str;
|
||||
} else {
|
||||
str = str.substr(start, end-start+1);
|
||||
}
|
||||
}
|
||||
|
||||
int Scorer::word_count(std::string sentence) {
|
||||
strip(sentence);
|
||||
int cnt = 1;
|
||||
for (int i=0; i<sentence.size(); i++) {
|
||||
if (sentence[i] == ' ' && sentence[i-1] != ' ') {
|
||||
cnt ++;
|
||||
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
|
||||
lm::base::Model* model = static_cast<lm::base::Model*>(_language_model);
|
||||
double cond_prob;
|
||||
State state, tmp_state, out_state;
|
||||
// avoid to inserting <s> in begin
|
||||
model->NullContextWrite(&state);
|
||||
for (size_t i = 0; i < words.size(); ++i) {
|
||||
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
|
||||
// encounter OOV
|
||||
if (word_index == 0) {
|
||||
return OOV_SCOER;
|
||||
}
|
||||
}
|
||||
return cnt;
|
||||
}
|
||||
|
||||
double Scorer::language_model_score(std::string sentence) {
|
||||
lm::base::Model *model = (lm::base::Model *)this->_language_model;
|
||||
State state, out_state;
|
||||
lm::FullScoreReturn ret;
|
||||
model->BeginSentenceWrite(&state);
|
||||
|
||||
for (util::TokenIter<util::SingleCharacter, true> it(sentence, ' '); it; ++it){
|
||||
lm::WordIndex wid = model->BaseVocabulary().Index(*it);
|
||||
ret = model->BaseFullScore(&state, wid, &out_state);
|
||||
cond_prob = model->BaseScore(&state, word_index, &out_state);
|
||||
tmp_state = state;
|
||||
state = out_state;
|
||||
out_state = tmp_state;
|
||||
}
|
||||
//log10 prob
|
||||
double log_prob = ret.prob;
|
||||
return log_prob;
|
||||
// log10 prob
|
||||
return cond_prob;
|
||||
}
|
||||
|
||||
void Scorer::reset_params(float alpha, float beta) {
|
||||
this->_alpha = alpha;
|
||||
this->_beta = beta;
|
||||
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
|
||||
std::vector<std::string> sentence;
|
||||
if (words.size() == 0) {
|
||||
for (size_t i = 0; i < _max_order; ++i) {
|
||||
sentence.push_back(START_TOKEN);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < _max_order - 1; ++i) {
|
||||
sentence.push_back(START_TOKEN);
|
||||
}
|
||||
sentence.insert(sentence.end(), words.begin(), words.end());
|
||||
}
|
||||
sentence.push_back(END_TOKEN);
|
||||
return get_log_prob(sentence);
|
||||
}
|
||||
|
||||
double Scorer::get_score(std::string sentence, bool log) {
|
||||
double lm_score = language_model_score(sentence);
|
||||
int word_cnt = word_count(sentence);
|
||||
|
||||
double final_score = 0.0;
|
||||
if (log == false) {
|
||||
final_score = pow(10, _alpha*lm_score) * pow(word_cnt, _beta);
|
||||
} else {
|
||||
final_score = _alpha*lm_score*std::log(10) + _beta*std::log(word_cnt);
|
||||
double Scorer::get_log_prob(const std::vector<std::string>& words) {
|
||||
assert(words.size() > _max_order);
|
||||
double score = 0.0;
|
||||
for (size_t i = 0; i < words.size() - _max_order + 1; ++i) {
|
||||
std::vector<std::string> ngram(words.begin() + i,
|
||||
words.begin() + i + _max_order);
|
||||
score += get_log_cond_prob(ngram);
|
||||
}
|
||||
return final_score;
|
||||
return score;
|
||||
}
|
||||
|
Loading…
Reference in new issue