refactor ctc opts, extract decoder interface, add ctc beamsearch score

pull/2524/head
Hui Zhang 2 years ago
parent 5c8725e8cd
commit bc1b6c2e7c

@ -135,7 +135,7 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# TLG decoder # TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \
tlg_decoder_main \ ctc_tlg_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \

@ -133,7 +133,7 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# TLG decoder # TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \
tlg_decoder_main \ ctc_tlg_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \ --model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \ --param_path=$model_dir/avg_5.jit.pdiparams \

@ -15,7 +15,7 @@ set(BINS
ctc_beam_search_decoder_main ctc_beam_search_decoder_main
nnet_logprob_decoder_main nnet_logprob_decoder_main
recognizer_main recognizer_main
tlg_decoder_main ctc_tlg_decoder_main
) )
foreach(bin_name IN LISTS BINS) foreach(bin_name IN LISTS BINS)

@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h" #include "base/common.h"
#include "decoder/ctc_decoders/decoder_utils.h" #include "decoder/ctc_decoders/decoder_utils.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "utils/file_utils.h" #include "utils/file_utils.h"
namespace ppspeech { namespace ppspeech {
@ -26,7 +26,7 @@ using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts), : opts_(opts),
init_ext_scorer_(nullptr), init_ext_scorer_(nullptr),
blank_id_(-1), blank_id_(opts.blank),
space_id_(-1), space_id_(-1),
num_frame_decoded_(0), num_frame_decoded_(0),
root_(nullptr) { root_(nullptr) {
@ -43,9 +43,9 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
} }
blank_id_ = 0; CHECK(blank_id_==0);
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
space_id_ = it - vocabulary_.begin(); space_id_ = it - vocabulary_.begin();
// if no space in vocabulary // if no space in vocabulary
if ((size_t)space_id_ >= vocabulary_.size()) { if ((size_t)space_id_ >= vocabulary_.size()) {

@ -14,67 +14,48 @@
// used by deepspeech2 // used by deepspeech2
#include "base/common.h" #pragma once
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_decoders/path_trie.h" #include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h" #include "decoder/ctc_decoders/scorer.h"
#include "kaldi/decoder/decodable-itf.h" #include "decoder/decoder_itf.h"
#include "util/parse-options.h"
#pragma once
namespace ppspeech { namespace ppspeech {
struct CTCBeamSearchOptions { class CTCBeamSearch : public DecoderInterface {
std::string dict_file;
std::string lm_path;
BaseFloat alpha;
BaseFloat beta;
BaseFloat cutoff_prob;
int beam_size;
int cutoff_top_n;
int num_proc_bsearch;
CTCBeamSearchOptions()
: dict_file("vocab.txt"),
lm_path(""),
alpha(1.9f),
beta(5.0),
beam_size(300),
cutoff_prob(0.99f),
cutoff_top_n(40),
num_proc_bsearch(10) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("dict", &dict_file, "dict file ");
opts->Register("lm-path", &lm_path, "language model file");
opts->Register("alpha", &alpha, "alpha");
opts->Register("beta", &beta, "beta");
opts->Register(
"beam-size", &beam_size, "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
opts->Register(
"num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
}
};
class CTCBeamSearch {
public: public:
explicit CTCBeamSearch(const CTCBeamSearchOptions& opts); explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
~CTCBeamSearch() {} ~CTCBeamSearch() {}
void InitDecoder(); void InitDecoder();
void Reset();
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
std::string GetFinalBestPath();
std::string GetPartialResult() {
CHECK(false) << "Not implement.";
return {};
}
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable); void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath(); std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath(); std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
int NumFrameDecoded(); int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs, int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words); std::vector<std::string>& nbest_words);
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();
private: private:
void ResetPrefixes(); void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam, int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx, const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff); const BaseFloat& min_cutoff);
@ -93,4 +74,4 @@ class CTCBeamSearch {
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
}; };
} // namespace basr } // namespace ppspeech

@ -0,0 +1,78 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "util/parse-options.h"
#pragma once
namespace ppspeech {
// Options for the CTC beam search decoders.
//
// The fields are grouped by consumer: one option common to every decoder,
// the options registered under "Ds2BeamSearchConfig", and the options
// registered under "U2BeamSearchConfig".
struct CTCBeamSearchOptions {
    // common
    int blank;  // blank id, default is 0

