simple ctc prefix beam search compile ok

pull/2524/head
Hui Zhang 3 years ago
parent bc1b6c2e7c
commit 3c3aa6b594

@ -15,6 +15,7 @@
#pragma once
#include <algorithm>
#include <cassert>
#include <cmath>
#include <condition_variable>
#include <cstring>
@ -35,6 +36,7 @@
#include <stdexcept>
#include <string>
#include <thread>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>

@ -25,8 +25,7 @@ namespace ppspeech {
void operator=(const TypeName&) = delete
#endif
constexpr float kFloatMax = std::numeric_limits<float>::max();
// kSpaceSymbol in UTF-8 is: ▁
const std::string kSpaceSymbol = "\xe2\x96\x81";
} // namespace ppspeech

@ -2,10 +2,11 @@ project(decoder)
# Make the bundled ctc_decoders sources visible to #include directives.
# The closing brace has to end the variable name: the previous spelling,
# "${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders}", referenced an undefined variable
# whose name contained "/ctc_decoders" and therefore expanded to nothing.
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ctc_decoders)
add_library(decoder STATIC
ctc_beam_search_decoder.cc
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
ctc_beam_search_decoder.cc
ctc_prefix_beam_search_decoder.cc
ctc_tlg_decoder.cc
recognizer.cc
)

