|
|
|
@ -18,7 +18,7 @@ std::string ctc_greedy_decoder(
|
|
|
|
|
const std::vector<std::string> &vocabulary) {
|
|
|
|
|
// dimension check
|
|
|
|
|
size_t num_time_steps = probs_seq.size();
|
|
|
|
|
for (size_t i = 0; i < num_time_steps; i++) {
|
|
|
|
|
for (size_t i = 0; i < num_time_steps; ++i) {
|
|
|
|
|
VALID_CHECK_EQ(probs_seq[i].size(),
|
|
|
|
|
vocabulary.size() + 1,
|
|
|
|
|
"The shape of probs_seq does not match with "
|
|
|
|
@ -28,7 +28,7 @@ std::string ctc_greedy_decoder(
|
|
|
|
|
size_t blank_id = vocabulary.size();
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> max_idx_vec;
|
|
|
|
|
for (size_t i = 0; i < num_time_steps; i++) {
|
|
|
|
|
for (size_t i = 0; i < num_time_steps; ++i) {
|
|
|
|
|
double max_prob = 0.0;
|
|
|
|
|
size_t max_idx = 0;
|
|
|
|
|
for (size_t j = 0; j < probs_seq[i].size(); j++) {
|
|
|
|
@ -41,14 +41,14 @@ std::string ctc_greedy_decoder(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> idx_vec;
|
|
|
|
|
for (size_t i = 0; i < max_idx_vec.size(); i++) {
|
|
|
|
|
for (size_t i = 0; i < max_idx_vec.size(); ++i) {
|
|
|
|
|
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
|
|
|
|
|
idx_vec.push_back(max_idx_vec[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string best_path_result;
|
|
|
|
|
for (size_t i = 0; i < idx_vec.size(); i++) {
|
|
|
|
|
for (size_t i = 0; i < idx_vec.size(); ++i) {
|
|
|
|
|
if (idx_vec[i] != blank_id) {
|
|
|
|
|
best_path_result += vocabulary[idx_vec[i]];
|
|
|
|
|
}
|
|
|
|
@ -65,7 +65,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
|
Scorer *ext_scorer) {
|
|
|
|
|
// dimension check
|
|
|
|
|
size_t num_time_steps = probs_seq.size();
|
|
|
|
|
for (size_t i = 0; i < num_time_steps; i++) {
|
|
|
|
|
for (size_t i = 0; i < num_time_steps; ++i) {
|
|
|
|
|
VALID_CHECK_EQ(probs_seq[i].size(),
|
|
|
|
|
vocabulary.size() + 1,
|
|
|
|
|
"The shape of probs_seq does not match with "
|
|
|
|
@ -111,7 +111,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
|
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
|
|
|
|
|
std::vector<double> prob = probs_seq[time_step];
|
|
|
|
|
std::vector<std::pair<int, double>> prob_idx;
|
|
|
|
|
for (size_t i = 0; i < prob.size(); i++) {
|
|
|
|
|
for (size_t i = 0; i < prob.size(); ++i) {
|
|
|
|
|
prob_idx.push_back(std::pair<int, double>(i, prob[i]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -134,7 +134,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
|
if (cutoff_prob < 1.0) {
|
|
|
|
|
double cum_prob = 0.0;
|
|
|
|
|
cutoff_len = 0;
|
|
|
|
|
for (size_t i = 0; i < prob_idx.size(); i++) {
|
|
|
|
|
for (size_t i = 0; i < prob_idx.size(); ++i) {
|
|
|
|
|
cum_prob += prob_idx[i].second;
|
|
|
|
|
cutoff_len += 1;
|
|
|
|
|
if (cum_prob >= cutoff_prob) break;
|
|
|
|
@ -145,7 +145,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
|
prob_idx.begin(), prob_idx.begin() + cutoff_len);
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::pair<size_t, float>> log_prob_idx;
|
|
|
|
|
for (size_t i = 0; i < cutoff_len; i++) {
|
|
|
|
|
for (size_t i = 0; i < cutoff_len; ++i) {
|
|
|
|
|
log_prob_idx.push_back(std::pair<int, float>(
|
|
|
|
|
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
|
|
|
|
|
}
|
|
|
|
@ -155,7 +155,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
|
auto c = log_prob_idx[index].first;
|
|
|
|
|
float log_prob_c = log_prob_idx[index].second;
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < prefixes.size() && i < beam_size; i++) {
|
|
|
|
|
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
|
|
|
|
|
auto prefix = prefixes[i];
|
|
|
|
|
|
|
|
|
|
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
|
|
|
@ -222,14 +222,14 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
|
prefixes.end(),
|
|
|
|
|
prefix_compare);
|
|
|
|
|
|
|
|
|
|
for (size_t i = beam_size; i < prefixes.size(); i++) {
|
|
|
|
|
for (size_t i = beam_size; i < prefixes.size(); ++i) {
|
|
|
|
|
prefixes[i]->remove();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // end of loop over time
|
|
|
|
|
|
|
|
|
|
// compute aproximate ctc score as the return score
|
|
|
|
|
for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
|
|
|
|
|
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
|
|
|
|
double approx_ctc = prefixes[i]->score;
|
|
|
|
|
|
|
|
|
|
if (ext_scorer != nullptr) {
|
|
|
|
@ -249,14 +249,14 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
|
|
|
|
// allow for the post processing
|
|
|
|
|
std::vector<PathTrie *> space_prefixes;
|
|
|
|
|
if (space_prefixes.empty()) {
|
|
|
|
|
for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
|
|
|
|
|
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
|
|
|
|
space_prefixes.push_back(prefixes[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
|
|
|
|
|
std::vector<std::pair<double, std::string>> output_vecs;
|
|
|
|
|
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) {
|
|
|
|
|
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
|
|
|
|
|
std::vector<int> output;
|
|
|
|
|
space_prefixes[i]->get_path_vec(output);
|
|
|
|
|
// convert index to string
|
|
|
|
@ -301,7 +301,7 @@ ctc_beam_search_decoder_batch(
|
|
|
|
|
|
|
|
|
|
// enqueue the tasks of decoding
|
|
|
|
|
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
|
|
|
|
|
for (size_t i = 0; i < batch_size; i++) {
|
|
|
|
|
for (size_t i = 0; i < batch_size; ++i) {
|
|
|
|
|
res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
|
|
|
|
|
probs_split[i],
|
|
|
|
|
beam_size,
|
|
|
|
@ -313,7 +313,7 @@ ctc_beam_search_decoder_batch(
|
|
|
|
|
|
|
|
|
|
// get decoding results
|
|
|
|
|
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
|
|
|
|
|
for (size_t i = 0; i < batch_size; i++) {
|
|
|
|
|
for (size_t i = 0; i < batch_size; ++i) {
|
|
|
|
|
batch_results.emplace_back(res[i].get());
|
|
|
|
|
}
|
|
|
|
|
return batch_results;
|
|
|
|
|