    // ds2
    std::string dict_file;  // vocab file path
    std::string lm_path;    // ngram language model path
    int beam_size;          // beam size for the beam search method
    BaseFloat alpha;        // presumably a scorer weight -- TODO confirm
    BaseFloat beta;         // presumably a scorer weight -- TODO confirm
    BaseFloat cutoff_prob;
    int cutoff_top_n;
    int num_proc_bsearch;

    // u2
    int first_beam_size;
    int second_beam_size;

    // Sets every option to its default value.
    //
    // The initializers are listed in declaration order: members are always
    // initialized in the order they are declared, so listing them in any
    // other order only misleads the reader and triggers -Wreorder.
    CTCBeamSearchOptions()
        : blank(0),
          dict_file("vocab.txt"),
          lm_path(""),
          beam_size(300),
          alpha(1.9f),
          beta(5.0),
          cutoff_prob(0.99f),
          cutoff_top_n(40),
          num_proc_bsearch(10),
          first_beam_size(10),
          second_beam_size(10) {}

    // Registers every option with `opts` so that it can be set from the
    // command line. Each help text is prefixed with the name of the config
    // group it belongs to, except for the common "blank" option.
    void Register(kaldi::OptionsItf* opts) {
        std::string module = "Ds2BeamSearchConfig: ";
        opts->Register("dict", &dict_file, module + "vocab file path.");
        opts->Register(
            "lm-path", &lm_path, module + "ngram language model path.");
        opts->Register("alpha", &alpha, module + "alpha");
        opts->Register("beta", &beta, module + "beta");
        opts->Register("beam-size",
                       &beam_size,
                       module + "beam size for beam search method");
        opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs");
        opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff top n");
        opts->Register(
            "num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch");

        opts->Register("blank", &blank, "blank id, default is 0.");

        module = "U2BeamSearchConfig: ";
        opts->Register(
            "first-beam-size", &first_beam_size, module + "first beam size.");
        opts->Register("second-beam-size",
                       &second_beam_size,
                       module + "second beam size.");
    }
};
} // namespace ppspeech

@ -0,0 +1,13 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -0,0 +1,64 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/decoder_itf.h"
#include "kaldi/decoder/decodable-itf.h"
namespace ppspeech {
// CTC prefix beam search decoder, configured by CTCBeamSearchOptions.
//
// NOTE(review): DecoderInterface::GetPartialResult() is pure virtual and is
// not declared here, so this class appears to remain abstract -- confirm and
// add the missing override together with its implementation.
class CTCPrefixBeamSearch : public DecoderInterface {
  public:
    explicit CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts);
    ~CTCPrefixBeamSearch() {}

    // Methods that are also declared by DecoderInterface.
    void InitDecoder();
    void Reset();
    void AdvanceDecode(
        const std::shared_ptr<kaldi::DecodableInterface>& decodable);
    std::string GetFinalBestPath();

    // Methods specific to this decoder.
    void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
    std::string GetBestPath();
    std::vector<std::pair<double, std::string>> GetNBestPath();
    int NumFrameDecoded();
    int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
                          std::vector<std::string>& nbest_words);

  private:
    void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
    void ResetPrefixes();
    int32 SearchOneChar(const bool& full_beam,
                        const std::pair<size_t, BaseFloat>& log_prob_idx,
                        const BaseFloat& min_cutoff);
    void CalculateApproxScore();
    void LMRescore();

    // Decoder configuration.
    CTCBeamSearchOptions opts_;
    // Blank token id -- presumably taken from opts_.blank, TODO confirm.
    size_t blank_id_;
    // Number of frames decoded so far (going by the name; the code that
    // updates it is not visible here).
    int num_frame_decoded_;

    DISALLOW_COPY_AND_ASSIGN(CTCPrefixBeamSearch);
};
}  // namespace ppspeech

@ -0,0 +1,68 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "utils/math.h"
namespace ppspeech {
// Scores and bookkeeping attached to one prefix during CTC prefix beam
// search.
// NOTE(review): the struct name looks like a typo of "PrefixScore"; it is
// kept unchanged because other code may refer to it.
struct PrefxiScore {
    // Decoding scores, in log scale.
    float b = -kFloatMax;   // score of paths ending in blank
    float nb = -kFloatMax;  // score of paths ending in non-blank