@ -26,9 +26,7 @@ using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts),
init_ext_scorer_(nullptr),
blank_id_(opts.blank),
space_id_(-1),
num_frame_decoded_(0),
root_(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
@ -43,7 +41,7 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
}
CHECK(blank_id_==0);
CHECK(opts_.blank==0);
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
space_id_ = it - vocabulary_.begin();
@ -167,7 +165,7 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
continue;
}
min_cutoff = prefixes_[num_prefixes_ - 1]->score +
std::log(prob[blank_id_]) -
std::log(prob[opts_.blank]) -
std::max(0.0, init_ext_scorer_->beta);
full_beam = (num_prefixes_ == beam_size);
@ -195,9 +193,9 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
for (size_t i = beam_size; i < prefixes_.size(); ++i) {
prefixes_[i]->remove();
}
} // if
} // end if
num_frame_decoded_++;
} // for probs_seq
} // end for probs_seq
}
int32 CTCBeamSearch::SearchOneChar(
@ -215,7 +213,7 @@ int32 CTCBeamSearch::SearchOneChar(
break;
}
if (c == blank_id_) {
if (c == opts_.blank) {
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;

@ -66,11 +66,10 @@ class CTCBeamSearch : public DecoderInterface {
CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id_;
int space_id_;
std::shared_ptr<PathTrie> root_;
std::vector<PathTrie*> prefixes_;
int num_frame_decoded_;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
};

@ -11,3 +11,307 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "utils/math.h"
#ifdef USE_PROFILING
#include "paddle/fluid/platform/profiler.h"
using paddle::platform::RecordEvent;
using paddle::platform::TracerEventType;
#endif
namespace ppspeech {
// Stores a copy of the decoder options and brings the search state to its
// initial (single empty hypothesis) configuration via InitDecoder().
CTCPrefixBeamSearch::CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts)
    : opts_(opts) {
  InitDecoder();
}
void CTCPrefixBeamSearch::InitDecoder() {
num_frame_decoded_ = 0;
cur_hyps_.clear();
hypotheses_.clear();
likelihood_.clear();
viterbi_likelihood_.clear();
times_.clear();
outputs_.clear();
abs_time_step_ = 0;
// empty hyp with Score
std::vector<int> empty;
PrefixScore prefix_score;
prefix_score.b = 0.0f; // log(1)
prefix_score.nb = -kBaseFloatMax; // log(0)
prefix_score.v_b = 0.0f; // log(1)
prefix_score.v_nb = 0.0f; // log(1)
cur_hyps_[empty] = prefix_score;
outputs_.emplace_back(empty);
hypotheses_.emplace_back(empty);
likelihood_.emplace_back(prefix_score.TotalScore());
times_.emplace_back(empty);
}
// Discards all decoding state; equivalent to calling InitDecoder().
void CTCPrefixBeamSearch::Reset() {
  InitDecoder();
}
// Currently does nothing: `decodable` is not read and no state is changed.
// Frames are consumed through AdvanceDecode() instead.
void CTCPrefixBeamSearch::Decode(
    std::shared_ptr<kaldi::DecodableInterface> decodable) {
}
// Returns num_frame_decoded_ + 1.
// NOTE(review): num_frame_decoded_ starts at 0 in InitDecoder() and is
// incremented once per frame in AdvanceDecoding(), so this value is one more
// than the number of frames processed - confirm this is what callers expect.
int32 CTCPrefixBeamSearch::NumFrameDecoded() {
  return num_frame_decoded_ + 1;
}
void CTCPrefixBeamSearch::UpdateOutputs(
const std::pair<std::vector<int>, PrefixScore>& prefix) {
const std::vector<int>& input = prefix.first;
// const std::vector<int>& start_boundaries = prefix.second.start_boundaries;
// const std::vector<int>& end_boundaries = prefix.second.end_boundaries;
std::vector<int> output;
int s = 0;
int e = 0;
for (int i = 0; i < input.size(); ++i) {
// if (s < start_boundaries.size() && i == start_boundaries[s]){
// // <context>
// output.emplace_back(context_graph_->start_tag_id());
// ++s;
// }
output.emplace_back(input[i]);
// if (e < end_boundaries.size() && i == end_boundaries[e]){
// // </context>
// output.emplace_back(context_graph_->end_tag_id());
// ++e;
// }
}
outputs_.emplace_back(output);
}
// Pulls frames from `decodable` one at a time, starting at index
// num_frame_decoded_, and hands each to AdvanceDecoding(). Stops as soon as
// FrameLikelihood() returns false for the requested frame.
void CTCPrefixBeamSearch::AdvanceDecode(
    const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
  for (;;) {
    std::vector<kaldi::BaseFloat> frame_prob;
    if (!decodable->FrameLikelihood(num_frame_decoded_, &frame_prob)) {
      break;
    }
    // AdvanceDecoding() expects a batch of frames, so wrap the single frame.
    // It advances num_frame_decoded_, which moves the next request forward.
    std::vector<std::vector<kaldi::BaseFloat>> likelihood(1, frame_prob);
    AdvanceDecoding(likelihood);
  }
}
static bool PrefixScoreCompare(
const std::pair<std::vector<int>, PrefixScore>& a,
const std::pair<std::vector<int>, PrefixScore>& b) {
// log domain
return a.second.TotalScore() > b.second.TotalScore();
}
// Runs CTC prefix beam search over a batch of frames.
//
// `logp` holds one score vector per frame. The scores are added directly to
// log-domain prefix scores, so they are presumably log-probabilities
// (NOTE(review): confirm against what FrameLikelihood() produces).
//
// For every frame the function
//   1. selects candidate tokens with TopK(),
//   2. extends every hypothesis in cur_hyps_ with every candidate token,
//   3. keeps at most opts_.second_beam_size extended hypotheses, best first,
//   4. installs them through UpdateHypotheses().
// abs_time_step_ and num_frame_decoded_ are each advanced once per frame.
// An empty `logp` returns immediately without touching any state.
void CTCPrefixBeamSearch::AdvanceDecoding(const std::vector<std::vector<float>>& logp) {
#ifdef USE_PROFILING
RecordEvent event(
"CtcPrefixBeamSearch::AdvanceDecoding", TracerEventType::UserDefined, 1);
#endif
if (logp.size() == 0) return;
// Clamp the first-pass beam to the vocabulary size. Note that the size is
// taken from the first frame only and reused for all frames in the batch.
int first_beam_size =
std::min(static_cast<int>(logp[0].size()), opts_.first_beam_size);
for (int t = 0; t < logp.size(); ++t, ++abs_time_step_) {
const std::vector<float>& logp_t = logp[t];
// Hypotheses produced by this frame, keyed by token-id prefix.
std::unordered_map<std::vector<int>, PrefixScore, PrefixScoreHash> next_hyps;
// 1. first beam prune, only select topk candidates
std::vector<float> topk_score;
std::vector<int32_t> topk_index;
TopK(logp_t, first_beam_size, &topk_score, &topk_index);
// 2. token passing
for (int i = 0; i < topk_index.size(); ++i) {
int id = topk_index[i];
auto prob = topk_score[i];
for (const auto& it : cur_hyps_) {
const std::vector<int>& prefix = it.first;
const PrefixScore& prefix_score = it.second;
// If prefix doesn't exist in next_hyps, next_hyps[prefix] will insert
// PrefixScore(-inf, -inf) by default, since the default constructor
// of PrefixScore will set fields b(blank ending Score) and
// nb(none blank ending Score) to -inf, respectively.
if (id == opts_.blank) {
// case 0: *a + <blank> => *a, *a<blank> + <blank> => *a, prefix not
// change
PrefixScore& next_score = next_hyps[prefix];
next_score.b = LogSumExp(next_score.b, prefix_score.Score() + prob);
// timestamp: blank is silence and does not affect the timestamps
next_score.v_b = prefix_score.ViterbiScore() + prob;
next_score.times_b = prefix_score.Times();
// Prefix not changed, copy the context from prefix
if (context_graph_ && !next_score.has_context) {
next_score.CopyContext(prefix_score);
next_score.has_context = true;
}
} else if (!prefix.empty() && id == prefix.back()) {
// case 1: *a + a => *a, prefix not changed
PrefixScore& next_score1 = next_hyps[prefix];
next_score1.nb = LogSumExp(next_score1.nb, prefix_score.nb + prob);
// timestamp: a non-blank symbol affects the timestamps
if (next_score1.v_nb < prefix_score.v_nb + prob) {
// compute viterbi Score
next_score1.v_nb = prefix_score.v_nb + prob;
if (next_score1.cur_token_prob < prob) {
// store max token prob
next_score1.cur_token_prob = prob;
// update this timestamp as token appeared here.
next_score1.times_nb = prefix_score.times_nb;
assert(next_score1.times_nb.size() > 0);
next_score1.times_nb.back() = abs_time_step_;
}
}
// Prefix not changed, copy the context from prefix
if (context_graph_ && !next_score1.has_context) {
next_score1.CopyContext(prefix_score);
next_score1.has_context = true;
}
// case 2: *a<blank> + a => *aa, prefix changed.
// Only the blank-ending score of the old prefix can start a repeated token.
std::vector<int> new_prefix(prefix);
new_prefix.emplace_back(id);
PrefixScore& next_score2 = next_hyps[new_prefix];
next_score2.nb = LogSumExp(next_score2.nb, prefix_score.b + prob);
// timestamp: a non-blank symbol affects the timestamps
if (next_score2.v_nb < prefix_score.v_b + prob) {
// compute viterbi Score
next_score2.v_nb = prefix_score.v_b + prob;
// new token added
next_score2.cur_token_prob = prob;
next_score2.times_nb = prefix_score.times_b;
next_score2.times_nb.emplace_back(abs_time_step_);
}
// Prefix changed, calculate the context Score.
if (context_graph_ && !next_score2.has_context) {
next_score2.UpdateContext(
context_graph_, prefix_score, id, prefix.size());
next_score2.has_context = true;
}
} else {
// id != prefix.back()
// case 3: *a + b => *ab, *a<blank> +b => *ab
std::vector<int> new_prefix(prefix);
new_prefix.emplace_back(id);
PrefixScore& next_score = next_hyps[new_prefix];
next_score.nb = LogSumExp(next_score.nb, prefix_score.Score() + prob);
// timestamp: a non-blank symbol affects the timestamps
if (next_score.v_nb < prefix_score.ViterbiScore() + prob) {
next_score.v_nb = prefix_score.ViterbiScore() + prob;
next_score.cur_token_prob = prob;
next_score.times_nb = prefix_score.Times();
next_score.times_nb.emplace_back(abs_time_step_);
}
// Prefix changed, calculate the context Score.
if (context_graph_ && !next_score.has_context) {
next_score.UpdateContext(
context_graph_, prefix_score, id, prefix.size());
next_score.has_context = true;
}
}
} // end for (const auto& it : cur_hyps_)
} // end for (int i = 0; i < topk_index.size(); ++i)
// 3. second beam prune, only keep top n best paths
std::vector<std::pair<std::vector<int>, PrefixScore>> arr(next_hyps.begin(),
next_hyps.end());
int second_beam_size =
std::min(static_cast<int>(arr.size()), opts_.second_beam_size);
// Partition so the best second_beam_size entries come first, drop the rest,
// then fully order the survivors by descending TotalScore().
std::nth_element(arr.begin(),
arr.begin() + second_beam_size,
arr.end(),
PrefixScoreCompare);
arr.resize(second_beam_size);
std::sort(arr.begin(), arr.end(), PrefixScoreCompare);
// 4. update cur_hyps by next_hyps, and get new result
UpdateHypotheses(arr);
num_frame_decoded_++;
} // end for (int t = 0; t < logp.size(); ++t, ++abs_time_step_)
}
void CTCPrefixBeamSearch::UpdateHypotheses(
const std::vector<std::pair<std::vector<int>, PrefixScore>>& hyps) {
cur_hyps_.clear();
outputs_.clear();
hypotheses_.clear();
likelihood_.clear();
viterbi_likelihood_.clear();
times_.clear();
for (auto& item : hyps) {
cur_hyps_[item.first] = item.second;
UpdateOutputs(item);
hypotheses_.emplace_back(std::move(item.first));
likelihood_.emplace_back(item.second.TotalScore());
viterbi_likelihood_.emplace_back(item.second.ViterbiScore());
times_.emplace_back(item.second.Times());
}
}
// End-of-utterance hook; delegates to UpdateFinalContext().
void CTCPrefixBeamSearch::FinalizeSearch() {
  UpdateFinalContext();
}
void CTCPrefixBeamSearch::UpdateFinalContext() {
if (context_graph_ == nullptr) return;
assert(hypotheses_.size() == cur_hyps_.size());
assert(hypotheses_.size() == likelihood_.size());
// We should backoff the context Score/state when the context is
// not fully matched at the last time.
for (const auto& prefix : hypotheses_) {
PrefixScore& prefix_score = cur_hyps_[prefix];
if (prefix_score.context_score != 0) {
// prefix_score.UpdateContext(context_graph_, prefix_score, 0,
// prefix.size());
}
}
std::vector<std::pair<std::vector<int>, PrefixScore>> arr(cur_hyps_.begin(),
cur_hyps_.end());
std::sort(arr.begin(), arr.end(), PrefixScoreCompare);
// Update cur_hyps_ and get new result
UpdateHypotheses(arr);
}
} // namespace ppspeech

@ -18,10 +18,8 @@
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/decoder_itf.h"
#include "kaldi/decoder/decodable-itf.h"
namespace ppspeech {
class ContextGraph;
class CTCPrefixBeamSearch : public DecoderInterface {
public:
explicit CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts);
@ -29,36 +27,74 @@ class CTCPrefixBeamSearch : public DecoderInterface {
void InitDecoder();
void Reset();
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
std::string GetFinalBestPath();
std::string GetPartialResult() {
CHECK(false) << "Not implement.";
return {};
}
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();
const std::vector<float>& ViterbiLikelihood() const {
return viterbi_likelihood_;
}
const std::vector<std::vector<int>>& Inputs() const { return hypotheses_; }
const std::vector<std::vector<int>>& Outputs() const { return outputs_; }
const std::vector<float>& Likelihood() const { return likelihood_; }
const std::vector<std::vector<int>>& Times() const { return times_; }
private:
void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff);
void CalculateApproxScore();
void LMRescore();
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& logp);
void FinalizeSearch();
void UpdateOutputs(const std::pair<std::vector<int>, PrefixScore>& prefix);
void UpdateHypotheses(
const std::vector<std::pair<std::vector<int>, PrefixScore>>& prefix);
void UpdateFinalContext();
private:
CTCBeamSearchOptions opts_;
size_t blank_id_;
int num_frame_decoded_;
int abs_time_step_ = 0;
std::unordered_map<std::vector<int>, PrefixScore, PrefixScoreHash>
cur_hyps_;
// n-best list and corresponding likelihood, in sorted order
std::vector<std::vector<int>> hypotheses_;
std::vector<float> likelihood_;
std::vector<std::vector<int>> times_;
std::vector<float> viterbi_likelihood_;
// Outputs contain the hypotheses_ and tags like: <context> and </context>
std::vector<std::vector<int>> outputs_;
std::shared_ptr<ContextGraph> context_graph_ = nullptr;
DISALLOW_COPY_AND_ASSIGN(CTCPrefixBeamSearch);
};
} // namespace basr
} // namespace ppspeech

