change probs' computation into log scale & add best path decoder

pull/2/head
Yibing Liu 8 years ago
parent ccea7c0150
commit a840f85423

@ -3,8 +3,11 @@
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include <cmath> #include <cmath>
#include <limits>
#include "ctc_beam_search_decoder.h" #include "ctc_beam_search_decoder.h"
typedef float log_prob_type;
template <typename T1, typename T2> template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b) bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b)
{ {
@ -17,6 +20,65 @@ bool pair_comp_second_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b)
return a.second > b.second; return a.second > b.second;
} }
// Computes log(exp(x) + exp(y)) without leaving log space.
//
// Any argument at or below -std::numeric_limits<T>::max() is treated as
// log(0), i.e. it contributes nothing to the sum and the other argument is
// returned unchanged. This matches the sentinel the decoder uses for
// "zero probability" and also covers -infinity.
//
// The larger argument is factored out before exponentiating so that the
// exp() calls never overflow.
template <typename T>
T log_sum_exp(T x, T y)
{
    // Most negative finite value of T: the "log of zero" sentinel.
    static const T num_min = -std::numeric_limits<T>::max();
    // Compare against the sentinel itself; comparing against -num_min
    // (== +max) would be true for every finite input.
    if (x <= num_min) return y;
    if (y <= num_min) return x;
    const T xmax = std::max(x, y);
    return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
}
// CTC best path (greedy) decoder.
//
// For every time step the label with the highest probability is selected
// (ties resolve to the lowest index). Consecutive repeats of the same label
// are merged, blank labels are dropped, and the remaining labels are mapped
// through `vocabulary` and concatenated.
//
// probs_seq:  one probability vector per time step; each must have
//             vocabulary.size() + 1 entries, the last one being the blank.
// vocabulary: output string for each non-blank label id.
//
// Returns the decoded string (empty when probs_seq is empty).
// If any time step has the wrong number of entries, an error message is
// printed to stdout and the process exits with status 1.
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
                                  std::vector<std::string> vocabulary) {
    // dimension check
    const size_t num_classes = vocabulary.size() + 1;
    const size_t num_time_steps = probs_seq.size();
    for (size_t t = 0; t < num_time_steps; t++) {
        if (probs_seq[t].size() != num_classes) {
            std::cout<<"The shape of probs_seq does not match"
                     <<" with the shape of the vocabulary!"<<std::endl;
            exit(1);
        }
    }

    // The blank label occupies the slot right after the vocabulary.
    const size_t blank_id = vocabulary.size();

    std::string best_path_result;
    bool has_prev = false;
    size_t prev_idx = 0;
    for (size_t t = 0; t < num_time_steps; t++) {
        // argmax over the labels of this time step
        size_t max_idx = 0;
        double max_prob = 0.0;
        for (size_t j = 0; j < probs_seq[t].size(); j++) {
            if (max_prob < probs_seq[t][j]) {
                max_idx = j;
                max_prob = probs_seq[t][j];
            }
        }

        // Emit only when the label changes (merges repeats); a blank between
        // two equal labels resets the run, so both copies are kept.
        if (!has_prev || max_idx != prev_idx) {
            if (max_idx != blank_id) {
                // Index the vocabulary by label id, not by output position.
                best_path_result += vocabulary[max_idx];
            }
        }
        prev_idx = max_idx;
        has_prev = true;
    }
    return best_path_result;
}
std::vector<std::pair<double, std::string> > std::vector<std::pair<double, std::string> >
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq, ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
int beam_size, int beam_size,
@ -52,106 +114,147 @@ std::vector<std::pair<double, std::string> >
// initialize // initialize
// two sets containing selected and candidate prefixes respectively // two sets containing selected and candidate prefixes respectively
std::map<std::string, double> prefix_set_prev, prefix_set_next; std::map<std::string, log_prob_type> prefix_set_prev, prefix_set_next;
// probability of prefixes ending with blank and non-blank // probability of prefixes ending with blank and non-blank
std::map<std::string, double> probs_b_prev, probs_nb_prev; std::map<std::string, log_prob_type> log_probs_b_prev, log_probs_nb_prev;
std::map<std::string, double> probs_b_cur, probs_nb_cur; std::map<std::string, log_prob_type> log_probs_b_cur, log_probs_nb_cur;
prefix_set_prev["\t"] = 1.0;
probs_b_prev["\t"] = 1.0; static log_prob_type NUM_MAX = std::numeric_limits<log_prob_type>::max();
probs_nb_prev["\t"] = 0.0; prefix_set_prev["\t"] = 0.0;
log_probs_b_prev["\t"] = 0.0;
log_probs_nb_prev["\t"] = -NUM_MAX;
for (int time_step=0; time_step<num_time_steps; time_step++) { for (int time_step=0; time_step<num_time_steps; time_step++) {
prefix_set_next.clear(); prefix_set_next.clear();
probs_b_cur.clear(); log_probs_b_cur.clear();
probs_nb_cur.clear(); log_probs_nb_cur.clear();
std::vector<double> prob = probs_seq[time_step]; std::vector<double> prob = probs_seq[time_step];
std::vector<std::pair<int, double> > prob_idx; std::vector<std::pair<int, double> > prob_idx;
for (int i=0; i<prob.size(); i++) { for (int i=0; i<prob.size(); i++) {
prob_idx.push_back(std::pair<int, double>(i, prob[i])); prob_idx.push_back(std::pair<int, double>(i, prob[i]));
} }
// pruning of vocabulary // pruning of vocabulary
int cutoff_len = prob.size();
if (cutoff_prob < 1.0) { if (cutoff_prob < 1.0) {
std::sort(prob_idx.begin(), prob_idx.end(), std::sort(prob_idx.begin(),
prob_idx.end(),
pair_comp_second_rev<int, double>); pair_comp_second_rev<int, double>);
float cum_prob = 0.0; double cum_prob = 0.0;
int cutoff_len = 0; cutoff_len = 0;
for (int i=0; i<prob_idx.size(); i++) { for (int i=0; i<prob_idx.size(); i++) {
cum_prob += prob_idx[i].second; cum_prob += prob_idx[i].second;
cutoff_len += 1; cutoff_len += 1;
if (cum_prob >= cutoff_prob) break; if (cum_prob >= cutoff_prob) break;
} }
prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(), prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(),
prob_idx.begin() + cutoff_len); prob_idx.begin() + cutoff_len);
} }
std::vector<std::pair<int, log_prob_type> > log_prob_idx;
for (int i=0; i<cutoff_len; i++) {
log_prob_idx.push_back(std::pair<int, log_prob_type>
(prob_idx[i].first, log(prob_idx[i].second)));
}
// extend prefix // extend prefix
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin(); for (std::map<std::string, log_prob_type>::iterator
it = prefix_set_prev.begin();
it != prefix_set_prev.end(); it++) { it != prefix_set_prev.end(); it++) {
std::string l = it->first; std::string l = it->first;
if( prefix_set_next.find(l) == prefix_set_next.end()) { if( prefix_set_next.find(l) == prefix_set_next.end()) {
probs_b_cur[l] = probs_nb_cur[l] = 0.0; log_probs_b_cur[l] = log_probs_nb_cur[l] = -NUM_MAX;
} }
for (int index=0; index<prob_idx.size(); index++) { for (int index=0; index<log_prob_idx.size(); index++) {
int c = prob_idx[index].first; int c = log_prob_idx[index].first;
double prob_c = prob_idx[index].second; log_prob_type log_prob_c = log_prob_idx[index].second;
log_prob_type log_probs_prev;
if (c == blank_id) { if (c == blank_id) {
probs_b_cur[l] += prob_c * (probs_b_prev[l] + probs_nb_prev[l]); log_probs_prev = log_sum_exp(log_probs_b_prev[l],
log_probs_nb_prev[l]);
log_probs_b_cur[l] = log_sum_exp(log_probs_b_cur[l],
log_prob_c+log_probs_prev);
} else { } else {
std::string last_char = l.substr(l.size()-1, 1); std::string last_char = l.substr(l.size()-1, 1);
std::string new_char = vocabulary[c]; std::string new_char = vocabulary[c];
std::string l_plus = l + new_char; std::string l_plus = l + new_char;
if( prefix_set_next.find(l_plus) == prefix_set_next.end()) { if( prefix_set_next.find(l_plus) == prefix_set_next.end()) {
probs_b_cur[l_plus] = probs_nb_cur[l_plus] = 0.0; log_probs_b_cur[l_plus] = -NUM_MAX;
log_probs_nb_cur[l_plus] = -NUM_MAX;
} }
if (last_char == new_char) { if (last_char == new_char) {
probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]; log_probs_nb_cur[l_plus] = log_sum_exp(
probs_nb_cur[l] += prob_c * probs_nb_prev[l]; log_probs_nb_cur[l_plus],
log_prob_c+log_probs_b_prev[l]
);
log_probs_nb_cur[l] = log_sum_exp(
log_probs_nb_cur[l],
log_prob_c+log_probs_nb_prev[l]
);
} else if (new_char == " ") { } else if (new_char == " ") {
double score = 1.0; float score = 0.0;
if (ext_scorer != NULL && l.size() > 1) { if (ext_scorer != NULL && l.size() > 1) {
score = ext_scorer->get_score(l.substr(1)); score = ext_scorer->get_score(l.substr(1), true);
} }
probs_nb_cur[l_plus] += score * prob_c * ( log_probs_prev = log_sum_exp(log_probs_b_prev[l],
probs_b_prev[l] + probs_nb_prev[l]); log_probs_nb_prev[l]);
log_probs_nb_cur[l_plus] = log_sum_exp(
log_probs_nb_cur[l_plus],
score + log_prob_c + log_probs_prev
);
} else { } else {
probs_nb_cur[l_plus] += prob_c * ( log_probs_prev = log_sum_exp(log_probs_b_prev[l],
probs_b_prev[l] + probs_nb_prev[l]); log_probs_nb_prev[l]);
log_probs_nb_cur[l_plus] = log_sum_exp(
log_probs_nb_cur[l_plus],
log_prob_c+log_probs_prev
);
} }
prefix_set_next[l_plus] = probs_nb_cur[l_plus] + probs_b_cur[l_plus]; prefix_set_next[l_plus] = log_sum_exp(
log_probs_nb_cur[l_plus],
log_probs_b_cur[l_plus]
);
} }
} }
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]; prefix_set_next[l] = log_sum_exp(log_probs_b_cur[l],
log_probs_nb_cur[l]);
} }
probs_b_prev = probs_b_cur; log_probs_b_prev = log_probs_b_cur;
probs_nb_prev = probs_nb_cur; log_probs_nb_prev = log_probs_nb_cur;
std::vector<std::pair<std::string, double> > std::vector<std::pair<std::string, log_prob_type> >
prefix_vec_next(prefix_set_next.begin(), prefix_vec_next(prefix_set_next.begin(),
prefix_set_next.end()); prefix_set_next.end());
std::sort(prefix_vec_next.begin(), std::sort(prefix_vec_next.begin(),
prefix_vec_next.end(), prefix_vec_next.end(),
pair_comp_second_rev<std::string, double>); pair_comp_second_rev<std::string, log_prob_type>);
int k = beam_size<prefix_vec_next.size() ? beam_size:prefix_vec_next.size(); int num_prefixes_next = prefix_vec_next.size();
prefix_set_prev = std::map<std::string, double> int k = beam_size<num_prefixes_next ? beam_size : num_prefixes_next;
(prefix_vec_next.begin(), prefix_vec_next.begin()+k); prefix_set_prev = std::map<std::string, log_prob_type> (
prefix_vec_next.begin(),
prefix_vec_next.begin() + k
);
} }
// post processing // post processing
std::vector<std::pair<double, std::string> > beam_result; std::vector<std::pair<double, std::string> > beam_result;
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin(); for (std::map<std::string, log_prob_type>::iterator
it != prefix_set_prev.end(); it++) { it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) {
if (it->second > 0.0 && it->first.size() > 1) { if (it->second > -NUM_MAX && it->first.size() > 1) {
double prob = it->second; log_prob_type log_prob = it->second;
std::string sentence = it->first.substr(1); std::string sentence = it->first.substr(1);
// scoring the last word // scoring the last word
if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') { if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') {
prob = prob * ext_scorer->get_score(sentence); log_prob = log_prob + ext_scorer->get_score(sentence, true);
}
if (log_prob > -NUM_MAX) {
std::pair<double, std::string> cur_result(log_prob, sentence);
beam_result.push_back(cur_result);
} }
double log_prob = log(prob);
beam_result.push_back(std::pair<double, std::string>(log_prob, sentence));
} }
} }
// sort the result and return // sort the result and return