    // Timestamp (viterbi) scores, in log scale.
    float v_b = -kFloatMax;   // viterbi score of the blank-ending path
    float v_nb = -kFloatMax;  // viterbi score of the non-blank-ending path
    float cur_token_prob = -kFloatMax;  // prob of the current token
    std::vector<int> times_b;   // times of the viterbi blank path
    std::vector<int> times_nb;  // times of the viterbi non-blank path

    // Context state.
    bool has_context = false;
    int context_state = 0;
    float context_score = 0;

    // Decoding score: blank and non-blank scores combined with LogSumExp.
    float Score() const { return LogSumExp(b, nb); }

    // Decoding score with the context bias added.
    float TotalScore() const { return Score() + context_score; }

    // Timestamp score: the larger of the two viterbi scores.
    float ViterbiScore() const { return (v_b < v_nb) ? v_nb : v_b; }

    // Timestamps of the viterbi path with the larger score; the non-blank
    // path is returned when the two scores tie.
    const std::vector<int>& Times() const {
        if (v_b > v_nb) {
            return times_b;
        }
        return times_nb;
    }
};
// Hash functor for a prefix stored as std::vector<int>, so that prefixes can
// be used as keys of unordered containers.
struct PrefixScoreHash {
    // https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector
    //
    // Starts from the prefix length and folds every element into the running
    // value. The element is first added to the unsuffixed hexadecimal
    // constant, so that addition happens in the constant's own type before
    // the result is widened to std::size_t.
    std::size_t operator()(const std::vector<int>& prefix) const {
        std::size_t hash = prefix.size();
        for (std::size_t idx = 0; idx < prefix.size(); ++idx) {
            const auto mixed = prefix[idx] + 0x9e3779b9;
            hash ^= mixed + (hash << 6) + (hash >> 2);
        }
        return hash;
    }
};
// A prefix paired with a second element.
// NOTE(review): the second element is PrefixScoreHash (the hash functor),
// not the score struct PrefxiScore defined above. That looks unintended for
// a "prefix with score" pair -- confirm before relying on this alias.
using PrefixWithScoreType = std::pair<std::vector<int>, PrefixScoreHash>;
} // namespace ppspeech

