make wfst work & align frame

pull/1599/head
Yang Zhou 3 years ago
parent 642e0840b4
commit 2455d88997

@ -27,14 +27,20 @@ DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "vocab.txt", "word symbol table"); DEFINE_string(word_symbol_table, "vocab.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph"); DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 10.0, "acoustic scale"); DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 5000, "decoder graph"); DEFINE_int32(max_active, 7500, "decoder graph");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=5) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector; using std::vector;
// test clg decoder by feeding speech feature.
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
@ -52,7 +58,8 @@ int main(int argc, char* argv[]) {
opts.word_symbol_table = word_symbol_table; opts.word_symbol_table = word_symbol_table;
opts.fst_path = graph_path; opts.fst_path = graph_path;
opts.opts.max_active = FLAGS_max_active; opts.opts.max_active = FLAGS_max_active;
opts.opts.beam = opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5;
ppspeech::TLGDecoder decoder(opts); ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts; ppspeech::ModelOptions model_opts;
@ -61,30 +68,55 @@ int main(int argc, char* argv[]) {
model_opts.cache_shape = "5-1-1024,5-1-1024"; model_opts.cache_shape = "5-1-1024,5-1-1024";
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data( std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable( std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data)); new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = 35; int32 chunk_size = FLAGS_receptive_field_length;
int32 chunk_stride = FLAGS_downsampling_rate;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder(); decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) { for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key(); string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value(); kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols()); raw_data->SetDim(feature.NumCols());
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << feature.NumRows();
LOG(INFO) << "cols: " << feature.NumCols();
int32 row_idx = 0; int32 row_idx = 0;
int32 num_chunks = feature.NumRows() / chunk_size; int32 padding_len = 0;
int32 ori_feature_len = feature.NumRows();
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
padding_len =
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len,
feature.NumCols(),
kaldi::kCopyData);
}
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size * kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols()); feature.NumCols());
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < chunk_size; ++row_id) { for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, row_idx); kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp( kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(), feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols()); feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp); f_chunk_tmp.CopyFromVec(tmp);
row_idx++; ++start;
} }
raw_data->Accept(feature_chunk); raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) { if (chunk_idx == num_chunks - 1) {

@ -1,50 +1,66 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_tlg_decoder.h" #include "decoder/ctc_tlg_decoder.h"
namespace ppspeech { namespace ppspeech {
TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path)); fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
CHECK(fst_ != nullptr); CHECK(fst_ != nullptr);
word_symbol_table_.reset(fst::SymbolTable::ReadText(opts.word_symbol_table)); word_symbol_table_.reset(
fst::SymbolTable::ReadText(opts.word_symbol_table));
decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts)); decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts));
decoder_->InitDecoding(); decoder_->InitDecoding();
frame_decoded_size_ = 0;
} }
void TLGDecoder::InitDecoder() { void TLGDecoder::InitDecoder() {
decoder_->InitDecoding(); decoder_->InitDecoding();
frame_decoded_size_ = 0;
} }
void TLGDecoder::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable) { void TLGDecoder::AdvanceDecode(
while (1) { const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
AdvanceDecoding(decodable.get()); while (!decodable->IsLastFrame(frame_decoded_size_)) {
if (decodable->IsLastFrame(num_frame_decoded_)) break; LOG(INFO) << "num frame decode: " << frame_decoded_size_;
AdvanceDecoding(decodable.get());
} }
} }
void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) { void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
// skip blank frame? decoder_->AdvanceDecoding(decodable, 1);
decoder_->AdvanceDecoding(decodable, 1); frame_decoded_size_++;
num_frame_decoded_++;
} }
void TLGDecoder::Reset() { void TLGDecoder::Reset() {
decoder_->InitDecoding(); InitDecoder();
return; return;
} }
std::string TLGDecoder::GetFinalBestPath() { std::string TLGDecoder::GetFinalBestPath() {
decoder_->FinalizeDecoding(); decoder_->FinalizeDecoding();
kaldi::Lattice lat; kaldi::Lattice lat;
kaldi::LatticeWeight weight; kaldi::LatticeWeight weight;
std::vector<int> alignment; std::vector<int> alignment;
std::vector<int> words_id; std::vector<int> words_id;
decoder_->GetBestPath(&lat, true); decoder_->GetBestPath(&lat, true);
fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight); fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);
std::string words; std::string words;
for (int32 idx = 0; idx < words_id.size(); ++idx) { for (int32 idx = 0; idx < words_id.size(); ++idx) {
std::string word = word_symbol_table_->Find(words_id[idx]); std::string word = word_symbol_table_->Find(words_id[idx]);
words += word; words += word;
} }
return words; return words;
} }
} }