@ -31,5 +31,9 @@ std::vector<std::pair<double, std::string> >
Scorer *ext_scorer=NULL, Scorer *ext_scorer=NULL,
bool nproc=false bool nproc=false
); );
/* CTC Best Path Decoder
 *
 * Picks the most probable label at every time step, merges consecutive
 * repeats, drops blank labels and returns the resulting string.
 *
 * probs_seq:  one probability vector per time step; every vector must hold
 *             vocabulary.size() + 1 entries, where index vocabulary.size()
 *             is the blank label.
 * vocabulary: the output strings for the non-blank labels.
 *
 * If a time step has a different number of entries, the implementation
 * prints a message to stdout and terminates the process via exit(1).
*/
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
std::vector<std::string> vocabulary);
#endif // CTC_BEAM_SEARCH_DECODER_H_ #endif // CTC_BEAM_SEARCH_DECODER_H_

@ -89,10 +89,15 @@ void Scorer::reset_params(float alpha, float beta) {
this->_beta = beta; this->_beta = beta;
} }
double Scorer::get_score(std::string sentence) { double Scorer::get_score(std::string sentence, bool log) {
double lm_score = language_model_score(sentence); double lm_score = language_model_score(sentence);
int word_cnt = word_count(sentence); int word_cnt = word_count(sentence);
double final_score = pow(10, _alpha*lm_score) * pow(word_cnt, _beta); double final_score = 0.0;
if (log == false) {
final_score = pow(10, _alpha*lm_score) * pow(word_cnt, _beta);
} else {
final_score = _alpha*lm_score*std::log(10) + _beta*std::log(word_cnt);
}
return final_score; return final_score;
} }