@ -20,35 +20,55 @@
namespace ppspeech {
struct PrefxiScore {
class ContextGraph;
struct PrefixScore {
// decoding, unit in log scale
float b = -kFloatMax; // blank ending score
float nb = -kFloatMax; // none-blank ending score
float b = -kBaseFloatMax; // blank ending score
float nb = -kBaseFloatMax; // none-blank ending score
// decoding score, sum
float Score() const { return LogSumExp(b, nb); }
// timestamp, unit in log scale
float v_b = -kFloatMax; // viterbi blank ending score
float v_nb = -kFloatMax; // niterbi none-blank ending score
float cur_token_prob = -kFloatMax; // prob of current token
std::vector<int> times_b; // times of viterbi blank path
std::vector<int> times_nb; // times of viterbi non-blank path
float v_b = -kBaseFloatMax; // viterbi blank ending score
float v_nb = -kBaseFloatMax; // niterbi none-blank ending score
float cur_token_prob = -kBaseFloatMax; // prob of current token
std::vector<int> times_b; // times of viterbi blank path
std::vector<int> times_nb; // times of viterbi non-blank path
// timestamp score, max
float ViterbiScore() const { return std::max(v_b, v_nb); }
// get timestamp
const std::vector<int>& Times() const {
return v_b > v_nb ? times_b : times_nb;
}
// context state
bool has_context = false;
int context_state = 0;
float context_score = 0;
std::vector<int> start_boundaries;
std::vector<int> end_boundaries;
// decoding score, sum
float Score() const { return LogSumExp(b, nb); }
// decoding score with context bias
float TotalScore() const { return Score() + context_score; }
// timestamp score, max
float ViterbiScore() const { return std::max(v_b, v_nb); }
void CopyContext(const PrefixScore& prefix_score) {
context_state = prefix_score.context_state;
context_score = prefix_score.context_score;
start_boundaries = prefix_score.start_boundaries;
end_boundaries = prefix_score.end_boundaries;
}
// get timestamp
const std::vector<int>& Times() const {
return v_b > v_nb ? times_b : times_nb;
void UpdateContext(const std::shared_ptr<ContextGraph>& constext_graph,
const PrefixScore& prefix_score,
int word_id,
int prefix_len) {
CHECK(false);
}
};

@ -63,8 +63,6 @@ class TLGDecoder : public DecoderInterface {
std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_;
// the frame size which have decoded starts from 0.
int32 num_frame_decoded_;
};

@ -31,7 +31,6 @@ class DecoderInterface {
virtual void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) = 0;
virtual std::string GetFinalBestPath() = 0;
virtual std::string GetPartialResult() = 0;
@ -46,7 +45,7 @@ class DecoderInterface {
// std::vector<std::string>& nbest_words);
private:
protected:
// void AdvanceDecoding(kaldi::DecodableInterface* decodable);
// current decoding frame number

@ -28,8 +28,8 @@ namespace ppspeech {
// Sum in log scale
float LogSumExp(float x, float y) {
if (x <= -kFloatMax) return y;
if (y <= -kFloatMax) return x;
if (x <= -kBaseFloatMax) return y;
if (y <= -kBaseFloatMax) return x;
float max = std::max(x, y);
return max + std::log(std::exp(x - max) + std::exp(y - max));
}

Loading…
Cancel
Save