refactor ctc opts, extract decoder interface, add ctc beamsearch score

pull/2524/head
Hui Zhang 2 years ago
parent 5c8725e8cd
commit bc1b6c2e7c

@ -135,7 +135,7 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \
tlg_decoder_main \
ctc_tlg_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \

@ -133,7 +133,7 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \
tlg_decoder_main \
ctc_tlg_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \

@ -15,7 +15,7 @@ set(BINS
ctc_beam_search_decoder_main
nnet_logprob_decoder_main
recognizer_main
tlg_decoder_main
ctc_tlg_decoder_main
)
foreach(bin_name IN LISTS BINS)

@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h"
#include "base/common.h"
#include "decoder/ctc_decoders/decoder_utils.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "utils/file_utils.h"
namespace ppspeech {
@ -26,7 +26,7 @@ using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts),
init_ext_scorer_(nullptr),
blank_id_(-1),
blank_id_(opts.blank),
space_id_(-1),
num_frame_decoded_(0),
root_(nullptr) {
@ -43,9 +43,9 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
}
blank_id_ = 0;
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
CHECK(blank_id_==0);
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
space_id_ = it - vocabulary_.begin();
// if no space in vocabulary
if ((size_t)space_id_ >= vocabulary_.size()) {

@ -14,67 +14,48 @@
// used by deepspeech2
#include "base/common.h"
#pragma once
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "kaldi/decoder/decodable-itf.h"
#include "util/parse-options.h"
#pragma once
#include "decoder/decoder_itf.h"
namespace ppspeech {
// Options consumed by the CTC beam search decoder. Every field can be
// overridden from the command line through Register().
struct CTCBeamSearchOptions {
    std::string dict_file;  // registered as "dict"
    std::string lm_path;    // registered as "lm-path"
    BaseFloat alpha;
    BaseFloat beta;
    BaseFloat cutoff_prob;
    int beam_size;
    int cutoff_top_n;
    int num_proc_bsearch;

    // Members are always initialized in declaration order, whatever order
    // the initializer list is written in. The list below follows the
    // declaration order (cutoff_prob before beam_size) so the code reads the
    // way it executes and does not trigger -Wreorder.
    CTCBeamSearchOptions()
        : dict_file("vocab.txt"),
          lm_path(""),
          alpha(1.9f),
          beta(5.0),
          cutoff_prob(0.99f),
          beam_size(300),
          cutoff_top_n(40),
          num_proc_bsearch(10) {}

    // Binds each field to a named command line option on `opts`.
    void Register(kaldi::OptionsItf* opts) {
        opts->Register("dict", &dict_file, "dict file ");
        opts->Register("lm-path", &lm_path, "language model file");
        opts->Register("alpha", &alpha, "alpha");
        opts->Register("beta", &beta, "beta");
        opts->Register(
            "beam-size", &beam_size, "beam size for beam search method");
        opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
        opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
        opts->Register(
            "num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
    }
};
class CTCBeamSearch {
class CTCBeamSearch : public DecoderInterface {
public:
explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
~CTCBeamSearch() {}
void InitDecoder();
void Reset();
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
std::string GetFinalBestPath();
std::string GetPartialResult() {
CHECK(false) << "Not implement.";
return {};
}
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();
private:
void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff);
@ -93,4 +74,4 @@ class CTCBeamSearch {
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
};
} // namespace basr
} // namespace ppspeech

@ -0,0 +1,78 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "util/parse-options.h"
#pragma once
namespace ppspeech {
// Options shared by the CTC beam search decoders. The fields are grouped by
// the decoder that uses them ("ds2" / "u2"), and every field can be
// overridden from the command line through Register().
struct CTCBeamSearchOptions {
    // common
    int blank;

    // ds2
    std::string dict_file;
    std::string lm_path;
    int beam_size;
    BaseFloat alpha;
    BaseFloat beta;
    BaseFloat cutoff_prob;
    int cutoff_top_n;
    int num_proc_bsearch;

    // u2
    int first_beam_size;
    int second_beam_size;

    // Members are always initialized in declaration order, whatever order
    // the initializer list is written in. The list below follows the
    // declaration order (beam_size before alpha/beta) so the code reads the
    // way it executes and does not trigger -Wreorder.
    CTCBeamSearchOptions()
        : blank(0),
          dict_file("vocab.txt"),
          lm_path(""),
          beam_size(300),
          alpha(1.9f),
          beta(5.0),
          cutoff_prob(0.99f),
          cutoff_top_n(40),
          num_proc_bsearch(10),
          first_beam_size(10),
          second_beam_size(10) {}

    // Binds each field to a named command line option on `opts`. Help texts
    // are prefixed with the config group they belong to (the "blank" option
    // carries no prefix).
    void Register(kaldi::OptionsItf* opts) {
        std::string module = "Ds2BeamSearchConfig: ";
        opts->Register("dict", &dict_file, module + "vocab file path.");
        opts->Register(
            "lm-path", &lm_path, module + "ngram language model path.");
        opts->Register("alpha", &alpha, module + "alpha");
        opts->Register("beta", &beta, module + "beta");
        opts->Register("beam-size",
                       &beam_size,
                       module + "beam size for beam search method");
        opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs");
        opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff top n");
        opts->Register(
            "num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch");

        opts->Register("blank", &blank, "blank id, default is 0.");

        module = "U2BeamSearchConfig: ";
        opts->Register(
            "first-beam-size", &first_beam_size, module + "first beam size.");
        opts->Register("second-beam-size",
                       &second_beam_size,
                       module + "second beam size.");
    }
};
} // namespace ppspeech

@ -0,0 +1,13 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@ -0,0 +1,64 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/decoder_itf.h"
#include "kaldi/decoder/decodable-itf.h"
namespace ppspeech {
class CTCPrefixBeamSearch : public DecoderInterface {
public:
explicit CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts);
~CTCPrefixBeamSearch() {}
void InitDecoder();
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();
private:
void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff);
void CalculateApproxScore();
void LMRescore();
void AdvanceDecoding(const std::vector<std::vector<BaseFloat>>& probs);
CTCBeamSearchOptions opts_;
size_t blank_id_;
int num_frame_decoded_;
DISALLOW_COPY_AND_ASSIGN(CTCPrefixBeamSearch);
};
} // namespace ppspeech

