format code

pull/578/head
Hui Zhang 4 years ago
parent e969a8ec80
commit f842c79a5f

@ -65,7 +65,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
prefixes.push_back(&root);
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
auto fst_dict =
static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
@ -80,10 +81,12 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
bool full_beam = false;
if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
std::sort(prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
std::log(prob[blank_id]) -
std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
}
@ -101,14 +104,15 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
}
// blank
if (c == blank_id) {
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
prefix->log_prob_b_cur = log_sum_exp(
prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
prefix->log_prob_nb_cur =
log_sum_exp(prefix->log_prob_nb_cur,
log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
@ -137,7 +141,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score = ext_scorer->get_log_cond_prob(ngram) *
ext_scorer->alpha;
log_p += score;
log_p += ext_scorer->beta;
}
@ -171,7 +176,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score =
ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score += ext_scorer->beta;
prefix->score += score;
}
@ -179,7 +185,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
}
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
@ -193,7 +200,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
// remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight:
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
approx_ctc -=
(ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}

@ -29,15 +29,17 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
// pruning of vacobulary
size_t cutoff_len = prob_step.size();
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
std::sort(
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
std::sort(prob_idx.begin(),
prob_idx.end(),
pair_comp_second_rev<int, double>);
if (cutoff_prob < 1.0) {
double cum_prob = 0.0;
cutoff_len = 0;
for (size_t i = 0; i < prob_idx.size(); ++i) {
cum_prob += prob_idx[i].second;
cutoff_len += 1;
if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break;
if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n)
break;
}
}
prob_idx = std::vector<std::pair<int, double>>(
@ -74,8 +76,8 @@ std::vector<std::pair<double, std::string>> get_beam_search_result(
for (size_t j = 0; j < output.size(); j++) {
output_str += vocabulary[output[j]];
}
std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
output_str);
std::pair<double, std::string> output_pair(
-space_prefixes[i]->approx_ctc, output_str);
output_vecs.emplace_back(output_pair);
}

@ -134,7 +134,8 @@ void PathTrie::remove() {
if (children_.size() == 0) {
auto child = parent->children_.begin();
for (child = parent->children_.begin(); child != parent->children_.end();
for (child = parent->children_.begin();
child != parent->children_.end();
++child) {
if (child->first == character) {
parent->children_.erase(child);

@ -38,7 +38,8 @@ public:
PathTrie* get_path_vec(std::vector<int>& output);
// get the prefix in index from some stop node to current nodel
PathTrie* get_path_vec(std::vector<int>& output,
PathTrie* get_path_vec(
std::vector<int>& output,
int stop,
size_t max_steps = std::numeric_limits<size_t>::max());

@ -1,12 +1,12 @@
#!/usr/bin/env python3
import sys
import pstats
import cProfile
from io import StringIO
import getopt
import os
from os.path import dirname, join
import pstats
import sys
from io import StringIO
from os.path import dirname
from os.path import join
import mmseg

Loading…
Cancel
Save