@ -22,24 +22,24 @@ TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
fst::SymbolTable::ReadText(opts.word_symbol_table)); fst::SymbolTable::ReadText(opts.word_symbol_table));
decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts)); decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts));
decoder_->InitDecoding(); decoder_->InitDecoding();
frame_decoded_size_ = 0; num_frame_decoded_ = 0;
} }
void TLGDecoder::InitDecoder() { void TLGDecoder::InitDecoder() {
decoder_->InitDecoding(); decoder_->InitDecoding();
frame_decoded_size_ = 0; num_frame_decoded_ = 0;
} }
void TLGDecoder::AdvanceDecode( void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) { const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (!decodable->IsLastFrame(frame_decoded_size_)) { while (!decodable->IsLastFrame(num_frame_decoded_)) {
AdvanceDecoding(decodable.get()); AdvanceDecoding(decodable.get());
} }
} }
void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) { void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
decoder_->AdvanceDecoding(decodable, 1); decoder_->AdvanceDecoding(decodable, 1);
frame_decoded_size_++; num_frame_decoded_++;
} }
void TLGDecoder::Reset() { void TLGDecoder::Reset() {
@ -48,7 +48,7 @@ void TLGDecoder::Reset() {
} }
std::string TLGDecoder::GetPartialResult() { std::string TLGDecoder::GetPartialResult() {
if (frame_decoded_size_ == 0) { if (num_frame_decoded_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.") // BestPathEnd if no frames were decoded.")
return std::string(""); return std::string("");
@ -68,7 +68,7 @@ std::string TLGDecoder::GetPartialResult() {
} }
std::string TLGDecoder::GetFinalBestPath() { std::string TLGDecoder::GetFinalBestPath() {
if (frame_decoded_size_ == 0) { if (num_frame_decoded_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.") // BestPathEnd if no frames were decoded.")
return std::string(""); return std::string("");

@ -14,8 +14,9 @@
#pragma once #pragma once
#include "base/basic_types.h" #include "base/common.h"
#include "kaldi/decoder/decodable-itf.h" #include "decoder/decoder_itf.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h" #include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h" #include "util/parse-options.h"
@ -30,21 +31,31 @@ struct TLGDecoderOptions {
TLGDecoderOptions() : word_symbol_table(""), fst_path("") {} TLGDecoderOptions() : word_symbol_table(""), fst_path("") {}
}; };
class TLGDecoder { class TLGDecoder : public DecoderInterface {
public: public:
explicit TLGDecoder(TLGDecoderOptions opts); explicit TLGDecoder(TLGDecoderOptions opts);
~TLGDecoder() = default;
void InitDecoder(); void InitDecoder();
void Reset();
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
std::string GetFinalBestPath();
std::string GetPartialResult();
void Decode(); void Decode();
std::string GetBestPath(); std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath(); std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
std::string GetPartialResult();
int NumFrameDecoded(); int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs, int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words); std::vector<std::string>& nbest_words);
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();
private: private:
void AdvanceDecoding(kaldi::DecodableInterface* decodable); void AdvanceDecoding(kaldi::DecodableInterface* decodable);
@ -53,7 +64,7 @@ class TLGDecoder {
std::shared_ptr<fst::Fst<fst::StdArc>> fst_; std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_; std::shared_ptr<fst::SymbolTable> word_symbol_table_;
// the frame size which have decoded starts from 0. // the frame size which have decoded starts from 0.
int32 frame_decoded_size_; int32 num_frame_decoded_;
}; };

@ -14,13 +14,15 @@
// todo refactor, replace with gtest
#include "base/flags.h" #include "base/common.h"
#include "base/log.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/ds2_nnet.h" #include "nnet/ds2_nnet.h"
#include "decoder/ctc_tlg_decoder.h"
#include "kaldi/util/table-types.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier");

@ -0,0 +1,56 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "kaldi/decoder/decodable-itf.h"
namespace ppspeech {
// Abstract decoder interface: every operation below is pure virtual and must
// be provided by a concrete decoder.
class DecoderInterface {
  public:
    virtual ~DecoderInterface() {}

    // Prepares the decoder for decoding.
    virtual void InitDecoder() = 0;

    // Resets the decoder state.
    virtual void Reset() = 0;

    // Advances decoding using the frames provided by `decodable`.
    virtual void AdvanceDecode(
        const std::shared_ptr<kaldi::DecodableInterface>& decodable) = 0;

    // Returns the final best path as a string.
    virtual std::string GetFinalBestPath() = 0;

    // Returns the partial (intermediate) result as a string.
    virtual std::string GetPartialResult() = 0;

    // Not part of the interface yet:
    // void Decode();

    // std::string GetBestPath();
    // std::vector<std::pair<double, std::string>> GetNBestPath();

    // int NumFrameDecoded();
    // int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
    //                       std::vector<std::string>& nbest_words);

  private:
    // void AdvanceDecoding(kaldi::DecodableInterface* decodable);

    // current decoding frame number
    // Initialized in-class so it never holds an indeterminate value.
    // NOTE(review): this member is private, so derived decoders cannot read
    // or update it -- confirm whether it should be protected or removed.
    int32 num_frame_decoded_ = 0;
};
} // namespace ppspeech

@ -39,14 +39,14 @@ class Decodable : public kaldi::DecodableInterface {
// forward nnet with feats // forward nnet with feats
bool AdvanceChunk(); bool AdvanceChunk();
// forward nnet with feats, and get nnet output // forward nnet with feats, and get nnet output
bool AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs, bool AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
int* vocab_dim); int* vocab_dim);
void AttentionRescoring(const std::vector<std::vector<int>>& hyps, void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight, float reverse_weight,
std::vector<float>* rescoring_score); std::vector<float>* rescoring_score);
virtual bool IsLastFrame(int32 frame); virtual bool IsLastFrame(int32 frame);

@ -56,9 +56,9 @@ class PaddleNnet : public NnetInterface {
NnetOut* out) override; NnetOut* out) override;
void AttentionRescoring(const std::vector<std::vector<int>>& hyps, void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight, float reverse_weight,
std::vector<float>* rescoring_score) override { std::vector<float>* rescoring_score) override {
VLOG(2) << "deepspeech2 not has AttentionRescoring."; VLOG(2) << "deepspeech2 not has AttentionRescoring.";
} }
void Dim(); void Dim();

Loading…
Cancel
Save