@ -0,0 +1,68 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "utils/math.h"
namespace ppspeech {
// Scores tracked for a single prefix during CTC prefix beam search.
struct PrefxiScore {
    // Decoding scores, in log scale.
    float b = -kFloatMax;   // score of paths ending in blank
    float nb = -kFloatMax;  // score of paths ending in non-blank

    // Timestamp (viterbi) scores, in log scale.
    float v_b = -kFloatMax;             // viterbi score, blank ending
    float v_nb = -kFloatMax;            // viterbi score, non-blank ending
    float cur_token_prob = -kFloatMax;  // prob of current token
    std::vector<int> times_b;           // times of viterbi blank path
    std::vector<int> times_nb;          // times of viterbi non-blank path

    // Context state.
    bool has_context = false;
    int context_state = 0;
    float context_score = 0;

    // Decoding score: log-sum-exp of the blank and non-blank endings.
    float Score() const {
        const float merged = LogSumExp(b, nb);
        return merged;
    }

    // Decoding score with the context bias added on top.
    float TotalScore() const {
        const float decoding_score = Score();
        return decoding_score + context_score;
    }

    // Timestamp score: the larger of the two viterbi endings.
    float ViterbiScore() const { return (v_b < v_nb) ? v_nb : v_b; }

    // Timestamps of the viterbi path; the non-blank path is returned unless
    // the blank ending scores strictly higher.
    const std::vector<int>& Times() const {
        if (v_b > v_nb) {
            return times_b;
        }
        return times_nb;
    }
};
// Hash functor for a std::vector<int> prefix, so a prefix can be used as a
// key in hashed containers.
struct PrefixScoreHash {
    // Combines the elements in order into a running hash seeded with the
    // prefix length.
    // https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector
    std::size_t operator()(const std::vector<int>& prefix) const {
        std::size_t hash = prefix.size();
        for (std::size_t k = 0; k < prefix.size(); ++k) {
            // The element is added to the unsigned 32-bit constant first, so
            // that partial sum wraps at 32 bits before being widened.
            const std::size_t mixed =
                prefix[k] + 0x9e3779b9 + (hash << 6) + (hash >> 2);
            hash ^= mixed;
        }
        return hash;
    }
};
// Pairs a std::vector<int> prefix with a PrefixScoreHash object.
// NOTE(review): the second element is the hash functor, not the score struct
// (PrefxiScore) that the alias name suggests - confirm this is intended.
using PrefixWithScoreType = std::pair<std::vector<int>, PrefixScoreHash>;
} // namespace ppspeech

