streamline source code

pull/2/head
Yibing Liu 7 years ago
parent 8ff6221d00
commit 9a79b41bcd

@ -10,8 +10,6 @@
#include "path_trie.h" #include "path_trie.h"
#include "ThreadPool.h" #include "ThreadPool.h"
typedef float log_prob_type;
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq, std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<std::string> vocabulary) std::vector<std::string> vocabulary)
{ {
@ -19,8 +17,8 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
int num_time_steps = probs_seq.size(); int num_time_steps = probs_seq.size();
for (int i=0; i<num_time_steps; i++) { for (int i=0; i<num_time_steps; i++) {
if (probs_seq[i].size() != vocabulary.size()+1) { if (probs_seq[i].size() != vocabulary.size()+1) {
std::cout<<"The shape of probs_seq does not match" std::cout << "The shape of probs_seq does not match"
<<" with the shape of the vocabulary!"<<std::endl; << " with the shape of the vocabulary!" << std::endl;
exit(1); exit(1);
} }
} }
@ -30,8 +28,8 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<int> max_idx_vec; std::vector<int> max_idx_vec;
double max_prob = 0.0; double max_prob = 0.0;
int max_idx = 0; int max_idx = 0;
for (int i=0; i<num_time_steps; i++) { for (int i = 0; i < num_time_steps; i++) {
for (int j=0; j<probs_seq[i].size(); j++) { for (int j = 0; j < probs_seq[i].size(); j++) {
if (max_prob < probs_seq[i][j]) { if (max_prob < probs_seq[i][j]) {
max_idx = j; max_idx = j;
max_prob = probs_seq[i][j]; max_prob = probs_seq[i][j];
@ -43,14 +41,14 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
} }
std::vector<int> idx_vec; std::vector<int> idx_vec;
for (int i=0; i<max_idx_vec.size(); i++) { for (int i = 0; i < max_idx_vec.size(); i++) {
if ((i == 0) || ((i>0) && max_idx_vec[i]!=max_idx_vec[i-1])) { if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i-1])) {
idx_vec.push_back(max_idx_vec[i]); idx_vec.push_back(max_idx_vec[i]);
} }
} }
std::string best_path_result; std::string best_path_result;
for (int i=0; i<idx_vec.size(); i++) { for (int i = 0; i < idx_vec.size(); i++) {
if (idx_vec[i] != blank_id) { if (idx_vec[i] != blank_id) {
best_path_result += vocabulary[idx_vec[i]]; best_path_result += vocabulary[idx_vec[i]];
} }
@ -68,8 +66,8 @@ std::vector<std::pair<double, std::string> >
{ {
// dimension check // dimension check
int num_time_steps = probs_seq.size(); int num_time_steps = probs_seq.size();
for (int i=0; i<num_time_steps; i++) { for (int i = 0; i < num_time_steps; i++) {
if (probs_seq[i].size() != vocabulary.size()+1) { if (probs_seq[i].size() != vocabulary.size() + 1) {
std::cout << " The shape of probs_seq does not match" std::cout << " The shape of probs_seq does not match"
<< " with the shape of the vocabulary!" << std::endl; << " with the shape of the vocabulary!" << std::endl;
exit(1); exit(1);
@ -86,19 +84,14 @@ std::vector<std::pair<double, std::string> >
std::vector<std::string>::iterator it = std::find(vocabulary.begin(), std::vector<std::string>::iterator it = std::find(vocabulary.begin(),
vocabulary.end(), " "); vocabulary.end(), " ");
int space_id = it - vocabulary.begin(); int space_id = it - vocabulary.begin();
// if no space in vocabulary
if(space_id >= vocabulary.size()) { if(space_id >= vocabulary.size()) {
std::cout << " The character space is not in the vocabulary!"<<std::endl; space_id = -2;
exit(1);
} }
static log_prob_type POS_INF = std::numeric_limits<log_prob_type>::max();
static log_prob_type NEG_INF = -POS_INF;
static log_prob_type NUM_MIN = std::numeric_limits<log_prob_type>::min();
// init // init
PathTrie root; PathTrie root;
root._log_prob_b_prev = 0.0; root._score = root._log_prob_b_prev = 0.0;
root._score = 0.0;
std::vector<PathTrie*> prefixes; std::vector<PathTrie*> prefixes;
prefixes.push_back(&root); prefixes.push_back(&root);
@ -140,17 +133,17 @@ std::vector<std::pair<double, std::string> >
prob_idx.begin() + cutoff_len); prob_idx.begin() + cutoff_len);
} }
std::vector<std::pair<int, log_prob_type> > log_prob_idx; std::vector<std::pair<int, float> > log_prob_idx;
for (int i=0; i<cutoff_len; i++) { for (int i = 0; i < cutoff_len; i++) {
log_prob_idx.push_back(std::pair<int, log_prob_type> log_prob_idx.push_back(std::pair<int, float>
(prob_idx[i].first, log(prob_idx[i].second + NUM_MIN))); (prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
} }
// loop over chars // loop over chars
for (int index = 0; index < log_prob_idx.size(); index++) { for (int index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first; auto c = log_prob_idx[index].first;
log_prob_type log_prob_c = log_prob_idx[index].second; float log_prob_c = log_prob_idx[index].second;
//log_prob_type log_probs_prev; //float log_probs_prev;
for (int i = 0; i < prefixes.size() && i<beam_size; i++) { for (int i = 0; i < prefixes.size() && i<beam_size; i++) {
auto prefix = prefixes[i]; auto prefix = prefixes[i];
@ -165,17 +158,16 @@ std::vector<std::pair<double, std::string> >
if (c == prefix->_character) { if (c == prefix->_character) {
prefix->_log_prob_nb_cur = log_sum_exp( prefix->_log_prob_nb_cur = log_sum_exp(
prefix->_log_prob_nb_cur, prefix->_log_prob_nb_cur,
log_prob_c + prefix->_log_prob_nb_prev log_prob_c + prefix->_log_prob_nb_prev);
);
} }
// get new prefix // get new prefix
auto prefix_new = prefix->get_path_trie(c); auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) { if (prefix_new != nullptr) {
float log_p = NEG_INF; float log_p = -NUM_FLT_INF;
if (c == prefix->_character if (c == prefix->_character
&& prefix->_log_prob_b_prev > NEG_INF) { && prefix->_log_prob_b_prev > -NUM_FLT_INF) {
log_p = log_prob_c + prefix->_log_prob_b_prev; log_p = log_prob_c + prefix->_log_prob_b_prev;
} else if (c != prefix->_character) { } else if (c != prefix->_character) {
log_p = log_prob_c + prefix->_score; log_p = log_prob_c + prefix->_score;
@ -201,7 +193,6 @@ std::vector<std::pair<double, std::string> >
log_p += score; log_p += score;
log_p += ext_scorer->beta; log_p += ext_scorer->beta;
} }
prefix_new->_log_prob_nb_cur = log_sum_exp( prefix_new->_log_prob_nb_cur = log_sum_exp(
prefix_new->_log_prob_nb_cur, log_p); prefix_new->_log_prob_nb_cur, log_p);
@ -273,7 +264,7 @@ std::vector<std::pair<double, std::string> >
} }
std::vector<std::vector<std::pair<double, std::string>>> std::vector<std::vector<std::pair<double, std::string> > >
ctc_beam_search_decoder_batch( ctc_beam_search_decoder_batch(
std::vector<std::vector<std::vector<double>>> probs_split, std::vector<std::vector<std::vector<double>>> probs_split,
int beam_size, int beam_size,
@ -292,12 +283,12 @@ std::vector<std::vector<std::pair<double, std::string>>>
// number of samples // number of samples
int batch_size = probs_split.size(); int batch_size = probs_split.size();
// dictionary init // dictionary init
if ( ext_scorer != nullptr) { if ( ext_scorer != nullptr
if (ext_scorer->_dictionary == nullptr) { && !ext_scorer->is_character_based()
// TODO: init dictionary && ext_scorer->_dictionary == nullptr) {
ext_scorer->set_char_map(vocabulary); // init dictionary
ext_scorer->fill_dictionary(true); ext_scorer->set_char_map(vocabulary);
} ext_scorer->fill_dictionary(true);
} }
// enqueue the tasks of decoding // enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res; std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
@ -308,7 +299,7 @@ std::vector<std::vector<std::pair<double, std::string>>>
); );
} }
// get decoding results // get decoding results
std::vector<std::vector<std::pair<double, std::string>>> batch_results; std::vector<std::vector<std::pair<double, std::string> > > batch_results;
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
batch_results.emplace_back(res[i].get()); batch_results.emplace_back(res[i].get());
} }

@ -15,7 +15,7 @@ size_t get_utf8_str_len(const std::string& str) {
//Splits string into vector of strings representing //Splits string into vector of strings representing
//UTF-8 characters (not same as chars) //UTF-8 characters (not same as chars)
//------------------------------------------------------ //------------------------------------------------------
std::vector<std::string> UTF8_split(const std::string& str) std::vector<std::string> split_utf8_str(const std::string& str)
{ {
std::vector<std::string> result; std::vector<std::string> result;
std::string out_str; std::string out_str;
@ -37,6 +37,29 @@ std::vector<std::string> UTF8_split(const std::string& str)
return result; return result;
} }
// Split a string into a list of strings on a given string
// delimiter. NB: delimiters on beginning / end of string are
// trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
std::vector<std::string> split_str(const std::string &s,
const std::string &delim) {
std::vector<std::string> result;
std::size_t start = 0, delim_len = delim.size();
while (true) {
std::size_t end = s.find(delim, start);
if (end == std::string::npos) {
if (start < s.size()) {
result.push_back(s.substr(start));
}
break;
}
if (end > start) {
result.push_back(s.substr(start, end - start));
}
start = end + delim_len;
}
return result;
}
//------------------------------------------------------- //-------------------------------------------------------
// Overriding less than operator for sorting // Overriding less than operator for sorting
//------------------------------------------------------- //-------------------------------------------------------
@ -80,7 +103,7 @@ bool add_word_to_dictionary(const std::string& word,
bool add_space, bool add_space,
int SPACE, int SPACE,
fst::StdVectorFst* dictionary) { fst::StdVectorFst* dictionary) {
auto characters = UTF8_split(word); auto characters = split_utf8_str(word);
std::vector<int> int_word; std::vector<int> int_word;

@ -4,14 +4,19 @@
#include <utility> #include <utility>
#include "path_trie.h" #include "path_trie.h"
const float NUM_FLT_INF = std::numeric_limits<float>::max();
const float NUM_FLT_MIN = std::numeric_limits<float>::min();
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b) bool pair_comp_first_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b)
{ {
return a.first > b.first; return a.first > b.first;
} }
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a, const std::pair<T1, T2> &b) bool pair_comp_second_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b)
{ {
return a.second > b.second; return a.second > b.second;
} }
@ -26,16 +31,18 @@ T log_sum_exp(const T &x, const T &y)
return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax; return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax;
} }
//-------------------------------------------------------
// Overriding less than operator for sorting // Functor for prefix comparsion
//-------------------------------------------------------
bool prefix_compare(const PathTrie* x, const PathTrie* y); bool prefix_compare(const PathTrie* x, const PathTrie* y);
// Get length of utf8 encoding string // Get length of utf8 encoding string
// See: http://stackoverflow.com/a/4063229 // See: http://stackoverflow.com/a/4063229
size_t get_utf8_str_len(const std::string& str); size_t get_utf8_str_len(const std::string& str);
std::vector<std::string> UTF8_split(const std::string &str); std::vector<std::string> split_str(const std::string &s,
const std::string &delim);
std::vector<std::string> split_utf8_str(const std::string &str);
void add_word_to_fst(const std::vector<int>& word, void add_word_to_fst(const std::vector<int>& word,
fst::StdVectorFst* dictionary); fst::StdVectorFst* dictionary);

@ -8,12 +8,11 @@
#include "decoder_utils.h" #include "decoder_utils.h"
PathTrie::PathTrie() { PathTrie::PathTrie() {
float lowest = -1.0*std::numeric_limits<float>::max(); _log_prob_b_prev = -NUM_FLT_INF;
_log_prob_b_prev = lowest; _log_prob_nb_prev = -NUM_FLT_INF;
_log_prob_nb_prev = lowest; _log_prob_b_cur = -NUM_FLT_INF;
_log_prob_b_cur = lowest; _log_prob_nb_cur = -NUM_FLT_INF;
_log_prob_nb_cur = lowest; _score = -NUM_FLT_INF;
_score = lowest;
_ROOT = -1; _ROOT = -1;
_character = _ROOT; _character = _ROOT;
@ -41,11 +40,10 @@ PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
if ( child != _children.end() ) { if ( child != _children.end() ) {
if (!child->second->_exists) { if (!child->second->_exists) {
child->second->_exists = true; child->second->_exists = true;
float lowest = -1.0*std::numeric_limits<float>::max(); child->second->_log_prob_b_prev = -NUM_FLT_INF;
child->second->_log_prob_b_prev = lowest; child->second->_log_prob_nb_prev = -NUM_FLT_INF;
child->second->_log_prob_nb_prev = lowest; child->second->_log_prob_b_cur = -NUM_FLT_INF;
child->second->_log_prob_b_cur = lowest; child->second->_log_prob_nb_cur = -NUM_FLT_INF;
child->second->_log_prob_nb_cur = lowest;
} }
return (child->second); return (child->second);
} else { } else {
@ -106,8 +104,8 @@ void PathTrie::iterate_to_vec(
_log_prob_b_prev = _log_prob_b_cur; _log_prob_b_prev = _log_prob_b_cur;
_log_prob_nb_prev = _log_prob_nb_cur; _log_prob_nb_prev = _log_prob_nb_cur;
_log_prob_b_cur = -1.0 * std::numeric_limits<float>::max(); _log_prob_b_cur = -NUM_FLT_INF;
_log_prob_nb_cur = -1.0 * std::numeric_limits<float>::max(); _log_prob_nb_cur = -NUM_FLT_INF;
_score = log_sum_exp(_log_prob_b_prev, _log_prob_nb_prev); _score = log_sum_exp(_log_prob_b_prev, _log_prob_nb_prev);
output.push_back(this); output.push_back(this);
@ -117,9 +115,6 @@ void PathTrie::iterate_to_vec(
} }
} }
//-------------------------------------------------------
// Effectively removes node
//-------------------------------------------------------
void PathTrie::remove() { void PathTrie::remove() {
_exists = false; _exists = false;

@ -17,7 +17,7 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
_language_model = nullptr; _language_model = nullptr;
_dictionary = nullptr; _dictionary = nullptr;
_max_order = 0; _max_order = 0;
_SPACE = -1; _SPACE_ID = -1;
// load language model // load language model
load_LM(lm_path.c_str()); load_LM(lm_path.c_str());
} }
@ -61,7 +61,7 @@ double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]); lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
// encounter OOV // encounter OOV
if (word_index == 0) { if (word_index == 0) {
return OOV_SCOER; return OOV_SCORE;
} }
cond_prob = model->BaseScore(&state, word_index, &out_state); cond_prob = model->BaseScore(&state, word_index, &out_state);
tmp_state = state; tmp_state = state;
@ -197,64 +197,27 @@ Scorer::split_labels(const std::vector<int> &labels) {
std::string s = vec2str(labels); std::string s = vec2str(labels);
std::vector<std::string> words; std::vector<std::string> words;
if (_is_character_based) { if (_is_character_based) {
words = UTF8_split(s); words = split_utf8_str(s);
} else { } else {
words = split_str(s, " "); words = split_str(s, " ");
} }
return words; return words;
} }
// Split a string into a list of strings on a given string
// delimiter. NB: delimiters on beginning / end of string are
// trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
std::vector<std::string> Scorer::split_str(const std::string &s,
const std::string &delim) {
std::vector<std::string> result;
std::size_t start = 0, delim_len = delim.size();
while (true) {
std::size_t end = s.find(delim, start);
if (end == std::string::npos) {
if (start < s.size()) {
result.push_back(s.substr(start));
}
break;
}
if (end > start) {
result.push_back(s.substr(start, end - start));
}
start = end + delim_len;
}
return result;
}
//---------------------------------------------------
// Add index to char list for searching language model
//---------------------------------------------------
void Scorer::set_char_map(std::vector<std::string> char_list) { void Scorer::set_char_map(std::vector<std::string> char_list) {
_char_list = char_list; _char_list = char_list;
std::string _SPACE_STR = " ";
for (unsigned int i = 0; i < _char_list.size(); i++) {
// if (_char_list[i] == _BLANK_STR) {
// _BLANK = i;
// } else
if (_char_list[i] == _SPACE_STR) {
_SPACE = i;
}
}
_char_map.clear(); _char_map.clear();
for(unsigned int i = 0; i < _char_list.size(); i++) for(unsigned int i = 0; i < _char_list.size(); i++)
{ {
if(i == (unsigned int)_SPACE){ if (_char_list[i] == " ") {
_SPACE_ID = i;
_char_map[' '] = i; _char_map[' '] = i;
} } else if(_char_list[i].size() == 1){
else if(_char_list[i].size() == 1){
_char_map[_char_list[i][0]] = i; _char_map[_char_list[i][0]] = i;
} }
} }
}
} //------------- End of set_char_map ----------------
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) { std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
std::vector<std::string> ngram; std::vector<std::string> ngram;
@ -265,10 +228,10 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
std::vector<int> prefix_vec; std::vector<int> prefix_vec;
if (_is_character_based) { if (_is_character_based) {
new_node = current_node->get_path_vec(prefix_vec, _SPACE, 1); new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID, 1);
current_node = new_node; current_node = new_node;
} else { } else {
new_node = current_node->get_path_vec(prefix_vec, _SPACE); new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID);
current_node = new_node->_parent; // Skipping spaces current_node = new_node->_parent; // Skipping spaces
} }
@ -279,7 +242,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
if (new_node->_character == -1) { if (new_node->_character == -1) {
// No more spaces, but still need order // No more spaces, but still need order
for (int i = 0; i < _max_order - order - 1; i++) { for (int i = 0; i < _max_order - order - 1; i++) {
ngram.push_back("<s>"); ngram.push_back(START_TOKEN);
} }
break; break;
} }
@ -288,10 +251,6 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
return ngram; return ngram;
} }
//---------------------------------------------------------
// Helper function to populate Trie with a vocab using the
// char_list for maping from string to int
//---------------------------------------------------------
void Scorer::fill_dictionary(bool add_space) { void Scorer::fill_dictionary(bool add_space) {
fst::StdVectorFst dictionary; fst::StdVectorFst dictionary;
@ -307,7 +266,7 @@ void Scorer::fill_dictionary(bool add_space) {
bool added = add_word_to_dictionary(word, bool added = add_word_to_dictionary(word,
char_map, char_map,
add_space, add_space,
_SPACE, _SPACE_ID,
&dictionary); &dictionary);
vocab_size += added ? 1 : 0; vocab_size += added ? 1 : 0;
} }

@ -11,7 +11,7 @@
#include "util/string_piece.hh" #include "util/string_piece.hh"
#include "path_trie.h" #include "path_trie.h"
const double OOV_SCOER = -1000.0; const double OOV_SCORE = -1000.0;
const std::string START_TOKEN = "<s>"; const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>"; const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>"; const std::string END_TOKEN = "</s>";
@ -68,18 +68,13 @@ protected:
double get_log_prob(const std::vector<std::string>& words); double get_log_prob(const std::vector<std::string>& words);
std::string vec2str(const std::vector<int> &input); std::string vec2str(const std::vector<int> &input);
std::vector<std::string> split_labels(const std::vector<int> &labels); std::vector<std::string> split_labels(const std::vector<int> &labels);
std::vector<std::string> split_str(const std::string &s,
const std::string &delim);
private: private:
void _init_char_list();
void _init_char_map();
void* _language_model; void* _language_model;
bool _is_character_based; bool _is_character_based;
size_t _max_order; size_t _max_order;
unsigned int _SPACE; int _SPACE_ID;
std::vector<std::string> _char_list; std::vector<std::string> _char_list;
std::unordered_map<char, int> _char_map; std::unordered_map<char, int> _char_map;

Loading…
Cancel
Save