@ -1,21 +1,33 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once #pragma once
#include "kaldi/decoder/lattice-faster-online-decoder.h" #include "base/basic_types.h"
#include "kaldi/decoder/decodable-itf.h" #include "kaldi/decoder/decodable-itf.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h" #include "util/parse-options.h"
#include "base/basic_types.h"
namespace ppspeech { namespace ppspeech {
struct TLGDecoderOptions { struct TLGDecoderOptions {
kaldi::LatticeFasterDecoderConfig opts; kaldi::LatticeFasterDecoderConfig opts;
// todo remove later, add into decode resource // todo remove later, add into decode resource
std::string word_symbol_table; std::string word_symbol_table;
std::string fst_path; std::string fst_path;
TLGDecoderOptions() TLGDecoderOptions() : word_symbol_table(""), fst_path("") {}
: word_symbol_table(""),
fst_path("") {}
}; };
class TLGDecoder { class TLGDecoder {
@ -34,14 +46,14 @@ class TLGDecoder {
void Reset(); void Reset();
private: private:
void AdvanceDecoding(kaldi::DecodableInterface* decodable); void AdvanceDecoding(kaldi::DecodableInterface* decodable);
std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_; std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_; std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_; std::shared_ptr<fst::SymbolTable> word_symbol_table_;
int32 num_frame_decoded_; // the frame size which have decoded starts from 0.
}; int32 frame_decoded_size_;
};
} // namespace ppspeech } // namespace ppspeech

@ -22,8 +22,13 @@ using std::vector;
using kaldi::Vector; using kaldi::Vector;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet, Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FrontendInterface>& frontend) const std::shared_ptr<FrontendInterface>& frontend,
: frontend_(frontend), nnet_(nnet), frame_offset_(0), frames_ready_(0) {} kaldi::BaseFloat acoustic_scale)
: frontend_(frontend),
nnet_(nnet),
frame_offset_(0),
frames_ready_(0),
acoustic_scale_(acoustic_scale) {}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) { void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
nnet_cache_ = likelihood; nnet_cache_ = likelihood;
@ -32,14 +37,14 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
// Decodable::Init(DecodableConfig config) { // Decodable::Init(DecodableConfig config) {
//} //}
int32 Decodable::NumFramesReady() const {
return frames_ready_;
}
// return the size of frame have computed.
int32 Decodable::NumFramesReady() const { return frames_ready_; }
// frame idx is from 0 to frame_ready_ -1;
bool Decodable::IsLastFrame(int32 frame) { bool Decodable::IsLastFrame(int32 frame) {
bool flag = EnsureFrameHaveComputed(frame); bool flag = EnsureFrameHaveComputed(frame);
//CHECK_LE(frame, frames_ready_); return frame >= frames_ready_;
return (flag == false) && (frame == frames_ready_);
} }
int32 Decodable::NumIndices() const { return 0; } int32 Decodable::NumIndices() const { return 0; }
@ -48,7 +53,8 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
CHECK_LE(index, nnet_cache_.NumCols()); CHECK_LE(index, nnet_cache_.NumCols());
CHECK_LE(frame, frames_ready_); CHECK_LE(frame, frames_ready_);
int32 frame_idx = frame - frame_offset_; int32 frame_idx = frame - frame_offset_;
return std::log(nnet_cache_(frame_idx, index) + std::numeric_limits<float>::min()); return acoustic_scale_ * std::log(nnet_cache_(frame_idx, index - 1) +
std::numeric_limits<float>::min());
} }
bool Decodable::EnsureFrameHaveComputed(int32 frame) { bool Decodable::EnsureFrameHaveComputed(int32 frame) {
@ -67,20 +73,12 @@ bool Decodable::AdvanceChunk() {
Vector<BaseFloat> inferences; Vector<BaseFloat> inferences;
Matrix<BaseFloat> nnet_cache_tmp; Matrix<BaseFloat> nnet_cache_tmp;
nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim); nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
nnet_cache_tmp.Resize(inferences.Dim() / nnet_dim, nnet_dim); nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim);
nnet_cache_tmp.CopyRowsFromVec(inferences); nnet_cache_.CopyRowsFromVec(inferences);
// skip blank
vector<int> no_blank_record;
BaseFloat blank_threshold = 0.98;
for (int32 idx = 0; idx < nnet_cache_.NumRows(); ++idx) {
if (nnet_cache_(idx, 0) > blank_threshold) {
}
}
frame_offset_ = frames_ready_; frame_offset_ = frames_ready_;
frames_ready_ += nnet_cache_.NumRows(); frames_ready_ += nnet_cache_.NumRows();
LOG(INFO) << "nnet size: " << nnet_cache_.NumRows();
return true; return true;
} }
@ -89,7 +87,8 @@ bool Decodable::FrameLogLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
if (EnsureFrameHaveComputed(frame) == false) return false; if (EnsureFrameHaveComputed(frame) == false) return false;
likelihood->resize(nnet_cache_.NumCols()); likelihood->resize(nnet_cache_.NumCols());
for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) { for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) {
(*likelihood)[idx] = nnet_cache_(frame - frame_offset_, idx); (*likelihood)[idx] =
nnet_cache_(frame - frame_offset_, idx) * acoustic_scale_;
} }
return true; return true;
} }