@ -30,7 +30,7 @@ public:
// reset params alpha & beta // reset params alpha & beta
void reset_params(float alpha, float beta); void reset_params(float alpha, float beta);
// get the final score // get the final score
double get_score(std::string); double get_score(std::string, bool log=false);
}; };
#endif //SCORER_H_ #endif //SCORER_H_

@ -0,0 +1,22 @@
"""Contains various CTC decoders in SWIG."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from swig_ctc_beam_search_decoder import ctc_beam_search_decoder as beam_search_decoder
from swig_ctc_beam_search_decoder import ctc_best_path_decoder as best_path__decoder
def ctc_best_path_decoder(probs_seq, vocabulary):
    """Wrapper for the SWIG CTC best path decoder.

    :param probs_seq: Per-time-step probabilities; must provide a
                      ``to_list()`` method whose result is passed to the
                      SWIG decoder.
    :param vocabulary: Vocabulary passed through to the SWIG decoder.
    :return: Whatever the SWIG decoder returns (the decoded string).
    """
    # NOTE(review): numpy arrays expose `tolist()`, not `to_list()` --
    # confirm which type callers actually pass in.
    return best_path__decoder(probs_seq.to_list(), vocabulary)
def ctc_beam_search_decoder(
        probs_seq,
        beam_size,
        vocabulary,
        blank_id,
        cutoff_prob=1.0,
        ext_scoring_func=None, ):
    """Wrapper for the SWIG CTC beam search decoder.

    :param probs_seq: Per-time-step probabilities; must provide a
                      ``to_list()`` method whose result is passed to the
                      SWIG decoder.
    :param beam_size: Beam width passed through to the SWIG decoder.
    :param vocabulary: Vocabulary passed through to the SWIG decoder.
    :param blank_id: Index of the blank label.
    :param cutoff_prob: Cutoff probability used for vocabulary pruning;
                        defaults to 1.0.
    :param ext_scoring_func: Optional external scorer; defaults to None.
    :return: Whatever the SWIG decoder returns (the beam search results).
    """
    # NOTE(review): numpy arrays expose `tolist()`, not `to_list()` --
    # confirm which type callers actually pass in.
    return beam_search_decoder(probs_seq.to_list(), beam_size, vocabulary,
                               blank_id, cutoff_prob, ext_scoring_func)
Loading…
Cancel
Save