format code

pull/578/head
Hui Zhang 4 years ago
parent e969a8ec80
commit f842c79a5f

@ -65,7 +65,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
prefixes.push_back(&root); prefixes.push_back(&root);
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary); auto fst_dict =
static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr); root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT); auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
@ -80,10 +81,12 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
bool full_beam = false; bool full_beam = false;
if (ext_scorer != nullptr) { if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size); size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort( std::sort(prefixes.begin(),
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); prefixes.begin() + num_prefixes,
prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score + min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta); std::log(prob[blank_id]) -
std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size); full_beam = (num_prefixes == beam_size);
} }
@ -101,14 +104,15 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
} }
// blank // blank
if (c == blank_id) { if (c == blank_id) {
prefix->log_prob_b_cur = prefix->log_prob_b_cur = log_sum_exp(
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue; continue;
} }
// repeated character // repeated character
if (c == prefix->character) { if (c == prefix->character) {
prefix->log_prob_nb_cur = log_sum_exp( prefix->log_prob_nb_cur =
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); log_sum_exp(prefix->log_prob_nb_cur,
log_prob_c + prefix->log_prob_nb_prev);
} }
// get new prefix // get new prefix
auto prefix_new = prefix->get_path_trie(c); auto prefix_new = prefix->get_path_trie(c);
@ -137,7 +141,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
float score = 0.0; float score = 0.0;
std::vector<std::string> ngram; std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score); ngram = ext_scorer->make_ngram(prefix_to_score);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; score = ext_scorer->get_log_cond_prob(ngram) *
ext_scorer->alpha;
log_p += score; log_p += score;
log_p += ext_scorer->beta; log_p += ext_scorer->beta;
} }
@ -171,7 +176,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
if (!prefix->is_empty() && prefix->character != space_id) { if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0; float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix); std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; score =
ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score += ext_scorer->beta; score += ext_scorer->beta;
prefix->score += score; prefix->score += score;
} }
@ -179,7 +185,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
} }
size_t num_prefixes = std::min(prefixes.size(), beam_size); size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
// compute aproximate ctc score as the return score, without affecting the // compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable. // return order of decoding result. To delete when decoder gets stable.
@ -193,7 +200,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
// remove word insert // remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight: // remove language model weight:
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; approx_ctc -=
(ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
} }
prefixes[i]->approx_ctc = approx_ctc; prefixes[i]->approx_ctc = approx_ctc;
} }

@ -29,15 +29,17 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
// pruning of vacobulary // pruning of vacobulary
size_t cutoff_len = prob_step.size(); size_t cutoff_len = prob_step.size();
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) { if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
std::sort( std::sort(prob_idx.begin(),
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>); prob_idx.end(),
pair_comp_second_rev<int, double>);
if (cutoff_prob < 1.0) { if (cutoff_prob < 1.0) {
double cum_prob = 0.0; double cum_prob = 0.0;
cutoff_len = 0; cutoff_len = 0;
for (size_t i = 0; i < prob_idx.size(); ++i) { for (size_t i = 0; i < prob_idx.size(); ++i) {
cum_prob += prob_idx[i].second; cum_prob += prob_idx[i].second;
cutoff_len += 1; cutoff_len += 1;
if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break; if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n)
break;
} }
} }
prob_idx = std::vector<std::pair<int, double>>( prob_idx = std::vector<std::pair<int, double>>(
@ -74,8 +76,8 @@ std::vector<std::pair<double, std::string>> get_beam_search_result(
for (size_t j = 0; j < output.size(); j++) { for (size_t j = 0; j < output.size(); j++) {
output_str += vocabulary[output[j]]; output_str += vocabulary[output[j]];
} }
std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc, std::pair<double, std::string> output_pair(
output_str); -space_prefixes[i]->approx_ctc, output_str);
output_vecs.emplace_back(output_pair); output_vecs.emplace_back(output_pair);
} }

@ -134,7 +134,8 @@ void PathTrie::remove() {
if (children_.size() == 0) { if (children_.size() == 0) {
auto child = parent->children_.begin(); auto child = parent->children_.begin();
for (child = parent->children_.begin(); child != parent->children_.end(); for (child = parent->children_.begin();
child != parent->children_.end();
++child) { ++child) {
if (child->first == character) { if (child->first == character) {
parent->children_.erase(child); parent->children_.erase(child);

@ -38,7 +38,8 @@ public:
PathTrie* get_path_vec(std::vector<int>& output); PathTrie* get_path_vec(std::vector<int>& output);
// get the prefix in index from some stop node to current nodel // get the prefix in index from some stop node to current nodel
PathTrie* get_path_vec(std::vector<int>& output, PathTrie* get_path_vec(
std::vector<int>& output,
int stop, int stop,
size_t max_steps = std::numeric_limits<size_t>::max()); size_t max_steps = std::numeric_limits<size_t>::max());

@ -1,12 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys
import pstats
import cProfile import cProfile
from io import StringIO
import getopt import getopt
import os import os
from os.path import dirname, join import pstats
import sys
from io import StringIO
from os.path import dirname
from os.path import join
import mmseg import mmseg

Loading…
Cancel
Save