format C++ source code

pull/2/head
Yibing Liu 7 years ago
parent a2ddfe8d9e
commit 5208b8e40f

@ -1,22 +1,21 @@
#include <iostream> #include "ctc_decoders.h"
#include <map>
#include <algorithm> #include <algorithm>
#include <utility>
#include <cmath> #include <cmath>
#include <iostream>
#include <limits> #include <limits>
#include "fst/fstlib.h" #include <map>
#include "ctc_decoders.h" #include <utility>
#include "ThreadPool.h"
#include "decoder_utils.h" #include "decoder_utils.h"
#include "fst/fstlib.h"
#include "path_trie.h" #include "path_trie.h"
#include "ThreadPool.h"
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq, std::string ctc_best_path_decoder(std::vector<std::vector<double>> probs_seq,
std::vector<std::string> vocabulary) std::vector<std::string> vocabulary) {
{
// dimension check // dimension check
int num_time_steps = probs_seq.size(); int num_time_steps = probs_seq.size();
for (int i=0; i<num_time_steps; i++) { for (int i = 0; i < num_time_steps; i++) {
if (probs_seq[i].size() != vocabulary.size()+1) { if (probs_seq[i].size() != vocabulary.size() + 1) {
std::cout << "The shape of probs_seq does not match" std::cout << "The shape of probs_seq does not match"
<< " with the shape of the vocabulary!" << std::endl; << " with the shape of the vocabulary!" << std::endl;
exit(1); exit(1);
@ -42,7 +41,7 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<int> idx_vec; std::vector<int> idx_vec;
for (int i = 0; i < max_idx_vec.size(); i++) { for (int i = 0; i < max_idx_vec.size(); i++) {
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i-1])) { if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
idx_vec.push_back(max_idx_vec[i]); idx_vec.push_back(max_idx_vec[i]);
} }
} }
@ -56,15 +55,14 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
return best_path_result; return best_path_result;
} }
std::vector<std::pair<double, std::string> > std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq, std::vector<std::vector<double>> probs_seq,
int beam_size, int beam_size,
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
int blank_id, int blank_id,
double cutoff_prob, double cutoff_prob,
int cutoff_top_n, int cutoff_top_n,
Scorer *ext_scorer) Scorer *extscorer) {
{
// dimension check // dimension check
int num_time_steps = probs_seq.size(); int num_time_steps = probs_seq.size();
for (int i = 0; i < num_time_steps; i++) { for (int i = 0; i < num_time_steps; i++) {
@ -82,35 +80,33 @@ std::vector<std::pair<double, std::string> >
} }
// assign space ID // assign space ID
std::vector<std::string>::iterator it = std::find(vocabulary.begin(), std::vector<std::string>::iterator it =
vocabulary.end(), " "); std::find(vocabulary.begin(), vocabulary.end(), " ");
int space_id = it - vocabulary.begin(); int space_id = it - vocabulary.begin();
// if no space in vocabulary // if no space in vocabulary
if(space_id >= vocabulary.size()) { if (space_id >= vocabulary.size()) {
space_id = -2; space_id = -2;
} }
// init prefixes' root // init prefixes' root
PathTrie root; PathTrie root;
root._score = root._log_prob_b_prev = 0.0; root.score = root.log_prob_b_prev = 0.0;
std::vector<PathTrie*> prefixes; std::vector<PathTrie *> prefixes;
prefixes.push_back(&root); prefixes.push_back(&root);
if ( ext_scorer != nullptr) { if (extscorer != nullptr) {
if (ext_scorer->is_char_map_empty()) { if (extscorer->is_char_map_empty()) {
ext_scorer->set_char_map(vocabulary); extscorer->set_char_map(vocabulary);
} }
if (!ext_scorer->is_character_based()) { if (!extscorer->is_character_based()) {
if (ext_scorer->dictionary == nullptr) { if (extscorer->dictionary == nullptr) {
// fill dictionary for fst // fill dictionary for fst
ext_scorer->fill_dictionary(true); extscorer->fill_dictionary(true);
} }
auto fst_dict = static_cast<fst::StdVectorFst*> auto fst_dict = static_cast<fst::StdVectorFst *>(extscorer->dictionary);
(ext_scorer->dictionary); fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr); root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH> auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher); root.set_matcher(matcher);
} }
} }
@ -118,45 +114,44 @@ std::vector<std::pair<double, std::string> >
// prefix search over time // prefix search over time
for (int time_step = 0; time_step < num_time_steps; time_step++) { for (int time_step = 0; time_step < num_time_steps; time_step++) {
std::vector<double> prob = probs_seq[time_step]; std::vector<double> prob = probs_seq[time_step];
std::vector<std::pair<int, double> > prob_idx; std::vector<std::pair<int, double>> prob_idx;
for (int i=0; i<prob.size(); i++) { for (int i = 0; i < prob.size(); i++) {
prob_idx.push_back(std::pair<int, double>(i, prob[i])); prob_idx.push_back(std::pair<int, double>(i, prob[i]));
} }
float min_cutoff = -NUM_FLT_INF; float min_cutoff = -NUM_FLT_INF;
bool full_beam = false; bool full_beam = false;
if (ext_scorer != nullptr) { if (extscorer != nullptr) {
int num_prefixes = std::min((int)prefixes.size(), beam_size); int num_prefixes = std::min((int)prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, std::sort(
prefix_compare); prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
min_cutoff = prefixes[num_prefixes-1]->_score + log(prob[blank_id]) min_cutoff = prefixes[num_prefixes - 1]->score + log(prob[blank_id]) -
- std::max(0.0, ext_scorer->beta); std::max(0.0, extscorer->beta);
full_beam = (num_prefixes == beam_size); full_beam = (num_prefixes == beam_size);
} }
// pruning of vacobulary // pruning of vacobulary
int cutoff_len = prob.size(); int cutoff_len = prob.size();
if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) { if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) {
std::sort(prob_idx.begin(), std::sort(
prob_idx.end(), prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
pair_comp_second_rev<int, double>);
if (cutoff_prob < 1.0) { if (cutoff_prob < 1.0) {
double cum_prob = 0.0; double cum_prob = 0.0;
cutoff_len = 0; cutoff_len = 0;
for (int i=0; i<prob_idx.size(); i++) { for (int i = 0; i < prob_idx.size(); i++) {
cum_prob += prob_idx[i].second; cum_prob += prob_idx[i].second;
cutoff_len += 1; cutoff_len += 1;
if (cum_prob >= cutoff_prob) break; if (cum_prob >= cutoff_prob) break;
} }
} }
cutoff_len = std::min(cutoff_len, cutoff_top_n); cutoff_len = std::min(cutoff_len, cutoff_top_n);
prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(), prob_idx = std::vector<std::pair<int, double>>(
prob_idx.begin() + cutoff_len); prob_idx.begin(), prob_idx.begin() + cutoff_len);
} }
std::vector<std::pair<int, float> > log_prob_idx; std::vector<std::pair<int, float>> log_prob_idx;
for (int i = 0; i < cutoff_len; i++) { for (int i = 0; i < cutoff_len; i++) {
log_prob_idx.push_back(std::pair<int, float> log_prob_idx.push_back(std::pair<int, float>(
(prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN))); prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
} }
// loop over chars // loop over chars
@ -164,24 +159,22 @@ std::vector<std::pair<double, std::string> >
auto c = log_prob_idx[index].first; auto c = log_prob_idx[index].first;
float log_prob_c = log_prob_idx[index].second; float log_prob_c = log_prob_idx[index].second;
for (int i = 0; i < prefixes.size() && i<beam_size; i++) { for (int i = 0; i < prefixes.size() && i < beam_size; i++) {
auto prefix = prefixes[i]; auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->_score < min_cutoff) { if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break; break;
} }
// blank // blank
if (c == blank_id) { if (c == blank_id) {
prefix->_log_prob_b_cur = log_sum_exp( prefix->log_prob_b_cur =
prefix->_log_prob_b_cur, log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
log_prob_c + prefix->_score);
continue; continue;
} }
// repeated character // repeated character
if (c == prefix->_character) { if (c == prefix->character) {
prefix->_log_prob_nb_cur = log_sum_exp( prefix->log_prob_nb_cur = log_sum_exp(
prefix->_log_prob_nb_cur, prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
log_prob_c + prefix->_log_prob_nb_prev);
} }
// get new prefix // get new prefix
auto prefix_new = prefix->get_path_trie(c); auto prefix_new = prefix->get_path_trie(c);
@ -189,36 +182,35 @@ std::vector<std::pair<double, std::string> >
if (prefix_new != nullptr) { if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF; float log_p = -NUM_FLT_INF;
if (c == prefix->_character if (c == prefix->character &&
&& prefix->_log_prob_b_prev > -NUM_FLT_INF) { prefix->log_prob_b_prev > -NUM_FLT_INF) {
log_p = log_prob_c + prefix->_log_prob_b_prev; log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->_character) { } else if (c != prefix->character) {
log_p = log_prob_c + prefix->_score; log_p = log_prob_c + prefix->score;
} }
// language model scoring // language model scoring
if (ext_scorer != nullptr && if (extscorer != nullptr &&
(c == space_id || ext_scorer->is_character_based()) ) { (c == space_id || extscorer->is_character_based())) {
PathTrie *prefix_to_score = nullptr; PathTrie *prefix_toscore = nullptr;
// skip scoring the space // skip scoring the space
if (ext_scorer->is_character_based()) { if (extscorer->is_character_based()) {
prefix_to_score = prefix_new; prefix_toscore = prefix_new;
} else { } else {
prefix_to_score = prefix; prefix_toscore = prefix;
} }
double score = 0.0; double score = 0.0;
std::vector<std::string> ngram; std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score); ngram = extscorer->make_ngram(prefix_toscore);
score = ext_scorer->get_log_cond_prob(ngram) * score = extscorer->get_log_cond_prob(ngram) * extscorer->alpha;
ext_scorer->alpha;
log_p += score; log_p += score;
log_p += ext_scorer->beta; log_p += extscorer->beta;
} }
prefix_new->_log_prob_nb_cur = log_sum_exp( prefix_new->log_prob_nb_cur =
prefix_new->_log_prob_nb_cur, log_p); log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
} }
} // end of loop over prefix } // end of loop over prefix
} // end of loop over chars } // end of loop over chars
@ -242,33 +234,32 @@ std::vector<std::pair<double, std::string> >
// compute aproximate ctc score as the return score // compute aproximate ctc score as the return score
for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) { for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
double approx_ctc = prefixes[i]->_score; double approx_ctc = prefixes[i]->score;
if (ext_scorer != nullptr) { if (extscorer != nullptr) {
std::vector<int> output; std::vector<int> output;
prefixes[i]->get_path_vec(output); prefixes[i]->get_path_vec(output);
size_t prefix_length = output.size(); size_t prefix_length = output.size();
auto words = ext_scorer->split_labels(output); auto words = extscorer->split_labels(output);
// remove word insert // remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; approx_ctc = approx_ctc - prefix_length * extscorer->beta;
// remove language model weight: // remove language model weight:
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) approx_ctc -= (extscorer->get_sent_log_prob(words)) * extscorer->alpha;
* ext_scorer->alpha;
} }
prefixes[i]->_approx_ctc = approx_ctc; prefixes[i]->approx_ctc = approx_ctc;
} }
// allow for the post processing // allow for the post processing
std::vector<PathTrie*> space_prefixes; std::vector<PathTrie *> space_prefixes;
if (space_prefixes.empty()) { if (space_prefixes.empty()) {
for (size_t i = 0; i < beam_size && i< prefixes.size(); i++) { for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
space_prefixes.push_back(prefixes[i]); space_prefixes.push_back(prefixes[i]);
} }
} }
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
std::vector<std::pair<double, std::string> > output_vecs; std::vector<std::pair<double, std::string>> output_vecs;
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) { for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) {
std::vector<int> output; std::vector<int> output;
space_prefixes[i]->get_path_vec(output); space_prefixes[i]->get_path_vec(output);
@ -277,17 +268,16 @@ std::vector<std::pair<double, std::string> >
for (int j = 0; j < output.size(); j++) { for (int j = 0; j < output.size(); j++) {
output_str += vocabulary[output[j]]; output_str += vocabulary[output[j]];
} }
std::pair<double, std::string> std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
output_pair(-space_prefixes[i]->_approx_ctc, output_str); output_str);
output_vecs.emplace_back(output_pair); output_vecs.emplace_back(output_pair);
} }
return output_vecs; return output_vecs;
} }
std::vector<std::vector<std::pair<double, std::string> > > std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch( ctc_beam_search_decoder_batch(
std::vector<std::vector<std::vector<double>>> probs_split, std::vector<std::vector<std::vector<double>>> probs_split,
int beam_size, int beam_size,
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
@ -295,8 +285,7 @@ std::vector<std::vector<std::pair<double, std::string> > >
int num_processes, int num_processes,
double cutoff_prob, double cutoff_prob,
int cutoff_top_n, int cutoff_top_n,
Scorer *ext_scorer Scorer *extscorer) {
) {
if (num_processes <= 0) { if (num_processes <= 0) {
std::cout << "num_processes must be nonnegative!" << std::endl; std::cout << "num_processes must be nonnegative!" << std::endl;
exit(1); exit(1);
@ -307,29 +296,32 @@ std::vector<std::vector<std::pair<double, std::string> > >
int batch_size = probs_split.size(); int batch_size = probs_split.size();
// scorer filling up // scorer filling up
if ( ext_scorer != nullptr) { if (extscorer != nullptr) {
if (ext_scorer->is_char_map_empty()) { if (extscorer->is_char_map_empty()) {
ext_scorer->set_char_map(vocabulary); extscorer->set_char_map(vocabulary);
} }
if(!ext_scorer->is_character_based() if (!extscorer->is_character_based() &&
&& ext_scorer->dictionary == nullptr) { extscorer->dictionary == nullptr) {
// init dictionary // init dictionary
ext_scorer->fill_dictionary(true); extscorer->fill_dictionary(true);
} }
} }
// enqueue the tasks of decoding // enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res; std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
res.emplace_back( res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
pool.enqueue(ctc_beam_search_decoder, probs_split[i], probs_split[i],
beam_size, vocabulary, blank_id, cutoff_prob, beam_size,
cutoff_top_n, ext_scorer) vocabulary,
); blank_id,
cutoff_prob,
cutoff_top_n,
extscorer));
} }
// get decoding results // get decoding results
std::vector<std::vector<std::pair<double, std::string> > > batch_results; std::vector<std::vector<std::pair<double, std::string>>> batch_results;
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
batch_results.emplace_back(res[i].get()); batch_results.emplace_back(res[i].get());
} }

