You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
261 lines
10 KiB
261 lines
10 KiB
8 years ago
|
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "ctc_decoders.h"
|
||
8 years ago
|
|
||
8 years ago
|
// Floating-point type used for the log probabilities that
// ctc_beam_search_decoder stores per prefix.
typedef double log_prob_type;
|
||
8 years ago
|
|
||
8 years ago
|
/// Comparator ordering pairs by their `first` member, largest first.
///
/// Intended as the Compare argument of std::sort. Both pairs are taken by
/// const reference so that sorting pairs holding expensive members (such as
/// std::string) does not copy them on every comparison.
///
/// @return true when a.first is strictly greater than b.first.
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2>& a, const std::pair<T1, T2>& b)
{
  return a.first > b.first;
}
|
||
|
|
||
|
/// Comparator ordering pairs by their `second` member, largest first.
///
/// Intended as the Compare argument of std::sort. Both pairs are taken by
/// const reference so that sorting pairs holding expensive members (such as
/// std::string keys) does not copy them on every comparison.
///
/// @return true when a.second is strictly greater than b.second.
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2>& a, const std::pair<T1, T2>& b)
{
  return a.second > b.second;
}
|
||
|
|
||
8 years ago
|
/// Returns log(exp(x) + exp(y)), computed without overflowing exp().
///
/// A value at or below -std::numeric_limits<T>::max() acts as log(0): if x
/// is such a value, y is returned unchanged; otherwise, if y is, x is
/// returned unchanged. Because x is tested first, y is returned when both
/// arguments qualify.
template <typename T>
T log_sum_exp(T x, T y)
{
  const T log_zero = -std::numeric_limits<T>::max();
  if (x <= log_zero) {
    return y;
  }
  if (y <= log_zero) {
    return x;
  }

  // Factor out the larger argument so both exponents are <= 0.
  const T pivot = (x < y) ? y : x;
  const T shifted_sum = std::exp(x - pivot) + std::exp(y - pivot);
  return pivot + std::log(shifted_sum);
}
|
||
|
|
||
|
/// Greedy (best-path) CTC decoding.
///
/// At every time step the highest-scoring column is selected (ties go to the
/// lowest index). Consecutive repeats of the same column are then merged and
/// blank columns are dropped; the vocabulary entries of the remaining columns
/// are concatenated.
///
/// @param probs_seq  One row per time step; every row must hold exactly
///                   vocabulary.size() + 1 entries.
/// @param vocabulary Output symbols indexed by column. The extra, last column
///                   (index vocabulary.size()) is the blank symbol.
/// @return The decoded string (empty if probs_seq is empty or all blanks).
///
/// On a shape mismatch an error is printed to stdout and the process is
/// terminated with exit(1).
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
                                  std::vector<std::string> vocabulary) {
  // dimension check: one column per vocabulary entry plus one for blank
  const std::size_t num_classes = vocabulary.size() + 1;
  for (std::size_t t = 0; t < probs_seq.size(); t++) {
    if (probs_seq[t].size() != num_classes) {
      std::cout<<"The shape of probs_seq does not match"
               <<" with the shape of the vocabulary!"<<std::endl;
      exit(1);
    }
  }

  // the blank symbol is the extra, last column
  const std::size_t blank_id = vocabulary.size();

  std::string best_path_result;
  // num_classes is never a valid column index, so the first step can never
  // be mistaken for a repeat of the previous one.
  std::size_t prev_idx = num_classes;
  for (std::size_t t = 0; t < probs_seq.size(); t++) {
    const std::vector<double>& row = probs_seq[t];
    // std::max_element returns the first maximal entry (lowest index wins
    // ties). Unlike a search seeded with 0.0 it also works when every entry
    // is <= 0. Rows are non-empty after the dimension check.
    const std::size_t max_idx = static_cast<std::size_t>(
        std::max_element(row.begin(), row.end()) - row.begin());
    // CTC collapse: merge consecutive repeats, then drop blanks
    if (max_idx != prev_idx && max_idx != blank_id) {
      best_path_result += vocabulary[max_idx];
    }
    prev_idx = max_idx;
  }
  return best_path_result;
}
|
||
|
|
||
8 years ago
|
std::vector<std::pair<double, std::string> >
|
||
8 years ago
|
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
|
||
|
int beam_size,
|
||
|
std::vector<std::string> vocabulary,
|
||
|
int blank_id,
|
||
|
double cutoff_prob,
|
||
8 years ago
|
LmScorer *ext_scorer,
|
||
8 years ago
|
bool nproc) {
|
||
|
// dimension check
|
||
8 years ago
|
int num_time_steps = probs_seq.size();
|
||
8 years ago
|
for (int i=0; i<num_time_steps; i++) {
|
||
|
if (probs_seq[i].size() != vocabulary.size()+1) {
|
||
8 years ago
|
std::cout << " The shape of probs_seq does not match"
|
||
|
<< " with the shape of the vocabulary!" << std::endl;
|
||
8 years ago
|
exit(1);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// blank_id check
|
||
|
if (blank_id > vocabulary.size()) {
|
||
8 years ago
|
std::cout << " Invalid blank_id! " << std::endl;
|
||
8 years ago
|
exit(1);
|
||
|
}
|
||
8 years ago
|
|
||
|
// assign space ID
|
||
8 years ago
|
std::vector<std::string>::iterator it = std::find(vocabulary.begin(),
|
||
|
vocabulary.end(), " ");
|
||
|
int space_id = it - vocabulary.begin();
|
||
8 years ago
|
if(space_id >= vocabulary.size()) {
|
||
8 years ago
|
std::cout << " The character space is not in the vocabulary!"<<std::endl;
|
||
8 years ago
|
exit(1);
|
||
8 years ago
|
}
|
||
8 years ago
|
|
||
8 years ago
|
// initialize
|
||
|
// two sets containing selected and candidate prefixes respectively
|
||
8 years ago
|
std::map<std::string, log_prob_type> prefix_set_prev, prefix_set_next;
|
||
8 years ago
|
// probability of prefixes ending with blank and non-blank
|
||
8 years ago
|
std::map<std::string, log_prob_type> log_probs_b_prev, log_probs_nb_prev;
|
||
|
std::map<std::string, log_prob_type> log_probs_b_cur, log_probs_nb_cur;
|
||
|
|
||
|
static log_prob_type NUM_MAX = std::numeric_limits<log_prob_type>::max();
|
||
|
prefix_set_prev["\t"] = 0.0;
|
||
|
log_probs_b_prev["\t"] = 0.0;
|
||
|
log_probs_nb_prev["\t"] = -NUM_MAX;
|
||
8 years ago
|
|
||
8 years ago
|
for (int time_step=0; time_step<num_time_steps; time_step++) {
|
||
|
prefix_set_next.clear();
|
||
8 years ago
|
log_probs_b_cur.clear();
|
||
|
log_probs_nb_cur.clear();
|
||
8 years ago
|
std::vector<double> prob = probs_seq[time_step];
|
||
|
|
||
|
std::vector<std::pair<int, double> > prob_idx;
|
||
|
for (int i=0; i<prob.size(); i++) {
|
||
|
prob_idx.push_back(std::pair<int, double>(i, prob[i]));
|
||
|
}
|
||
8 years ago
|
|
||
8 years ago
|
// pruning of vacobulary
|
||
8 years ago
|
int cutoff_len = prob.size();
|
||
8 years ago
|
if (cutoff_prob < 1.0) {
|
||
8 years ago
|
std::sort(prob_idx.begin(),
|
||
|
prob_idx.end(),
|
||
8 years ago
|
pair_comp_second_rev<int, double>);
|
||
8 years ago
|
double cum_prob = 0.0;
|
||
|
cutoff_len = 0;
|
||
8 years ago
|
for (int i=0; i<prob_idx.size(); i++) {
|
||
|
cum_prob += prob_idx[i].second;
|
||
|
cutoff_len += 1;
|
||
|
if (cum_prob >= cutoff_prob) break;
|
||
|
}
|
||
8 years ago
|
prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(),
|
||
8 years ago
|
prob_idx.begin() + cutoff_len);
|
||
8 years ago
|
}
|
||
8 years ago
|
|
||
|
std::vector<std::pair<int, log_prob_type> > log_prob_idx;
|
||
|
for (int i=0; i<cutoff_len; i++) {
|
||
|
log_prob_idx.push_back(std::pair<int, log_prob_type>
|
||
|
(prob_idx[i].first, log(prob_idx[i].second)));
|
||
|
}
|
||
|
|
||
8 years ago
|
// extend prefix
|
||
8 years ago
|
for (std::map<std::string, log_prob_type>::iterator
|
||
|
it = prefix_set_prev.begin();
|
||
8 years ago
|
it != prefix_set_prev.end(); it++) {
|
||
|
std::string l = it->first;
|
||
|
if( prefix_set_next.find(l) == prefix_set_next.end()) {
|
||
8 years ago
|
log_probs_b_cur[l] = log_probs_nb_cur[l] = -NUM_MAX;
|
||
8 years ago
|
}
|
||
|
|
||
8 years ago
|
for (int index=0; index<log_prob_idx.size(); index++) {
|
||
|
int c = log_prob_idx[index].first;
|
||
|
log_prob_type log_prob_c = log_prob_idx[index].second;
|
||
|
log_prob_type log_probs_prev;
|
||
8 years ago
|
if (c == blank_id) {
|
||
8 years ago
|
log_probs_prev = log_sum_exp(log_probs_b_prev[l],
|
||
|
log_probs_nb_prev[l]);
|
||
|
log_probs_b_cur[l] = log_sum_exp(log_probs_b_cur[l],
|
||
|
log_prob_c+log_probs_prev);
|
||
8 years ago
|
} else {
|
||
|
std::string last_char = l.substr(l.size()-1, 1);
|
||
|
std::string new_char = vocabulary[c];
|
||
8 years ago
|
std::string l_plus = l + new_char;
|
||
8 years ago
|
|
||
|
if( prefix_set_next.find(l_plus) == prefix_set_next.end()) {
|
||
8 years ago
|
log_probs_b_cur[l_plus] = -NUM_MAX;
|
||
|
log_probs_nb_cur[l_plus] = -NUM_MAX;
|
||
8 years ago
|
}
|
||
|
if (last_char == new_char) {
|
||
8 years ago
|
log_probs_nb_cur[l_plus] = log_sum_exp(
|
||
|
log_probs_nb_cur[l_plus],
|
||
|
log_prob_c+log_probs_b_prev[l]
|
||
|
);
|
||
|
log_probs_nb_cur[l] = log_sum_exp(
|
||
|
log_probs_nb_cur[l],
|
||
|
log_prob_c+log_probs_nb_prev[l]
|
||
|
);
|
||
8 years ago
|
} else if (new_char == " ") {
|
||
8 years ago
|
float score = 0.0;
|
||
8 years ago
|
if (ext_scorer != NULL && l.size() > 1) {
|
||
8 years ago
|
score = ext_scorer->get_score(l.substr(1), true);
|
||
8 years ago
|
}
|
||
8 years ago
|
log_probs_prev = log_sum_exp(log_probs_b_prev[l],
|
||
|
log_probs_nb_prev[l]);
|
||
|
log_probs_nb_cur[l_plus] = log_sum_exp(
|
||
|
log_probs_nb_cur[l_plus],
|
||
|
score + log_prob_c + log_probs_prev
|
||
|
);
|
||
8 years ago
|
} else {
|
||
8 years ago
|
log_probs_prev = log_sum_exp(log_probs_b_prev[l],
|
||
|
log_probs_nb_prev[l]);
|
||
|
log_probs_nb_cur[l_plus] = log_sum_exp(
|
||
|
log_probs_nb_cur[l_plus],
|
||
|
log_prob_c+log_probs_prev
|
||
|
);
|
||
8 years ago
|
}
|
||
8 years ago
|
prefix_set_next[l_plus] = log_sum_exp(
|
||
|
log_probs_nb_cur[l_plus],
|
||
|
log_probs_b_cur[l_plus]
|
||
|
);
|
||
8 years ago
|
}
|
||
|
}
|
||
|
|
||
8 years ago
|
prefix_set_next[l] = log_sum_exp(log_probs_b_cur[l],
|
||
|
log_probs_nb_cur[l]);
|
||
8 years ago
|
}
|
||
|
|
||
8 years ago
|
log_probs_b_prev = log_probs_b_cur;
|
||
|
log_probs_nb_prev = log_probs_nb_cur;
|
||
|
std::vector<std::pair<std::string, log_prob_type> >
|
||
8 years ago
|
prefix_vec_next(prefix_set_next.begin(),
|
||
|
prefix_set_next.end());
|
||
|
std::sort(prefix_vec_next.begin(),
|
||
|
prefix_vec_next.end(),
|
||
8 years ago
|
pair_comp_second_rev<std::string, log_prob_type>);
|
||
|
int num_prefixes_next = prefix_vec_next.size();
|
||
|
int k = beam_size<num_prefixes_next ? beam_size : num_prefixes_next;
|
||
|
prefix_set_prev = std::map<std::string, log_prob_type> (
|
||
|
prefix_vec_next.begin(),
|
||
|
prefix_vec_next.begin() + k
|
||
|
);
|
||
8 years ago
|
}
|
||
|
|
||
|
// post processing
|
||
|
std::vector<std::pair<double, std::string> > beam_result;
|
||
8 years ago
|
for (std::map<std::string, log_prob_type>::iterator
|
||
|
it = prefix_set_prev.begin(); it != prefix_set_prev.end(); it++) {
|
||
|
if (it->second > -NUM_MAX && it->first.size() > 1) {
|
||
|
log_prob_type log_prob = it->second;
|
||
8 years ago
|
std::string sentence = it->first.substr(1);
|
||
|
// scoring the last word
|
||
|
if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') {
|
||
8 years ago
|
log_prob = log_prob + ext_scorer->get_score(sentence, true);
|
||
|
}
|
||
|
if (log_prob > -NUM_MAX) {
|
||
|
std::pair<double, std::string> cur_result(log_prob, sentence);
|
||
|
beam_result.push_back(cur_result);
|
||
8 years ago
|
}
|
||
|
}
|
||
|
}
|
||
|
// sort the result and return
|
||
8 years ago
|
std::sort(beam_result.begin(), beam_result.end(),
|
||
|
pair_comp_first_rev<double, std::string>);
|
||
8 years ago
|
return beam_result;
|
||
|
}
|