@ -22,24 +22,24 @@ TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
fst::SymbolTable::ReadText(opts.word_symbol_table));
decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts));
decoder_->InitDecoding();
frame_decoded_size_ = 0;
num_frame_decoded_ = 0;
}
void TLGDecoder::InitDecoder() {
decoder_->InitDecoding();
frame_decoded_size_ = 0;
num_frame_decoded_ = 0;
}
void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (!decodable->IsLastFrame(frame_decoded_size_)) {
while (!decodable->IsLastFrame(num_frame_decoded_)) {
AdvanceDecoding(decodable.get());
}
}
void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
decoder_->AdvanceDecoding(decodable, 1);
frame_decoded_size_++;
num_frame_decoded_++;
}
void TLGDecoder::Reset() {
@ -48,7 +48,7 @@ void TLGDecoder::Reset() {
}
std::string TLGDecoder::GetPartialResult() {
if (frame_decoded_size_ == 0) {
if (num_frame_decoded_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.")
return std::string("");
@ -68,7 +68,7 @@ std::string TLGDecoder::GetPartialResult() {
}
std::string TLGDecoder::GetFinalBestPath() {
if (frame_decoded_size_ == 0) {
if (num_frame_decoded_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.")
return std::string("");

@ -14,8 +14,9 @@
#pragma once
#include "base/basic_types.h"
#include "kaldi/decoder/decodable-itf.h"
#include "base/common.h"
#include "decoder/decoder_itf.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h"
@ -30,21 +31,31 @@ struct TLGDecoderOptions {
TLGDecoderOptions() : word_symbol_table(""), fst_path("") {}
};
class TLGDecoder {
class TLGDecoder : public DecoderInterface {
public:
explicit TLGDecoder(TLGDecoderOptions opts);
~TLGDecoder() = default;
void InitDecoder();
void Reset();
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
std::string GetFinalBestPath();
std::string GetPartialResult();
void Decode();
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
std::string GetPartialResult();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();
private:
void AdvanceDecoding(kaldi::DecodableInterface* decodable);
@ -53,7 +64,7 @@ class TLGDecoder {
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_;
// the frame size which have decoded starts from 0.
int32 frame_decoded_size_;
int32 num_frame_decoded_;
};

@ -14,13 +14,15 @@
// todo refactor, replace with gtest
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_tlg_decoder.h"
#include "base/common.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/ds2_nnet.h"
#include "decoder/ctc_tlg_decoder.h"
#include "kaldi/util/table-types.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");

@ -0,0 +1,56 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "kaldi/decoder/decodable-itf.h"
namespace ppspeech {
// Abstract interface that decoders derive from, so callers can drive any
// of them through the same set of virtual calls.
class DecoderInterface {
  public:
    virtual ~DecoderInterface() {}

    // Every concrete decoder must implement the five calls below.
    virtual void InitDecoder() = 0;
    virtual void Reset() = 0;
    // Takes the decodable object that the decoder reads from.
    virtual void AdvanceDecode(
        const std::shared_ptr<kaldi::DecodableInterface>& decodable) = 0;
    virtual std::string GetFinalBestPath() = 0;
    virtual std::string GetPartialResult() = 0;

    // Not part of the interface yet; kept commented out.
    // void Decode();
    // std::string GetBestPath();
    // std::vector<std::pair<double, std::string>> GetNBestPath();
    // int NumFrameDecoded();
    // int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
    // std::vector<std::string>& nbest_words);

  private:
    // void AdvanceDecoding(kaldi::DecodableInterface* decodable);

    // current decoding frame number
    // Given an in-class default so the member never holds an indeterminate
    // value; the interface has no constructor that would set it.
    int32 num_frame_decoded_ = 0;
};
} // namespace ppspeech

@ -39,14 +39,14 @@ class Decodable : public kaldi::DecodableInterface {
// forward nnet with feats
bool AdvanceChunk();
// forward nnet with feats, and get nnet output
bool AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
int* vocab_dim);
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score);
float reverse_weight,
std::vector<float>* rescoring_score);
virtual bool IsLastFrame(int32 frame);

@ -56,9 +56,9 @@ class PaddleNnet : public NnetInterface {
NnetOut* out) override;
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) override {
VLOG(2) << "deepspeech2 not has AttentionRescoring.";
float reverse_weight,
std::vector<float>* rescoring_score) override {
VLOG(2) << "deepspeech2 not has AttentionRescoring.";
}
void Dim();

Loading…
Cancel
Save