@ -1,9 +1,9 @@
#ifndef CTC_BEAM_SEARCH_DECODER_H_ #ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_ #define CTC_BEAM_SEARCH_DECODER_H_
#include <vector>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector>
#include "scorer.h" #include "scorer.h"
/* CTC Best Path Decoder /* CTC Best Path Decoder
@ -16,7 +16,7 @@
* A vector that each element is a pair of score and decoding result, * A vector that each element is a pair of score and decoding result,
* in desending order. * in desending order.
*/ */
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq, std::string ctc_best_path_decoder(std::vector<std::vector<double>> probs_seq,
std::vector<std::string> vocabulary); std::vector<std::string> vocabulary);
/* CTC Beam Search Decoder /* CTC Beam Search Decoder
@ -34,15 +34,14 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
* A vector that each element is a pair of score and decoding result, * A vector that each element is a pair of score and decoding result,
* in desending order. * in desending order.
*/ */
std::vector<std::pair<double, std::string> > std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq, std::vector<std::vector<double>> probs_seq,
int beam_size, int beam_size,
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
int blank_id, int blank_id,
double cutoff_prob=1.0, double cutoff_prob = 1.0,
int cutoff_top_n=40, int cutoff_top_n = 40,
Scorer *ext_scorer=NULL Scorer *ext_scorer = NULL);
);
/* CTC Beam Search Decoder for batch data, the interface is consistent with the /* CTC Beam Search Decoder for batch data, the interface is consistent with the
* original decoder in Python version. * original decoder in Python version.
@ -63,15 +62,14 @@ std::vector<std::pair<double, std::string> >
* sample. * sample.
*/ */
std::vector<std::vector<std::pair<double, std::string>>> std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(std::vector<std::vector<std::vector<double>>> probs_split, ctc_beam_search_decoder_batch(
std::vector<std::vector<std::vector<double>>> probs_split,
int beam_size, int beam_size,
std::vector<std::string> vocabulary, std::vector<std::string> vocabulary,
int blank_id, int blank_id,
int num_processes, int num_processes,
double cutoff_prob=1.0, double cutoff_prob = 1.0,
int cutoff_top_n=40, int cutoff_top_n = 40,
Scorer *ext_scorer=NULL Scorer *ext_scorer = NULL);
);
#endif // CTC_BEAM_SEARCH_DECODER_H_ #endif // CTC_BEAM_SEARCH_DECODER_H_

