|
|
@ -15,10 +15,10 @@ bool pair_comp_second_rev(const std::pair<T1, T2> a, const std::pair<T1, T2> b)
|
|
|
|
return a.second > b.second;
|
|
|
|
return a.second > b.second;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/* CTC beam search decoder in C++, the interface is consistent with the original
|
|
|
|
/* CTC beam search decoder in C++, the interface is consistent with the original
|
|
|
|
decoder in the Python version.
|
|
|
|
decoder in the Python version.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
std::vector<std::pair<double, std::string> >
|
|
|
|
std::vector<std::pair<double, std::string> >
|
|
|
|
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
|
|
|
|
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
|
|
|
|
int beam_size,
|
|
|
|
int beam_size,
|
|
|
|
std::vector<std::string> vocabulary,
|
|
|
|
std::vector<std::string> vocabulary,
|
|
|
@ -29,15 +29,15 @@ std::vector<std::pair<double, std::string> >
|
|
|
|
)
|
|
|
|
)
|
|
|
|
{
|
|
|
|
{
|
|
|
|
int num_time_steps = probs_seq.size();
|
|
|
|
int num_time_steps = probs_seq.size();
|
|
|
|
|
|
|
|
|
|
|
|
// assign space ID
|
|
|
|
// assign space ID
|
|
|
|
std::vector<std::string>::iterator it = std::find(vocabulary.begin(), vocabulary.end(), " ");
|
|
|
|
std::vector<std::string>::iterator it = std::find(vocabulary.begin(), vocabulary.end(), " ");
|
|
|
|
int space_id = it-vocabulary.begin();
|
|
|
|
int space_id = it-vocabulary.begin();
|
|
|
|
if(space_id >= vocabulary.size()) {
|
|
|
|
if(space_id >= vocabulary.size()) {
|
|
|
|
std::cout<<"The character space is not in the vocabulary!";
|
|
|
|
std::cout<<"The character space is not in the vocabulary!";
|
|
|
|
exit(1);
|
|
|
|
exit(1);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// initialize
|
|
|
|
// initialize
|
|
|
|
// two sets containing selected and candidate prefixes respectively
|
|
|
|
// two sets containing selected and candidate prefixes respectively
|
|
|
|
std::map<std::string, double> prefix_set_prev, prefix_set_next;
|
|
|
|
std::map<std::string, double> prefix_set_prev, prefix_set_next;
|
|
|
@ -47,7 +47,7 @@ std::vector<std::pair<double, std::string> >
|
|
|
|
prefix_set_prev["\t"] = 1.0;
|
|
|
|
prefix_set_prev["\t"] = 1.0;
|
|
|
|
probs_b_prev["\t"] = 1.0;
|
|
|
|
probs_b_prev["\t"] = 1.0;
|
|
|
|
probs_nb_prev["\t"] = 0.0;
|
|
|
|
probs_nb_prev["\t"] = 0.0;
|
|
|
|
|
|
|
|
|
|
|
|
for (int time_step=0; time_step<num_time_steps; time_step++) {
|
|
|
|
for (int time_step=0; time_step<num_time_steps; time_step++) {
|
|
|
|
prefix_set_next.clear();
|
|
|
|
prefix_set_next.clear();
|
|
|
|
probs_b_cur.clear();
|
|
|
|
probs_b_cur.clear();
|
|
|
@ -70,8 +70,8 @@ std::vector<std::pair<double, std::string> >
|
|
|
|
}
|
|
|
|
}
|
|
|
|
prob_idx = std::vector<std::pair<int, double> >(prob_idx.begin(), prob_idx.begin()+cutoff_len);
|
|
|
|
prob_idx = std::vector<std::pair<int, double> >(prob_idx.begin(), prob_idx.begin()+cutoff_len);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// extend prefix
|
|
|
|
// extend prefix
|
|
|
|
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
|
|
|
|
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
|
|
|
|
it != prefix_set_prev.end(); it++) {
|
|
|
|
it != prefix_set_prev.end(); it++) {
|
|
|
|
std::string l = it->first;
|
|
|
|
std::string l = it->first;
|
|
|
|
if( prefix_set_next.find(l) == prefix_set_next.end()) {
|
|
|
|
if( prefix_set_next.find(l) == prefix_set_next.end()) {
|
|
|
@ -109,12 +109,12 @@ std::vector<std::pair<double, std::string> >
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l];
|
|
|
|
prefix_set_next[l] = probs_b_cur[l]+probs_nb_cur[l];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
probs_b_prev = probs_b_cur;
|
|
|
|
probs_b_prev = probs_b_cur;
|
|
|
|
probs_nb_prev = probs_nb_cur;
|
|
|
|
probs_nb_prev = probs_nb_cur;
|
|
|
|
std::vector<std::pair<std::string, double> >
|
|
|
|
std::vector<std::pair<std::string, double> >
|
|
|
|
prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end());
|
|
|
|
prefix_vec_next(prefix_set_next.begin(), prefix_set_next.end());
|
|
|
|
std::sort(prefix_vec_next.begin(), prefix_vec_next.end(), pair_comp_second_rev<std::string, double>);
|
|
|
|
std::sort(prefix_vec_next.begin(), prefix_vec_next.end(), pair_comp_second_rev<std::string, double>);
|
|
|
|
int k = beam_size<prefix_vec_next.size() ? beam_size : prefix_vec_next.size();
|
|
|
|
int k = beam_size<prefix_vec_next.size() ? beam_size : prefix_vec_next.size();
|
|
|
@ -124,7 +124,7 @@ std::vector<std::pair<double, std::string> >
|
|
|
|
|
|
|
|
|
|
|
|
// post processing
|
|
|
|
// post processing
|
|
|
|
std::vector<std::pair<double, std::string> > beam_result;
|
|
|
|
std::vector<std::pair<double, std::string> > beam_result;
|
|
|
|
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
|
|
|
|
for (std::map<std::string, double>::iterator it = prefix_set_prev.begin();
|
|
|
|
it != prefix_set_prev.end(); it++) {
|
|
|
|
it != prefix_set_prev.end(); it++) {
|
|
|
|
if (it->second > 0.0 && it->first.size() > 1) {
|
|
|
|
if (it->second > 0.0 && it->first.size() > 1) {
|
|
|
|
double prob = it->second;
|
|
|
|
double prob = it->second;
|
|
|
@ -133,8 +133,8 @@ std::vector<std::pair<double, std::string> >
|
|
|
|
if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') {
|
|
|
|
if (ext_scorer != NULL && sentence[sentence.size()-1] != ' ') {
|
|
|
|
prob = prob * ext_scorer->get_score(sentence);
|
|
|
|
prob = prob * ext_scorer->get_score(sentence);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
double log_prob = log(it->second);
|
|
|
|
double log_prob = log(prob);
|
|
|
|
beam_result.push_back(std::pair<double, std::string>(log_prob, it->first));
|
|
|
|
beam_result.push_back(std::pair<double, std::string>(log_prob, sentence));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// sort the result and return
|
|
|
|
// sort the result and return
|
|
|
|