|
|
@ -2,33 +2,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
#include "base/basic_types.h"
|
|
|
|
#include "base/basic_types.h"
|
|
|
|
#include "decoder/ctc_decoders/decoder_utils.h"
|
|
|
|
#include "decoder/ctc_decoders/decoder_utils.h"
|
|
|
|
|
|
|
|
#include "utils/file_utils.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace ppspeech {
|
|
|
|
namespace ppspeech {
|
|
|
|
|
|
|
|
|
|
|
|
using std::vector;
|
|
|
|
using std::vector;
|
|
|
|
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
|
|
|
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
|
|
|
|
|
|
|
|
|
|
|
CTCBeamSearch::CTCBeamSearch(std::shared_ptr<CTCBeamSearchOptions> opts) :
|
|
|
|
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) :
|
|
|
|
opts_(opts),
|
|
|
|
opts_(opts),
|
|
|
|
vocabulary_(nullptr),
|
|
|
|
|
|
|
|
init_ext_scorer_(nullptr),
|
|
|
|
init_ext_scorer_(nullptr),
|
|
|
|
blank_id(-1),
|
|
|
|
blank_id(-1),
|
|
|
|
space_id(-1),
|
|
|
|
space_id(-1),
|
|
|
|
num_frame_decoded(0),
|
|
|
|
num_frame_decoded_(0),
|
|
|
|
root(nullptr) {
|
|
|
|
root(nullptr) {
|
|
|
|
|
|
|
|
|
|
|
|
LOG(INFO) << "dict path: " << opts_.dict_file;
|
|
|
|
LOG(INFO) << "dict path: " << opts_.dict_file;
|
|
|
|
vocabulary_ = std::make_shared<vector<string>>();
|
|
|
|
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
|
|
|
|
if (!basr::ReadDictToVector(opts_.dict_file, *vocabulary_)) {
|
|
|
|
|
|
|
|
LOG(INFO) << "load the dict failed";
|
|
|
|
LOG(INFO) << "load the dict failed";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_->size();
|
|
|
|
LOG(INFO) << "read the vocabulary success, dict size: " << vocabulary_.size();
|
|
|
|
|
|
|
|
|
|
|
|
LOG(INFO) << "language model path: " << opts_.lm_path;
|
|
|
|
LOG(INFO) << "language model path: " << opts_.lm_path;
|
|
|
|
init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha,
|
|
|
|
init_ext_scorer_ = std::make_shared<Scorer>(opts_.alpha,
|
|
|
|
opts_.beta,
|
|
|
|
opts_.beta,
|
|
|
|
opts_.lm_path,
|
|
|
|
opts_.lm_path,
|
|
|
|
*vocabulary_);
|
|
|
|
vocabulary_);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void CTCBeamSearch::Reset() {
|
|
|
|
void CTCBeamSearch::Reset() {
|
|
|
@ -39,11 +38,11 @@ void CTCBeamSearch::Reset() {
|
|
|
|
void CTCBeamSearch::InitDecoder() {
|
|
|
|
void CTCBeamSearch::InitDecoder() {
|
|
|
|
|
|
|
|
|
|
|
|
blank_id = 0;
|
|
|
|
blank_id = 0;
|
|
|
|
auto it = std::find(vocabulary_->begin(), vocabulary_->end(), " ");
|
|
|
|
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
|
|
|
|
|
|
|
|
|
|
|
|
space_id = it - vocabulary_->begin();
|
|
|
|
space_id = it - vocabulary_.begin();
|
|
|
|
// if no space in vocabulary
|
|
|
|
// if no space in vocabulary
|
|
|
|
if ((size_t)space_id >= vocabulary_->size()) {
|
|
|
|
if ((size_t)space_id >= vocabulary_.size()) {
|
|
|
|
space_id = -2;
|
|
|
|
space_id = -2;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -63,19 +62,24 @@ void CTCBeamSearch::InitDecoder() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void CTCBeamSearch::Decode(std::shared_ptr<kaldi::DecodableInterface> decodable) {
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int32 CTCBeamSearch::NumFrameDecoded() {
|
|
|
|
int32 CTCBeamSearch::NumFrameDecoded() {
|
|
|
|
return num_frame_decoded_;
|
|
|
|
return num_frame_decoded_;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// todo rename, refactor
|
|
|
|
// todo rename, refactor
|
|
|
|
void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable, int max_frames) {
|
|
|
|
void CTCBeamSearch::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable,
|
|
|
|
|
|
|
|
int max_frames) {
|
|
|
|
while (max_frames > 0) {
|
|
|
|
while (max_frames > 0) {
|
|
|
|
vector<vector<BaseFloat>> likelihood;
|
|
|
|
vector<vector<BaseFloat>> likelihood;
|
|
|
|
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
|
|
|
|
if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
|
|
|
|
likelihood.push_back(decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
|
|
|
|
AdvanceDecoding(result);
|
|
|
|
AdvanceDecoding(likelihood);
|
|
|
|
max_frames--;
|
|
|
|
max_frames--;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -91,32 +95,21 @@ void CTCBeamSearch::ResetPrefixes() {
|
|
|
|
|
|
|
|
|
|
|
|
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs,
|
|
|
|
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>&probs,
|
|
|
|
vector<string>& nbest_words) {
|
|
|
|
vector<string>& nbest_words) {
|
|
|
|
std::thread::id this_id = std::this_thread::get_id();
|
|
|
|
kaldi::Timer timer;
|
|
|
|
Timer timer;
|
|
|
|
|
|
|
|
vector<vector<double>> double_probs(probs.size(), vector<double>(probs[0].size(), 0));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int row = probs.size();
|
|
|
|
|
|
|
|
int col = probs[0].size();
|
|
|
|
|
|
|
|
for(int i = 0; i < row; i++) {
|
|
|
|
|
|
|
|
for (int j = 0; j < col; j++){
|
|
|
|
|
|
|
|
double_probs[i][j] = static_cast<double>(probs[i][j]);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
timer.Reset();
|
|
|
|
timer.Reset();
|
|
|
|
AdvanceDecoding(double_probs);
|
|
|
|
AdvanceDecoding(probs);
|
|
|
|
LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f;
|
|
|
|
LOG(INFO) <<"ctc decoding elapsed time(s) " << static_cast<float>(timer.Elapsed()) / 1000.0f;
|
|
|
|
return 0;
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
|
|
|
|
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
|
|
|
|
return get_beam_search_result(prefixes, *vocabulary_, opts_.beam_size);
|
|
|
|
return get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
string CTCBeamSearch::GetBestPath() {
|
|
|
|
string CTCBeamSearch::GetBestPath() {
|
|
|
|
std::vector<std::pair<double, std::string>> result;
|
|
|
|
std::vector<std::pair<double, std::string>> result;
|
|
|
|
result = get_beam_search_result(prefixes, *vocabulary_, opts_.beam_size);
|
|
|
|
result = get_beam_search_result(prefixes, vocabulary_, opts_.beam_size);
|
|
|
|
return result[0]->second;
|
|
|
|
return result[0].second;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
string CTCBeamSearch::GetFinalBestPath() {
|
|
|
|
string CTCBeamSearch::GetFinalBestPath() {
|
|
|
@ -125,12 +118,22 @@ string CTCBeamSearch::GetFinalBestPath() {
|
|
|
|
return GetBestPath();
|
|
|
|
return GetBestPath();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void CTCBeamSearch::AdvanceDecoding(const vector<vector<double>>& probs_seq) {
|
|
|
|
void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
|
|
|
|
size_t num_time_steps = probs_seq.size();
|
|
|
|
size_t num_time_steps = probs.size();
|
|
|
|
size_t beam_size = opts_.beam_size;
|
|
|
|
size_t beam_size = opts_.beam_size;
|
|
|
|
double cutoff_prob = opts_.cutoff_prob;
|
|
|
|
double cutoff_prob = opts_.cutoff_prob;
|
|
|
|
size_t cutoff_top_n = opts_.cutoff_top_n;
|
|
|
|
size_t cutoff_top_n = opts_.cutoff_top_n;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vector<vector<double>> probs_seq(probs.size(), vector<double>(probs[0].size(), 0));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int row = probs.size();
|
|
|
|
|
|
|
|
int col = probs[0].size();
|
|
|
|
|
|
|
|
for(int i = 0; i < row; i++) {
|
|
|
|
|
|
|
|
for (int j = 0; j < col; j++){
|
|
|
|
|
|
|
|
probs_seq[i][j] = static_cast<double>(probs[i][j]);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
|
|
|
|
for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
|
|
|
|
const auto& prob = probs_seq[time_step];
|
|
|
|
const auto& prob = probs_seq[time_step];
|
|
|
|
|
|
|
|
|
|
|
@ -158,6 +161,7 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<double>>& probs_seq) {
|
|
|
|
size_t log_prob_idx_len = log_prob_idx.size();
|
|
|
|
size_t log_prob_idx_len = log_prob_idx.size();
|
|
|
|
for (size_t index = 0; index < log_prob_idx_len; index++) {
|
|
|
|
for (size_t index = 0; index < log_prob_idx_len; index++) {
|
|
|
|
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
|
|
|
|
SearchOneChar(full_beam, log_prob_idx[index], min_cutoff);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
prefixes.clear();
|
|
|
|
prefixes.clear();
|
|
|
|
|
|
|
|
|
|
|
@ -177,9 +181,9 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<double>>& probs_seq) {
|
|
|
|
} // for probs_seq
|
|
|
|
} // for probs_seq
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int CTCBeamSearch::SearchOneChar(const bool& full_beam,
|
|
|
|
int32 CTCBeamSearch::SearchOneChar(const bool& full_beam,
|
|
|
|
const std::pair<size_t, float>& log_prob_idx,
|
|
|
|
const std::pair<size_t, BaseFloat>& log_prob_idx,
|
|
|
|
const float& min_cutoff) {
|
|
|
|
const BaseFloat& min_cutoff) {
|
|
|
|
size_t beam_size = opts_.beam_size;
|
|
|
|
size_t beam_size = opts_.beam_size;
|
|
|
|
const auto& c = log_prob_idx.first;
|
|
|
|
const auto& c = log_prob_idx.first;
|
|
|
|
const auto& log_prob_c = log_prob_idx.second;
|
|
|
|
const auto& log_prob_c = log_prob_idx.second;
|
|
|
|