@ -1,7 +1,7 @@
#include <limits> #include "decoder_utils.h"
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include "decoder_utils.h" #include <limits>
size_t get_utf8_str_len(const std::string& str) { size_t get_utf8_str_len(const std::string& str) {
size_t str_len = 0; size_t str_len = 0;
@ -11,17 +11,14 @@ size_t get_utf8_str_len(const std::string& str) {
return str_len; return str_len;
} }
std::vector<std::string> split_utf8_str(const std::string& str) std::vector<std::string> split_utf8_str(const std::string& str) {
{
std::vector<std::string> result; std::vector<std::string> result;
std::string out_str; std::string out_str;
for (char c : str) for (char c : str) {
{ if ((c & 0xc0) != 0x80) // new UTF-8 character
if ((c & 0xc0) != 0x80) //new UTF-8 character
{
if (!out_str.empty())
{ {
if (!out_str.empty()) {
result.push_back(out_str); result.push_back(out_str);
out_str.clear(); out_str.clear();
} }
@ -33,8 +30,8 @@ std::vector<std::string> split_utf8_str(const std::string& str)
return result; return result;
} }
std::vector<std::string> split_str(const std::string &s, std::vector<std::string> split_str(const std::string& s,
const std::string &delim) { const std::string& delim) {
std::vector<std::string> result; std::vector<std::string> result;
std::size_t start = 0, delim_len = delim.size(); std::size_t start = 0, delim_len = delim.size();
while (true) { while (true) {
@ -54,14 +51,14 @@ std::vector<std::string> split_str(const std::string &s,
} }
bool prefix_compare(const PathTrie* x, const PathTrie* y) { bool prefix_compare(const PathTrie* x, const PathTrie* y) {
if (x->_score == y->_score) { if (x->score == y->score) {
if (x->_character == y->_character) { if (x->character == y->character) {
return false; return false;
} else { } else {
return (x->_character < y->_character); return (x->character < y->character);
} }
} else { } else {
return x->_score > y->_score; return x->score > y->score;
} }
} }
@ -82,7 +79,8 @@ void add_word_to_fst(const std::vector<int>& word,
dictionary->SetFinal(dst, fst::StdArc::Weight::One()); dictionary->SetFinal(dst, fst::StdArc::Weight::One());
} }
bool add_word_to_dictionary(const std::string& word, bool add_word_to_dictionary(
const std::string& word,
const std::unordered_map<std::string, int>& char_map, const std::unordered_map<std::string, int>& char_map,
bool add_space, bool add_space,
int SPACE_ID, int SPACE_ID,

@ -10,34 +10,31 @@ const float NUM_FLT_MIN = std::numeric_limits<float>::min();
// Function template for comparing two pairs // Function template for comparing two pairs
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a, bool pair_comp_first_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) const std::pair<T1, T2> &b) {
{
return a.first > b.first; return a.first > b.first;
} }
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a, bool pair_comp_second_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) const std::pair<T1, T2> &b) {
{
return a.second > b.second; return a.second > b.second;
} }
template <typename T> template <typename T>
T log_sum_exp(const T &x, const T &y) T log_sum_exp(const T &x, const T &y) {
{
static T num_min = -std::numeric_limits<T>::max(); static T num_min = -std::numeric_limits<T>::max();
if (x <= num_min) return y; if (x <= num_min) return y;
if (y <= num_min) return x; if (y <= num_min) return x;
T xmax = std::max(x, y); T xmax = std::max(x, y);
return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax; return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
} }
// Functor for prefix comparsion // Functor for prefix comparsion
bool prefix_compare(const PathTrie* x, const PathTrie* y); bool prefix_compare(const PathTrie *x, const PathTrie *y);
// Get length of utf8 encoding string // Get length of utf8 encoding string
// See: http://stackoverflow.com/a/4063229 // See: http://stackoverflow.com/a/4063229
size_t get_utf8_str_len(const std::string& str); size_t get_utf8_str_len(const std::string &str);
// Split a string into a list of strings on a given string // Split a string into a list of strings on a given string
// delimiter. NB: delimiters on beginning / end of string are // delimiter. NB: delimiters on beginning / end of string are
@ -50,13 +47,14 @@ std::vector<std::string> split_str(const std::string &s,
std::vector<std::string> split_utf8_str(const std::string &str); std::vector<std::string> split_utf8_str(const std::string &str);
// Add a word in index to the dicionary of fst // Add a word in index to the dicionary of fst
void add_word_to_fst(const std::vector<int>& word, void add_word_to_fst(const std::vector<int> &word,
fst::StdVectorFst* dictionary); fst::StdVectorFst *dictionary);
// Add a word in string to dictionary // Add a word in string to dictionary
bool add_word_to_dictionary(const std::string& word, bool add_word_to_dictionary(
const std::unordered_map<std::string, int>& char_map, const std::string &word,
const std::unordered_map<std::string, int> &char_map,
bool add_space, bool add_space,
int SPACE_ID, int SPACE_ID,
fst::StdVectorFst* dictionary); fst::StdVectorFst *dictionary);
#endif // DECODER_UTILS_H #endif // DECODER_UTILS_H

@ -4,20 +4,20 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "path_trie.h"
#include "decoder_utils.h" #include "decoder_utils.h"
#include "path_trie.h"
PathTrie::PathTrie() { PathTrie::PathTrie() {
_log_prob_b_prev = -NUM_FLT_INF; log_prob_b_prev = -NUM_FLT_INF;
_log_prob_nb_prev = -NUM_FLT_INF; log_prob_nb_prev = -NUM_FLT_INF;
_log_prob_b_cur = -NUM_FLT_INF; log_prob_b_cur = -NUM_FLT_INF;
_log_prob_nb_cur = -NUM_FLT_INF; log_prob_nb_cur = -NUM_FLT_INF;
_score = -NUM_FLT_INF; score = -NUM_FLT_INF;
_ROOT = -1; _ROOT = -1;
_character = _ROOT; character = _ROOT;
_exists = true; _exists = true;
_parent = nullptr; parent = nullptr;
_dictionary = nullptr; _dictionary = nullptr;
_dictionary_state = 0; _dictionary_state = 0;
_has_dictionary = false; _has_dictionary = false;
@ -37,13 +37,13 @@ PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
break; break;
} }
} }
if ( child != _children.end() ) { if (child != _children.end()) {
if (!child->second->_exists) { if (!child->second->_exists) {
child->second->_exists = true; child->second->_exists = true;
child->second->_log_prob_b_prev = -NUM_FLT_INF; child->second->log_prob_b_prev = -NUM_FLT_INF;
child->second->_log_prob_nb_prev = -NUM_FLT_INF; child->second->log_prob_nb_prev = -NUM_FLT_INF;
child->second->_log_prob_b_cur = -NUM_FLT_INF; child->second->log_prob_b_cur = -NUM_FLT_INF;
child->second->_log_prob_nb_cur = -NUM_FLT_INF; child->second->log_prob_nb_cur = -NUM_FLT_INF;
} }
return (child->second); return (child->second);
} else { } else {
@ -61,8 +61,8 @@ PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
return nullptr; return nullptr;
} else { } else {
PathTrie* new_path = new PathTrie; PathTrie* new_path = new PathTrie;
new_path->_character = new_char; new_path->character = new_char;
new_path->_parent = this; new_path->parent = this;
new_path->_dictionary = _dictionary; new_path->_dictionary = _dictionary;
new_path->_dictionary_state = _matcher->Value().nextstate; new_path->_dictionary_state = _matcher->Value().nextstate;
new_path->_has_dictionary = true; new_path->_has_dictionary = true;
@ -72,8 +72,8 @@ PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
} }
} else { } else {
PathTrie* new_path = new PathTrie; PathTrie* new_path = new PathTrie;
new_path->_character = new_char; new_path->character = new_char;
new_path->_parent = this; new_path->parent = this;
_children.push_back(std::make_pair(new_char, new_path)); _children.push_back(std::make_pair(new_char, new_path));
return new_path; return new_path;
} }
@ -87,27 +87,24 @@ PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
PathTrie* PathTrie::get_path_vec(std::vector<int>& output, PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
int stop, int stop,
size_t max_steps) { size_t max_steps) {
if (_character == stop || if (character == stop || character == _ROOT || output.size() == max_steps) {
_character == _ROOT ||
output.size() == max_steps) {
std::reverse(output.begin(), output.end()); std::reverse(output.begin(), output.end());
return this; return this;
} else { } else {
output.push_back(_character); output.push_back(character);
return _parent->get_path_vec(output, stop, max_steps); return parent->get_path_vec(output, stop, max_steps);
} }
} }
void PathTrie::iterate_to_vec( void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
std::vector<PathTrie*>& output) {
if (_exists) { if (_exists) {
_log_prob_b_prev = _log_prob_b_cur; log_prob_b_prev = log_prob_b_cur;
_log_prob_nb_prev = _log_prob_nb_cur; log_prob_nb_prev = log_prob_nb_cur;
_log_prob_b_cur = -NUM_FLT_INF; log_prob_b_cur = -NUM_FLT_INF;
_log_prob_nb_cur = -NUM_FLT_INF; log_prob_nb_cur = -NUM_FLT_INF;
_score = log_sum_exp(_log_prob_b_prev, _log_prob_nb_prev); score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
output.push_back(this); output.push_back(this);
} }
for (auto child : _children) { for (auto child : _children) {
@ -119,17 +116,17 @@ void PathTrie::remove() {
_exists = false; _exists = false;
if (_children.size() == 0) { if (_children.size() == 0) {
auto child = _parent->_children.begin(); auto child = parent->_children.begin();
for (child = _parent->_children.begin(); for (child = parent->_children.begin(); child != parent->_children.end();
child != _parent->_children.end(); ++child) { ++child) {
if (child->first == _character) { if (child->first == character) {
_parent->_children.erase(child); parent->_children.erase(child);
break; break;
} }
} }
if ( _parent->_children.size() == 0 && !_parent->_exists ) { if (parent->_children.size() == 0 && !parent->_exists) {
_parent->remove(); parent->remove();
} }
delete this; delete this;

@ -1,12 +1,12 @@
#ifndef PATH_TRIE_H #ifndef PATH_TRIE_H
#define PATH_TRIE_H #define PATH_TRIE_H
#pragma once #pragma once
#include <fst/fstlib.h>
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <fst/fstlib.h>
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
@ -17,38 +17,36 @@ public:
PathTrie* get_path_trie(int new_char, bool reset = true); PathTrie* get_path_trie(int new_char, bool reset = true);
PathTrie* get_path_vec(std::vector<int> &output); PathTrie* get_path_vec(std::vector<int>& output);
PathTrie* get_path_vec(std::vector<int>& output, PathTrie* get_path_vec(std::vector<int>& output,
int stop, int stop,
size_t max_steps = std::numeric_limits<size_t>::max()); size_t max_steps = std::numeric_limits<size_t>::max());
void iterate_to_vec(std::vector<PathTrie*> &output); void iterate_to_vec(std::vector<PathTrie*>& output);
void set_dictionary(fst::StdVectorFst* dictionary); void set_dictionary(fst::StdVectorFst* dictionary);
void set_matcher(std::shared_ptr<FSTMATCH> matcher); void set_matcher(std::shared_ptr<FSTMATCH> matcher);
bool is_empty() { bool is_empty() { return _ROOT == character; }
return _ROOT == _character;
}
void remove(); void remove();
float _log_prob_b_prev; float log_prob_b_prev;
float _log_prob_nb_prev; float log_prob_nb_prev;
float _log_prob_b_cur; float log_prob_b_cur;
float _log_prob_nb_cur; float log_prob_nb_cur;
float _score; float score;
float _approx_ctc; float approx_ctc;
int character;
PathTrie* parent;
private:
int _ROOT; int _ROOT;
int _character;
bool _exists; bool _exists;
PathTrie *_parent; std::vector<std::pair<int, PathTrie*>> _children;
std::vector<std::pair<int, PathTrie*> > _children;
fst::StdVectorFst* _dictionary; fst::StdVectorFst* _dictionary;
fst::StdVectorFst::StateId _dictionary_state; fst::StdVectorFst::StateId _dictionary_state;

@ -1,12 +1,12 @@
#include <iostream> #include "scorer.h"
#include <unistd.h> #include <unistd.h>
#include <iostream>
#include "decoder_utils.h"
#include "lm/config.hh" #include "lm/config.hh"
#include "lm/state.hh"
#include "lm/model.hh" #include "lm/model.hh"
#include "util/tokenize_piece.hh" #include "lm/state.hh"
#include "util/string_piece.hh" #include "util/string_piece.hh"
#include "scorer.h" #include "util/tokenize_piece.hh"
#include "decoder_utils.h"
using namespace lm::ngram; using namespace lm::ngram;
@ -25,8 +25,7 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
Scorer::~Scorer() { Scorer::~Scorer() {
if (_language_model != nullptr) if (_language_model != nullptr)
delete static_cast<lm::base::Model*>(_language_model); delete static_cast<lm::base::Model*>(_language_model);
if (dictionary != nullptr) if (dictionary != nullptr) delete static_cast<fst::StdVectorFst*>(dictionary);
delete static_cast<fst::StdVectorFst*>(dictionary);
} }
void Scorer::load_LM(const char* filename) { void Scorer::load_LM(const char* filename) {
@ -41,11 +40,9 @@ void Scorer::load_LM(const char* filename) {
_max_order = static_cast<lm::base::Model*>(_language_model)->Order(); _max_order = static_cast<lm::base::Model*>(_language_model)->Order();
_vocabulary = enumerate.vocabulary; _vocabulary = enumerate.vocabulary;
for (size_t i = 0; i < _vocabulary.size(); ++i) { for (size_t i = 0; i < _vocabulary.size(); ++i) {
if (_is_character_based if (_is_character_based && _vocabulary[i] != UNK_TOKEN &&
&& _vocabulary[i] != UNK_TOKEN _vocabulary[i] != START_TOKEN && _vocabulary[i] != END_TOKEN &&
&& _vocabulary[i] != START_TOKEN get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
&& _vocabulary[i] != END_TOKEN
&& get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
_is_character_based = false; _is_character_based = false;
} }
} }
@ -112,10 +109,8 @@ std::string Scorer::vec2str(const std::vector<int>& input) {
return word; return word;
} }
std::vector<std::string> std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
Scorer::split_labels(const std::vector<int> &labels) { if (labels.empty()) return {};
if (labels.empty())
return {};
std::string s = vec2str(labels); std::string s = vec2str(labels);
std::vector<std::string> words; std::vector<std::string> words;
@ -131,12 +126,11 @@ void Scorer::set_char_map(std::vector<std::string> char_list) {
_char_list = char_list; _char_list = char_list;
_char_map.clear(); _char_map.clear();
for(unsigned int i = 0; i < _char_list.size(); i++) for (unsigned int i = 0; i < _char_list.size(); i++) {
{
if (_char_list[i] == " ") { if (_char_list[i] == " ") {
_SPACE_ID = i; _SPACE_ID = i;
_char_map[' '] = i; _char_map[' '] = i;
} else if(_char_list[i].size() == 1){ } else if (_char_list[i].size() == 1) {
_char_map[_char_list[i][0]] = i; _char_map[_char_list[i][0]] = i;
} }
} }
@ -155,14 +149,14 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
current_node = new_node; current_node = new_node;
} else { } else {
new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID); new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID);
current_node = new_node->_parent; // Skipping spaces current_node = new_node->parent; // Skipping spaces
} }
// reconstruct word // reconstruct word
std::string word = vec2str(prefix_vec); std::string word = vec2str(prefix_vec);
ngram.push_back(word); ngram.push_back(word);
if (new_node->_character == -1) { if (new_node->character == -1) {
// No more spaces, but still need order // No more spaces, but still need order
for (int i = 0; i < _max_order - order - 1; i++) { for (int i = 0; i < _max_order - order - 1; i++) {
ngram.push_back(START_TOKEN); ngram.push_back(START_TOKEN);
@ -175,7 +169,6 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
} }
void Scorer::fill_dictionary(bool add_space) { void Scorer::fill_dictionary(bool add_space) {
fst::StdVectorFst dictionary; fst::StdVectorFst dictionary;
// First reverse char_list so ints can be accessed by chars // First reverse char_list so ints can be accessed by chars
std::unordered_map<std::string, int> char_map; std::unordered_map<std::string, int> char_map;
@ -186,11 +179,8 @@ void Scorer::fill_dictionary(bool add_space) {
// For each unigram convert to ints and put in trie // For each unigram convert to ints and put in trie
int vocab_size = 0; int vocab_size = 0;
for (const auto& word : _vocabulary) { for (const auto& word : _vocabulary) {
bool added = add_word_to_dictionary(word, bool added = add_word_to_dictionary(
char_map, word, char_map, add_space, _SPACE_ID, &dictionary);
add_space,
_SPACE_ID,
&dictionary);
vocab_size += added ? 1 : 0; vocab_size += added ? 1 : 0;
} }
@ -215,5 +205,4 @@ void Scorer::fill_dictionary(bool add_space) {
// memory usage of the dictionary // memory usage of the dictionary
fst::Minimize(new_dict); fst::Minimize(new_dict);
this->dictionary = new_dict; this->dictionary = new_dict;
} }

@ -1,22 +1,22 @@
#ifndef SCORER_H_ #ifndef SCORER_H_
#define SCORER_H_ #define SCORER_H_
#include <string>
#include <memory> #include <memory>
#include <vector> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "lm/enumerate_vocab.hh" #include "lm/enumerate_vocab.hh"
#include "lm/word_index.hh"
#include "lm/virtual_interface.hh" #include "lm/virtual_interface.hh"
#include "util/string_piece.hh" #include "lm/word_index.hh"
#include "path_trie.h" #include "path_trie.h"
#include "util/string_piece.hh"
const double OOV_SCORE = -1000.0; const double OOV_SCORE = -1000.0;
const std::string START_TOKEN = "<s>"; const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>"; const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>"; const std::string END_TOKEN = "</s>";
// Implement a callback to retrive string vocabulary. // Implement a callback to retrive string vocabulary.
class RetriveStrEnumerateVocab : public lm::EnumerateVocab { class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public: public:
RetriveStrEnumerateVocab() {} RetriveStrEnumerateVocab() {}
@ -33,7 +33,7 @@ public:
// Scorer scorer(alpha, beta, "path_of_language_model"); // Scorer scorer(alpha, beta, "path_of_language_model");
// scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" }); // scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
// scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); // scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
class Scorer{ class Scorer {
public: public:
Scorer(double alpha, double beta, const std::string& lm_path); Scorer(double alpha, double beta, const std::string& lm_path);
~Scorer(); ~Scorer();
@ -44,7 +44,7 @@ public:
size_t get_max_order() { return _max_order; } size_t get_max_order() { return _max_order; }
bool is_char_map_empty() {return _char_map.size() == 0; } bool is_char_map_empty() { return _char_map.size() == 0; }
bool is_character_based() { return _is_character_based; } bool is_character_based() { return _is_character_based; }
@ -60,7 +60,7 @@ public:
// set char map // set char map
void set_char_map(std::vector<std::string> char_list); void set_char_map(std::vector<std::string> char_list);
std::vector<std::string> split_labels(const std::vector<int> &labels); std::vector<std::string> split_labels(const std::vector<int>& labels);
// expose to decoder // expose to decoder
double alpha; double alpha;
@ -74,7 +74,7 @@ protected:
double get_log_prob(const std::vector<std::string>& words); double get_log_prob(const std::vector<std::string>& words);
std::string vec2str(const std::vector<int> &input); std::string vec2str(const std::vector<int>& input);
private: private:
void* _language_model; void* _language_model;

Loading…
Cancel
Save