Revert "align nnet decoder & refactor"

pull/1562/head
YangZhou 3 years ago committed by GitHub
parent bb07144ca1
commit 5383dff250

@@ -1,5 +1,5 @@
 cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
-add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
-target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
-target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
+add_executable(offline-decoder-main ${CMAKE_CURRENT_SOURCE_DIR}/offline-decoder-main.cc)
+target_include_directories(offline-decoder-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(offline-decoder-main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})

@@ -17,75 +17,50 @@
 #include "base/flags.h"
 #include "base/log.h"
 #include "decoder/ctc_beam_search_decoder.h"
-#include "frontend/raw_audio.h"
 #include "kaldi/util/table-types.h"
 #include "nnet/decodable.h"
 #include "nnet/paddle_nnet.h"
 
-DEFINE_string(feature_respecifier, "", "test feature rspecifier");
-DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
-DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
-DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
-DEFINE_string(lm_path, "lm.klm", "language model");
+DEFINE_string(feature_respecifier, "", "test nnet prob");
 
 using kaldi::BaseFloat;
 using kaldi::Matrix;
 using std::vector;
 
+// void SplitFeature(kaldi::Matrix<BaseFloat> feature,
+//                   int32 chunk_size,
+//                   std::vector<kaldi::Matrix<BaseFloat>* feature_chunks) {
+//}
+
 int main(int argc, char* argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, false);
     google::InitGoogleLogging(argv[0]);
 
     kaldi::SequentialBaseFloatMatrixReader feature_reader(
         FLAGS_feature_respecifier);
 
-    std::string model_graph = FLAGS_model_path;
-    std::string model_params = FLAGS_param_path;
-    std::string dict_file = FLAGS_dict_file;
-    std::string lm_path = FLAGS_lm_path;
-
+    // test nnet_output --> decoder result
     int32 num_done = 0, num_err = 0;
 
     ppspeech::CTCBeamSearchOptions opts;
-    opts.dict_file = dict_file;
-    opts.lm_path = lm_path;
     ppspeech::CTCBeamSearch decoder(opts);
 
     ppspeech::ModelOptions model_opts;
-    model_opts.model_path = model_graph;
-    model_opts.params_path = model_params;
     std::shared_ptr<ppspeech::PaddleNnet> nnet(
         new ppspeech::PaddleNnet(model_opts));
-    std::shared_ptr<ppspeech::RawDataCache> raw_data(
-        new ppspeech::RawDataCache());
     std::shared_ptr<ppspeech::Decodable> decodable(
-        new ppspeech::Decodable(nnet, raw_data));
+        new ppspeech::Decodable(nnet));
 
-    int32 chunk_size = 35;
+    // int32 chunk_size = 35;
     decoder.InitDecoder();
 
     for (; !feature_reader.Done(); feature_reader.Next()) {
         string utt = feature_reader.Key();
         const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
-        raw_data->SetDim(feature.NumCols());
-        int32 row_idx = 0;
-        int32 num_chunks = feature.NumRows() / chunk_size;
-        for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
-            kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
-                                                          feature.NumCols());
-            for (int row_id = 0; row_id < chunk_size; ++row_id) {
-                kaldi::SubVector<kaldi::BaseFloat> tmp(feature, row_idx);
-                kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
-                    feature_chunk.Data() + row_id * feature.NumCols(),
-                    feature.NumCols());
-                f_chunk_tmp.CopyFromVec(tmp);
-                row_idx++;
-            }
-            raw_data->Accept(feature_chunk);
-            if (chunk_idx == num_chunks - 1) {
-                raw_data->SetFinished();
-            }
-            decoder.AdvanceDecode(decodable);
-        }
+        decodable->FeedFeatures(feature);
+        decoder.AdvanceDecode(decodable, 8);
+        decodable->InputFinished();
         std::string result;
         result = decoder.GetFinalBestPath();
         KALDI_LOG << " the result of " << utt << " is " << result;
@@ -96,4 +71,4 @@ int main(int argc, char* argv[]) {
     KALDI_LOG << "Done " << num_done << " utterances, " << num_err
               << " with errors.";
     return (num_done != 0 ? 0 : 1);
 }
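Note (editorial): after this revert the test binary no longer drives a RawDataCache frontend or splits features into chunks; it pushes the whole utterance through Decodable::FeedFeatures() and advances the decoder with an explicit frame budget. A minimal sketch of the same flow, assuming the API exactly as shown in this hunk (the budget value is illustrative, not from the commit):

    decodable->FeedFeatures(feature);                      // nnet runs over all frames at once
    decoder.AdvanceDecode(decodable, feature.NumRows());   // budget large enough to reach the last frame
    decodable->InputFinished();
    std::string result = decoder.GetFinalBestPath();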

@@ -79,19 +79,21 @@ void CTCBeamSearch::Decode(
     return;
 }
 
-int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 1; }
+int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; }
 
 // todo rename, refactor
 void CTCBeamSearch::AdvanceDecode(
-    const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
-    while (1) {
+    const std::shared_ptr<kaldi::DecodableInterface>& decodable,
+    int max_frames) {
+    while (max_frames > 0) {
         vector<vector<BaseFloat>> likelihood;
-        vector<BaseFloat> frame_prob;
-        bool flag =
-            decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob);
-        if (flag == false) break;
-        likelihood.push_back(frame_prob);
+        if (decodable->IsLastFrame(NumFrameDecoded() + 1)) {
+            break;
+        }
+        likelihood.push_back(
+            decodable->FrameLogLikelihood(NumFrameDecoded() + 1));
         AdvanceDecoding(likelihood);
+        max_frames--;
     }
 }
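Note (editorial): AdvanceDecode() now takes a frame budget, pulls one frame of posteriors per iteration via FrameLogLikelihood(), and stops early once the decodable reports the last frame. A hedged usage sketch (the budget of 16 is arbitrary, not from the commit):

    decoder.InitDecoder();
    // Consumes at most 16 frames of already-computed posteriors; fewer if the last frame is reached.
    decoder.AdvanceDecode(decodable, 16);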

@@ -32,8 +32,8 @@ struct CTCBeamSearchOptions {
     int cutoff_top_n;
     int num_proc_bsearch;
     CTCBeamSearchOptions()
-        : dict_file("vocab.txt"),
-          lm_path("lm.klm"),
+        : dict_file("./model/words.txt"),
+          lm_path("./model/lm.arpa"),
           alpha(1.9f),
           beta(5.0),
           beam_size(300),
@@ -68,7 +68,8 @@ class CTCBeamSearch {
     int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
                           std::vector<std::string>& nbest_words);
     void AdvanceDecode(
-        const std::shared_ptr<kaldi::DecodableInterface>& decodable);
+        const std::shared_ptr<kaldi::DecodableInterface>& decodable,
+        int max_frames);
     void Reset();
 
   private:
@@ -82,7 +83,8 @@ class CTCBeamSearch {
     CTCBeamSearchOptions opts_;
     std::shared_ptr<Scorer> init_ext_scorer_;  // todo separate later
+    // std::vector<DecodeResult> decoder_results_;
     std::vector<std::string> vocabulary_;  // todo remove later
     size_t blank_id;
     int space_id;
     std::shared_ptr<PathTrie> root;

@@ -18,8 +18,6 @@
 #include "base/common.h"
 #include "frontend/feature_extractor_interface.h"
 
-#pragma once
-
 namespace ppspeech {
 
 class RawAudioCache : public FeatureExtractorInterface {
@@ -47,12 +45,13 @@ class RawAudioCache : public FeatureExtractorInterface {
     DISALLOW_COPY_AND_ASSIGN(RawAudioCache);
 };
 
-// it is a datasource for testing different frontend module.
-// it accepts waves or feats.
-class RawDataCache : public FeatureExtractorInterface {
+// it is a data source to test different frontend module.
+// it Accepts waves or feats.
+class RawDataCache: public FeatureExtractorInterface {
   public:
     explicit RawDataCache() { finished_ = false; }
-    virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
+    virtual void Accept(
+        const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
         data_ = inputs;
     }
     virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
@@ -63,15 +62,14 @@ class RawDataCache : public FeatureExtractorInterface {
         data_.Resize(0);
         return true;
     }
-    virtual size_t Dim() const { return dim_; }
+    //the dim is data_ length
+    virtual size_t Dim() const { return data_.Dim(); }
     virtual void SetFinished() { finished_ = true; }
     virtual bool IsFinished() const { return finished_; }
-    void SetDim(int32 dim) { dim_ = dim; }
 
   private:
     kaldi::Vector<kaldi::BaseFloat> data_;
     bool finished_;
-    int32 dim_;
     DISALLOW_COPY_AND_ASSIGN(RawDataCache);
 };
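Note (editorial): the reverted RawDataCache drops the explicit dim_/SetDim() bookkeeping; Dim() is now simply the length of the vector last passed to Accept(). A small hedged usage sketch (the 80-dim value is illustrative):

    ppspeech::RawDataCache cache;
    kaldi::Vector<kaldi::BaseFloat> chunk(80);   // e.g. one frame of 80-dim features
    cache.Accept(chunk);                         // Dim() now reports 80
    kaldi::Vector<kaldi::BaseFloat> out;
    cache.Read(&out);                            // hands back the cached data; data_ is then Resize(0)'d
    cache.SetFinished();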

@@ -1,3 +1,17 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 // itf/decodable-itf.h
 
 // Copyright 2009-2011  Microsoft Corporation;  Saarland University;
@@ -42,8 +56,10 @@ namespace kaldi {
    For online decoding, where the features are coming in in real time, it is
    important to understand the IsLastFrame() and NumFramesReady() functions.
-   There are two ways these are used: the old online-decoding code, in ../online/,
-   and the new online-decoding code, in ../online2/.  In the old online-decoding
+   There are two ways these are used: the old online-decoding code, in
+   ../online/,
+   and the new online-decoding code, in ../online2/.  In the old
+   online-decoding
    code, the decoder would do:
    \code{.cc}
    for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
@@ -52,13 +68,16 @@ namespace kaldi {
    \endcode
    and the call to IsLastFrame would block if the features had not arrived yet.
    The decodable object would have to know when to terminate the decoding.  This
-   online-decoding mode is still supported, it is what happens when you call, for
+   online-decoding mode is still supported, it is what happens when you call,
+   for
    example, LatticeFasterDecoder::Decode().
 
    We realized that this "blocking" mode of decoding is not very convenient
    because it forces the program to be multi-threaded and makes it complex to
-   control endpointing.  In the "new" decoding code, you don't call (for example)
-   LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(),
+   control endpointing.  In the "new" decoding code, you don't call (for
+   example)
+   LatticeFasterDecoder::Decode(), you call
+   LatticeFasterDecoder::InitDecoding(),
    and then each time you get more features, you provide them to the decodable
    object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
    something like this:
@@ -68,7 +87,8 @@ namespace kaldi {
    }
    \endcode
    So the decodable object never has IsLastFrame() called.  For decoding where
-   you are starting with a matrix of features, the NumFramesReady() function will
+   you are starting with a matrix of features, the NumFramesReady() function
+   will
    always just return the number of frames in the file, and IsLastFrame() will
    return true for the last frame.
@@ -80,45 +100,52 @@ namespace kaldi {
    frame of the file once we've decided to terminate decoding.
 */
 class DecodableInterface {
   public:
     /// Returns the log likelihood, which will be negated in the decoder.
-    /// The "frame" starts from zero. You should verify that NumFramesReady() > frame
-    /// before calling this.
-    virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
-
-    /// Returns true if this is the last frame.  Frames are zero-based, so the
-    /// first frame is zero.  IsLastFrame(-1) will return false, unless the file
-    /// is empty (which is a case that I'm not sure all the code will handle, so
-    /// be careful).  Caution: the behavior of this function in an online setting
-    /// is being changed somewhat.  In future it may return false in cases where
-    /// we haven't yet decided to terminate decoding, but later true if we decide
-    /// to terminate decoding.  The plan in future is to rely more on
-    /// NumFramesReady(), and in future, IsLastFrame() would always return false
-    /// in an online-decoding setting, and would only return true in a
-    /// decoding-from-matrix setting where we want to allow the last delta or LDA
-    /// features to be flushed out for compatibility with the baseline setup.
-    virtual bool IsLastFrame(int32 frame) const = 0;
-
-    /// The call NumFramesReady() will return the number of frames currently available
-    /// for this decodable object.  This is for use in setups where you don't want the
-    /// decoder to block while waiting for input.  This is newly added as of Jan 2014,
-    /// and I hope, going forward, to rely on this mechanism more than IsLastFrame to
-    /// know when to stop decoding.
-    virtual int32 NumFramesReady() const {
-        KALDI_ERR << "NumFramesReady() not implemented for this decodable type.";
-        return -1;
-    }
-
-    /// Returns the number of states in the acoustic model
-    /// (they will be indexed one-based, i.e. from 1 to NumIndices();
-    /// this is for compatibility with OpenFst).
-    virtual int32 NumIndices() const = 0;
-
-    virtual bool FrameLogLikelihood(int32 frame,
-                                    std::vector<kaldi::BaseFloat>* likelihood) = 0;
-
-    virtual ~DecodableInterface() {}
+    /// The "frame" starts from zero. You should verify that NumFramesReady() >
+    /// frame
+    /// before calling this.
+    virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;
+
+    /// Returns true if this is the last frame.  Frames are zero-based, so the
+    /// first frame is zero.  IsLastFrame(-1) will return false, unless the file
+    /// is empty (which is a case that I'm not sure all the code will handle, so
+    /// be careful).  Caution: the behavior of this function in an online
+    /// setting
+    /// is being changed somewhat.  In future it may return false in cases where
+    /// we haven't yet decided to terminate decoding, but later true if we
+    /// decide
+    /// to terminate decoding.  The plan in future is to rely more on
+    /// NumFramesReady(), and in future, IsLastFrame() would always return false
+    /// in an online-decoding setting, and would only return true in a
+    /// decoding-from-matrix setting where we want to allow the last delta or
+    /// LDA
+    /// features to be flushed out for compatibility with the baseline setup.
+    virtual bool IsLastFrame(int32 frame) const = 0;
+
+    /// The call NumFramesReady() will return the number of frames currently
+    /// available
+    /// for this decodable object.  This is for use in setups where you don't
+    /// want the
+    /// decoder to block while waiting for input.  This is newly added as of Jan
+    /// 2014,
+    /// and I hope, going forward, to rely on this mechanism more than
+    /// IsLastFrame to
+    /// know when to stop decoding.
+    virtual int32 NumFramesReady() const {
+        KALDI_ERR
+            << "NumFramesReady() not implemented for this decodable type.";
+        return -1;
+    }
+
+    /// Returns the number of states in the acoustic model
+    /// (they will be indexed one-based, i.e. from 1 to NumIndices();
+    /// this is for compatibility with OpenFst).
+    virtual int32 NumIndices() const = 0;
+
+    virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame) = 0;
+
+    virtual ~DecodableInterface() {}
 };
 
 /// @}
 }  // namespace Kaldi
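Note (editorial): relative to upstream Kaldi, this header carries an extra pure-virtual FrameLogLikelihood(), which after the revert returns a whole frame of posteriors instead of filling a caller-supplied vector. A hedged sketch of a consumer written against this interface, mirroring the frame loop in the documentation above (not code from this commit):

    for (int32 frame = 0; !decodable->IsLastFrame(frame); ++frame) {
        std::vector<kaldi::BaseFloat> probs = decodable->FrameLogLikelihood(frame);
        // hand `probs` to the search, e.g. one step of CTC beam search
    }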

@@ -18,16 +18,9 @@ namespace ppspeech {
 
 using kaldi::BaseFloat;
 using kaldi::Matrix;
-using std::vector;
-using kaldi::Vector;
 
-Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
-                     const std::shared_ptr<FeatureExtractorInterface>& frontend)
-    : frontend_(frontend),
-      nnet_(nnet),
-      finished_(false),
-      frame_offset_(0),
-      frames_ready_(0) {}
+Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet)
+    : frontend_(NULL), nnet_(nnet), finished_(false), frames_ready_(0) {}
 
 void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
     frames_ready_ += likelihood.NumRows();
@@ -38,46 +31,26 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
 
 bool Decodable::IsLastFrame(int32 frame) const {
     CHECK_LE(frame, frames_ready_);
-    return IsInputFinished() && (frame == frames_ready_ - 1);
+    return finished_ && (frame == frames_ready_ - 1);
 }
 
 int32 Decodable::NumIndices() const { return 0; }
 
-BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
-    CHECK_LE(index, nnet_cache_.NumCols());
-    return 0;
-}
-
-bool Decodable::EnsureFrameHaveComputed(int32 frame) {
-    if (frame >= frames_ready_) {
-        return AdvanceChunk();
-    }
-    return true;
-}
-
-bool Decodable::AdvanceChunk() {
-    Vector<BaseFloat> features;
-    if (frontend_->Read(&features) == false) {
-        return false;
-    }
-    int32 nnet_dim = 0;
-    Vector<BaseFloat> inferences;
-    nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
-    nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim);
-    nnet_cache_.CopyRowsFromVec(inferences);
-    frame_offset_ = frames_ready_;
+BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return 0; }
+
+void Decodable::FeedFeatures(const Matrix<kaldi::BaseFloat>& features) {
+    nnet_->FeedForward(features, &nnet_cache_);
     frames_ready_ += nnet_cache_.NumRows();
-    return true;
+    return;
 }
 
-bool Decodable::FrameLogLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
+std::vector<BaseFloat> Decodable::FrameLogLikelihood(int32 frame) {
     std::vector<BaseFloat> result;
-    if (EnsureFrameHaveComputed(frame) == false) return false;
-    likelihood->resize(nnet_cache_.NumCols());
+    result.reserve(nnet_cache_.NumCols());
     for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) {
-        (*likelihood)[idx] = nnet_cache_(frame - frame_offset_, idx);
+        result[idx] = nnet_cache_(frame, idx);
     }
-    return true;
+    return result;
 }
 
 void Decodable::Reset() {
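Note (editorial): in the reverted FrameLogLikelihood(), reserve() only raises capacity; the vector's size stays zero, so result[idx] writes past the end and the function effectively returns an empty vector. A safer variant would size the vector up front (an editorial sketch, not part of the commit):

    std::vector<BaseFloat> result(nnet_cache_.NumCols());  // sized and value-initialized
    for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) {
        result[idx] = nnet_cache_(frame, idx);
    }
    return result;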

@@ -24,35 +24,25 @@ struct DecodableOpts;
 class Decodable : public kaldi::DecodableInterface {
   public:
-    explicit Decodable(
-        const std::shared_ptr<NnetInterface>& nnet,
-        const std::shared_ptr<FeatureExtractorInterface>& frontend);
+    explicit Decodable(const std::shared_ptr<NnetInterface>& nnet);
     // void Init(DecodableOpts config);
     virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);
     virtual bool IsLastFrame(int32 frame) const;
     virtual int32 NumIndices() const;
-    virtual bool FrameLogLikelihood(int32 frame,
-                                    std::vector<kaldi::BaseFloat>* likelihood);
-    // for offline test
-    void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
+    virtual std::vector<BaseFloat> FrameLogLikelihood(int32 frame);
+    void Acceptlikelihood(
+        const kaldi::Matrix<kaldi::BaseFloat>& likelihood);  // remove later
+    void FeedFeatures(const kaldi::Matrix<kaldi::BaseFloat>&
+                          feature);  // only for test, todo remove later
     void Reset();
-    bool IsInputFinished() const { return frontend_->IsFinished(); }
-    bool EnsureFrameHaveComputed(int32 frame);
+    void InputFinished() { finished_ = true; }
 
   private:
-    bool AdvanceChunk();
     std::shared_ptr<FeatureExtractorInterface> frontend_;
     std::shared_ptr<NnetInterface> nnet_;
     kaldi::Matrix<kaldi::BaseFloat> nnet_cache_;
+    // std::vector<std::vector<kaldi::BaseFloat>> nnet_cache_;
     bool finished_;
-    int32 frame_offset_;
     int32 frames_ready_;
-    // todo: feature frame mismatch with nnet inference frame
-    // eg: 35 frame features output 8 frame inferences
-    // so use subsampled_frame
-    int32 current_log_post_subsampled_offset_;
-    int32 num_chunk_computed_;
 };
 
 }  // namespace ppspeech

@@ -23,10 +23,8 @@ namespace ppspeech {
 
 class NnetInterface {
   public:
-    virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
-                             int32 feature_dim,
-                             kaldi::Vector<kaldi::BaseFloat>* inferences,
-                             int32* inference_dim) = 0;
+    virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
+                             kaldi::Matrix<kaldi::BaseFloat>* inferences) = 0;
     virtual void Reset() = 0;
     virtual ~NnetInterface() {}
 };

@@ -21,7 +21,6 @@ using std::vector;
 using std::string;
 using std::shared_ptr;
 using kaldi::Matrix;
-using kaldi::Vector;
 
 void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) {
     std::vector<std::string> cache_names;
@@ -144,27 +143,34 @@ shared_ptr<Tensor<BaseFloat>> PaddleNnet::GetCacheEncoder(const string& name) {
     return cache_encouts_[iter->second];
 }
 
-void PaddleNnet::FeedForward(const Vector<BaseFloat>& features,
-                             int32 feature_dim,
-                             Vector<BaseFloat>* inferences,
-                             int32* inference_dim) {
+void PaddleNnet::FeedForward(const Matrix<BaseFloat>& features,
+                             Matrix<BaseFloat>* inferences) {
     paddle_infer::Predictor* predictor = GetPredictor();
-    int feat_row = features.Dim() / feature_dim;
+    int row = features.NumRows();
+    int col = features.NumCols();
+    std::vector<BaseFloat> feed_feature;
+    // todo refactor feed feature: SmileGoat
+    feed_feature.reserve(row * col);
+    for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) {
+        for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) {
+            feed_feature.push_back(features(row_idx, col_idx));
+        }
+    }
 
     std::vector<std::string> input_names = predictor->GetInputNames();
     std::vector<std::string> output_names = predictor->GetOutputNames();
-    LOG(INFO) << "feat info: rows, cols: " << feat_row << ", " << feature_dim;
+    LOG(INFO) << "feat info: row=" << row << ", col= " << col;
 
     std::unique_ptr<paddle_infer::Tensor> input_tensor =
         predictor->GetInputHandle(input_names[0]);
-    std::vector<int> INPUT_SHAPE = {1, feat_row, feature_dim};
+    std::vector<int> INPUT_SHAPE = {1, row, col};
     input_tensor->Reshape(INPUT_SHAPE);
-    input_tensor->CopyFromCpu(features.Data());
+    input_tensor->CopyFromCpu(feed_feature.data());
     std::unique_ptr<paddle_infer::Tensor> input_len =
         predictor->GetInputHandle(input_names[1]);
     std::vector<int> input_len_size = {1};
     input_len->Reshape(input_len_size);
     std::vector<int64_t> audio_len;
-    audio_len.push_back(feat_row);
+    audio_len.push_back(row);
     input_len->CopyFromCpu(audio_len.data());
 
     std::unique_ptr<paddle_infer::Tensor> h_box =
@@ -197,12 +203,20 @@ void PaddleNnet::FeedForward(const Vector<BaseFloat>& features,
     std::unique_ptr<paddle_infer::Tensor> output_tensor =
         predictor->GetOutputHandle(output_names[0]);
     std::vector<int> output_shape = output_tensor->shape();
-    int32 row = output_shape[1];
-    int32 col = output_shape[2];
-    inferences->Resize(row * col);
-    *inference_dim = col;
-    output_tensor->CopyToCpu(inferences->Data());
+    row = output_shape[1];
+    col = output_shape[2];
+    vector<float> inferences_result;
+    inferences->Resize(row, col);
+    inferences_result.resize(row * col);
+    output_tensor->CopyToCpu(inferences_result.data());
     ReleasePredictor(predictor);
+
+    for (int row_idx = 0; row_idx < row; ++row_idx) {
+        for (int col_idx = 0; col_idx < col; ++col_idx) {
+            (*inferences)(row_idx, col_idx) =
+                inferences_result[col * row_idx + col_idx];
+        }
+    }
 }
 
 }  // namespace ppspeech
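Note (editorial): the loops above flatten the kaldi::Matrix into a contiguous row-major buffer before CopyFromCpu() and rebuild the output matrix with the usual row-major offset col * row_idx + col_idx. A hedged per-row alternative, assuming Kaldi's MatrixBase::RowData() accessor (each row is contiguous even when the matrix stride is padded):

    std::vector<BaseFloat> feed_feature(row * col);         // sized up front instead of push_back
    for (int row_idx = 0; row_idx < row; ++row_idx) {
        const BaseFloat* src = features.RowData(row_idx);   // pointer to one contiguous row
        std::copy(src, src + col, feed_feature.begin() + row_idx * col);
    }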

@@ -39,8 +39,12 @@ struct ModelOptions {
     bool enable_fc_padding;
     bool enable_profile;
     ModelOptions()
-        : model_path("avg_1.jit.pdmodel"),
-          params_path("avg_1.jit.pdiparams"),
+        : model_path(
+              "../../../../model/paddle_online_deepspeech/model/"
+              "avg_1.jit.pdmodel"),
+          params_path(
+              "../../../../model/paddle_online_deepspeech/model/"
+              "avg_1.jit.pdiparams"),
           thread_num(2),
           use_gpu(false),
           input_names(
@@ -103,11 +107,8 @@ class Tensor {
 class PaddleNnet : public NnetInterface {
   public:
     PaddleNnet(const ModelOptions& opts);
-    virtual void FeedForward(const kaldi::Vector<kaldi::BaseFloat>& features,
-                             int32 feature_dim,
-                             kaldi::Vector<kaldi::BaseFloat>* inferences,
-                             int32* inference_dim);
-    void Dim();
+    virtual void FeedForward(const kaldi::Matrix<kaldi::BaseFloat>& features,
+                             kaldi::Matrix<kaldi::BaseFloat>* inferences);
     virtual void Reset();
     std::shared_ptr<Tensor<kaldi::BaseFloat>> GetCacheEncoder(
         const std::string& name);
