|
|
|
@ -6,35 +6,47 @@
|
|
|
|
|
#include "ctc_beam_search_decoder.h"
|
|
|
|
|
|
|
|
|
|
template <typename T1, typename T2>
|
|
|
|
|
bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b) {
|
|
|
|
|
bool pair_comp_first_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b)
|
|
|
|
|
{
|
|
|
|
|
return a.first > b.first;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T1, typename T2>
|
|
|
|
|
bool pair_comp_second_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b) {
|
|
|
|
|
bool pair_comp_second_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b)
|
|
|
|
|
{
|
|
|
|
|
return a.second > b.second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/* CTC beam search decoder in C++, the interface is consistent with the original
|
|
|
|
|
decoder in Python version.
|
|
|
|
|
*/
|
|
|
|
|
std::vector<std::pair<double, std::string> >
|
|
|
|
|
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
|
|
|
|
|
int beam_size,
|
|
|
|
|
std::vector<std::string> vocabulary,
|
|
|
|
|
int blank_id,
|
|
|
|
|
double cutoff_prob,
|
|
|
|
|
Scorer *ext_scorer,
|
|
|
|
|
bool nproc
|
|
|
|
|
)
|
|
|
|
|
{
|
|
|
|
|
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
|
|
|
|
|
int beam_size,
|
|
|
|
|
std::vector<std::string> vocabulary,
|
|
|
|
|
int blank_id,
|
|
|
|
|
double cutoff_prob,
|
|
|
|
|
Scorer *ext_scorer,
|
|
|
|
|
bool nproc) {
|
|
|
|
|
// dimension check
|
|
|
|
|
int num_time_steps = probs_seq.size();
|
|
|
|
|
for (int i=0; i<num_time_steps; i++) {
|
|
|
|
|
if (probs_seq[i].size() != vocabulary.size()+1) {
|
|
|
|
|
std::cout<<"The shape of probs_seq does not match"
|
|
|
|
|
<<" with the shape of the vocabulary!"<<std::endl;
|
|
|
|
|
exit(1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// blank_id check
|
|
|
|
|
if (blank_id > vocabulary.size()) {
|
|
|
|
|
std::cout<<"Invalid blank_id!"<<std::endl;
|
|
|
|
|
exit(1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// assign space ID
|
|
|
|
|
std::vector<std::string>::iterator it = std::find(vocabulary.begin(), vocabulary.end(), " ");
|
|
|
|
|
int space_id = it-vocabulary.begin();
|
|
|
|
|
std::vector<std::string>::iterator it = std::find(vocabulary.begin(),
|
|
|
|
|
vocabulary.end(), " ");
|
|
|
|
|
int space_id = it - vocabulary.begin();
|
|
|
|
|
if(space_id >= vocabulary.size()) {
|
|
|
|
|
std::cout<<"The character space is not in the vocabulary!";
|
|
|
|
|
std::cout<<"The character space is not in the vocabulary!"<<std::endl;
|
|
|
|
|
exit(1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -60,7 +72,8 @@ std::vector<std::pair<double, std::string> >
|
|
|
|
|
}
|
|
|
|
|
// pruning of vacobulary
|
|
|
|
|
if (cutoff_prob < 1.0) {
|
|
|
|
|
std::sort(prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
|
|
|
|
|
std::sort(prob_idx.begin(), prob_idx.end(),
|
|
|
|
|
pair_comp_second_rev<int, double>);
|
|
|
|
|
float cum_prob = 0.0;
|
|
|
|
|
int cutoff_len = 0;
|
|
|
|
|
for (int i=0; i<prob_idx.size(); i++) {
|
|
|
|
@ -68,7 +81,8 @@ std::vector<std::pair<double, std::string> >
|
|
|
|
|
cutoff_len += 1;
|
|
|
|
|
if (cum_prob >= cutoff_prob) break;
|
|
|
|
|
}
|
|
|
|
|
prob_idx = std::vector<std::pair<int, double> >(prob_idx.begin(), prob_idx.begin()+cutoff_len);
|
|
|
|
|
prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(),
|
|
|
|
|
prob_idx.begin() + cutoff_len);
|
|
|
|
|
}
|
|
|
|
|
// extend prefix
|
|
|
|
|
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
|
|
|
|
@ -82,11 +96,11 @@ std::vector<std::pair<double, std::string> >
|
|
|
|
|
int c = prob_idx[index].first;
|
|
|
|
|
double prob_c = prob_idx[index].second;
|
|
|
|
|
if (c == blank_id) {
|
|
|
|
|
probs_b_cur[l] += prob_c*(probs_b_prev[l]+probs_nb_prev[l]);
|
|
|
|
|
probs_b_cur[l] += prob_c * (probs_b_prev[l] + probs_nb_prev[l]);
|
|
|
|
|
} else {
|
|
|
|
|
std::string last_char = l.substr(l.size()-1, 1);
|
|
|
|
|
std::string new_char = vocabulary[c];
|
|
|
|
|
std::string l_plus = l+new_char;
|
|
|
|
|
std::string l_plus = l + new_char;
|
|
|
|
|
|
|
|
|
|
if( prefix_set_next.find(l_plus) == prefix_set_next.end()) {
|
|
|
|
|
probs_b_cur[l_plus] = probs_nb_cur[l_plus] = 0.0;
|
|
|
|
@ -105,19 +119,22 @@ std::vector<std::pair<double, std::string> >
|
|
|
|
|
probs_nb_cur[l_plus] += prob_c * (
|
|
|
|
|
probs_b_prev[l] + probs_nb_prev[l]);
|
|
|
|
|
}
|
|
|
|
|
prefix_set_next[l_plus] = probs_nb_cur[l_plus]+probs_b_cur[l_plus];
|
|
|
|
|
prefix_set_next[l_plus] = probs_nb_cur[l_plus] + probs_b_cur[l_plus];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l];
|
|
|
|
|
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
probs_b_prev = probs_b_cur;
|
|
|
|
|
probs_nb_prev = probs_nb_cur;
|
|
|
|
|
std::vector<std::pair<std::string, double> >
|
|
|
|
|
prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end());
|
|
|
|
|
std::sort(prefix_vec_next.begin(), prefix_vec_next.end(), pair_comp_second_rev<std::string, double>);
|
|
|
|
|
int k = beam_size<prefix_vec_next.size() ? beam_size : prefix_vec_next.size();
|
|
|
|
|
prefix_vec_next(prefix_set_next.begin(),
|
|
|
|
|
prefix_set_next.end());
|
|
|
|
|
std::sort(prefix_vec_next.begin(),
|
|
|
|
|
prefix_vec_next.end(),
|
|
|
|
|
pair_comp_second_rev<std::string, double>);
|
|
|
|
|
int k = beam_size<prefix_vec_next.size() ? beam_size:prefix_vec_next.size();
|
|
|
|
|
prefix_set_prev = std::map<std::string, double>
|
|
|
|
|
(prefix_vec_next.begin(), prefix_vec_next.begin()+k);
|
|
|
|
|
}
|
|
|
|
@ -138,6 +155,7 @@ std::vector<std::pair<double, std::string> >
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// sort the result and return
|
|
|
|
|
std::sort(beam_result.begin(), beam_result.end(), pair_comp_first_rev<double, std::string>);
|
|
|
|
|
std::sort(beam_result.begin(), beam_result.end(),
|
|
|
|
|
pair_comp_first_rev<double, std::string>);
|
|
|
|
|
return beam_result;
|
|
|
|
|
}
|
|
|
|
|