make wfst work & align frame

pull/1599/head
Yang Zhou 2 years ago
parent 642e0840b4
commit 2455d88997

@ -27,14 +27,20 @@ DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "vocab.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 10.0, "acoustic scale");
DEFINE_int32(max_active, 5000, "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "decoder graph");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=5) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test clg decoder by feeding speech feature.
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
@ -52,7 +58,8 @@ int main(int argc, char* argv[]) {
opts.word_symbol_table = word_symbol_table;
opts.fst_path = graph_path;
opts.opts.max_active = FLAGS_max_active;
opts.opts.beam =
opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5;
ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts;
@ -61,30 +68,55 @@ int main(int argc, char* argv[]) {
model_opts.cache_shape = "5-1-1024,5-1-1024";
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(
new ppspeech::DataCache());
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = 35;
int32 chunk_size = FLAGS_receptive_field_length;
int32 chunk_stride = FLAGS_downsampling_rate;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder();
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << feature.NumRows();
LOG(INFO) << "cols: " << feature.NumCols();
int32 row_idx = 0;
int32 num_chunks = feature.NumRows() / chunk_size;
int32 padding_len = 0;
int32 ori_feature_len = feature.NumRows();
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
padding_len =
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len,
feature.NumCols(),
kaldi::kCopyData);
}
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, row_idx);
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
row_idx++;
++start;
}
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {

@ -1,50 +1,66 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_tlg_decoder.h"
namespace ppspeech {
TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
CHECK(fst_ != nullptr);
word_symbol_table_.reset(fst::SymbolTable::ReadText(opts.word_symbol_table));
word_symbol_table_.reset(
fst::SymbolTable::ReadText(opts.word_symbol_table));
decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts));
decoder_->InitDecoding();
frame_decoded_size_ = 0;
}
void TLGDecoder::InitDecoder() {
decoder_->InitDecoding();
frame_decoded_size_ = 0;
}
void TLGDecoder::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (1) {
AdvanceDecoding(decodable.get());
if (decodable->IsLastFrame(num_frame_decoded_)) break;
void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (!decodable->IsLastFrame(frame_decoded_size_)) {
LOG(INFO) << "num frame decode: " << frame_decoded_size_;
AdvanceDecoding(decodable.get());
}
}
void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
// skip blank frame?
decoder_->AdvanceDecoding(decodable, 1);
num_frame_decoded_++;
decoder_->AdvanceDecoding(decodable, 1);
frame_decoded_size_++;
}
void TLGDecoder::Reset() {
decoder_->InitDecoding();
return;
InitDecoder();
return;
}
std::string TLGDecoder::GetFinalBestPath() {
decoder_->FinalizeDecoding();
kaldi::Lattice lat;
kaldi::LatticeWeight weight;
std::vector<int> alignment;
std::vector<int> words_id;
decoder_->GetBestPath(&lat, true);
fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);
std::string words;
for (int32 idx = 0; idx < words_id.size(); ++idx) {
std::string word = word_symbol_table_->Find(words_id[idx]);
words += word;
}
return words;
decoder_->FinalizeDecoding();
kaldi::Lattice lat;
kaldi::LatticeWeight weight;
std::vector<int> alignment;
std::vector<int> words_id;
decoder_->GetBestPath(&lat, true);
fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);
std::string words;
for (int32 idx = 0; idx < words_id.size(); ++idx) {
std::string word = word_symbol_table_->Find(words_id[idx]);
words += word;
}
return words;
}
}

@ -1,21 +1,33 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "base/basic_types.h"
#include "kaldi/decoder/decodable-itf.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h"
#include "base/basic_types.h"
namespace ppspeech {
struct TLGDecoderOptions {
kaldi::LatticeFasterDecoderConfig opts;
// todo remove later, add into decode resource
std::string word_symbol_table;
std::string fst_path;
TLGDecoderOptions()
: word_symbol_table(""),
fst_path("") {}
kaldi::LatticeFasterDecoderConfig opts;
// todo remove later, add into decode resource
std::string word_symbol_table;
std::string fst_path;
TLGDecoderOptions() : word_symbol_table(""), fst_path("") {}
};
class TLGDecoder {
@ -34,14 +46,14 @@ class TLGDecoder {
void Reset();
private:
void AdvanceDecoding(kaldi::DecodableInterface* decodable);
void AdvanceDecoding(kaldi::DecodableInterface* decodable);
std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_;
int32 num_frame_decoded_;
};
// the frame size which have decoded starts from 0.
int32 frame_decoded_size_;
};
} // namespace ppspeech

