add wfst decoder (#2886)

pull/2923/head
YangZhou 3 years ago committed by GitHub
parent 5042a1686a
commit 21183d48b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -33,7 +33,7 @@ set(FETCHCONTENT_BASE_DIR ${fc_patch})
# compiler option
# Keep the same with openfst, -fPIC or -fpic
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ldl")
SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb")
SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall")

@ -1,6 +1,7 @@
set(srcs)
list(APPEND srcs
ctc_prefix_beam_search_decoder.cc
ctc_tlg_decoder.cc
)
add_library(decoder STATIC ${srcs})
@ -9,6 +10,7 @@ target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder)
# test
set(TEST_BINS
ctc_prefix_beam_search_decoder_main
ctc_tlg_decoder_main
)
foreach(bin_name IN LISTS TEST_BINS)

@ -45,7 +45,7 @@ class CTCPrefixBeamSearch : public DecoderBase {
void FinalizeSearch();
const std::shared_ptr<fst::SymbolTable> VocabTable() const {
const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const override {
return unit_table_;
}
@ -57,7 +57,6 @@ class CTCPrefixBeamSearch : public DecoderBase {
}
const std::vector<std::vector<int>>& Times() const { return times_; }
protected:
std::string GetBestPath() override;
std::vector<std::pair<double, std::string>> GetNBestPath() override;

@ -15,7 +15,7 @@
#include "decoder/ctc_tlg_decoder.h"
namespace ppspeech {
TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
TLGDecoder::TLGDecoder(TLGDecoderOptions opts) : opts_(opts) {
fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
CHECK(fst_ != nullptr);
@ -68,14 +68,52 @@ std::string TLGDecoder::GetPartialResult() {
return words;
}
// Finalize lattice decoding and extract the n-best hypotheses.
// Fills hypotheses_ (ilabel token ids, shifted by -1), times_ (fake
// per-token frame indices for now), olabels (word ids) and likelihood_
// (negated total lattice cost) — one entry per n-best path.
void TLGDecoder::FinalizeSearch() {
    decoder_->FinalizeDecoding();

    kaldi::CompactLattice clat;
    decoder_->GetLattice(&clat, true);
    kaldi::Lattice lat, nbest_lat;
    fst::ConvertLattice(clat, &lat);
    // Keep up to opts_.nbest shortest (best-scoring) paths.
    fst::ShortestPath(lat, &nbest_lat, opts_.nbest);
    std::vector<kaldi::Lattice> nbest_lats;
    fst::ConvertNbestToVector(nbest_lat, &nbest_lats);

    hypotheses_.clear();
    hypotheses_.reserve(nbest_lats.size());
    likelihood_.clear();
    likelihood_.reserve(nbest_lats.size());
    times_.clear();
    times_.reserve(nbest_lats.size());

    // Iterate by const reference: the original `for (auto lat : ...)`
    // deep-copied every lattice per iteration.
    for (const auto& path : nbest_lats) {
        kaldi::LatticeWeight weight;
        std::vector<int> hypothese;
        std::vector<int> time;
        std::vector<int> alignment;  // linear ilabel sequence (0 == epsilon)
        std::vector<int> words_id;   // linear olabel (word id) sequence
        fst::GetLinearSymbolSequence(path, &alignment, &words_id, &weight);

        // Guard the empty case: the original computed
        // `alignment.size() - 1` (size_t underflow) and then read
        // alignment[0] of an empty vector — undefined behavior.
        if (!alignment.empty()) {
            size_t idx = 0;
            // Collapse runs of identical labels and skip epsilons.
            for (; idx + 1 < alignment.size(); ++idx) {
                if (alignment[idx] == 0) continue;
                if (alignment[idx] != alignment[idx + 1]) {
                    hypothese.push_back(alignment[idx] - 1);
                    time.push_back(idx);  // fake time, todo later
                }
            }
            // Last label: the original pushed unconditionally, emitting a
            // bogus -1 token when the final label was epsilon (0).
            if (alignment[idx] != 0) {
                hypothese.push_back(alignment[idx] - 1);
                time.push_back(idx);  // fake time, todo later
            }
        }

        hypotheses_.push_back(hypothese);
        times_.push_back(time);
        olabels.push_back(words_id);
        // Total cost = graph (Value1) + acoustic (Value2); negate so a
        // larger likelihood means a better path.
        likelihood_.push_back(-(weight.Value2() + weight.Value1()));
    }
}
std::string TLGDecoder::GetFinalBestPath() {
if (num_frame_decoded_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.")
return std::string("");
}
decoder_->FinalizeDecoding();
kaldi::Lattice lat;
kaldi::LatticeWeight weight;
std::vector<int> alignment;

@ -19,9 +19,8 @@
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h"
DECLARE_string(graph_path);
DECLARE_string(word_symbol_table);
DECLARE_string(graph_path);
DECLARE_int32(max_active);
DECLARE_double(beam);
DECLARE_double(lattice_beam);
@ -33,6 +32,9 @@ struct TLGDecoderOptions {
// todo remove later, add into decode resource
std::string word_symbol_table;
std::string fst_path;
int nbest;
// Defaults: empty table/graph paths (TLG decoding disabled until set) and
// 10-best lattice paths.
TLGDecoderOptions() : word_symbol_table(""), fst_path(""), nbest(10) {}
static TLGDecoderOptions InitFromFlags() {
TLGDecoderOptions decoder_opts;
@ -44,6 +46,7 @@ struct TLGDecoderOptions {
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
// decoder_opts.nbest = FLAGS_lattice_nbest;
LOG(INFO) << "LatticeFasterDecoder max active: "
<< decoder_opts.opts.max_active;
LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam;
@ -59,20 +62,38 @@ class TLGDecoder : public DecoderBase {
explicit TLGDecoder(TLGDecoderOptions opts);
~TLGDecoder() = default;
void InitDecoder();
void Reset();
void InitDecoder() override;
void Reset() override;
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
const std::shared_ptr<kaldi::DecodableInterface>& decodable) override;
void Decode();
std::string GetFinalBestPath() override;
std::string GetPartialResult() override;
const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const override {
return word_symbol_table_;
}
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
const std::vector<std::string>& nbest_words);
void FinalizeSearch() override;
const std::vector<std::vector<int>>& Inputs() const override {
return hypotheses_;
}
const std::vector<std::vector<int>>& Outputs() const override {
return olabels;
} // outputs_; }
const std::vector<float>& Likelihood() const override {
return likelihood_;
}
const std::vector<std::vector<int>>& Times() const override {
return times_;
}
protected:
std::string GetBestPath() override {
CHECK(false);
@ -90,9 +111,15 @@ class TLGDecoder : public DecoderBase {
private:
void AdvanceDecoding(kaldi::DecodableInterface* decodable);
std::vector<std::vector<int>> hypotheses_;
std::vector<std::vector<int>> olabels;
std::vector<float> likelihood_;
std::vector<std::vector<int>> times_;
std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_;
TLGDecoderOptions opts_;
};

@ -14,16 +14,16 @@
// todo refactor, repalce with gtest
#include "base/common.h"
#include "decoder/ctc_tlg_decoder.h"
#include "base/common.h"
#include "decoder/param.h"
#include "frontend/audio/data_cache.h"
#include "frontend/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/ds2_nnet.h"
#include "nnet/nnet_producer.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(nnet_prob_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
@ -39,8 +39,8 @@ int main(int argc, char* argv[]) {
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::SequentialBaseFloatMatrixReader nnet_prob_reader(
FLAGS_nnet_prob_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int32 num_done = 0, num_err = 0;
@ -53,66 +53,19 @@ int main(int argc, char* argv[]) {
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
std::make_shared<ppspeech::NnetProducer>(nullptr);
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
new ppspeech::Decodable(nnet_producer, FLAGS_acoustic_scale));
decoder.InitDecoder();
kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << feature.NumRows();
LOG(INFO) << "cols: " << feature.NumCols();
int32 row_idx = 0;
int32 padding_len = 0;
int32 ori_feature_len = feature.NumRows();
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
padding_len =
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len,
feature.NumCols(),
kaldi::kCopyData);
}
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
++start;
}
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
decoder.AdvanceDecode(decodable);
}
for (; !nnet_prob_reader.Done(); nnet_prob_reader.Next()) {
string utt = nnet_prob_reader.Key();
kaldi::Matrix<BaseFloat> prob = nnet_prob_reader.Value();
decodable->Acceptlikelihood(prob);
decoder.AdvanceDecode(decodable);
std::string result;
result = decoder.GetFinalBestPath();
decodable->Reset();

@ -1,4 +1,3 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
@ -16,6 +15,7 @@
#pragma once
#include "base/common.h"
#include "fst/symbol-table.h"
#include "kaldi/decoder/decodable-itf.h"
namespace ppspeech {
@ -41,6 +41,14 @@ class DecoderInterface {
virtual std::string GetPartialResult() = 0;
virtual const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const = 0;
virtual void FinalizeSearch() = 0;
virtual const std::vector<std::vector<int>>& Inputs() const = 0;
virtual const std::vector<std::vector<int>>& Outputs() const = 0;
virtual const std::vector<float>& Likelihood() const = 0;
virtual const std::vector<std::vector<int>>& Times() const = 0;
protected:
// virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0;

@ -57,8 +57,8 @@ DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
// decoder
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "", "decoder graph");
DEFINE_string(word_symbol_table, "", "word symbol table");
DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam");

@ -27,8 +27,6 @@ class Decodable : public kaldi::DecodableInterface {
explicit Decodable(const std::shared_ptr<NnetProducer>& nnet_producer,
kaldi::BaseFloat acoustic_scale = 1.0);
// void Init(DecodableOpts config);
// nnet logprob output, used by wfst
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);

@ -23,25 +23,25 @@ using kaldi::BaseFloat;
// Construct the producer: wire up the nnet and frontend, reset internal
// state, and start the background evaluation thread only when a nnet is
// actually present. (The pasted diff contained both the old and the new
// constructor body back-to-back with unbalanced braces; this is the merged
// new version.)
NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
                           std::shared_ptr<FrontendInterface> frontend)
    : nnet_(nnet), frontend_(frontend) {
    abort_ = false;
    Reset();
    // nnet_ may be nullptr (e.g. decoding from precomputed probabilities);
    // in that case no evaluation thread is needed.
    if (nnet_ != nullptr) thread_ = std::thread(RunNnetEvaluation, this);
}
// Feed a chunk of raw input features to the frontend and wake any thread
// waiting on the producer's condition variable.
void NnetProducer::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
frontend_->Accept(inputs);
// Signal that new input is available for the evaluation loop.
condition_variable_.notify_one();
}
void NnetProducer::UnLock() {
void NnetProducer::WaitProduce() {
std::unique_lock<std::mutex> lock(read_mutex_);
while (frontend_->IsFinished() == false && cache_.empty()) {
condition_read_ready_.wait(lock);
condition_read_ready_.wait(lock);
}
return;
}
// Static thread entry point: trampolines into the member worker loop.
// (The diff paste duplicated the old `NnetProducer *me` and new
// `NnetProducer* me` signatures; merged to the new one.)
void NnetProducer::RunNnetEvaluation(NnetProducer* me) {
    me->RunNnetEvaluationInteral();
}
@ -55,7 +55,7 @@ void NnetProducer::RunNnetEvaluationInteral() {
result = Compute();
} while (result);
if (frontend_->IsFinished() == true) {
if (cache_.empty()) finished_ = true;
if (cache_.empty()) finished_ = true;
}
}
LOG(INFO) << "NnetEvaluationInteral exit";

@ -34,9 +34,9 @@ class NnetProducer {
// nnet
bool Read(std::vector<kaldi::BaseFloat>* nnet_prob);
bool ReadandCompute(std::vector<kaldi::BaseFloat>* nnet_prob);
static void RunNnetEvaluation(NnetProducer *me);
static void RunNnetEvaluation(NnetProducer* me);
void RunNnetEvaluationInteral();
void UnLock();
void WaitProduce();
void Wait() {
abort_ = true;
@ -56,12 +56,12 @@ class NnetProducer {
// True once production is done — set by the evaluation loop when the
// frontend reports finished and the cache has been drained.
bool IsFinished() const { return finished_; }
// Join the background evaluation thread (if one was started) before the
// members it touches are destroyed. The duplicated join line in the pasted
// diff was residue of an indentation-only change; a single join suffices.
~NnetProducer() {
    if (thread_.joinable()) thread_.join();
}
void Reset() {
frontend_->Reset();
nnet_->Reset();
if (frontend_ != NULL) frontend_->Reset();
if (nnet_ != NULL) nnet_->Reset();
VLOG(3) << "feature cache reset: cache size: " << cache_.size();
cache_.clear();
finished_ = false;

@ -33,11 +33,15 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
decodable_.reset(new Decodable(nnet_producer_, am_scale));
CHECK_NE(resource.vocab_path, "");
decoder_.reset(new CTCPrefixBeamSearch(
resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
if (resource.decoder_opts.tlg_decoder_opts.fst_path == "") {
LOG(INFO) << resource.decoder_opts.tlg_decoder_opts.fst_path;
decoder_.reset(new CTCPrefixBeamSearch(
resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
} else {
decoder_.reset(new TLGDecoder(resource.decoder_opts.tlg_decoder_opts));
}
unit_table_ = decoder_->VocabTable();
symbol_table_ = unit_table_;
symbol_table_ = decoder_->WordSymbolTable();
global_frame_offset_ = 0;
input_finished_ = false;
@ -56,11 +60,14 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource,
decodable_.reset(new Decodable(nnet_producer_, am_scale));
CHECK_NE(resource.vocab_path, "");
decoder_.reset(new CTCPrefixBeamSearch(
resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
if (resource.decoder_opts.tlg_decoder_opts.fst_path == "") {
decoder_.reset(new CTCPrefixBeamSearch(
resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
} else {
decoder_.reset(new TLGDecoder(resource.decoder_opts.tlg_decoder_opts));
}
unit_table_ = decoder_->VocabTable();
symbol_table_ = unit_table_;
symbol_table_ = decoder_->WordSymbolTable();
global_frame_offset_ = 0;
input_finished_ = false;
@ -109,10 +116,11 @@ void U2Recognizer::RunDecoderSearch(U2Recognizer* me) {
// Decoder-search worker loop: wait for nnet output, advance the decoder,
// and when the producer is finished do a final advance and publish a
// (non-final) result. (The pasted diff interleaved old/new line pairs —
// UnLock()/WaitProduce() and Decode()/AdvanceDecode() — this is the
// reconstructed new version; confirm against the upstream commit.)
void U2Recognizer::RunDecoderSearchInternal() {
    LOG(INFO) << "DecoderSearchInteral begin";
    while (!nnet_producer_->IsFinished()) {
        // Blocks until the producer cache is non-empty or input finished.
        nnet_producer_->WaitProduce();
        decoder_->AdvanceDecode(decodable_);
    }
    // Drain whatever remains after the producer finishes.
    decoder_->AdvanceDecode(decodable_);
    UpdateResult(false);
    LOG(INFO) << "DecoderSearchInteral exit";
}
@ -140,7 +148,7 @@ void U2Recognizer::UpdateResult(bool finish) {
const auto& times = decoder_->Times();
result_.clear();
CHECK_EQ(hypotheses.size(), likelihood.size());
CHECK_EQ(inputs.size(), likelihood.size());
for (size_t i = 0; i < hypotheses.size(); i++) {
const std::vector<int>& hypothesis = hypotheses[i];
@ -148,13 +156,9 @@ void U2Recognizer::UpdateResult(bool finish) {
path.score = likelihood[i];
for (size_t j = 0; j < hypothesis.size(); j++) {
std::string word = symbol_table_->Find(hypothesis[j]);
// A detailed explanation of this if-else branch can be found in
// https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
if (decoder_->Type() == kWfstBeamSearch) {
path.sentence += (" " + word);
} else {
path.sentence += (word);
}
// path.sentence += (" " + word); // todo SmileGoat: add blank
// processor
path.sentence += word; // todo SmileGoat: add blank processor
}
// TimeStamp is only supported in final result
@ -162,7 +166,7 @@ void U2Recognizer::UpdateResult(bool finish) {
// various FST operations when building the decoding graph. So here we
// use time stamp of the input(e2e model unit), which is more accurate,
// and it requires the symbol table of the e2e model used in training.
if (unit_table_ != nullptr && finish) {
if (symbol_table_ != nullptr && finish) {
int offset = global_frame_offset_ * FrameShiftInMs();
const std::vector<int>& input = inputs[i];
@ -170,7 +174,7 @@ void U2Recognizer::UpdateResult(bool finish) {
CHECK_EQ(input.size(), time_stamp.size());
for (size_t j = 0; j < input.size(); j++) {
std::string word = unit_table_->Find(input[j]);
std::string word = symbol_table_->Find(input[j]);
int start =
time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0
@ -214,7 +218,7 @@ void U2Recognizer::UpdateResult(bool finish) {
void U2Recognizer::AttentionRescoring() {
decoder_->FinalizeSearch();
UpdateResult(true);
UpdateResult(false);
// No need to do rescoring
if (0.0 == opts_.decoder_opts.rescoring_weight) {

@ -17,6 +17,7 @@
#include "decoder/common.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "decoder/decoder_itf.h"
#include "frontend/feature_pipeline.h"
#include "fst/fstlib.h"
@ -33,6 +34,8 @@ DECLARE_int32(blank);
DECLARE_double(acoustic_scale);
DECLARE_string(vocab_path);
DECLARE_string(word_symbol_table);
// DECLARE_string(fst_path);
namespace ppspeech {
@ -59,6 +62,7 @@ struct DecodeOptions {
// CtcEndpointConfig ctc_endpoint_opts;
CTCBeamSearchOptions ctc_prefix_search_opts{};
TLGDecoderOptions tlg_decoder_opts{};
static DecodeOptions InitFromFlags() {
DecodeOptions decoder_opts;
@ -70,6 +74,13 @@ struct DecodeOptions {
decoder_opts.ctc_prefix_search_opts.blank = FLAGS_blank;
decoder_opts.ctc_prefix_search_opts.first_beam_size = FLAGS_nbest;
decoder_opts.ctc_prefix_search_opts.second_beam_size = FLAGS_nbest;
// decoder_opts.tlg_decoder_opts.fst_path = "";//FLAGS_fst_path;
// decoder_opts.tlg_decoder_opts.word_symbol_table =
// FLAGS_word_symbol_table;
// decoder_opts.tlg_decoder_opts.nbest = FLAGS_nbest;
decoder_opts.tlg_decoder_opts =
ppspeech::TLGDecoderOptions::InitFromFlags();
LOG(INFO) << "chunk_size: " << decoder_opts.chunk_size;
LOG(INFO) << "num_left_chunks: " << decoder_opts.num_left_chunks;
LOG(INFO) << "ctc_weight: " << decoder_opts.ctc_weight;
@ -113,7 +124,7 @@ class U2Recognizer {
public:
explicit U2Recognizer(const U2RecognizerResource& resouce);
explicit U2Recognizer(const U2RecognizerResource& resource,
std::shared_ptr<NnetBase> nnet);
std::shared_ptr<NnetBase> nnet);
~U2Recognizer();
void InitDecoder();
void ResetContinuousDecoding();
@ -154,10 +165,9 @@ class U2Recognizer {
std::shared_ptr<NnetProducer> nnet_producer_;
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<CTCPrefixBeamSearch> decoder_;
std::unique_ptr<DecoderBase> decoder_;
// e2e unit symbol table
std::shared_ptr<fst::SymbolTable> unit_table_ = nullptr;
std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
std::vector<DecodeResult> result_;

Loading…
Cancel
Save