@ -14,8 +14,8 @@
#include "base/common.h" #include "base/common.h"
#include "frontend/audio/frontend_itf.h" #include "frontend/audio/frontend_itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/decoder/decodable-itf.h" #include "kaldi/decoder/decodable-itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h" #include "nnet/nnet_itf.h"
namespace ppspeech { namespace ppspeech {
@ -25,7 +25,8 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface { class Decodable : public kaldi::DecodableInterface {
public: public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet, explicit Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FrontendInterface>& frontend); const std::shared_ptr<FrontendInterface>& frontend,
kaldi::BaseFloat acoustic_scale = 1.0);
// void Init(DecodableOpts config); // void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame); virtual bool IsLastFrame(int32 frame);
@ -38,14 +39,12 @@ class Decodable : public kaldi::DecodableInterface {
void Reset(); void Reset();
bool IsInputFinished() const { return frontend_->IsFinished(); } bool IsInputFinished() const { return frontend_->IsFinished(); }
bool EnsureFrameHaveComputed(int32 frame); bool EnsureFrameHaveComputed(int32 frame);
private: private:
bool AdvanceChunk(); bool AdvanceChunk();
std::shared_ptr<FrontendInterface> frontend_; std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_; std::shared_ptr<NnetInterface> nnet_;
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_; kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
// std::vector<std::vector<kaldi::BaseFloat>> nnet_cache_;
int32 frame_offset_; int32 frame_offset_;
int32 frames_ready_; int32 frames_ready_;
// todo: feature frame mismatch with nnet inference frame // todo: feature frame mismatch with nnet inference frame
@ -53,6 +52,7 @@ class Decodable : public kaldi::DecodableInterface {
// so use subsampled_frame // so use subsampled_frame
int32 current_log_post_subsampled_offset_; int32 current_log_post_subsampled_offset_;
int32 num_chunk_computed_; int32 num_chunk_computed_;
kaldi::BaseFloat acoustic_scale_;
}; };
} // namespace ppspeech } // namespace ppspeech

Loading…
Cancel
Save