diff --git a/deploy/ctc_decoders.cpp b/deploy/ctc_decoders.cpp index 4e94edfb..cedb943e 100644 --- a/deploy/ctc_decoders.cpp +++ b/deploy/ctc_decoders.cpp @@ -1,337 +1,329 @@ -#include -#include +#include "ctc_decoders.h" #include -#include #include +#include #include -#include "fst/fstlib.h" -#include "ctc_decoders.h" +#include +#include +#include "ThreadPool.h" #include "decoder_utils.h" +#include "fst/fstlib.h" #include "path_trie.h" -#include "ThreadPool.h" -std::string ctc_best_path_decoder(std::vector > probs_seq, - std::vector vocabulary) -{ - // dimension check - int num_time_steps = probs_seq.size(); - for (int i=0; i> probs_seq, + std::vector vocabulary) { + // dimension check + int num_time_steps = probs_seq.size(); + for (int i = 0; i < num_time_steps; i++) { + if (probs_seq[i].size() != vocabulary.size() + 1) { + std::cout << "The shape of probs_seq does not match" + << " with the shape of the vocabulary!" << std::endl; + exit(1); } - - int blank_id = vocabulary.size(); - - std::vector max_idx_vec; - double max_prob = 0.0; - int max_idx = 0; - for (int i = 0; i < num_time_steps; i++) { - for (int j = 0; j < probs_seq[i].size(); j++) { - if (max_prob < probs_seq[i][j]) { - max_idx = j; - max_prob = probs_seq[i][j]; - } - } - max_idx_vec.push_back(max_idx); - max_prob = 0.0; - max_idx = 0; + } + + int blank_id = vocabulary.size(); + + std::vector max_idx_vec; + double max_prob = 0.0; + int max_idx = 0; + for (int i = 0; i < num_time_steps; i++) { + for (int j = 0; j < probs_seq[i].size(); j++) { + if (max_prob < probs_seq[i][j]) { + max_idx = j; + max_prob = probs_seq[i][j]; + } } - - std::vector idx_vec; - for (int i = 0; i < max_idx_vec.size(); i++) { - if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i-1])) { - idx_vec.push_back(max_idx_vec[i]); - } + max_idx_vec.push_back(max_idx); + max_prob = 0.0; + max_idx = 0; + } + + std::vector idx_vec; + for (int i = 0; i < max_idx_vec.size(); i++) { + if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) { + idx_vec.push_back(max_idx_vec[i]); } + } - std::string best_path_result; - for (int i = 0; i < idx_vec.size(); i++) { - if (idx_vec[i] != blank_id) { - best_path_result += vocabulary[idx_vec[i]]; - } + std::string best_path_result; + for (int i = 0; i < idx_vec.size(); i++) { + if (idx_vec[i] != blank_id) { + best_path_result += vocabulary[idx_vec[i]]; } - return best_path_result; + } + return best_path_result; } -std::vector > - ctc_beam_search_decoder(std::vector > probs_seq, - int beam_size, - std::vector vocabulary, - int blank_id, - double cutoff_prob, - int cutoff_top_n, - Scorer *ext_scorer) -{ - // dimension check - int num_time_steps = probs_seq.size(); - for (int i = 0; i < num_time_steps; i++) { - if (probs_seq[i].size() != vocabulary.size() + 1) { - std::cout << " The shape of probs_seq does not match" - << " with the shape of the vocabulary!" << std::endl; - exit(1); - } +std::vector> ctc_beam_search_decoder( + std::vector> probs_seq, + int beam_size, + std::vector vocabulary, + int blank_id, + double cutoff_prob, + int cutoff_top_n, + Scorer *extscorer) { + // dimension check + int num_time_steps = probs_seq.size(); + for (int i = 0; i < num_time_steps; i++) { + if (probs_seq[i].size() != vocabulary.size() + 1) { + std::cout << " The shape of probs_seq does not match" + << " with the shape of the vocabulary!" << std::endl; + exit(1); } - - // blank_id check - if (blank_id > vocabulary.size()) { - std::cout << " Invalid blank_id! 
" << std::endl; - exit(1); + } + + // blank_id check + if (blank_id > vocabulary.size()) { + std::cout << " Invalid blank_id! " << std::endl; + exit(1); + } + + // assign space ID + std::vector::iterator it = + std::find(vocabulary.begin(), vocabulary.end(), " "); + int space_id = it - vocabulary.begin(); + // if no space in vocabulary + if (space_id >= vocabulary.size()) { + space_id = -2; + } + + // init prefixes' root + PathTrie root; + root.score = root.log_prob_b_prev = 0.0; + std::vector prefixes; + prefixes.push_back(&root); + + if (extscorer != nullptr) { + if (extscorer->is_char_map_empty()) { + extscorer->set_char_map(vocabulary); } - - // assign space ID - std::vector::iterator it = std::find(vocabulary.begin(), - vocabulary.end(), " "); - int space_id = it - vocabulary.begin(); - // if no space in vocabulary - if(space_id >= vocabulary.size()) { - space_id = -2; + if (!extscorer->is_character_based()) { + if (extscorer->dictionary == nullptr) { + // fill dictionary for fst + extscorer->fill_dictionary(true); + } + auto fst_dict = static_cast(extscorer->dictionary); + fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); + root.set_dictionary(dict_ptr); + auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); + root.set_matcher(matcher); + } + } + + // prefix search over time + for (int time_step = 0; time_step < num_time_steps; time_step++) { + std::vector prob = probs_seq[time_step]; + std::vector> prob_idx; + for (int i = 0; i < prob.size(); i++) { + prob_idx.push_back(std::pair(i, prob[i])); } - // init prefixes' root - PathTrie root; - root._score = root._log_prob_b_prev = 0.0; - std::vector prefixes; - prefixes.push_back(&root); + float min_cutoff = -NUM_FLT_INF; + bool full_beam = false; + if (extscorer != nullptr) { + int num_prefixes = std::min((int)prefixes.size(), beam_size); + std::sort( + prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); + min_cutoff = prefixes[num_prefixes - 1]->score + log(prob[blank_id]) - + std::max(0.0, extscorer->beta); + full_beam = (num_prefixes == beam_size); + } - if ( ext_scorer != nullptr) { - if (ext_scorer->is_char_map_empty()) { - ext_scorer->set_char_map(vocabulary); - } - if (!ext_scorer->is_character_based()) { - if (ext_scorer->dictionary == nullptr) { - // fill dictionary for fst - ext_scorer->fill_dictionary(true); - } - auto fst_dict = static_cast - (ext_scorer->dictionary); - fst::StdVectorFst* dict_ptr = fst_dict->Copy(true); - root.set_dictionary(dict_ptr); - auto matcher = std::make_shared - (*dict_ptr, fst::MATCH_INPUT); - root.set_matcher(matcher); + // pruning of vacobulary + int cutoff_len = prob.size(); + if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) { + std::sort( + prob_idx.begin(), prob_idx.end(), pair_comp_second_rev); + if (cutoff_prob < 1.0) { + double cum_prob = 0.0; + cutoff_len = 0; + for (int i = 0; i < prob_idx.size(); i++) { + cum_prob += prob_idx[i].second; + cutoff_len += 1; + if (cum_prob >= cutoff_prob) break; } + } + cutoff_len = std::min(cutoff_len, cutoff_top_n); + prob_idx = std::vector>( + prob_idx.begin(), prob_idx.begin() + cutoff_len); + } + std::vector> log_prob_idx; + for (int i = 0; i < cutoff_len; i++) { + log_prob_idx.push_back(std::pair( + prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN))); } - // prefix search over time - for (int time_step = 0; time_step < num_time_steps; time_step++) { - std::vector prob = probs_seq[time_step]; - std::vector > prob_idx; - for (int i=0; i(i, prob[i])); - } + // loop over chars + for (int index = 0; index < 
log_prob_idx.size(); index++) { + auto c = log_prob_idx[index].first; + float log_prob_c = log_prob_idx[index].second; - float min_cutoff = -NUM_FLT_INF; - bool full_beam = false; - if (ext_scorer != nullptr) { - int num_prefixes = std::min((int)prefixes.size(), beam_size); - std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, - prefix_compare); - min_cutoff = prefixes[num_prefixes-1]->_score + log(prob[blank_id]) - - std::max(0.0, ext_scorer->beta); - full_beam = (num_prefixes == beam_size); - } - - // pruning of vacobulary - int cutoff_len = prob.size(); - if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) { - std::sort(prob_idx.begin(), - prob_idx.end(), - pair_comp_second_rev); - if (cutoff_prob < 1.0) { - double cum_prob = 0.0; - cutoff_len = 0; - for (int i=0; i= cutoff_prob) break; - } - } - cutoff_len = std::min(cutoff_len, cutoff_top_n); - prob_idx = std::vector >( prob_idx.begin(), - prob_idx.begin() + cutoff_len); - } - std::vector > log_prob_idx; - for (int i = 0; i < cutoff_len; i++) { - log_prob_idx.push_back(std::pair - (prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN))); - } + for (int i = 0; i < prefixes.size() && i < beam_size; i++) { + auto prefix = prefixes[i]; - // loop over chars - for (int index = 0; index < log_prob_idx.size(); index++) { - auto c = log_prob_idx[index].first; - float log_prob_c = log_prob_idx[index].second; - - for (int i = 0; i < prefixes.size() && i_score < min_cutoff) { - break; - } - // blank - if (c == blank_id) { - prefix->_log_prob_b_cur = log_sum_exp( - prefix->_log_prob_b_cur, - log_prob_c + prefix->_score); - continue; - } - // repeated character - if (c == prefix->_character) { - prefix->_log_prob_nb_cur = log_sum_exp( - prefix->_log_prob_nb_cur, - log_prob_c + prefix->_log_prob_nb_prev); - } - // get new prefix - auto prefix_new = prefix->get_path_trie(c); - - if (prefix_new != nullptr) { - float log_p = -NUM_FLT_INF; - - if (c == prefix->_character - && prefix->_log_prob_b_prev > -NUM_FLT_INF) { - log_p = log_prob_c + prefix->_log_prob_b_prev; - } else if (c != prefix->_character) { - log_p = log_prob_c + prefix->_score; - } - - // language model scoring - if (ext_scorer != nullptr && - (c == space_id || ext_scorer->is_character_based()) ) { - PathTrie *prefix_to_score = nullptr; - - // skip scoring the space - if (ext_scorer->is_character_based()) { - prefix_to_score = prefix_new; - } else { - prefix_to_score = prefix; - } - - double score = 0.0; - std::vector ngram; - ngram = ext_scorer->make_ngram(prefix_to_score); - score = ext_scorer->get_log_cond_prob(ngram) * - ext_scorer->alpha; - - log_p += score; - log_p += ext_scorer->beta; - } - prefix_new->_log_prob_nb_cur = log_sum_exp( - prefix_new->_log_prob_nb_cur, log_p); - } - } // end of loop over prefix - } // end of loop over chars - - prefixes.clear(); - // update log probs - root.iterate_to_vec(prefixes); - - // only preserve top beam_size prefixes - if (prefixes.size() >= beam_size) { - std::nth_element(prefixes.begin(), - prefixes.begin() + beam_size, - prefixes.end(), - prefix_compare); - - for (size_t i = beam_size; i < prefixes.size(); i++) { - prefixes[i]->remove(); - } + if (full_beam && log_prob_c + prefix->score < min_cutoff) { + break; } - } // end of loop over time - - // compute aproximate ctc score as the return score - for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) { - double approx_ctc = prefixes[i]->_score; - - if (ext_scorer != nullptr) { - std::vector output; - prefixes[i]->get_path_vec(output); - size_t prefix_length = 
output.size(); - auto words = ext_scorer->split_labels(output); - // remove word insert - approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; - // remove language model weight: - approx_ctc -= (ext_scorer->get_sent_log_prob(words)) - * ext_scorer->alpha; + // blank + if (c == blank_id) { + prefix->log_prob_b_cur = + log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); + continue; + } + // repeated character + if (c == prefix->character) { + prefix->log_prob_nb_cur = log_sum_exp( + prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); } + // get new prefix + auto prefix_new = prefix->get_path_trie(c); + + if (prefix_new != nullptr) { + float log_p = -NUM_FLT_INF; + + if (c == prefix->character && + prefix->log_prob_b_prev > -NUM_FLT_INF) { + log_p = log_prob_c + prefix->log_prob_b_prev; + } else if (c != prefix->character) { + log_p = log_prob_c + prefix->score; + } + + // language model scoring + if (extscorer != nullptr && + (c == space_id || extscorer->is_character_based())) { + PathTrie *prefix_toscore = nullptr; + + // skip scoring the space + if (extscorer->is_character_based()) { + prefix_toscore = prefix_new; + } else { + prefix_toscore = prefix; + } - prefixes[i]->_approx_ctc = approx_ctc; - } + double score = 0.0; + std::vector ngram; + ngram = extscorer->make_ngram(prefix_toscore); + score = extscorer->get_log_cond_prob(ngram) * extscorer->alpha; - // allow for the post processing - std::vector space_prefixes; - if (space_prefixes.empty()) { - for (size_t i = 0; i < beam_size && i< prefixes.size(); i++) { - space_prefixes.push_back(prefixes[i]); + log_p += score; + log_p += extscorer->beta; + } + prefix_new->log_prob_nb_cur = + log_sum_exp(prefix_new->log_prob_nb_cur, log_p); } + } // end of loop over prefix + } // end of loop over chars + + prefixes.clear(); + // update log probs + root.iterate_to_vec(prefixes); + + // only preserve top beam_size prefixes + if (prefixes.size() >= beam_size) { + std::nth_element(prefixes.begin(), + prefixes.begin() + beam_size, + prefixes.end(), + prefix_compare); + + for (size_t i = beam_size; i < prefixes.size(); i++) { + prefixes[i]->remove(); + } } - - std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); - std::vector > output_vecs; - for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) { - std::vector output; - space_prefixes[i]->get_path_vec(output); - // convert index to string - std::string output_str; - for (int j = 0; j < output.size(); j++) { - output_str += vocabulary[output[j]]; - } - std::pair - output_pair(-space_prefixes[i]->_approx_ctc, output_str); - output_vecs.emplace_back(output_pair); + } // end of loop over time + + // compute aproximate ctc score as the return score + for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) { + double approx_ctc = prefixes[i]->score; + + if (extscorer != nullptr) { + std::vector output; + prefixes[i]->get_path_vec(output); + size_t prefix_length = output.size(); + auto words = extscorer->split_labels(output); + // remove word insert + approx_ctc = approx_ctc - prefix_length * extscorer->beta; + // remove language model weight: + approx_ctc -= (extscorer->get_sent_log_prob(words)) * extscorer->alpha; } - return output_vecs; - } - - -std::vector > > - ctc_beam_search_decoder_batch( - std::vector>> probs_split, - int beam_size, - std::vector vocabulary, - int blank_id, - int num_processes, - double cutoff_prob, - int cutoff_top_n, - Scorer *ext_scorer - ) { - if (num_processes <= 0) { - std::cout << "num_processes must be 
nonnegative!" << std::endl; - exit(1); + prefixes[i]->approx_ctc = approx_ctc; + } + + // allow for the post processing + std::vector space_prefixes; + if (space_prefixes.empty()) { + for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) { + space_prefixes.push_back(prefixes[i]); } - // thread pool - ThreadPool pool(num_processes); - // number of samples - int batch_size = probs_split.size(); - - // scorer filling up - if ( ext_scorer != nullptr) { - if (ext_scorer->is_char_map_empty()) { - ext_scorer->set_char_map(vocabulary); - } - if(!ext_scorer->is_character_based() - && ext_scorer->dictionary == nullptr) { - // init dictionary - ext_scorer->fill_dictionary(true); - } + } + + std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); + std::vector> output_vecs; + for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) { + std::vector output; + space_prefixes[i]->get_path_vec(output); + // convert index to string + std::string output_str; + for (int j = 0; j < output.size(); j++) { + output_str += vocabulary[output[j]]; } + std::pair output_pair(-space_prefixes[i]->approx_ctc, + output_str); + output_vecs.emplace_back(output_pair); + } - // enqueue the tasks of decoding - std::vector>>> res; - for (int i = 0; i < batch_size; i++) { - res.emplace_back( - pool.enqueue(ctc_beam_search_decoder, probs_split[i], - beam_size, vocabulary, blank_id, cutoff_prob, - cutoff_top_n, ext_scorer) - ); - } + return output_vecs; +} - // get decoding results - std::vector > > batch_results; - for (int i = 0; i < batch_size; i++) { - batch_results.emplace_back(res[i].get()); +std::vector>> +ctc_beam_search_decoder_batch( + std::vector>> probs_split, + int beam_size, + std::vector vocabulary, + int blank_id, + int num_processes, + double cutoff_prob, + int cutoff_top_n, + Scorer *extscorer) { + if (num_processes <= 0) { + std::cout << "num_processes must be nonnegative!" << std::endl; + exit(1); + } + // thread pool + ThreadPool pool(num_processes); + // number of samples + int batch_size = probs_split.size(); + + // scorer filling up + if (extscorer != nullptr) { + if (extscorer->is_char_map_empty()) { + extscorer->set_char_map(vocabulary); + } + if (!extscorer->is_character_based() && + extscorer->dictionary == nullptr) { + // init dictionary + extscorer->fill_dictionary(true); } - return batch_results; + } + + // enqueue the tasks of decoding + std::vector>>> res; + for (int i = 0; i < batch_size; i++) { + res.emplace_back(pool.enqueue(ctc_beam_search_decoder, + probs_split[i], + beam_size, + vocabulary, + blank_id, + cutoff_prob, + cutoff_top_n, + extscorer)); + } + + // get decoding results + std::vector>> batch_results; + for (int i = 0; i < batch_size; i++) { + batch_results.emplace_back(res[i].get()); + } + return batch_results; } diff --git a/deploy/ctc_decoders.h b/deploy/ctc_decoders.h index 58d2b789..78edefb7 100644 --- a/deploy/ctc_decoders.h +++ b/deploy/ctc_decoders.h @@ -1,9 +1,9 @@ #ifndef CTC_BEAM_SEARCH_DECODER_H_ #define CTC_BEAM_SEARCH_DECODER_H_ -#include #include #include +#include #include "scorer.h" /* CTC Best Path Decoder @@ -16,8 +16,8 @@ * A vector that each element is a pair of score and decoding result, * in desending order. 
*/ -std::string ctc_best_path_decoder(std::vector > probs_seq, - std::vector vocabulary); +std::string ctc_best_path_decoder(std::vector> probs_seq, + std::vector vocabulary); /* CTC Beam Search Decoder @@ -34,15 +34,14 @@ std::string ctc_best_path_decoder(std::vector > probs_seq, * A vector that each element is a pair of score and decoding result, * in desending order. */ -std::vector > - ctc_beam_search_decoder(std::vector > probs_seq, - int beam_size, - std::vector vocabulary, - int blank_id, - double cutoff_prob=1.0, - int cutoff_top_n=40, - Scorer *ext_scorer=NULL - ); +std::vector> ctc_beam_search_decoder( + std::vector> probs_seq, + int beam_size, + std::vector vocabulary, + int blank_id, + double cutoff_prob = 1.0, + int cutoff_top_n = 40, + Scorer *ext_scorer = NULL); /* CTC Beam Search Decoder for batch data, the interface is consistent with the * original decoder in Python version. @@ -63,15 +62,14 @@ std::vector > * sample. */ std::vector>> - ctc_beam_search_decoder_batch(std::vector>> probs_split, - int beam_size, - std::vector vocabulary, - int blank_id, - int num_processes, - double cutoff_prob=1.0, - int cutoff_top_n=40, - Scorer *ext_scorer=NULL - ); - +ctc_beam_search_decoder_batch( + std::vector>> probs_split, + int beam_size, + std::vector vocabulary, + int blank_id, + int num_processes, + double cutoff_prob = 1.0, + int cutoff_top_n = 40, + Scorer *ext_scorer = NULL); -#endif // CTC_BEAM_SEARCH_DECODER_H_ +#endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/deploy/decoder_utils.cpp b/deploy/decoder_utils.cpp index 37674f71..bed0f623 100644 --- a/deploy/decoder_utils.cpp +++ b/deploy/decoder_utils.cpp @@ -1,113 +1,111 @@ -#include +#include "decoder_utils.h" #include #include -#include "decoder_utils.h" +#include size_t get_utf8_str_len(const std::string& str) { - size_t str_len = 0; - for (char c : str) { - str_len += ((c & 0xc0) != 0x80); - } - return str_len; + size_t str_len = 0; + for (char c : str) { + str_len += ((c & 0xc0) != 0x80); + } + return str_len; } -std::vector split_utf8_str(const std::string& str) -{ +std::vector split_utf8_str(const std::string& str) { std::vector result; std::string out_str; - for (char c : str) + for (char c : str) { + if ((c & 0xc0) != 0x80) // new UTF-8 character { - if ((c & 0xc0) != 0x80) //new UTF-8 character - { - if (!out_str.empty()) - { - result.push_back(out_str); - out_str.clear(); - } - } - - out_str.append(1, c); + if (!out_str.empty()) { + result.push_back(out_str); + out_str.clear(); + } } + + out_str.append(1, c); + } result.push_back(out_str); return result; } -std::vector split_str(const std::string &s, - const std::string &delim) { - std::vector result; - std::size_t start = 0, delim_len = delim.size(); - while (true) { - std::size_t end = s.find(delim, start); - if (end == std::string::npos) { - if (start < s.size()) { - result.push_back(s.substr(start)); - } - break; - } - if (end > start) { - result.push_back(s.substr(start, end - start)); - } - start = end + delim_len; +std::vector split_str(const std::string& s, + const std::string& delim) { + std::vector result; + std::size_t start = 0, delim_len = delim.size(); + while (true) { + std::size_t end = s.find(delim, start); + if (end == std::string::npos) { + if (start < s.size()) { + result.push_back(s.substr(start)); + } + break; + } + if (end > start) { + result.push_back(s.substr(start, end - start)); } - return result; + start = end + delim_len; + } + return result; } -bool prefix_compare(const PathTrie* x, const PathTrie* y) { - if (x->_score == y->_score) { 
- if (x->_character == y->_character) { - return false; - } else { - return (x->_character < y->_character); - } +bool prefix_compare(const PathTrie* x, const PathTrie* y) { + if (x->score == y->score) { + if (x->character == y->character) { + return false; } else { - return x->_score > y->_score; + return (x->character < y->character); } + } else { + return x->score > y->score; + } } void add_word_to_fst(const std::vector& word, fst::StdVectorFst* dictionary) { - if (dictionary->NumStates() == 0) { - fst::StdVectorFst::StateId start = dictionary->AddState(); - assert(start == 0); - dictionary->SetStart(start); - } - fst::StdVectorFst::StateId src = dictionary->Start(); - fst::StdVectorFst::StateId dst; - for (auto c : word) { - dst = dictionary->AddState(); - dictionary->AddArc(src, fst::StdArc(c, c, 0, dst)); - src = dst; - } - dictionary->SetFinal(dst, fst::StdArc::Weight::One()); + if (dictionary->NumStates() == 0) { + fst::StdVectorFst::StateId start = dictionary->AddState(); + assert(start == 0); + dictionary->SetStart(start); + } + fst::StdVectorFst::StateId src = dictionary->Start(); + fst::StdVectorFst::StateId dst; + for (auto c : word) { + dst = dictionary->AddState(); + dictionary->AddArc(src, fst::StdArc(c, c, 0, dst)); + src = dst; + } + dictionary->SetFinal(dst, fst::StdArc::Weight::One()); } -bool add_word_to_dictionary(const std::string& word, - const std::unordered_map& char_map, - bool add_space, - int SPACE_ID, - fst::StdVectorFst* dictionary) { - auto characters = split_utf8_str(word); +bool add_word_to_dictionary( + const std::string& word, + const std::unordered_map& char_map, + bool add_space, + int SPACE_ID, + fst::StdVectorFst* dictionary) { + auto characters = split_utf8_str(word); - std::vector int_word; + std::vector int_word; - for (auto& c : characters) { - if (c == " ") { - int_word.push_back(SPACE_ID); - } else { - auto int_c = char_map.find(c); - if (int_c != char_map.end()) { - int_word.push_back(int_c->second); - } else { - return false; // return without adding - } - } + for (auto& c : characters) { + if (c == " ") { + int_word.push_back(SPACE_ID); + } else { + auto int_c = char_map.find(c); + if (int_c != char_map.end()) { + int_word.push_back(int_c->second); + } else { + return false; // return without adding + } } + } - if (add_space) { - int_word.push_back(SPACE_ID); - } + if (add_space) { + int_word.push_back(SPACE_ID); + } - add_word_to_fst(int_word, dictionary); - return true; + add_word_to_fst(int_word, dictionary); + return true; } diff --git a/deploy/decoder_utils.h b/deploy/decoder_utils.h index 829ea76d..51985c86 100644 --- a/deploy/decoder_utils.h +++ b/deploy/decoder_utils.h @@ -10,34 +10,31 @@ const float NUM_FLT_MIN = std::numeric_limits::min(); // Function template for comparing two pairs template bool pair_comp_first_rev(const std::pair &a, - const std::pair &b) -{ - return a.first > b.first; + const std::pair &b) { + return a.first > b.first; } template bool pair_comp_second_rev(const std::pair &a, - const std::pair &b) -{ - return a.second > b.second; + const std::pair &b) { + return a.second > b.second; } template -T log_sum_exp(const T &x, const T &y) -{ - static T num_min = -std::numeric_limits::max(); - if (x <= num_min) return y; - if (y <= num_min) return x; - T xmax = std::max(x, y); - return std::log(std::exp(x-xmax) + std::exp(y-xmax)) + xmax; +T log_sum_exp(const T &x, const T &y) { + static T num_min = -std::numeric_limits::max(); + if (x <= num_min) return y; + if (y <= num_min) return x; + T xmax = std::max(x, y); + 
return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax; } // Functor for prefix comparsion -bool prefix_compare(const PathTrie* x, const PathTrie* y); +bool prefix_compare(const PathTrie *x, const PathTrie *y); // Get length of utf8 encoding string // See: http://stackoverflow.com/a/4063229 -size_t get_utf8_str_len(const std::string& str); +size_t get_utf8_str_len(const std::string &str); // Split a string into a list of strings on a given string // delimiter. NB: delimiters on beginning / end of string are @@ -50,13 +47,14 @@ std::vector split_str(const std::string &s, std::vector split_utf8_str(const std::string &str); // Add a word in index to the dicionary of fst -void add_word_to_fst(const std::vector& word, - fst::StdVectorFst* dictionary); +void add_word_to_fst(const std::vector &word, + fst::StdVectorFst *dictionary); // Add a word in string to dictionary -bool add_word_to_dictionary(const std::string& word, - const std::unordered_map& char_map, - bool add_space, - int SPACE_ID, - fst::StdVectorFst* dictionary); -#endif // DECODER_UTILS_H +bool add_word_to_dictionary( + const std::string &word, + const std::unordered_map &char_map, + bool add_space, + int SPACE_ID, + fst::StdVectorFst *dictionary); +#endif // DECODER_UTILS_H diff --git a/deploy/path_trie.cpp b/deploy/path_trie.cpp index b22f2a47..db0b20cb 100644 --- a/deploy/path_trie.cpp +++ b/deploy/path_trie.cpp @@ -4,145 +4,142 @@ #include #include -#include "path_trie.h" #include "decoder_utils.h" +#include "path_trie.h" PathTrie::PathTrie() { - _log_prob_b_prev = -NUM_FLT_INF; - _log_prob_nb_prev = -NUM_FLT_INF; - _log_prob_b_cur = -NUM_FLT_INF; - _log_prob_nb_cur = -NUM_FLT_INF; - _score = -NUM_FLT_INF; - - _ROOT = -1; - _character = _ROOT; - _exists = true; - _parent = nullptr; - _dictionary = nullptr; - _dictionary_state = 0; - _has_dictionary = false; - _matcher = nullptr; // finds arcs in FST + log_prob_b_prev = -NUM_FLT_INF; + log_prob_nb_prev = -NUM_FLT_INF; + log_prob_b_cur = -NUM_FLT_INF; + log_prob_nb_cur = -NUM_FLT_INF; + score = -NUM_FLT_INF; + + _ROOT = -1; + character = _ROOT; + _exists = true; + parent = nullptr; + _dictionary = nullptr; + _dictionary_state = 0; + _has_dictionary = false; + _matcher = nullptr; // finds arcs in FST } PathTrie::~PathTrie() { - for (auto child : _children) { - delete child.second; - } + for (auto child : _children) { + delete child.second; + } } PathTrie* PathTrie::get_path_trie(int new_char, bool reset) { - auto child = _children.begin(); - for (child = _children.begin(); child != _children.end(); ++child) { - if (child->first == new_char) { - break; - } + auto child = _children.begin(); + for (child = _children.begin(); child != _children.end(); ++child) { + if (child->first == new_char) { + break; } - if ( child != _children.end() ) { - if (!child->second->_exists) { - child->second->_exists = true; - child->second->_log_prob_b_prev = -NUM_FLT_INF; - child->second->_log_prob_nb_prev = -NUM_FLT_INF; - child->second->_log_prob_b_cur = -NUM_FLT_INF; - child->second->_log_prob_nb_cur = -NUM_FLT_INF; + } + if (child != _children.end()) { + if (!child->second->_exists) { + child->second->_exists = true; + child->second->log_prob_b_prev = -NUM_FLT_INF; + child->second->log_prob_nb_prev = -NUM_FLT_INF; + child->second->log_prob_b_cur = -NUM_FLT_INF; + child->second->log_prob_nb_cur = -NUM_FLT_INF; + } + return (child->second); + } else { + if (_has_dictionary) { + _matcher->SetState(_dictionary_state); + bool found = _matcher->Find(new_char); + if (!found) { + // Adding this 
character causes word outside dictionary + auto FSTZERO = fst::TropicalWeight::Zero(); + auto final_weight = _dictionary->Final(_dictionary_state); + bool is_final = (final_weight != FSTZERO); + if (is_final && reset) { + _dictionary_state = _dictionary->Start(); } - return (child->second); + return nullptr; + } else { + PathTrie* new_path = new PathTrie; + new_path->character = new_char; + new_path->parent = this; + new_path->_dictionary = _dictionary; + new_path->_dictionary_state = _matcher->Value().nextstate; + new_path->_has_dictionary = true; + new_path->_matcher = _matcher; + _children.push_back(std::make_pair(new_char, new_path)); + return new_path; + } } else { - if (_has_dictionary) { - _matcher->SetState(_dictionary_state); - bool found = _matcher->Find(new_char); - if (!found) { - // Adding this character causes word outside dictionary - auto FSTZERO = fst::TropicalWeight::Zero(); - auto final_weight = _dictionary->Final(_dictionary_state); - bool is_final = (final_weight != FSTZERO); - if (is_final && reset) { - _dictionary_state = _dictionary->Start(); - } - return nullptr; - } else { - PathTrie* new_path = new PathTrie; - new_path->_character = new_char; - new_path->_parent = this; - new_path->_dictionary = _dictionary; - new_path->_dictionary_state = _matcher->Value().nextstate; - new_path->_has_dictionary = true; - new_path->_matcher = _matcher; - _children.push_back(std::make_pair(new_char, new_path)); - return new_path; - } - } else { - PathTrie* new_path = new PathTrie; - new_path->_character = new_char; - new_path->_parent = this; - _children.push_back(std::make_pair(new_char, new_path)); - return new_path; - } + PathTrie* new_path = new PathTrie; + new_path->character = new_char; + new_path->parent = this; + _children.push_back(std::make_pair(new_char, new_path)); + return new_path; } + } } PathTrie* PathTrie::get_path_vec(std::vector& output) { - return get_path_vec(output, _ROOT); + return get_path_vec(output, _ROOT); } PathTrie* PathTrie::get_path_vec(std::vector& output, - int stop, - size_t max_steps) { - if (_character == stop || - _character == _ROOT || - output.size() == max_steps) { - std::reverse(output.begin(), output.end()); - return this; - } else { - output.push_back(_character); - return _parent->get_path_vec(output, stop, max_steps); - } + int stop, + size_t max_steps) { + if (character == stop || character == _ROOT || output.size() == max_steps) { + std::reverse(output.begin(), output.end()); + return this; + } else { + output.push_back(character); + return parent->get_path_vec(output, stop, max_steps); + } } -void PathTrie::iterate_to_vec( - std::vector& output) { - if (_exists) { - _log_prob_b_prev = _log_prob_b_cur; - _log_prob_nb_prev = _log_prob_nb_cur; +void PathTrie::iterate_to_vec(std::vector& output) { + if (_exists) { + log_prob_b_prev = log_prob_b_cur; + log_prob_nb_prev = log_prob_nb_cur; - _log_prob_b_cur = -NUM_FLT_INF; - _log_prob_nb_cur = -NUM_FLT_INF; + log_prob_b_cur = -NUM_FLT_INF; + log_prob_nb_cur = -NUM_FLT_INF; - _score = log_sum_exp(_log_prob_b_prev, _log_prob_nb_prev); - output.push_back(this); - } - for (auto child : _children) { - child.second->iterate_to_vec(output); - } + score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); + output.push_back(this); + } + for (auto child : _children) { + child.second->iterate_to_vec(output); + } } void PathTrie::remove() { - _exists = false; - - if (_children.size() == 0) { - auto child = _parent->_children.begin(); - for (child = _parent->_children.begin(); - child != 
_parent->_children.end(); ++child) {
-            if (child->first == _character) {
-                _parent->_children.erase(child);
-                break;
-            }
-        }
-
-        if ( _parent->_children.size() == 0 && !_parent->_exists ) {
-            _parent->remove();
-        }
+  _exists = false;
+
+  if (_children.size() == 0) {
+    auto child = parent->_children.begin();
+    for (child = parent->_children.begin(); child != parent->_children.end();
+         ++child) {
+      if (child->first == character) {
+        parent->_children.erase(child);
+        break;
+      }
+    }
-        delete this;
+    if (parent->_children.size() == 0 && !parent->_exists) {
+      parent->remove();
     }
+
+    delete this;
+  }
 }
 
 void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
-    _dictionary = dictionary;
-    _dictionary_state = dictionary->Start();
-    _has_dictionary = true;
+  _dictionary = dictionary;
+  _dictionary_state = dictionary->Start();
+  _has_dictionary = true;
 }
 
 using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
 void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
-    _matcher = matcher;
+  _matcher = matcher;
 }
diff --git a/deploy/path_trie.h b/deploy/path_trie.h
index 7b378e3f..cac524a3 100644
--- a/deploy/path_trie.h
+++ b/deploy/path_trie.h
@@ -1,59 +1,57 @@
 #ifndef PATH_TRIE_H
 #define PATH_TRIE_H
 #pragma once
+#include
 #include
 #include
 #include
 #include
 #include
-#include
 
 using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
 
 class PathTrie {
 public:
-    PathTrie();
-    ~PathTrie();
-
-    PathTrie* get_path_trie(int new_char, bool reset = true);
+  PathTrie();
+  ~PathTrie();
 
-    PathTrie* get_path_vec(std::vector<int> &output);
+  PathTrie* get_path_trie(int new_char, bool reset = true);
 
-    PathTrie* get_path_vec(std::vector<int>& output,
-                           int stop,
-                           size_t max_steps = std::numeric_limits<size_t>::max());
+  PathTrie* get_path_vec(std::vector<int>& output);
 
-    void iterate_to_vec(std::vector<PathTrie*> &output);
+  PathTrie* get_path_vec(std::vector<int>& output,
+                         int stop,
+                         size_t max_steps = std::numeric_limits<size_t>::max());
 
-    void set_dictionary(fst::StdVectorFst* dictionary);
+  void iterate_to_vec(std::vector<PathTrie*>& output);
 
-    void set_matcher(std::shared_ptr<FSTMATCH> matcher);
+  void set_dictionary(fst::StdVectorFst* dictionary);
 
-    bool is_empty() {
-        return _ROOT == _character;
-    }
+  void set_matcher(std::shared_ptr<FSTMATCH> matcher);
 
-    void remove();
+  bool is_empty() { return _ROOT == character; }
 
-    float _log_prob_b_prev;
-    float _log_prob_nb_prev;
-    float _log_prob_b_cur;
-    float _log_prob_nb_cur;
-    float _score;
-    float _approx_ctc;
+  void remove();
 
+  float log_prob_b_prev;
+  float log_prob_nb_prev;
+  float log_prob_b_cur;
+  float log_prob_nb_cur;
+  float score;
+  float approx_ctc;
+  int character;
+  PathTrie* parent;
 
-    int _ROOT;
-    int _character;
-    bool _exists;
+private:
+  int _ROOT;
+  bool _exists;
 
-    PathTrie *_parent;
-    std::vector<std::pair<int, PathTrie*> > _children;
+  std::vector<std::pair<int, PathTrie*>> _children;
 
-    fst::StdVectorFst* _dictionary;
-    fst::StdVectorFst::StateId _dictionary_state;
-    bool _has_dictionary;
-    std::shared_ptr<FSTMATCH> _matcher;
+  fst::StdVectorFst* _dictionary;
+  fst::StdVectorFst::StateId _dictionary_state;
+  bool _has_dictionary;
+  std::shared_ptr<FSTMATCH> _matcher;
 };
 
-#endif // PATH_TRIE_H
+#endif  // PATH_TRIE_H
diff --git a/deploy/scorer.cpp b/deploy/scorer.cpp
index ced71995..8651eb61 100644
--- a/deploy/scorer.cpp
+++ b/deploy/scorer.cpp
@@ -1,219 +1,208 @@
-#include
+#include "scorer.h"
 #include
+#include
+#include "decoder_utils.h"
 #include "lm/config.hh"
-#include "lm/state.hh"
 #include "lm/model.hh"
-#include "util/tokenize_piece.hh"
+#include "lm/state.hh"
 #include "util/string_piece.hh"
-#include "scorer.h"
-#include "decoder_utils.h"
+#include "util/tokenize_piece.hh"
 
 using namespace lm::ngram;
Scorer::Scorer(double alpha, double beta, const std::string& lm_path) { - this->alpha = alpha; - this->beta = beta; - _is_character_based = true; - _language_model = nullptr; - dictionary = nullptr; - _max_order = 0; - _SPACE_ID = -1; - // load language model - load_LM(lm_path.c_str()); + this->alpha = alpha; + this->beta = beta; + _is_character_based = true; + _language_model = nullptr; + dictionary = nullptr; + _max_order = 0; + _SPACE_ID = -1; + // load language model + load_LM(lm_path.c_str()); } Scorer::~Scorer() { - if (_language_model != nullptr) - delete static_cast(_language_model); - if (dictionary != nullptr) - delete static_cast(dictionary); + if (_language_model != nullptr) + delete static_cast(_language_model); + if (dictionary != nullptr) delete static_cast(dictionary); } void Scorer::load_LM(const char* filename) { - if (access(filename, F_OK) != 0) { - std::cerr << "Invalid language model file !!!" << std::endl; - exit(1); - } - RetriveStrEnumerateVocab enumerate; - lm::ngram::Config config; - config.enumerate_vocab = &enumerate; - _language_model = lm::ngram::LoadVirtual(filename, config); - _max_order = static_cast(_language_model)->Order(); - _vocabulary = enumerate.vocabulary; - for (size_t i = 0; i < _vocabulary.size(); ++i) { - if (_is_character_based - && _vocabulary[i] != UNK_TOKEN - && _vocabulary[i] != START_TOKEN - && _vocabulary[i] != END_TOKEN - && get_utf8_str_len(enumerate.vocabulary[i]) > 1) { - _is_character_based = false; - } + if (access(filename, F_OK) != 0) { + std::cerr << "Invalid language model file !!!" << std::endl; + exit(1); + } + RetriveStrEnumerateVocab enumerate; + lm::ngram::Config config; + config.enumerate_vocab = &enumerate; + _language_model = lm::ngram::LoadVirtual(filename, config); + _max_order = static_cast(_language_model)->Order(); + _vocabulary = enumerate.vocabulary; + for (size_t i = 0; i < _vocabulary.size(); ++i) { + if (_is_character_based && _vocabulary[i] != UNK_TOKEN && + _vocabulary[i] != START_TOKEN && _vocabulary[i] != END_TOKEN && + get_utf8_str_len(enumerate.vocabulary[i]) > 1) { + _is_character_based = false; } + } } double Scorer::get_log_cond_prob(const std::vector& words) { - lm::base::Model* model = static_cast(_language_model); - double cond_prob; - lm::ngram::State state, tmp_state, out_state; - // avoid to inserting in begin - model->NullContextWrite(&state); - for (size_t i = 0; i < words.size(); ++i) { - lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]); - // encounter OOV - if (word_index == 0) { - return OOV_SCORE; - } - cond_prob = model->BaseScore(&state, word_index, &out_state); - tmp_state = state; - state = out_state; - out_state = tmp_state; + lm::base::Model* model = static_cast(_language_model); + double cond_prob; + lm::ngram::State state, tmp_state, out_state; + // avoid to inserting in begin + model->NullContextWrite(&state); + for (size_t i = 0; i < words.size(); ++i) { + lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]); + // encounter OOV + if (word_index == 0) { + return OOV_SCORE; } - // log10 prob - return cond_prob; + cond_prob = model->BaseScore(&state, word_index, &out_state); + tmp_state = state; + state = out_state; + out_state = tmp_state; + } + // log10 prob + return cond_prob; } double Scorer::get_sent_log_prob(const std::vector& words) { - std::vector sentence; - if (words.size() == 0) { - for (size_t i = 0; i < _max_order; ++i) { - sentence.push_back(START_TOKEN); - } - } else { - for (size_t i = 0; i < _max_order - 1; ++i) { - 
sentence.push_back(START_TOKEN); - } - sentence.insert(sentence.end(), words.begin(), words.end()); + std::vector sentence; + if (words.size() == 0) { + for (size_t i = 0; i < _max_order; ++i) { + sentence.push_back(START_TOKEN); } - sentence.push_back(END_TOKEN); - return get_log_prob(sentence); + } else { + for (size_t i = 0; i < _max_order - 1; ++i) { + sentence.push_back(START_TOKEN); + } + sentence.insert(sentence.end(), words.begin(), words.end()); + } + sentence.push_back(END_TOKEN); + return get_log_prob(sentence); } double Scorer::get_log_prob(const std::vector& words) { - assert(words.size() > _max_order); - double score = 0.0; - for (size_t i = 0; i < words.size() - _max_order + 1; ++i) { - std::vector ngram(words.begin() + i, - words.begin() + i + _max_order); - score += get_log_cond_prob(ngram); - } - return score; + assert(words.size() > _max_order); + double score = 0.0; + for (size_t i = 0; i < words.size() - _max_order + 1; ++i) { + std::vector ngram(words.begin() + i, + words.begin() + i + _max_order); + score += get_log_cond_prob(ngram); + } + return score; } void Scorer::reset_params(float alpha, float beta) { - this->alpha = alpha; - this->beta = beta; + this->alpha = alpha; + this->beta = beta; } std::string Scorer::vec2str(const std::vector& input) { - std::string word; - for (auto ind : input) { - word += _char_list[ind]; - } - return word; + std::string word; + for (auto ind : input) { + word += _char_list[ind]; + } + return word; } -std::vector -Scorer::split_labels(const std::vector &labels) { - if (labels.empty()) - return {}; - - std::string s = vec2str(labels); - std::vector words; - if (_is_character_based) { - words = split_utf8_str(s); - } else { - words = split_str(s, " "); - } - return words; +std::vector Scorer::split_labels(const std::vector& labels) { + if (labels.empty()) return {}; + + std::string s = vec2str(labels); + std::vector words; + if (_is_character_based) { + words = split_utf8_str(s); + } else { + words = split_str(s, " "); + } + return words; } void Scorer::set_char_map(std::vector char_list) { - _char_list = char_list; - _char_map.clear(); - - for(unsigned int i = 0; i < _char_list.size(); i++) - { - if (_char_list[i] == " ") { - _SPACE_ID = i; - _char_map[' '] = i; - } else if(_char_list[i].size() == 1){ - _char_map[_char_list[i][0]] = i; - } + _char_list = char_list; + _char_map.clear(); + + for (unsigned int i = 0; i < _char_list.size(); i++) { + if (_char_list[i] == " ") { + _SPACE_ID = i; + _char_map[' '] = i; + } else if (_char_list[i].size() == 1) { + _char_map[_char_list[i][0]] = i; } + } } std::vector Scorer::make_ngram(PathTrie* prefix) { - std::vector ngram; - PathTrie* current_node = prefix; - PathTrie* new_node = nullptr; - - for (int order = 0; order < _max_order; order++) { - std::vector prefix_vec; - - if (_is_character_based) { - new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID, 1); - current_node = new_node; - } else { - new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID); - current_node = new_node->_parent; // Skipping spaces - } - - // reconstruct word - std::string word = vec2str(prefix_vec); - ngram.push_back(word); - - if (new_node->_character == -1) { - // No more spaces, but still need order - for (int i = 0; i < _max_order - order - 1; i++) { - ngram.push_back(START_TOKEN); - } - break; - } - } - std::reverse(ngram.begin(), ngram.end()); - return ngram; -} - -void Scorer::fill_dictionary(bool add_space) { + std::vector ngram; + PathTrie* current_node = prefix; + PathTrie* new_node = nullptr; 
- fst::StdVectorFst dictionary; - // First reverse char_list so ints can be accessed by chars - std::unordered_map char_map; - for (unsigned int i = 0; i < _char_list.size(); i++) { - char_map[_char_list[i]] = i; - } + for (int order = 0; order < _max_order; order++) { + std::vector prefix_vec; - // For each unigram convert to ints and put in trie - int vocab_size = 0; - for (const auto& word : _vocabulary) { - bool added = add_word_to_dictionary(word, - char_map, - add_space, - _SPACE_ID, - &dictionary); - vocab_size += added ? 1 : 0; + if (_is_character_based) { + new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID, 1); + current_node = new_node; + } else { + new_node = current_node->get_path_vec(prefix_vec, _SPACE_ID); + current_node = new_node->parent; // Skipping spaces } - std::cerr << "Vocab Size " << vocab_size << std::endl; - - // Simplify FST - - // This gets rid of "epsilon" transitions in the FST. - // These are transitions that don't require a string input to be taken. - // Getting rid of them is necessary to make the FST determinisitc, but - // can greatly increase the size of the FST - fst::RmEpsilon(&dictionary); - fst::StdVectorFst* new_dict = new fst::StdVectorFst; + // reconstruct word + std::string word = vec2str(prefix_vec); + ngram.push_back(word); - // This makes the FST deterministic, meaning for any string input there's - // only one possible state the FST could be in. It is assumed our - // dictionary is deterministic when using it. - // (lest we'd have to check for multiple transitions at each state) - fst::Determinize(dictionary, new_dict); - - // Finds the simplest equivalent fst. This is unnecessary but decreases - // memory usage of the dictionary - fst::Minimize(new_dict); - this->dictionary = new_dict; + if (new_node->character == -1) { + // No more spaces, but still need order + for (int i = 0; i < _max_order - order - 1; i++) { + ngram.push_back(START_TOKEN); + } + break; + } + } + std::reverse(ngram.begin(), ngram.end()); + return ngram; +} +void Scorer::fill_dictionary(bool add_space) { + fst::StdVectorFst dictionary; + // First reverse char_list so ints can be accessed by chars + std::unordered_map char_map; + for (unsigned int i = 0; i < _char_list.size(); i++) { + char_map[_char_list[i]] = i; + } + + // For each unigram convert to ints and put in trie + int vocab_size = 0; + for (const auto& word : _vocabulary) { + bool added = add_word_to_dictionary( + word, char_map, add_space, _SPACE_ID, &dictionary); + vocab_size += added ? 1 : 0; + } + + std::cerr << "Vocab Size " << vocab_size << std::endl; + + // Simplify FST + + // This gets rid of "epsilon" transitions in the FST. + // These are transitions that don't require a string input to be taken. + // Getting rid of them is necessary to make the FST determinisitc, but + // can greatly increase the size of the FST + fst::RmEpsilon(&dictionary); + fst::StdVectorFst* new_dict = new fst::StdVectorFst; + + // This makes the FST deterministic, meaning for any string input there's + // only one possible state the FST could be in. It is assumed our + // dictionary is deterministic when using it. + // (lest we'd have to check for multiple transitions at each state) + fst::Determinize(dictionary, new_dict); + + // Finds the simplest equivalent fst. 
This is unnecessary but decreases
+  // memory usage of the dictionary
+  fst::Minimize(new_dict);
+  this->dictionary = new_dict;
 }
diff --git a/deploy/scorer.h b/deploy/scorer.h
index e3d61a71..0c78b987 100644
--- a/deploy/scorer.h
+++ b/deploy/scorer.h
@@ -1,31 +1,31 @@
 #ifndef SCORER_H_
 #define SCORER_H_
 
-#include
 #include
-#include
 #include
+#include
 
 #include "lm/enumerate_vocab.hh"
-#include "lm/word_index.hh"
 #include "lm/virtual_interface.hh"
-#include "util/string_piece.hh"
+#include "lm/word_index.hh"
 #include "path_trie.h"
+#include "util/string_piece.hh"
 
 const double OOV_SCORE = -1000.0;
 const std::string START_TOKEN = "<s>";
 const std::string UNK_TOKEN = "<unk>";
 const std::string END_TOKEN = "</s>";
 
- // Implement a callback to retrive string vocabulary.
+// Implement a callback to retrieve string vocabulary.
 class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
 public:
-    RetriveStrEnumerateVocab() {}
+  RetriveStrEnumerateVocab() {}
 
-    void Add(lm::WordIndex index, const StringPiece& str) {
-        vocabulary.push_back(std::string(str.data(), str.length()));
-    }
+  void Add(lm::WordIndex index, const StringPiece& str) {
+    vocabulary.push_back(std::string(str.data(), str.length()));
+  }
 
-    std::vector<std::string> vocabulary;
+  std::vector<std::string> vocabulary;
 };
 
 // External scorer to query languange score for n-gram or sentence.
@@ -33,59 +33,59 @@ public:
 //     Scorer scorer(alpha, beta, "path_of_language_model");
 //     scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
 //     scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
-class Scorer{
+class Scorer {
 public:
-    Scorer(double alpha, double beta, const std::string& lm_path);
-    ~Scorer();
+  Scorer(double alpha, double beta, const std::string& lm_path);
+  ~Scorer();
 
-    double get_log_cond_prob(const std::vector<std::string>& words);
+  double get_log_cond_prob(const std::vector<std::string>& words);
 
-    double get_sent_log_prob(const std::vector<std::string>& words);
+  double get_sent_log_prob(const std::vector<std::string>& words);
 
-    size_t get_max_order() { return _max_order; }
+  size_t get_max_order() { return _max_order; }
 
-    bool is_char_map_empty() {return _char_map.size() == 0; }
+  bool is_char_map_empty() { return _char_map.size() == 0; }
 
-    bool is_character_based() { return _is_character_based; }
+  bool is_character_based() { return _is_character_based; }
 
-    // reset params alpha & beta
-    void reset_params(float alpha, float beta);
+  // reset params alpha & beta
+  void reset_params(float alpha, float beta);
 
-    // make ngram
-    std::vector<std::string> make_ngram(PathTrie* prefix);
+  // make ngram
+  std::vector<std::string> make_ngram(PathTrie* prefix);
 
-    // fill dictionary for fst
-    void fill_dictionary(bool add_space);
+  // fill dictionary for fst
+  void fill_dictionary(bool add_space);
 
-    // set char map
-    void set_char_map(std::vector<std::string> char_list);
+  // set char map
+  void set_char_map(std::vector<std::string> char_list);
 
-    std::vector<std::string> split_labels(const std::vector<int> &labels);
+  std::vector<std::string> split_labels(const std::vector<int>& labels);
 
-    // expose to decoder
-    double alpha;
-    double beta;
+  // expose to decoder
+  double alpha;
+  double beta;
 
-    // fst dictionary
-    void* dictionary;
+  // fst dictionary
+  void* dictionary;
 
 protected:
-    void load_LM(const char* filename);
+  void load_LM(const char* filename);
 
-    double get_log_prob(const std::vector<std::string>& words);
+  double get_log_prob(const std::vector<std::string>& words);
 
-    std::string vec2str(const std::vector<int> &input);
+  std::string vec2str(const std::vector<int>& input);
 
 private:
-    void* _language_model;
-    bool _is_character_based;
-    size_t _max_order;
+  void* _language_model;
+  bool _is_character_based;
+  size_t _max_order;
 
-    int _SPACE_ID;
-    std::vector<std::string> _char_list;
-    std::unordered_map<char, int> _char_map;
+  int _SPACE_ID;
+  std::vector<std::string> _char_list;
+  std::unordered_map<char, int> _char_map;
 
-    std::vector<std::string> _vocabulary;
+  std::vector<std::string> _vocabulary;
 };
 
-#endif // SCORER_H_
+#endif  // SCORER_H_
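
Reviewer note, not part of the patch: a minimal usage sketch of the two entry points declared in deploy/ctc_decoders.h, for anyone who wants to smoke-test the reformatted API. The probability matrix and three-symbol vocabulary below are made-up placeholders; a real caller feeds the acoustic model's per-timestep softmax output, whose last column is the blank. The defaults for cutoff_prob, cutoff_top_n and ext_scorer (1.0, 40, NULL) come from the header above; passing a Scorer built as in the scorer.h comment enables language model rescoring, and linking needs the same kenlm/openfst dependencies the rest of deploy/ already uses.

    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    #include "ctc_decoders.h"

    int main() {
      // Each row is one time step: P(vocab[0]), P(vocab[1]), P(vocab[2]), P(blank).
      std::vector<std::vector<double>> probs_seq = {{0.1, 0.6, 0.1, 0.2},
                                                    {0.2, 0.1, 0.5, 0.2}};
      std::vector<std::string> vocabulary = {"a", "b", " "};
      int blank_id = vocabulary.size();  // blank is appended after the vocabulary

      // Greedy (best path) decoding.
      std::string best = ctc_best_path_decoder(probs_seq, vocabulary);
      std::cout << "best path: " << best << std::endl;

      // Beam search without an external scorer; the remaining arguments keep
      // their declared defaults.
      std::vector<std::pair<double, std::string>> beams =
          ctc_beam_search_decoder(probs_seq, /*beam_size=*/4, vocabulary, blank_id);
      for (const auto &b : beams) {
        std::cout << b.first << "\t" << b.second << std::endl;
      }
      return 0;
    }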