parent
a2ddfe8d9e
commit
5208b8e40f
@ -1,337 +1,329 @@
|
|||||||
#include <iostream>
|
#include "ctc_decoders.h"
|
||||||
#include <map>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <utility>
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <iostream>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include "fst/fstlib.h"
|
#include <map>
|
||||||
#include "ctc_decoders.h"
|
#include <utility>
|
||||||
|
#include "ThreadPool.h"
|
||||||
#include "decoder_utils.h"
|
#include "decoder_utils.h"
|
||||||
|
#include "fst/fstlib.h"
|
||||||
#include "path_trie.h"
|
#include "path_trie.h"
|
||||||
#include "ThreadPool.h"
|
|
||||||
|
|
||||||
std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
|
std::string ctc_best_path_decoder(std::vector<std::vector<double>> probs_seq,
|
||||||
std::vector<std::string> vocabulary)
|
std::vector<std::string> vocabulary) {
|
||||||
{
|
// dimension check
|
||||||
// dimension check
|
int num_time_steps = probs_seq.size();
|
||||||
int num_time_steps = probs_seq.size();
|
for (int i = 0; i < num_time_steps; i++) {
|
||||||
for (int i=0; i<num_time_steps; i++) {
|
if (probs_seq[i].size() != vocabulary.size() + 1) {
|
||||||
if (probs_seq[i].size() != vocabulary.size()+1) {
|
std::cout << "The shape of probs_seq does not match"
|
||||||
std::cout << "The shape of probs_seq does not match"
|
<< " with the shape of the vocabulary!" << std::endl;
|
||||||
<< " with the shape of the vocabulary!" << std::endl;
|
exit(1);
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
int blank_id = vocabulary.size();
|
|
||||||
|
int blank_id = vocabulary.size();
|
||||||
std::vector<int> max_idx_vec;
|
|
||||||
double max_prob = 0.0;
|
std::vector<int> max_idx_vec;
|
||||||
int max_idx = 0;
|
double max_prob = 0.0;
|
||||||
for (int i = 0; i < num_time_steps; i++) {
|
int max_idx = 0;
|
||||||
for (int j = 0; j < probs_seq[i].size(); j++) {
|
for (int i = 0; i < num_time_steps; i++) {
|
||||||
if (max_prob < probs_seq[i][j]) {
|
for (int j = 0; j < probs_seq[i].size(); j++) {
|
||||||
max_idx = j;
|
if (max_prob < probs_seq[i][j]) {
|
||||||
max_prob = probs_seq[i][j];
|
max_idx = j;
|
||||||
}
|
max_prob = probs_seq[i][j];
|
||||||
}
|
}
|
||||||
max_idx_vec.push_back(max_idx);
|
|
||||||
max_prob = 0.0;
|
|
||||||
max_idx = 0;
|
|
||||||
}
|
}
|
||||||
|
max_idx_vec.push_back(max_idx);
|
||||||
std::vector<int> idx_vec;
|
max_prob = 0.0;
|
||||||
for (int i = 0; i < max_idx_vec.size(); i++) {
|
max_idx = 0;
|
||||||
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i-1])) {
|
}
|
||||||
idx_vec.push_back(max_idx_vec[i]);
|
|
||||||
}
|
std::vector<int> idx_vec;
|
||||||
|
for (int i = 0; i < max_idx_vec.size(); i++) {
|
||||||
|
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
|
||||||
|
idx_vec.push_back(max_idx_vec[i]);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::string best_path_result;
|
std::string best_path_result;
|
||||||
for (int i = 0; i < idx_vec.size(); i++) {
|
for (int i = 0; i < idx_vec.size(); i++) {
|
||||||
if (idx_vec[i] != blank_id) {
|
if (idx_vec[i] != blank_id) {
|
||||||
best_path_result += vocabulary[idx_vec[i]];
|
best_path_result += vocabulary[idx_vec[i]];
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return best_path_result;
|
}
|
||||||
|
return best_path_result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::pair<double, std::string> >
|
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
||||||
ctc_beam_search_decoder(std::vector<std::vector<double> > probs_seq,
|
std::vector<std::vector<double>> probs_seq,
|
||||||
int beam_size,
|
int beam_size,
|
||||||
std::vector<std::string> vocabulary,
|
std::vector<std::string> vocabulary,
|
||||||
int blank_id,
|
int blank_id,
|
||||||
double cutoff_prob,
|
double cutoff_prob,
|
||||||
int cutoff_top_n,
|
int cutoff_top_n,
|
||||||
Scorer *ext_scorer)
|
Scorer *extscorer) {
|
||||||
{
|
// dimension check
|
||||||
// dimension check
|
int num_time_steps = probs_seq.size();
|
||||||
int num_time_steps = probs_seq.size();
|
for (int i = 0; i < num_time_steps; i++) {
|
||||||
for (int i = 0; i < num_time_steps; i++) {
|
if (probs_seq[i].size() != vocabulary.size() + 1) {
|
||||||
if (probs_seq[i].size() != vocabulary.size() + 1) {
|
std::cout << " The shape of probs_seq does not match"
|
||||||
std::cout << " The shape of probs_seq does not match"
|
<< " with the shape of the vocabulary!" << std::endl;
|
||||||
<< " with the shape of the vocabulary!" << std::endl;
|
exit(1);
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
// blank_id check
|
|
||||||
if (blank_id > vocabulary.size()) {
|
// blank_id check
|
||||||
std::cout << " Invalid blank_id! " << std::endl;
|
if (blank_id > vocabulary.size()) {
|
||||||
exit(1);
|
std::cout << " Invalid blank_id! " << std::endl;
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// assign space ID
|
||||||
|
std::vector<std::string>::iterator it =
|
||||||
|
std::find(vocabulary.begin(), vocabulary.end(), " ");
|
||||||
|
int space_id = it - vocabulary.begin();
|
||||||
|
// if no space in vocabulary
|
||||||
|
if (space_id >= vocabulary.size()) {
|
||||||
|
space_id = -2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// init prefixes' root
|
||||||
|
PathTrie root;
|
||||||
|
root.score = root.log_prob_b_prev = 0.0;
|
||||||
|
std::vector<PathTrie *> prefixes;
|
||||||
|
prefixes.push_back(&root);
|
||||||
|
|
||||||
|
if (extscorer != nullptr) {
|
||||||
|
if (extscorer->is_char_map_empty()) {
|
||||||
|
extscorer->set_char_map(vocabulary);
|
||||||
}
|
}
|
||||||
|
if (!extscorer->is_character_based()) {
|
||||||
// assign space ID
|
if (extscorer->dictionary == nullptr) {
|
||||||
std::vector<std::string>::iterator it = std::find(vocabulary.begin(),
|
// fill dictionary for fst
|
||||||
vocabulary.end(), " ");
|
extscorer->fill_dictionary(true);
|
||||||
int space_id = it - vocabulary.begin();
|
}
|
||||||
// if no space in vocabulary
|
auto fst_dict = static_cast<fst::StdVectorFst *>(extscorer->dictionary);
|
||||||
if(space_id >= vocabulary.size()) {
|
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
|
||||||
space_id = -2;
|
root.set_dictionary(dict_ptr);
|
||||||
|
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
|
||||||
|
root.set_matcher(matcher);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// prefix search over time
|
||||||
|
for (int time_step = 0; time_step < num_time_steps; time_step++) {
|
||||||
|
std::vector<double> prob = probs_seq[time_step];
|
||||||
|
std::vector<std::pair<int, double>> prob_idx;
|
||||||
|
for (int i = 0; i < prob.size(); i++) {
|
||||||
|
prob_idx.push_back(std::pair<int, double>(i, prob[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
// init prefixes' root
|
float min_cutoff = -NUM_FLT_INF;
|
||||||
PathTrie root;
|
bool full_beam = false;
|
||||||
root._score = root._log_prob_b_prev = 0.0;
|
if (extscorer != nullptr) {
|
||||||
std::vector<PathTrie*> prefixes;
|
int num_prefixes = std::min((int)prefixes.size(), beam_size);
|
||||||
prefixes.push_back(&root);
|
std::sort(
|
||||||
|
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
|
||||||
|
min_cutoff = prefixes[num_prefixes - 1]->score + log(prob[blank_id]) -
|
||||||
|
std::max(0.0, extscorer->beta);
|
||||||
|
full_beam = (num_prefixes == beam_size);
|
||||||
|
}
|
||||||
|
|
||||||
if ( ext_scorer != nullptr) {
|
// pruning of vacobulary
|
||||||
if (ext_scorer->is_char_map_empty()) {
|
int cutoff_len = prob.size();
|
||||||
ext_scorer->set_char_map(vocabulary);
|
if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) {
|
||||||
}
|
std::sort(
|
||||||
if (!ext_scorer->is_character_based()) {
|
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
|
||||||
if (ext_scorer->dictionary == nullptr) {
|
if (cutoff_prob < 1.0) {
|
||||||
// fill dictionary for fst
|
double cum_prob = 0.0;
|
||||||
ext_scorer->fill_dictionary(true);
|
cutoff_len = 0;
|
||||||
}
|
for (int i = 0; i < prob_idx.size(); i++) {
|
||||||
auto fst_dict = static_cast<fst::StdVectorFst*>
|
cum_prob += prob_idx[i].second;
|
||||||
(ext_scorer->dictionary);
|
cutoff_len += 1;
|
||||||
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
|
if (cum_prob >= cutoff_prob) break;
|
||||||
root.set_dictionary(dict_ptr);
|
|
||||||
auto matcher = std::make_shared<FSTMATCH>
|
|
||||||
(*dict_ptr, fst::MATCH_INPUT);
|
|
||||||
root.set_matcher(matcher);
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
cutoff_len = std::min(cutoff_len, cutoff_top_n);
|
||||||
|
prob_idx = std::vector<std::pair<int, double>>(
|
||||||
|
prob_idx.begin(), prob_idx.begin() + cutoff_len);
|
||||||
|
}
|
||||||
|
std::vector<std::pair<int, float>> log_prob_idx;
|
||||||
|
for (int i = 0; i < cutoff_len; i++) {
|
||||||
|
log_prob_idx.push_back(std::pair<int, float>(
|
||||||
|
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// prefix search over time
|
// loop over chars
|
||||||
for (int time_step = 0; time_step < num_time_steps; time_step++) {
|
for (int index = 0; index < log_prob_idx.size(); index++) {
|
||||||
std::vector<double> prob = probs_seq[time_step];
|
auto c = log_prob_idx[index].first;
|
||||||
std::vector<std::pair<int, double> > prob_idx;
|
float log_prob_c = log_prob_idx[index].second;
|
||||||
for (int i=0; i<prob.size(); i++) {
|
|
||||||
prob_idx.push_back(std::pair<int, double>(i, prob[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
float min_cutoff = -NUM_FLT_INF;
|
for (int i = 0; i < prefixes.size() && i < beam_size; i++) {
|
||||||
bool full_beam = false;
|
auto prefix = prefixes[i];
|
||||||
if (ext_scorer != nullptr) {
|
|
||||||
int num_prefixes = std::min((int)prefixes.size(), beam_size);
|
|
||||||
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes,
|
|
||||||
prefix_compare);
|
|
||||||
min_cutoff = prefixes[num_prefixes-1]->_score + log(prob[blank_id])
|
|
||||||
- std::max(0.0, ext_scorer->beta);
|
|
||||||
full_beam = (num_prefixes == beam_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
// pruning of vacobulary
|
|
||||||
int cutoff_len = prob.size();
|
|
||||||
if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) {
|
|
||||||
std::sort(prob_idx.begin(),
|
|
||||||
prob_idx.end(),
|
|
||||||
pair_comp_second_rev<int, double>);
|
|
||||||
if (cutoff_prob < 1.0) {
|
|
||||||
double cum_prob = 0.0;
|
|
||||||
cutoff_len = 0;
|
|
||||||
for (int i=0; i<prob_idx.size(); i++) {
|
|
||||||
cum_prob += prob_idx[i].second;
|
|
||||||
cutoff_len += 1;
|
|
||||||
if (cum_prob >= cutoff_prob) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cutoff_len = std::min(cutoff_len, cutoff_top_n);
|
|
||||||
prob_idx = std::vector<std::pair<int, double> >( prob_idx.begin(),
|
|
||||||
prob_idx.begin() + cutoff_len);
|
|
||||||
}
|
|
||||||
std::vector<std::pair<int, float> > log_prob_idx;
|
|
||||||
for (int i = 0; i < cutoff_len; i++) {
|
|
||||||
log_prob_idx.push_back(std::pair<int, float>
|
|
||||||
(prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// loop over chars
|
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
||||||
for (int index = 0; index < log_prob_idx.size(); index++) {
|
break;
|
||||||
auto c = log_prob_idx[index].first;
|
|
||||||
float log_prob_c = log_prob_idx[index].second;
|
|
||||||
|
|
||||||
for (int i = 0; i < prefixes.size() && i<beam_size; i++) {
|
|
||||||
auto prefix = prefixes[i];
|
|
||||||
|
|
||||||
if (full_beam && log_prob_c + prefix->_score < min_cutoff) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
// blank
|
|
||||||
if (c == blank_id) {
|
|
||||||
prefix->_log_prob_b_cur = log_sum_exp(
|
|
||||||
prefix->_log_prob_b_cur,
|
|
||||||
log_prob_c + prefix->_score);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
// repeated character
|
|
||||||
if (c == prefix->_character) {
|
|
||||||
prefix->_log_prob_nb_cur = log_sum_exp(
|
|
||||||
prefix->_log_prob_nb_cur,
|
|
||||||
log_prob_c + prefix->_log_prob_nb_prev);
|
|
||||||
}
|
|
||||||
// get new prefix
|
|
||||||
auto prefix_new = prefix->get_path_trie(c);
|
|
||||||
|
|
||||||
if (prefix_new != nullptr) {
|
|
||||||
float log_p = -NUM_FLT_INF;
|
|
||||||
|
|
||||||
if (c == prefix->_character
|
|
||||||
&& prefix->_log_prob_b_prev > -NUM_FLT_INF) {
|
|
||||||
log_p = log_prob_c + prefix->_log_prob_b_prev;
|
|
||||||
} else if (c != prefix->_character) {
|
|
||||||
log_p = log_prob_c + prefix->_score;
|
|
||||||
}
|
|
||||||
|
|
||||||
// language model scoring
|
|
||||||
if (ext_scorer != nullptr &&
|
|
||||||
(c == space_id || ext_scorer->is_character_based()) ) {
|
|
||||||
PathTrie *prefix_to_score = nullptr;
|
|
||||||
|
|
||||||
// skip scoring the space
|
|
||||||
if (ext_scorer->is_character_based()) {
|
|
||||||
prefix_to_score = prefix_new;
|
|
||||||
} else {
|
|
||||||
prefix_to_score = prefix;
|
|
||||||
}
|
|
||||||
|
|
||||||
double score = 0.0;
|
|
||||||
std::vector<std::string> ngram;
|
|
||||||
ngram = ext_scorer->make_ngram(prefix_to_score);
|
|
||||||
score = ext_scorer->get_log_cond_prob(ngram) *
|
|
||||||
ext_scorer->alpha;
|
|
||||||
|
|
||||||
log_p += score;
|
|
||||||
log_p += ext_scorer->beta;
|
|
||||||
}
|
|
||||||
prefix_new->_log_prob_nb_cur = log_sum_exp(
|
|
||||||
prefix_new->_log_prob_nb_cur, log_p);
|
|
||||||
}
|
|
||||||
} // end of loop over prefix
|
|
||||||
} // end of loop over chars
|
|
||||||
|
|
||||||
prefixes.clear();
|
|
||||||
// update log probs
|
|
||||||
root.iterate_to_vec(prefixes);
|
|
||||||
|
|
||||||
// only preserve top beam_size prefixes
|
|
||||||
if (prefixes.size() >= beam_size) {
|
|
||||||
std::nth_element(prefixes.begin(),
|
|
||||||
prefixes.begin() + beam_size,
|
|
||||||
prefixes.end(),
|
|
||||||
prefix_compare);
|
|
||||||
|
|
||||||
for (size_t i = beam_size; i < prefixes.size(); i++) {
|
|
||||||
prefixes[i]->remove();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} // end of loop over time
|
// blank
|
||||||
|
if (c == blank_id) {
|
||||||
// compute aproximate ctc score as the return score
|
prefix->log_prob_b_cur =
|
||||||
for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
|
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
|
||||||
double approx_ctc = prefixes[i]->_score;
|
continue;
|
||||||
|
}
|
||||||
if (ext_scorer != nullptr) {
|
// repeated character
|
||||||
std::vector<int> output;
|
if (c == prefix->character) {
|
||||||
prefixes[i]->get_path_vec(output);
|
prefix->log_prob_nb_cur = log_sum_exp(
|
||||||
size_t prefix_length = output.size();
|
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
|
||||||
auto words = ext_scorer->split_labels(output);
|
|
||||||
// remove word insert
|
|
||||||
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
|
|
||||||
// remove language model weight:
|
|
||||||
approx_ctc -= (ext_scorer->get_sent_log_prob(words))
|
|
||||||
* ext_scorer->alpha;
|
|
||||||
}
|
}
|
||||||
|
// get new prefix
|
||||||
|
auto prefix_new = prefix->get_path_trie(c);
|
||||||
|
|
||||||
|
if (prefix_new != nullptr) {
|
||||||
|
float log_p = -NUM_FLT_INF;
|
||||||
|
|
||||||
|
if (c == prefix->character &&
|
||||||
|
prefix->log_prob_b_prev > -NUM_FLT_INF) {
|
||||||
|
log_p = log_prob_c + prefix->log_prob_b_prev;
|
||||||
|
} else if (c != prefix->character) {
|
||||||
|
log_p = log_prob_c + prefix->score;
|
||||||
|
}
|
||||||
|
|
||||||
|
// language model scoring
|
||||||
|
if (extscorer != nullptr &&
|
||||||
|
(c == space_id || extscorer->is_character_based())) {
|
||||||
|
PathTrie *prefix_toscore = nullptr;
|
||||||
|
|
||||||
|
// skip scoring the space
|
||||||
|
if (extscorer->is_character_based()) {
|
||||||
|
prefix_toscore = prefix_new;
|
||||||
|
} else {
|
||||||
|
prefix_toscore = prefix;
|
||||||
|
}
|
||||||
|
|
||||||
prefixes[i]->_approx_ctc = approx_ctc;
|
double score = 0.0;
|
||||||
}
|
std::vector<std::string> ngram;
|
||||||
|
ngram = extscorer->make_ngram(prefix_toscore);
|
||||||
|
score = extscorer->get_log_cond_prob(ngram) * extscorer->alpha;
|
||||||
|
|
||||||
// allow for the post processing
|
log_p += score;
|
||||||
std::vector<PathTrie*> space_prefixes;
|
log_p += extscorer->beta;
|
||||||
if (space_prefixes.empty()) {
|
}
|
||||||
for (size_t i = 0; i < beam_size && i< prefixes.size(); i++) {
|
prefix_new->log_prob_nb_cur =
|
||||||
space_prefixes.push_back(prefixes[i]);
|
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
||||||
}
|
}
|
||||||
|
} // end of loop over prefix
|
||||||
|
} // end of loop over chars
|
||||||
|
|
||||||
|
prefixes.clear();
|
||||||
|
// update log probs
|
||||||
|
root.iterate_to_vec(prefixes);
|
||||||
|
|
||||||
|
// only preserve top beam_size prefixes
|
||||||
|
if (prefixes.size() >= beam_size) {
|
||||||
|
std::nth_element(prefixes.begin(),
|
||||||
|
prefixes.begin() + beam_size,
|
||||||
|
prefixes.end(),
|
||||||
|
prefix_compare);
|
||||||
|
|
||||||
|
for (size_t i = beam_size; i < prefixes.size(); i++) {
|
||||||
|
prefixes[i]->remove();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
} // end of loop over time
|
||||||
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
|
|
||||||
std::vector<std::pair<double, std::string> > output_vecs;
|
// compute aproximate ctc score as the return score
|
||||||
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) {
|
for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
|
||||||
std::vector<int> output;
|
double approx_ctc = prefixes[i]->score;
|
||||||
space_prefixes[i]->get_path_vec(output);
|
|
||||||
// convert index to string
|
if (extscorer != nullptr) {
|
||||||
std::string output_str;
|
std::vector<int> output;
|
||||||
for (int j = 0; j < output.size(); j++) {
|
prefixes[i]->get_path_vec(output);
|
||||||
output_str += vocabulary[output[j]];
|
size_t prefix_length = output.size();
|
||||||
}
|
auto words = extscorer->split_labels(output);
|
||||||
std::pair<double, std::string>
|
// remove word insert
|
||||||
output_pair(-space_prefixes[i]->_approx_ctc, output_str);
|
approx_ctc = approx_ctc - prefix_length * extscorer->beta;
|
||||||
output_vecs.emplace_back(output_pair);
|
// remove language model weight:
|
||||||
|
approx_ctc -= (extscorer->get_sent_log_prob(words)) * extscorer->alpha;
|
||||||
}
|
}
|
||||||
|
|
||||||
return output_vecs;
|
prefixes[i]->approx_ctc = approx_ctc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// allow for the post processing
|
||||||
std::vector<std::vector<std::pair<double, std::string> > >
|
std::vector<PathTrie *> space_prefixes;
|
||||||
ctc_beam_search_decoder_batch(
|
if (space_prefixes.empty()) {
|
||||||
std::vector<std::vector<std::vector<double>>> probs_split,
|
for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
|
||||||
int beam_size,
|
space_prefixes.push_back(prefixes[i]);
|
||||||
std::vector<std::string> vocabulary,
|
|
||||||
int blank_id,
|
|
||||||
int num_processes,
|
|
||||||
double cutoff_prob,
|
|
||||||
int cutoff_top_n,
|
|
||||||
Scorer *ext_scorer
|
|
||||||
) {
|
|
||||||
if (num_processes <= 0) {
|
|
||||||
std::cout << "num_processes must be nonnegative!" << std::endl;
|
|
||||||
exit(1);
|
|
||||||
}
|
}
|
||||||
// thread pool
|
}
|
||||||
ThreadPool pool(num_processes);
|
|
||||||
// number of samples
|
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
|
||||||
int batch_size = probs_split.size();
|
std::vector<std::pair<double, std::string>> output_vecs;
|
||||||
|
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) {
|
||||||
// scorer filling up
|
std::vector<int> output;
|
||||||
if ( ext_scorer != nullptr) {
|
space_prefixes[i]->get_path_vec(output);
|
||||||
if (ext_scorer->is_char_map_empty()) {
|
// convert index to string
|
||||||
ext_scorer->set_char_map(vocabulary);
|
std::string output_str;
|
||||||
}
|
for (int j = 0; j < output.size(); j++) {
|
||||||
if(!ext_scorer->is_character_based()
|
output_str += vocabulary[output[j]];
|
||||||
&& ext_scorer->dictionary == nullptr) {
|
|
||||||
// init dictionary
|
|
||||||
ext_scorer->fill_dictionary(true);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
|
||||||
|
output_str);
|
||||||
|
output_vecs.emplace_back(output_pair);
|
||||||
|
}
|
||||||
|
|
||||||
// enqueue the tasks of decoding
|
return output_vecs;
|
||||||
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
|
}
|
||||||
for (int i = 0; i < batch_size; i++) {
|
|
||||||
res.emplace_back(
|
|
||||||
pool.enqueue(ctc_beam_search_decoder, probs_split[i],
|
|
||||||
beam_size, vocabulary, blank_id, cutoff_prob,
|
|
||||||
cutoff_top_n, ext_scorer)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// get decoding results
|
std::vector<std::vector<std::pair<double, std::string>>>
|
||||||
std::vector<std::vector<std::pair<double, std::string> > > batch_results;
|
ctc_beam_search_decoder_batch(
|
||||||
for (int i = 0; i < batch_size; i++) {
|
std::vector<std::vector<std::vector<double>>> probs_split,
|
||||||
batch_results.emplace_back(res[i].get());
|
int beam_size,
|
||||||
|
std::vector<std::string> vocabulary,
|
||||||
|
int blank_id,
|
||||||
|
int num_processes,
|
||||||
|
double cutoff_prob,
|
||||||
|
int cutoff_top_n,
|
||||||
|
Scorer *extscorer) {
|
||||||
|
if (num_processes <= 0) {
|
||||||
|
std::cout << "num_processes must be nonnegative!" << std::endl;
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
// thread pool
|
||||||
|
ThreadPool pool(num_processes);
|
||||||
|
// number of samples
|
||||||
|
int batch_size = probs_split.size();
|
||||||
|
|
||||||
|
// scorer filling up
|
||||||
|
if (extscorer != nullptr) {
|
||||||
|
if (extscorer->is_char_map_empty()) {
|
||||||
|
extscorer->set_char_map(vocabulary);
|
||||||
|
}
|
||||||
|
if (!extscorer->is_character_based() &&
|
||||||
|
extscorer->dictionary == nullptr) {
|
||||||
|
// init dictionary
|
||||||
|
extscorer->fill_dictionary(true);
|
||||||
}
|
}
|
||||||
return batch_results;
|
}
|
||||||
|
|
||||||
|
// enqueue the tasks of decoding
|
||||||
|
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
|
||||||
|
for (int i = 0; i < batch_size; i++) {
|
||||||
|
res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
|
||||||
|
probs_split[i],
|
||||||
|
beam_size,
|
||||||
|
vocabulary,
|
||||||
|
blank_id,
|
||||||
|
cutoff_prob,
|
||||||
|
cutoff_top_n,
|
||||||
|
extscorer));
|
||||||
|
}
|
||||||
|
|
||||||
|
// get decoding results
|
||||||
|
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
|
||||||
|
for (int i = 0; i < batch_size; i++) {
|
||||||
|
batch_results.emplace_back(res[i].get());
|
||||||
|
}
|
||||||
|
return batch_results;
|
||||||
}
|
}
|
||||||
|
@ -1,113 +1,111 @@
|
|||||||
#include <limits>
|
#include "decoder_utils.h"
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include "decoder_utils.h"
|
#include <limits>
|
||||||
|
|
||||||
size_t get_utf8_str_len(const std::string& str) {
|
size_t get_utf8_str_len(const std::string& str) {
|
||||||
size_t str_len = 0;
|
size_t str_len = 0;
|
||||||
for (char c : str) {
|
for (char c : str) {
|
||||||
str_len += ((c & 0xc0) != 0x80);
|
str_len += ((c & 0xc0) != 0x80);
|
||||||
}
|
}
|
||||||
return str_len;
|
return str_len;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> split_utf8_str(const std::string& str)
|
std::vector<std::string> split_utf8_str(const std::string& str) {
|
||||||
{
|
|
||||||
std::vector<std::string> result;
|
std::vector<std::string> result;
|
||||||
std::string out_str;
|
std::string out_str;
|
||||||
|
|
||||||
for (char c : str)
|
for (char c : str) {
|
||||||
|
if ((c & 0xc0) != 0x80) // new UTF-8 character
|
||||||
{
|
{
|
||||||
if ((c & 0xc0) != 0x80) //new UTF-8 character
|
if (!out_str.empty()) {
|
||||||
{
|
result.push_back(out_str);
|
||||||
if (!out_str.empty())
|
out_str.clear();
|
||||||
{
|
}
|
||||||
result.push_back(out_str);
|
|
||||||
out_str.clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
out_str.append(1, c);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
out_str.append(1, c);
|
||||||
|
}
|
||||||
result.push_back(out_str);
|
result.push_back(out_str);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> split_str(const std::string &s,
|
std::vector<std::string> split_str(const std::string& s,
|
||||||
const std::string &delim) {
|
const std::string& delim) {
|
||||||
std::vector<std::string> result;
|
std::vector<std::string> result;
|
||||||
std::size_t start = 0, delim_len = delim.size();
|
std::size_t start = 0, delim_len = delim.size();
|
||||||
while (true) {
|
while (true) {
|
||||||
std::size_t end = s.find(delim, start);
|
std::size_t end = s.find(delim, start);
|
||||||
if (end == std::string::npos) {
|
if (end == std::string::npos) {
|
||||||
if (start < s.size()) {
|
if (start < s.size()) {
|
||||||
result.push_back(s.substr(start));
|
result.push_back(s.substr(start));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (end > start) {
|
if (end > start) {
|
||||||
result.push_back(s.substr(start, end - start));
|
result.push_back(s.substr(start, end - start));
|
||||||
}
|
|
||||||
start = end + delim_len;
|
|
||||||
}
|
}
|
||||||
return result;
|
start = end + delim_len;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool prefix_compare(const PathTrie* x, const PathTrie* y) {
|
bool prefix_compare(const PathTrie* x, const PathTrie* y) {
|
||||||
if (x->_score == y->_score) {
|
if (x->score == y->score) {
|
||||||
if (x->_character == y->_character) {
|
if (x->character == y->character) {
|
||||||
return false;
|
return false;
|
||||||
} else {
|
|
||||||
return (x->_character < y->_character);
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
return x->_score > y->_score;
|
return (x->character < y->character);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
return x->score > y->score;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void add_word_to_fst(const std::vector<int>& word,
|
void add_word_to_fst(const std::vector<int>& word,
|
||||||
fst::StdVectorFst* dictionary) {
|
fst::StdVectorFst* dictionary) {
|
||||||
if (dictionary->NumStates() == 0) {
|
if (dictionary->NumStates() == 0) {
|
||||||
fst::StdVectorFst::StateId start = dictionary->AddState();
|
fst::StdVectorFst::StateId start = dictionary->AddState();
|
||||||
assert(start == 0);
|
assert(start == 0);
|
||||||
dictionary->SetStart(start);
|
dictionary->SetStart(start);
|
||||||
}
|
}
|
||||||
fst::StdVectorFst::StateId src = dictionary->Start();
|
fst::StdVectorFst::StateId src = dictionary->Start();
|
||||||
fst::StdVectorFst::StateId dst;
|
fst::StdVectorFst::StateId dst;
|
||||||
for (auto c : word) {
|
for (auto c : word) {
|
||||||
dst = dictionary->AddState();
|
dst = dictionary->AddState();
|
||||||
dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
|
dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
|
||||||
src = dst;
|
src = dst;
|
||||||
}
|
}
|
||||||
dictionary->SetFinal(dst, fst::StdArc::Weight::One());
|
dictionary->SetFinal(dst, fst::StdArc::Weight::One());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool add_word_to_dictionary(const std::string& word,
|
bool add_word_to_dictionary(
|
||||||
const std::unordered_map<std::string, int>& char_map,
|
const std::string& word,
|
||||||
bool add_space,
|
const std::unordered_map<std::string, int>& char_map,
|
||||||
int SPACE_ID,
|
bool add_space,
|
||||||
fst::StdVectorFst* dictionary) {
|
int SPACE_ID,
|
||||||
auto characters = split_utf8_str(word);
|
fst::StdVectorFst* dictionary) {
|
||||||
|
auto characters = split_utf8_str(word);
|
||||||
|
|
||||||
std::vector<int> int_word;
|
std::vector<int> int_word;
|
||||||
|
|
||||||
for (auto& c : characters) {
|
for (auto& c : characters) {
|
||||||
if (c == " ") {
|
if (c == " ") {
|
||||||
int_word.push_back(SPACE_ID);
|
int_word.push_back(SPACE_ID);
|
||||||
} else {
|
} else {
|
||||||
auto int_c = char_map.find(c);
|
auto int_c = char_map.find(c);
|
||||||
if (int_c != char_map.end()) {
|
if (int_c != char_map.end()) {
|
||||||
int_word.push_back(int_c->second);
|
int_word.push_back(int_c->second);
|
||||||
} else {
|
} else {
|
||||||
return false; // return without adding
|
return false; // return without adding
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (add_space) {
|
if (add_space) {
|
||||||
int_word.push_back(SPACE_ID);
|
int_word.push_back(SPACE_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
add_word_to_fst(int_word, dictionary);
|
add_word_to_fst(int_word, dictionary);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -1,59 +1,57 @@
|
|||||||
#ifndef PATH_TRIE_H
|
#ifndef PATH_TRIE_H
|
||||||
#define PATH_TRIE_H
|
#define PATH_TRIE_H
|
||||||
#pragma once
|
#pragma once
|
||||||
|
#include <fst/fstlib.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <fst/fstlib.h>
|
|
||||||
|
|
||||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||||
|
|
||||||
class PathTrie {
|
class PathTrie {
|
||||||
public:
|
public:
|
||||||
PathTrie();
|
PathTrie();
|
||||||
~PathTrie();
|
~PathTrie();
|
||||||
|
|
||||||
PathTrie* get_path_trie(int new_char, bool reset = true);
|
|
||||||
|
|
||||||
PathTrie* get_path_vec(std::vector<int> &output);
|
PathTrie* get_path_trie(int new_char, bool reset = true);
|
||||||
|
|
||||||
PathTrie* get_path_vec(std::vector<int>& output,
|
PathTrie* get_path_vec(std::vector<int>& output);
|
||||||
int stop,
|
|
||||||
size_t max_steps = std::numeric_limits<size_t>::max());
|
|
||||||
|
|
||||||
void iterate_to_vec(std::vector<PathTrie*> &output);
|
PathTrie* get_path_vec(std::vector<int>& output,
|
||||||
|
int stop,
|
||||||
|
size_t max_steps = std::numeric_limits<size_t>::max());
|
||||||
|
|
||||||
void set_dictionary(fst::StdVectorFst* dictionary);
|
void iterate_to_vec(std::vector<PathTrie*>& output);
|
||||||
|
|
||||||
void set_matcher(std::shared_ptr<FSTMATCH> matcher);
|
void set_dictionary(fst::StdVectorFst* dictionary);
|
||||||
|
|
||||||
bool is_empty() {
|
void set_matcher(std::shared_ptr<FSTMATCH> matcher);
|
||||||
return _ROOT == _character;
|
|
||||||
}
|
|
||||||
|
|
||||||
void remove();
|
bool is_empty() { return _ROOT == character; }
|
||||||
|
|
||||||
float _log_prob_b_prev;
|
void remove();
|
||||||
float _log_prob_nb_prev;
|
|
||||||
float _log_prob_b_cur;
|
|
||||||
float _log_prob_nb_cur;
|
|
||||||
float _score;
|
|
||||||
float _approx_ctc;
|
|
||||||
|
|
||||||
|
float log_prob_b_prev;
|
||||||
|
float log_prob_nb_prev;
|
||||||
|
float log_prob_b_cur;
|
||||||
|
float log_prob_nb_cur;
|
||||||
|
float score;
|
||||||
|
float approx_ctc;
|
||||||
|
int character;
|
||||||
|
PathTrie* parent;
|
||||||
|
|
||||||
int _ROOT;
|
private:
|
||||||
int _character;
|
int _ROOT;
|
||||||
bool _exists;
|
bool _exists;
|
||||||
|
|
||||||
PathTrie *_parent;
|
std::vector<std::pair<int, PathTrie*>> _children;
|
||||||
std::vector<std::pair<int, PathTrie*> > _children;
|
|
||||||
|
|
||||||
fst::StdVectorFst* _dictionary;
|
fst::StdVectorFst* _dictionary;
|
||||||
fst::StdVectorFst::StateId _dictionary_state;
|
fst::StdVectorFst::StateId _dictionary_state;
|
||||||
bool _has_dictionary;
|
bool _has_dictionary;
|
||||||
std::shared_ptr<FSTMATCH> _matcher;
|
std::shared_ptr<FSTMATCH> _matcher;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // PATH_TRIE_H
|
#endif // PATH_TRIE_H
|
||||||
|
@ -1,219 +1,208 @@
|
|||||||
#include <iostream>
|
#include "scorer.h"
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include "decoder_utils.h"
|
||||||
#include "lm/config.hh"
|
#include "lm/config.hh"
|
||||||
#include "lm/state.hh"
|
|
||||||
#include "lm/model.hh"
|
#include "lm/model.hh"
|
||||||
#include "util/tokenize_piece.hh"
|
#include "lm/state.hh"
|
||||||
#include "util/string_piece.hh"
|
#include "util/string_piece.hh"
|
||||||
#include "scorer.h"
|
#include "util/tokenize_piece.hh"
|
||||||
#include "decoder_utils.h"
|
|
||||||
|
|
||||||
using namespace lm::ngram;
|
using namespace lm::ngram;
|
||||||
|
|
||||||
Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
|
Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
|
||||||
this->alpha = alpha;
|
this->alpha = alpha;
|
||||||
this->beta = beta;
|
this->beta = beta;
|
||||||
_is_character_based = true;
|
_is_character_based = true;
|
||||||
_language_model = nullptr;
|
_language_model = nullptr;
|
||||||
dictionary = nullptr;
|
dictionary = nullptr;
|
||||||
_max_order = 0;
|
_max_order = 0;
|
||||||
_SPACE_ID = -1;
|
_SPACE_ID = -1;
|
||||||
// load language model
|
// load language model
|
||||||
load_LM(lm_path.c_str());
|
load_LM(lm_path.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
Scorer::~Scorer() {
|
Scorer::~Scorer() {
|
||||||
if (_language_model != nullptr)
|
if (_language_model != nullptr)
|
||||||
delete static_cast<lm::base::Model*>(_language_model);
|
delete static_cast<lm::base::Model*>(_language_model);
|
||||||
if (dictionary != nullptr)
|
if (dictionary != nullptr) delete static_cast<fst::StdVectorFst*>(dictionary);
|
||||||
delete static_cast<fst::StdVectorFst*>(dictionary);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Scorer::load_LM(const char* filename) {
|
void Scorer::load_LM(const char* filename) {
|
||||||
if (access(filename, F_OK) != 0) {
|
if (access(filename, F_OK) != 0) {
|
||||||
std::cerr << "Invalid language model file !!!" << std::endl;
|
std::cerr << "Invalid language model file !!!" << std::endl;
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
RetriveStrEnumerateVocab enumerate;
|
RetriveStrEnumerateVocab enumerate;
|
||||||
lm::ngram::Config config;
|
lm::ngram::Config config;
|
||||||
config.enumerate_vocab = &enumerate;
|
config.enumerate_vocab = &enumerate;
|
||||||
_language_model = lm::ngram::LoadVirtual(filename, config);
|
_language_model = lm::ngram::LoadVirtual(filename, config);
|
||||||
_max_order = static_cast<lm::base::Model*>(_language_model)->Order();
|
_max_order = static_cast<lm::base::Model*>(_language_model)->Order();
|
||||||
_vocabulary = enumerate.vocabulary;
|
_vocabulary = enumerate.vocabulary;
|
||||||
for (size_t i = 0; i < _vocabulary.size(); ++i) {
|
for (size_t i = 0; i < _vocabulary.size(); ++i) {
|
||||||
if (_is_character_based
|
if (_is_character_based && _vocabulary[i] != UNK_TOKEN &&
|
||||||
&& _vocabulary[i] != UNK_TOKEN
|
_vocabulary[i] != START_TOKEN && _vocabulary[i] != END_TOKEN &&
|
||||||
&& _vocabulary[i] != START_TOKEN
|
get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
|
||||||
&& _vocabulary[i] != END_TOKEN
|
_is_character_based = false;
|
||||||
&& get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
|
|
||||||
_is_character_based = false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
|
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
|
||||||
lm::base::Model* model = static_cast<lm::base::Model*>(_language_model);
|
lm::base::Model* model = static_cast<lm::base::Model*>(_language_model);
|
||||||
double cond_prob;
|
double cond_prob;
|
||||||
lm::ngram::State state, tmp_state, out_state;
|
lm::ngram::State state, tmp_state, out_state;
|
||||||
// avoid to inserting <s> in begin
|
// avoid to inserting <s> in begin
|
||||||
model->NullContextWrite(&state);
|
model->NullContextWrite(&state);
|
||||||
for (size_t i = 0; i < words.size(); ++i) {
|
for (size_t i = 0; i < words.size(); ++i) {
|
||||||
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
|
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
|
||||||
// encounter OOV
|
// encounter OOV
|
||||||
if (word_index == 0) {
|
if (word_index == 0) {
|
||||||
return OOV_SCORE;
|
return OOV_SCORE;
|
||||||
}
|
|
||||||
cond_prob = model->BaseScore(&state, word_index, &out_state);
|
|
||||||
tmp_state = state;
|
|
||||||
state = out_state;
|
|
||||||
out_state = tmp_state;
|
|
||||||
}
|
}
|
||||||
// log10 prob
|
cond_prob = model->BaseScore(&state, word_index, &out_state);
|
||||||
return cond_prob;
|
tmp_state = state;
|
||||||
|
state = out_state;
|
||||||
|
out_state = tmp_state;
|
||||||
|
}
|
||||||
|
// log10 prob
|
||||||
|
return cond_prob;
|
||||||
}
|
}
|
||||||
|
|
||||||
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
|
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
|
||||||
std::vector<std::string> sentence;
|
std::vector<std::string> sentence;
|
||||||
if (words.size() == 0) {
|
if (words.size() == 0) {
|
||||||
for (size_t i = 0; i < _max_order; ++i) {
|
for (size_t i = 0; i < _max_order; ++i) {
|
||||||
sentence.push_back(START_TOKEN);
|
sentence.push_back(START_TOKEN);
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (size_t i = 0; i < _max_order - 1; ++i) {
|
|
||||||
sentence.push_back(START_TOKEN);
|
|
||||||
}
|
|
||||||
sentence.insert(sentence.end(), words.begin(), words.end());
|
|
||||||
}
|
}
|
||||||
sentence.push_back(END_TOKEN);
|
} else {
|
||||||
return get_log_prob(sentence);
|
for (size_t i = 0; i < _max_order - 1; ++i) {
|
||||||
|
sentence.push_back(START_TOKEN);
|
||||||
|
}
|
||||||
|
sentence.insert(sentence.end(), words.begin(), words.end());
|
||||||
|
}
|
||||||
|
sentence.push_back(END_TOKEN);
|
||||||
|
return get_log_prob(sentence);
|
||||||
}
|
}
|
||||||
|
|
||||||
double Scorer::get_log_prob(const std::vector<std::string>& words) {
|
double Scorer::get_log_prob(const std::vector<std::string>& words) {
|
||||||
assert(words.size() > _max_order);
|
assert(words.size() > _max_order);
|
||||||
double score = 0.0;
|
double score = 0.0;
|
||||||
for (size_t i = 0; i < words.size() - _max_order + 1; ++i) {
|
for (size_t i = 0; i < words.size() - _max_order + 1; ++i) {
|
||||||
std::vector<std::string> ngram(words.begin() + i,
|
std::vector<std::string> ngram(words.begin() + i,
|
||||||
words.begin() + i + _max_order);
|
words.begin() + i + _max_order);
|
||||||
score += get_log_cond_prob(ngram);
|
score += get_log_cond_prob(ngram);
|
||||||
}
|
}
|
||||||
return score;
|
return score;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Scorer::reset_params(float alpha, float beta) {
|
void Scorer::reset_params(float alpha, float beta) {
|
||||||
this->alpha = alpha;
|
this->alpha = alpha;
|
||||||
this->beta = beta;
|
this->beta = beta;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Scorer::vec2str(const std::vector<int>& input) {
|
std::string Scorer::vec2str(const std::vector<int>& input) {
|
||||||
std::string word;
|
std::string word;
|
||||||
for (auto ind : input) {
|
for (auto ind : input) {
|
||||||
word += _char_list[ind];
|
word += _char_list[ind];
|
||||||
}
|
}
|
||||||
return word;
|
return word;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string>
|
std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
|
||||||
Scorer::split_labels(const std::vector<int> &labels) {
|
if (labels.empty()) return {};
|
||||||
if (labels.empty())
|
|
||||||
return {};
|
std::string s = vec2str(labels);
|
||||||
|
std::vector<std::string> words;
|
||||||
std::string s = vec2str(labels);
|
if (_is_character_based) {
|
||||||
std::vector<std::string> words;
|
words = split_utf8_str(s);
|
||||||
if (_is_character_based) {
|
} else {
|
||||||
words = split_utf8_str(s);
|
words = split_str(s, " ");
|
||||||
} else {
|
}
|
||||||
words = split_str(s, " ");
|
return words;
|
||||||
}
|
|
||||||
return words;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Scorer::set_char_map(std::vector<std::string> char_list) {
|
void Scorer::set_char_map(std::vector<std::string> char_list) {
|
||||||
_char_list = char_list;
|
_char_list = char_list;
|
||||||
_char_map.clear();
|
_char_map.clear();
|
||||||
|
|
||||||
for(unsigned int i = 0; i < _char_list.size(); i++)
|
for (unsigned int i = 0; i < _char_list.size(); i++) {
|
||||||
{
|
if (_char_list[i] == " ") {
|
||||||
if (_char_list[i] == " ") {
|
_SPACE_ID = i;
|
||||||
_SPACE_ID = i;
|
_char_map[' '] = i;
|
||||||
_char_map[' '] = i;
|
} else if (_char_list[i].size() == 1) {
|
||||||
} else if(_char_list[i].size() == 1){
|
_char_map[_char_list[i][0]] = i;
|
||||||
_char_map[_char_list[i][0]] = i;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
|
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
|
||||||
std::vector<std::string> ngram;
|
std::vector<std::string> ngram;
|
||||||
PathTrie* current_node = prefix;
|
PathTrie* current_node = prefix;
|
||||||
PathTrie* new_node = nullptr;
|
PathTrie* new_node = nullptr;
|
||||||
|
|
||||||
for (int order = 0; order < _max_order; order++) {
|
|
||||||
std::vector<int> prefix_vec;
|
|
||||||
|
|
||||||
if (_is_character_based) {
|
|
||||||
new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID, 1);
|
|
||||||
current_node = new_node;
|
|
||||||
} else {
|
|
||||||
new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID);
|
|
||||||
current_node = new_node->_parent; // Skipping spaces
|
|
||||||
}
|
|
||||||
|
|
||||||
// reconstruct word
|
|
||||||
std::string word = vec2str(prefix_vec);
|
|
||||||
ngram.push_back(word);
|
|
||||||
|
|
||||||
if (new_node->_character == -1) {
|
|
||||||
// No more spaces, but still need order
|
|
||||||
for (int i = 0; i < _max_order - order - 1; i++) {
|
|
||||||
ngram.push_back(START_TOKEN);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::reverse(ngram.begin(), ngram.end());
|
|
||||||
return ngram;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Scorer::fill_dictionary(bool add_space) {
|
|
||||||
|
|
||||||
fst::StdVectorFst dictionary;
|
for (int order = 0; order < _max_order; order++) {
|
||||||
// First reverse char_list so ints can be accessed by chars
|
std::vector<int> prefix_vec;
|
||||||
std::unordered_map<std::string, int> char_map;
|
|
||||||
for (unsigned int i = 0; i < _char_list.size(); i++) {
|
|
||||||
char_map[_char_list[i]] = i;
|
|
||||||
}
|
|
||||||
|
|
||||||
// For each unigram convert to ints and put in trie
|
if (_is_character_based) {
|
||||||
int vocab_size = 0;
|
new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID, 1);
|
||||||
for (const auto& word : _vocabulary) {
|
current_node = new_node;
|
||||||
bool added = add_word_to_dictionary(word,
|
} else {
|
||||||
char_map,
|
new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID);
|
||||||
add_space,
|
current_node = new_node->parent; // Skipping spaces
|
||||||
_SPACE_ID,
|
|
||||||
&dictionary);
|
|
||||||
vocab_size += added ? 1 : 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::cerr << "Vocab Size " << vocab_size << std::endl;
|
// reconstruct word
|
||||||
|
std::string word = vec2str(prefix_vec);
|
||||||
// Simplify FST
|
ngram.push_back(word);
|
||||||
|
|
||||||
// This gets rid of "epsilon" transitions in the FST.
|
|
||||||
// These are transitions that don't require a string input to be taken.
|
|
||||||
// Getting rid of them is necessary to make the FST determinisitc, but
|
|
||||||
// can greatly increase the size of the FST
|
|
||||||
fst::RmEpsilon(&dictionary);
|
|
||||||
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
|
|
||||||
|
|
||||||
// This makes the FST deterministic, meaning for any string input there's
|
if (new_node->character == -1) {
|
||||||
// only one possible state the FST could be in. It is assumed our
|
// No more spaces, but still need order
|
||||||
// dictionary is deterministic when using it.
|
for (int i = 0; i < _max_order - order - 1; i++) {
|
||||||
// (lest we'd have to check for multiple transitions at each state)
|
ngram.push_back(START_TOKEN);
|
||||||
fst::Determinize(dictionary, new_dict);
|
}
|
||||||
|
break;
|
||||||
// Finds the simplest equivalent fst. This is unnecessary but decreases
|
}
|
||||||
// memory usage of the dictionary
|
}
|
||||||
fst::Minimize(new_dict);
|
std::reverse(ngram.begin(), ngram.end());
|
||||||
this->dictionary = new_dict;
|
return ngram;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Scorer::fill_dictionary(bool add_space) {
|
||||||
|
fst::StdVectorFst dictionary;
|
||||||
|
// First reverse char_list so ints can be accessed by chars
|
||||||
|
std::unordered_map<std::string, int> char_map;
|
||||||
|
for (unsigned int i = 0; i < _char_list.size(); i++) {
|
||||||
|
char_map[_char_list[i]] = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For each unigram convert to ints and put in trie
|
||||||
|
int vocab_size = 0;
|
||||||
|
for (const auto& word : _vocabulary) {
|
||||||
|
bool added = add_word_to_dictionary(
|
||||||
|
word, char_map, add_space, _SPACE_ID, &dictionary);
|
||||||
|
vocab_size += added ? 1 : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cerr << "Vocab Size " << vocab_size << std::endl;
|
||||||
|
|
||||||
|
// Simplify FST
|
||||||
|
|
||||||
|
// This gets rid of "epsilon" transitions in the FST.
|
||||||
|
// These are transitions that don't require a string input to be taken.
|
||||||
|
// Getting rid of them is necessary to make the FST determinisitc, but
|
||||||
|
// can greatly increase the size of the FST
|
||||||
|
fst::RmEpsilon(&dictionary);
|
||||||
|
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
|
||||||
|
|
||||||
|
// This makes the FST deterministic, meaning for any string input there's
|
||||||
|
// only one possible state the FST could be in. It is assumed our
|
||||||
|
// dictionary is deterministic when using it.
|
||||||
|
// (lest we'd have to check for multiple transitions at each state)
|
||||||
|
fst::Determinize(dictionary, new_dict);
|
||||||
|
|
||||||
|
// Finds the simplest equivalent fst. This is unnecessary but decreases
|
||||||
|
// memory usage of the dictionary
|
||||||
|
fst::Minimize(new_dict);
|
||||||
|
this->dictionary = new_dict;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in new issue