@ -22,8 +22,13 @@ using std::vector;
using kaldi::Vector;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FrontendInterface>& frontend)
: frontend_(frontend), nnet_(nnet), frame_offset_(0), frames_ready_(0) {}
const std::shared_ptr<FrontendInterface>& frontend,
kaldi::BaseFloat acoustic_scale)
: frontend_(frontend),
nnet_(nnet),
frame_offset_(0),
frames_ready_(0),
acoustic_scale_(acoustic_scale) {}
void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
nnet_cache_ = likelihood;
@ -32,14 +37,14 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
// Decodable::Init(DecodableConfig config) {
//}
int32 Decodable::NumFramesReady() const {
return frames_ready_;
}
// return the size of frame have computed.
int32 Decodable::NumFramesReady() const { return frames_ready_; }
// frame idx is from 0 to frame_ready_ -1;
bool Decodable::IsLastFrame(int32 frame) {
bool flag = EnsureFrameHaveComputed(frame);
//CHECK_LE(frame, frames_ready_);
return (flag == false) && (frame == frames_ready_);
return frame >= frames_ready_;
}
int32 Decodable::NumIndices() const { return 0; }
@ -48,7 +53,8 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
CHECK_LE(index, nnet_cache_.NumCols());
CHECK_LE(frame, frames_ready_);
int32 frame_idx = frame - frame_offset_;
return std::log(nnet_cache_(frame_idx, index) + std::numeric_limits<float>::min());
return acoustic_scale_ * std::log(nnet_cache_(frame_idx, index - 1) +
std::numeric_limits<float>::min());
}
bool Decodable::EnsureFrameHaveComputed(int32 frame) {
@ -67,20 +73,12 @@ bool Decodable::AdvanceChunk() {
Vector<BaseFloat> inferences;
Matrix<BaseFloat> nnet_cache_tmp;
nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
nnet_cache_tmp.Resize(inferences.Dim() / nnet_dim, nnet_dim);
nnet_cache_tmp.CopyRowsFromVec(inferences);
// skip blank
vector<int> no_blank_record;
BaseFloat blank_threshold = 0.98;
for (int32 idx = 0; idx < nnet_cache_.NumRows(); ++idx) {
if (nnet_cache_(idx, 0) > blank_threshold) {
}
}
nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim);
nnet_cache_.CopyRowsFromVec(inferences);
frame_offset_ = frames_ready_;
frames_ready_ += nnet_cache_.NumRows();
LOG(INFO) << "nnet size: " << nnet_cache_.NumRows();
return true;
}
@ -89,7 +87,8 @@ bool Decodable::FrameLogLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
if (EnsureFrameHaveComputed(frame) == false) return false;
likelihood->resize(nnet_cache_.NumCols());
for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) {
(*likelihood)[idx] = nnet_cache_(frame - frame_offset_, idx);
(*likelihood)[idx] =
nnet_cache_(frame - frame_offset_, idx) * acoustic_scale_;
}
return true;
}

@ -14,8 +14,8 @@
#include "base/common.h"
#include "frontend/audio/frontend_itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/decoder/decodable-itf.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "nnet/nnet_itf.h"
namespace ppspeech {
@ -25,7 +25,8 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface {
public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet,
const std::shared_ptr<FrontendInterface>& frontend);
const std::shared_ptr<FrontendInterface>& frontend,
kaldi::BaseFloat acoustic_scale = 1.0);
// void Init(DecodableOpts config);
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
virtual bool IsLastFrame(int32 frame);
@ -38,14 +39,12 @@ class Decodable : public kaldi::DecodableInterface {
void Reset();
bool IsInputFinished() const { return frontend_->IsFinished(); }
bool EnsureFrameHaveComputed(int32 frame);
private:
bool AdvanceChunk();
std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_;
kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
// std::vector<std::vector<kaldi::BaseFloat>> nnet_cache_;
int32 frame_offset_;
int32 frames_ready_;
// todo: feature frame mismatch with nnet inference frame
@ -53,6 +52,7 @@ class Decodable : public kaldi::DecodableInterface {
// so use subsampled_frame
int32 current_log_post_subsampled_offset_;
int32 num_chunk_computed_;
kaldi::BaseFloat acoustic_scale_;
};
} // namespace ppspeech

Loading…
Cancel
Save