|
|
@ -26,10 +26,10 @@ using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
|
|
|
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
|
|
|
|
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
|
|
|
|
: opts_(opts),
|
|
|
|
: opts_(opts),
|
|
|
|
init_ext_scorer_(nullptr),
|
|
|
|
init_ext_scorer_(nullptr),
|
|
|
|
blank_id(-1),
|
|
|
|
blank_id_(-1),
|
|
|
|
space_id(-1),
|
|
|
|
space_id_(-1),
|
|
|
|
num_frame_decoded_(0),
|
|
|
|
num_frame_decoded_(0),
|
|
|
|
root(nullptr) {
|
|
|
|
root_(nullptr) {
|
|
|
|
LOG(INFO) << "dict path: " << opts_.dict_file;
|
|
|
|
LOG(INFO) << "dict path: " << opts_.dict_file;
|
|
|
|
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
|
|
|
|
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
|
|
|
|
LOG(INFO) << "load the dict failed";
|
|
|
|
LOG(INFO) << "load the dict failed";
|
|
|
@ -40,37 +40,40 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
|
|
|
|
LOG(INFO) << "language model path: " << opts_.lm_path;
|
|
|
|
LOG(INFO) << "language model path: " << opts_.lm_path;
|
|
|
|
init_ext_scorer_ = std::make_shared<Scorer>(
|
|
|
|
init_ext_scorer_ = std::make_shared<Scorer>(
|
|
|
|
opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
|
|
|
|
opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void CTCBeamSearch::Reset() {
|
|
|
|
|
|
|
|
num_frame_decoded_ = 0;
|
|
|
|
|
|
|
|
ResetPrefixes();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void CTCBeamSearch::InitDecoder() {
|
|
|
|
blank_id_ = 0;
|
|
|
|
blank_id = 0;
|
|
|
|
|
|
|
|
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
|
|
|
|
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
|
|
|
|
|
|
|
|
|
|
|
|
space_id = it - vocabulary_.begin();
|
|
|
|
space_id_ = it - vocabulary_.begin();
|
|
|
|
// if no space in vocabulary
|
|
|
|
// if no space in vocabulary
|
|
|
|
if ((size_t)space_id >= vocabulary_.size()) {
|
|
|
|
if ((size_t)space_id_ >= vocabulary_.size()) {
|
|
|
|
space_id = -2;
|
|
|
|
space_id_ = -2;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void CTCBeamSearch::Reset() {
|
|
|
|
|
|
|
|
//num_frame_decoded_ = 0;
|
|
|
|
|
|
|
|
//ResetPrefixes();
|
|
|
|
|
|
|
|
InitDecoder();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ResetPrefixes();
|
|
|
|
void CTCBeamSearch::InitDecoder() {
|
|
|
|
|
|
|
|
num_frame_decoded_ = 0;
|
|
|
|
|
|
|
|
//ResetPrefixes();
|
|
|
|
|
|
|
|
prefixes_.clear();
|
|
|
|
|
|
|
|
|
|
|
|
root = std::make_shared<PathTrie>();
|
|
|
|
root_ = std::make_shared<PathTrie>();
|
|
|
|
root->score = root->log_prob_b_prev = 0.0;
|
|
|
|
root_->score = root_->log_prob_b_prev = 0.0;
|
|
|
|
prefixes.push_back(root.get());
|
|
|
|
prefixes_.push_back(root_.get());
|
|
|
|
if (init_ext_scorer_ != nullptr &&
|
|
|
|
if (init_ext_scorer_ != nullptr &&
|
|
|
|
!init_ext_scorer_->is_character_based()) {
|
|
|
|
!init_ext_scorer_->is_character_based()) {
|
|
|
|
auto fst_dict =
|
|
|
|
auto fst_dict =
|
|
|
|
static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
|
|
|
|
static_cast<fst::StdVectorFst*>(init_ext_scorer_->dictionary);
|
|
|
|
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
|
|
|
|
fst::StdVectorFst* dict_ptr = fst_dict->Copy(true);
|
|
|
|
root->set_dictionary(dict_ptr);
|
|
|
|
root_->set_dictionary(dict_ptr);
|
|
|
|
|
|
|
|
|
|
|
|
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
|
|
|
|
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
|
|
|
|
root->set_matcher(matcher);
|
|
|
|
root_->set_matcher(matcher);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -96,12 +99,13 @@ void CTCBeamSearch::AdvanceDecode(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void CTCBeamSearch::ResetPrefixes() {
|
|
|
|
void CTCBeamSearch::ResetPrefixes() {
|
|
|
|
for (size_t i = 0; i < prefixes.size(); i++) {
|
|
|
|
for (size_t i = 0; i < prefixes_.size(); i++) {
|
|
|
|
if (prefixes[i] != nullptr) {
|
|
|
|
if (prefixes_[i] != nullptr) {
|
|
|
|
delete prefixes[i];
|
|
|
|
delete prefixes_[i];
|
|
|
|
prefixes[i] = nullptr;
|
|
|
|
prefixes_[i] = nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
prefixes_.clear();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
|
|
|
|
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
|
|
|
@ -115,12 +119,12 @@ int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
|
|
|
|
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
|
|
|
|
return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
|
|
|
|
return get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
string CTCBeamSearch::GetBestPath() {
|
|
|
|
string CTCBeamSearch::GetBestPath() {
|
|
|
|
std::vector<std::pair<double, std::string>> result;
|
|
|
|
std::vector<std::pair<double, std::string>> result;
|
|
|
|
result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
|
|
|
|
result = get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size);
|
|
|
|
return result[0].second;
|
|
|
|
return result[0].second;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -153,19 +157,19 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
|
|
|
|
float min_cutoff = -NUM_FLT_INF;
|
|
|
|
float min_cutoff = -NUM_FLT_INF;
|
|
|
|
bool full_beam = false;
|
|
|
|
bool full_beam = false;
|
|
|
|
if (init_ext_scorer_ != nullptr) {
|
|
|
|
if (init_ext_scorer_ != nullptr) {
|
|
|
|
size_t num_prefixes = std::min(prefixes.size(), beam_size);
|
|
|
|
size_t num_prefixes_ = std::min(prefixes_.size(), beam_size);
|
|
|
|
std::sort(prefixes.begin(),
|
|
|
|
std::sort(prefixes_.begin(),
|
|
|
|
prefixes.begin() + num_prefixes,
|
|
|
|
prefixes_.begin() + num_prefixes_,
|
|
|
|
prefix_compare);
|
|
|
|
prefix_compare);
|
|
|
|
|
|
|
|
|
|
|
|
if (num_prefixes == 0) {
|
|
|
|
if (num_prefixes_ == 0) {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
min_cutoff = prefixes[num_prefixes - 1]->score +
|
|
|
|
min_cutoff = prefixes_[num_prefixes_ - 1]->score +
|
|
|
|
std::log(prob[blank_id]) -
|
|
|
|
std::log(prob[blank_id_]) -
|
|
|
|
std::max(0.0, init_ext_scorer_->beta);
|
|
|
|
std::max(0.0, init_ext_scorer_->beta);
|
|
|
|
|
|
|
|
|
|
|
|
full_beam = (num_prefixes == beam_size);
|
|
|
|
full_beam = (num_prefixes_ == beam_size);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
vector<std::pair<size_t, float>> log_prob_idx =
|
|
|
|
vector<std::pair<size_t, float>> log_prob_idx =
|
|
|
@ -177,18 +181,18 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
|
|
|
|
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
|
|
|
|
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
prefixes.clear();
|
|
|
|
prefixes_.clear();
|
|
|
|
|
|
|
|
|
|
|
|
// update log probs
|
|
|
|
// update log probs
|
|
|
|
root->iterate_to_vec(prefixes);
|
|
|
|
root_->iterate_to_vec(prefixes_);
|
|
|
|
// only preserve top beam_size prefixes
|
|
|
|
// only preserve top beam_size prefixes_
|
|
|
|
if (prefixes.size() >= beam_size) {
|
|
|
|
if (prefixes_.size() >= beam_size) {
|
|
|
|
std::nth_element(prefixes.begin(),
|
|
|
|
std::nth_element(prefixes_.begin(),
|
|
|
|
prefixes.begin() + beam_size,
|
|
|
|
prefixes_.begin() + beam_size,
|
|
|
|
prefixes.end(),
|
|
|
|
prefixes_.end(),
|
|
|
|
prefix_compare);
|
|
|
|
prefix_compare);
|
|
|
|
for (size_t i = beam_size; i < prefixes.size(); ++i) {
|
|
|
|
for (size_t i = beam_size; i < prefixes_.size(); ++i) {
|
|
|
|
prefixes[i]->remove();
|
|
|
|
prefixes_[i]->remove();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} // if
|
|
|
|
} // if
|
|
|
|
num_frame_decoded_++;
|
|
|
|
num_frame_decoded_++;
|
|
|
@ -202,15 +206,15 @@ int32 CTCBeamSearch::SearchOneChar(
|
|
|
|
size_t beam_size = opts_.beam_size;
|
|
|
|
size_t beam_size = opts_.beam_size;
|
|
|
|
const auto& c = log_prob_idx.first;
|
|
|
|
const auto& c = log_prob_idx.first;
|
|
|
|
const auto& log_prob_c = log_prob_idx.second;
|
|
|
|
const auto& log_prob_c = log_prob_idx.second;
|
|
|
|
size_t prefixes_len = std::min(prefixes.size(), beam_size);
|
|
|
|
size_t prefixes__len = std::min(prefixes_.size(), beam_size);
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < prefixes_len; ++i) {
|
|
|
|
for (size_t i = 0; i < prefixes__len; ++i) {
|
|
|
|
auto prefix = prefixes[i];
|
|
|
|
auto prefix = prefixes_[i];
|
|
|
|
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
|
|
|
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (c == blank_id) {
|
|
|
|
if (c == blank_id_) {
|
|
|
|
prefix->log_prob_b_cur =
|
|
|
|
prefix->log_prob_b_cur =
|
|
|
|
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
|
|
|
|
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
@ -238,7 +242,7 @@ int32 CTCBeamSearch::SearchOneChar(
|
|
|
|
|
|
|
|
|
|
|
|
// language model scoring
|
|
|
|
// language model scoring
|
|
|
|
if (init_ext_scorer_ != nullptr &&
|
|
|
|
if (init_ext_scorer_ != nullptr &&
|
|
|
|
(c == space_id || init_ext_scorer_->is_character_based())) {
|
|
|
|
(c == space_id_ || init_ext_scorer_->is_character_based())) {
|
|
|
|
PathTrie* prefix_to_score = nullptr;
|
|
|
|
PathTrie* prefix_to_score = nullptr;
|
|
|
|
// skip scoring the space
|
|
|
|
// skip scoring the space
|
|
|
|
if (init_ext_scorer_->is_character_based()) {
|
|
|
|
if (init_ext_scorer_->is_character_based()) {
|
|
|
@ -266,17 +270,17 @@ int32 CTCBeamSearch::SearchOneChar(
|
|
|
|
|
|
|
|
|
|
|
|
void CTCBeamSearch::CalculateApproxScore() {
|
|
|
|
void CTCBeamSearch::CalculateApproxScore() {
|
|
|
|
size_t beam_size = opts_.beam_size;
|
|
|
|
size_t beam_size = opts_.beam_size;
|
|
|
|
size_t num_prefixes = std::min(prefixes.size(), beam_size);
|
|
|
|
size_t num_prefixes_ = std::min(prefixes_.size(), beam_size);
|
|
|
|
std::sort(
|
|
|
|
std::sort(
|
|
|
|
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
|
|
|
|
prefixes_.begin(), prefixes_.begin() + num_prefixes_, prefix_compare);
|
|
|
|
|
|
|
|
|
|
|
|
// compute aproximate ctc score as the return score, without affecting the
|
|
|
|
// compute aproximate ctc score as the return score, without affecting the
|
|
|
|
// return order of decoding result. To delete when decoder gets stable.
|
|
|
|
// return order of decoding result. To delete when decoder gets stable.
|
|
|
|
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
|
|
|
for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) {
|
|
|
|
double approx_ctc = prefixes[i]->score;
|
|
|
|
double approx_ctc = prefixes_[i]->score;
|
|
|
|
if (init_ext_scorer_ != nullptr) {
|
|
|
|
if (init_ext_scorer_ != nullptr) {
|
|
|
|
vector<int> output;
|
|
|
|
vector<int> output;
|
|
|
|
prefixes[i]->get_path_vec(output);
|
|
|
|
prefixes_[i]->get_path_vec(output);
|
|
|
|
auto prefix_length = output.size();
|
|
|
|
auto prefix_length = output.size();
|
|
|
|
auto words = init_ext_scorer_->split_labels(output);
|
|
|
|
auto words = init_ext_scorer_->split_labels(output);
|
|
|
|
// remove word insert
|
|
|
|
// remove word insert
|
|
|
@ -285,7 +289,7 @@ void CTCBeamSearch::CalculateApproxScore() {
|
|
|
|
approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
|
|
|
|
approx_ctc -= (init_ext_scorer_->get_sent_log_prob(words)) *
|
|
|
|
init_ext_scorer_->alpha;
|
|
|
|
init_ext_scorer_->alpha;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
prefixes[i]->approx_ctc = approx_ctc;
|
|
|
|
prefixes_[i]->approx_ctc = approx_ctc;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -293,9 +297,9 @@ void CTCBeamSearch::LMRescore() {
|
|
|
|
size_t beam_size = opts_.beam_size;
|
|
|
|
size_t beam_size = opts_.beam_size;
|
|
|
|
if (init_ext_scorer_ != nullptr &&
|
|
|
|
if (init_ext_scorer_ != nullptr &&
|
|
|
|
!init_ext_scorer_->is_character_based()) {
|
|
|
|
!init_ext_scorer_->is_character_based()) {
|
|
|
|
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
|
|
|
for (size_t i = 0; i < beam_size && i < prefixes_.size(); ++i) {
|
|
|
|
auto prefix = prefixes[i];
|
|
|
|
auto prefix = prefixes_[i];
|
|
|
|
if (!prefix->is_empty() && prefix->character != space_id) {
|
|
|
|
if (!prefix->is_empty() && prefix->character != space_id_) {
|
|
|
|
float score = 0.0;
|
|
|
|
float score = 0.0;
|
|
|
|
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
|
|
|
|
vector<string> ngram = init_ext_scorer_->make_ngram(prefix);
|
|
|
|
score = init_ext_scorer_->get_log_cond_prob(ngram) *
|
|
|
|
score = init_ext_scorer_->get_log_cond_prob(ngram) *
|
|
|
|