diff --git a/speechx/examples/decoder/CMakeLists.txt b/speechx/examples/decoder/CMakeLists.txt index cf90d094..4bd5c6cf 100644 --- a/speechx/examples/decoder/CMakeLists.txt +++ b/speechx/examples/decoder/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) -add_executable(offline-decoder-main ${CMAKE_CURRENT_SOURCE_DIR}/offline-decoder-main.cc) -target_include_directories(offline-decoder-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(offline-decoder-main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) \ No newline at end of file +add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc) +target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) diff --git a/speechx/examples/decoder/offline-decoder-main.cc b/speechx/examples/decoder/offline_decoder_main.cc similarity index 52% rename from speechx/examples/decoder/offline-decoder-main.cc rename to speechx/examples/decoder/offline_decoder_main.cc index 8e6e7850..cffca39a 100644 --- a/speechx/examples/decoder/offline-decoder-main.cc +++ b/speechx/examples/decoder/offline_decoder_main.cc @@ -17,50 +17,75 @@ #include "base/flags.h" #include "base/log.h" #include "decoder/ctc_beam_search_decoder.h" +#include "frontend/raw_audio.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" #include "nnet/paddle_nnet.h" -DEFINE_string(feature_respecifier, "", "test nnet prob"); +DEFINE_string(feature_respecifier, "", "test feature rspecifier"); +DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); +DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); +DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); +DEFINE_string(lm_path, "lm.klm", "language model"); + using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; -// void SplitFeature(kaldi::Matrix feature, -// int32 chunk_size, -// std::vector* feature_chunks) { - -//} - int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); kaldi::SequentialBaseFloatMatrixReader feature_reader( FLAGS_feature_respecifier); + std::string model_graph = FLAGS_model_path; + std::string model_params = FLAGS_param_path; + std::string dict_file = FLAGS_dict_file; + std::string lm_path = FLAGS_lm_path; - // test nnet_output --> decoder result int32 num_done = 0, num_err = 0; ppspeech::CTCBeamSearchOptions opts; + opts.dict_file = dict_file; + opts.lm_path = lm_path; ppspeech::CTCBeamSearch decoder(opts); ppspeech::ModelOptions model_opts; + model_opts.model_path = model_graph; + model_opts.params_path = model_params; std::shared_ptr nnet( new ppspeech::PaddleNnet(model_opts)); - + std::shared_ptr raw_data( + new ppspeech::RawDataCache()); std::shared_ptr decodable( - new ppspeech::Decodable(nnet)); + new ppspeech::Decodable(nnet, raw_data)); - // int32 chunk_size = 35; + int32 chunk_size = 35; decoder.InitDecoder(); for (; !feature_reader.Done(); feature_reader.Next()) { string utt = feature_reader.Key(); const kaldi::Matrix feature = feature_reader.Value(); - decodable->FeedFeatures(feature); - decoder.AdvanceDecode(decodable, 8); - decodable->InputFinished(); + raw_data->SetDim(feature.NumCols()); + int32 row_idx = 0; + int32 num_chunks = feature.NumRows() / chunk_size; + for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + kaldi::Vector feature_chunk(chunk_size * + feature.NumCols()); + for (int row_id = 0; row_id < chunk_size; ++row_id) { + kaldi::SubVector tmp(feature, row_idx); + kaldi::SubVector f_chunk_tmp( + feature_chunk.Data() + row_id * feature.NumCols(), + feature.NumCols()); + f_chunk_tmp.CopyFromVec(tmp); + row_idx++; + } + raw_data->Accept(feature_chunk); + if (chunk_idx == num_chunks - 1) { + raw_data->SetFinished(); + } + decoder.AdvanceDecode(decodable); + } std::string result; result = decoder.GetFinalBestPath(); KALDI_LOG << " the result of " << utt << " is " << result; @@ -71,4 +96,4 @@ int main(int argc, char* argv[]) { KALDI_LOG << "Done " << num_done << " utterances, " << num_err << " with errors."; return (num_done != 0 ? 0 : 1); -} \ No newline at end of file +} diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index 92c57858..8106b710 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -79,21 +79,19 @@ void CTCBeamSearch::Decode( return; } -int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; } +int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 1; } // todo rename, refactor void CTCBeamSearch::AdvanceDecode( - const std::shared_ptr& decodable, - int max_frames) { - while (max_frames > 0) { + const std::shared_ptr& decodable) { + while (1) { vector> likelihood; - if (decodable->IsLastFrame(NumFrameDecoded() + 1)) { - break; - } - likelihood.push_back( - decodable->FrameLogLikelihood(NumFrameDecoded() + 1)); + vector frame_prob; + bool flag = + decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob); + if (flag == false) break; + likelihood.push_back(frame_prob); AdvanceDecoding(likelihood); - max_frames--; } } diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h index 1e6ac88b..53700e27 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.h @@ -32,8 +32,8 @@ struct CTCBeamSearchOptions { int cutoff_top_n; int num_proc_bsearch; CTCBeamSearchOptions() - : dict_file("./model/words.txt"), - lm_path("./model/lm.arpa"), + : dict_file("vocab.txt"), + lm_path("lm.klm"), alpha(1.9f), beta(5.0), beam_size(300), @@ -68,8 +68,7 @@ class CTCBeamSearch { int DecodeLikelihoods(const std::vector>& probs, std::vector& nbest_words); void AdvanceDecode( - const std::shared_ptr& decodable, - int max_frames); + const std::shared_ptr& decodable); void Reset(); private: @@ -83,8 +82,7 @@ class CTCBeamSearch { CTCBeamSearchOptions opts_; std::shared_ptr init_ext_scorer_; // todo separate later - // std::vector decoder_results_; - std::vector vocabulary_; // todo remove later + std::vector vocabulary_; // todo remove later size_t blank_id; int space_id; std::shared_ptr root; diff --git a/speechx/speechx/frontend/raw_audio.h b/speechx/speechx/frontend/raw_audio.h index 996b6e78..1a326b3c 100644 --- a/speechx/speechx/frontend/raw_audio.h +++ b/speechx/speechx/frontend/raw_audio.h @@ -18,6 +18,8 @@ #include "base/common.h" #include "frontend/feature_extractor_interface.h" +#pragma once + namespace ppspeech { class RawAudioCache : public FeatureExtractorInterface { @@ -45,13 +47,12 @@ class RawAudioCache : public FeatureExtractorInterface { DISALLOW_COPY_AND_ASSIGN(RawAudioCache); }; -// it is a data source to test different frontend module. -// it Accepts waves or feats. -class RawDataCache: public FeatureExtractorInterface { +// it is a datasource for testing different frontend module. +// it accepts waves or feats. +class RawDataCache : public FeatureExtractorInterface { public: explicit RawDataCache() { finished_ = false; } - virtual void Accept( - const kaldi::VectorBase& inputs) { + virtual void Accept(const kaldi::VectorBase& inputs) { data_ = inputs; } virtual bool Read(kaldi::Vector* feats) { @@ -62,14 +63,15 @@ class RawDataCache: public FeatureExtractorInterface { data_.Resize(0); return true; } - //the dim is data_ length - virtual size_t Dim() const { return data_.Dim(); } + virtual size_t Dim() const { return dim_; } virtual void SetFinished() { finished_ = true; } virtual bool IsFinished() const { return finished_; } + void SetDim(int32 dim) { dim_ = dim; } private: kaldi::Vector data_; bool finished_; + int32 dim_; DISALLOW_COPY_AND_ASSIGN(RawDataCache); }; diff --git a/speechx/speechx/nnet/decodable-itf.h b/speechx/speechx/nnet/decodable-itf.h index 3ea9b557..37c3007b 100644 --- a/speechx/speechx/nnet/decodable-itf.h +++ b/speechx/speechx/nnet/decodable-itf.h @@ -1,17 +1,3 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - // itf/decodable-itf.h // Copyright 2009-2011 Microsoft Corporation; Saarland University; @@ -56,10 +42,8 @@ namespace kaldi { For online decoding, where the features are coming in in real time, it is important to understand the IsLastFrame() and NumFramesReady() functions. - There are two ways these are used: the old online-decoding code, in - ../online/, - and the new online-decoding code, in ../online2/. In the old - online-decoding + There are two ways these are used: the old online-decoding code, in ../online/, + and the new online-decoding code, in ../online2/. In the old online-decoding code, the decoder would do: \code{.cc} for (int frame = 0; !decodable.IsLastFrame(frame); frame++) { @@ -68,16 +52,13 @@ namespace kaldi { \endcode and the call to IsLastFrame would block if the features had not arrived yet. The decodable object would have to know when to terminate the decoding. This - online-decoding mode is still supported, it is what happens when you call, - for + online-decoding mode is still supported, it is what happens when you call, for example, LatticeFasterDecoder::Decode(). We realized that this "blocking" mode of decoding is not very convenient because it forces the program to be multi-threaded and makes it complex to - control endpointing. In the "new" decoding code, you don't call (for - example) - LatticeFasterDecoder::Decode(), you call - LatticeFasterDecoder::InitDecoding(), + control endpointing. In the "new" decoding code, you don't call (for example) + LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(), and then each time you get more features, you provide them to the decodable object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does something like this: @@ -87,8 +68,7 @@ namespace kaldi { } \endcode So the decodable object never has IsLastFrame() called. For decoding where - you are starting with a matrix of features, the NumFramesReady() function - will + you are starting with a matrix of features, the NumFramesReady() function will always just return the number of frames in the file, and IsLastFrame() will return true for the last frame. @@ -100,52 +80,45 @@ namespace kaldi { frame of the file once we've decided to terminate decoding. */ class DecodableInterface { - public: - /// Returns the log likelihood, which will be negated in the decoder. - /// The "frame" starts from zero. You should verify that NumFramesReady() > - /// frame - /// before calling this. - virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0; - - /// Returns true if this is the last frame. Frames are zero-based, so the - /// first frame is zero. IsLastFrame(-1) will return false, unless the file - /// is empty (which is a case that I'm not sure all the code will handle, so - /// be careful). Caution: the behavior of this function in an online - /// setting - /// is being changed somewhat. In future it may return false in cases where - /// we haven't yet decided to terminate decoding, but later true if we - /// decide - /// to terminate decoding. The plan in future is to rely more on - /// NumFramesReady(), and in future, IsLastFrame() would always return false - /// in an online-decoding setting, and would only return true in a - /// decoding-from-matrix setting where we want to allow the last delta or - /// LDA - /// features to be flushed out for compatibility with the baseline setup. - virtual bool IsLastFrame(int32 frame) const = 0; - - /// The call NumFramesReady() will return the number of frames currently - /// available - /// for this decodable object. This is for use in setups where you don't - /// want the - /// decoder to block while waiting for input. This is newly added as of Jan - /// 2014, - /// and I hope, going forward, to rely on this mechanism more than - /// IsLastFrame to - /// know when to stop decoding. - virtual int32 NumFramesReady() const { - KALDI_ERR - << "NumFramesReady() not implemented for this decodable type."; - return -1; - } - - /// Returns the number of states in the acoustic model - /// (they will be indexed one-based, i.e. from 1 to NumIndices(); - /// this is for compatibility with OpenFst). - virtual int32 NumIndices() const = 0; - - virtual std::vector FrameLogLikelihood(int32 frame) = 0; - - virtual ~DecodableInterface() {} + public: + /// Returns the log likelihood, which will be negated in the decoder. + /// The "frame" starts from zero. You should verify that NumFramesReady() > frame + /// before calling this. + virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0; + + /// Returns true if this is the last frame. Frames are zero-based, so the + /// first frame is zero. IsLastFrame(-1) will return false, unless the file + /// is empty (which is a case that I'm not sure all the code will handle, so + /// be careful). Caution: the behavior of this function in an online setting + /// is being changed somewhat. In future it may return false in cases where + /// we haven't yet decided to terminate decoding, but later true if we decide + /// to terminate decoding. The plan in future is to rely more on + /// NumFramesReady(), and in future, IsLastFrame() would always return false + /// in an online-decoding setting, and would only return true in a + /// decoding-from-matrix setting where we want to allow the last delta or LDA + /// features to be flushed out for compatibility with the baseline setup. + virtual bool IsLastFrame(int32 frame) const = 0; + + /// The call NumFramesReady() will return the number of frames currently available + /// for this decodable object. This is for use in setups where you don't want the + /// decoder to block while waiting for input. This is newly added as of Jan 2014, + /// and I hope, going forward, to rely on this mechanism more than IsLastFrame to + /// know when to stop decoding. + virtual int32 NumFramesReady() const { + KALDI_ERR << "NumFramesReady() not implemented for this decodable type."; + return -1; + } + + /// Returns the number of states in the acoustic model + /// (they will be indexed one-based, i.e. from 1 to NumIndices(); + /// this is for compatibility with OpenFst). + virtual int32 NumIndices() const = 0; + + virtual bool FrameLogLikelihood(int32 frame, + std::vector* likelihood) = 0; + + + virtual ~DecodableInterface() {} }; /// @} } // namespace Kaldi diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index d92f4fd3..79c896aa 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -18,9 +18,16 @@ namespace ppspeech { using kaldi::BaseFloat; using kaldi::Matrix; +using std::vector; +using kaldi::Vector; -Decodable::Decodable(const std::shared_ptr& nnet) - : frontend_(NULL), nnet_(nnet), finished_(false), frames_ready_(0) {} +Decodable::Decodable(const std::shared_ptr& nnet, + const std::shared_ptr& frontend) + : frontend_(frontend), + nnet_(nnet), + finished_(false), + frame_offset_(0), + frames_ready_(0) {} void Decodable::Acceptlikelihood(const Matrix& likelihood) { frames_ready_ += likelihood.NumRows(); @@ -31,26 +38,46 @@ void Decodable::Acceptlikelihood(const Matrix& likelihood) { bool Decodable::IsLastFrame(int32 frame) const { CHECK_LE(frame, frames_ready_); - return finished_ && (frame == frames_ready_ - 1); + return IsInputFinished() && (frame == frames_ready_ - 1); } int32 Decodable::NumIndices() const { return 0; } -BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return 0; } +BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { + CHECK_LE(index, nnet_cache_.NumCols()); + return 0; +} -void Decodable::FeedFeatures(const Matrix& features) { - nnet_->FeedForward(features, &nnet_cache_); +bool Decodable::EnsureFrameHaveComputed(int32 frame) { + if (frame >= frames_ready_) { + return AdvanceChunk(); + } + return true; +} + +bool Decodable::AdvanceChunk() { + Vector features; + if (frontend_->Read(&features) == false) { + return false; + } + int32 nnet_dim = 0; + Vector inferences; + nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim); + nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim); + nnet_cache_.CopyRowsFromVec(inferences); + frame_offset_ = frames_ready_; frames_ready_ += nnet_cache_.NumRows(); - return; + return true; } -std::vector Decodable::FrameLogLikelihood(int32 frame) { +bool Decodable::FrameLogLikelihood(int32 frame, vector* likelihood) { std::vector result; - result.reserve(nnet_cache_.NumCols()); + if (EnsureFrameHaveComputed(frame) == false) return false; + likelihood->resize(nnet_cache_.NumCols()); for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) { - result[idx] = nnet_cache_(frame, idx); + (*likelihood)[idx] = nnet_cache_(frame - frame_offset_, idx); } - return result; + return true; } void Decodable::Reset() { diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h index 5a59d6ab..72d194b9 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/speechx/speechx/nnet/decodable.h @@ -24,25 +24,35 @@ struct DecodableOpts; class Decodable : public kaldi::DecodableInterface { public: - explicit Decodable(const std::shared_ptr& nnet); + explicit Decodable( + const std::shared_ptr& nnet, + const std::shared_ptr& frontend); // void Init(DecodableOpts config); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual bool IsLastFrame(int32 frame) const; virtual int32 NumIndices() const; - virtual std::vector FrameLogLikelihood(int32 frame); - void Acceptlikelihood( - const kaldi::Matrix& likelihood); // remove later - void FeedFeatures(const kaldi::Matrix& - feature); // only for test, todo remove later + virtual bool FrameLogLikelihood(int32 frame, + std::vector* likelihood); + // for offline test + void Acceptlikelihood(const kaldi::Matrix& likelihood); void Reset(); - void InputFinished() { finished_ = true; } + bool IsInputFinished() const { return frontend_->IsFinished(); } + bool EnsureFrameHaveComputed(int32 frame); private: + bool AdvanceChunk(); std::shared_ptr frontend_; std::shared_ptr nnet_; kaldi::Matrix nnet_cache_; + // std::vector> nnet_cache_; bool finished_; + int32 frame_offset_; int32 frames_ready_; + // todo: feature frame mismatch with nnet inference frame + // eg: 35 frame features output 8 frame inferences + // so use subsampled_frame + int32 current_log_post_subsampled_offset_; + int32 num_chunk_computed_; }; } // namespace ppspeech diff --git a/speechx/speechx/nnet/nnet_interface.h b/speechx/speechx/nnet/nnet_interface.h index fe669f0a..ac040fba 100644 --- a/speechx/speechx/nnet/nnet_interface.h +++ b/speechx/speechx/nnet/nnet_interface.h @@ -23,8 +23,10 @@ namespace ppspeech { class NnetInterface { public: - virtual void FeedForward(const kaldi::Matrix& features, - kaldi::Matrix* inferences) = 0; + virtual void FeedForward(const kaldi::Vector& features, + int32 feature_dim, + kaldi::Vector* inferences, + int32* inference_dim) = 0; virtual void Reset() = 0; virtual ~NnetInterface() {} }; diff --git a/speechx/speechx/nnet/paddle_nnet.cc b/speechx/speechx/nnet/paddle_nnet.cc index 5dea4e51..c4b91cf6 100644 --- a/speechx/speechx/nnet/paddle_nnet.cc +++ b/speechx/speechx/nnet/paddle_nnet.cc @@ -21,6 +21,7 @@ using std::vector; using std::string; using std::shared_ptr; using kaldi::Matrix; +using kaldi::Vector; void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { std::vector cache_names; @@ -143,34 +144,27 @@ shared_ptr> PaddleNnet::GetCacheEncoder(const string& name) { return cache_encouts_[iter->second]; } -void PaddleNnet::FeedForward(const Matrix& features, - Matrix* inferences) { +void PaddleNnet::FeedForward(const Vector& features, + int32 feature_dim, + Vector* inferences, + int32* inference_dim) { paddle_infer::Predictor* predictor = GetPredictor(); - int row = features.NumRows(); - int col = features.NumCols(); - std::vector feed_feature; - // todo refactor feed feature: SmileGoat - feed_feature.reserve(row * col); - for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) { - for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) { - feed_feature.push_back(features(row_idx, col_idx)); - } - } + int feat_row = features.Dim() / feature_dim; std::vector input_names = predictor->GetInputNames(); std::vector output_names = predictor->GetOutputNames(); - LOG(INFO) << "feat info: row=" << row << ", col= " << col; + LOG(INFO) << "feat info: rows, cols: " << feat_row << ", " << feature_dim; std::unique_ptr input_tensor = predictor->GetInputHandle(input_names[0]); - std::vector INPUT_SHAPE = {1, row, col}; + std::vector INPUT_SHAPE = {1, feat_row, feature_dim}; input_tensor->Reshape(INPUT_SHAPE); - input_tensor->CopyFromCpu(feed_feature.data()); + input_tensor->CopyFromCpu(features.Data()); std::unique_ptr input_len = predictor->GetInputHandle(input_names[1]); std::vector input_len_size = {1}; input_len->Reshape(input_len_size); std::vector audio_len; - audio_len.push_back(row); + audio_len.push_back(feat_row); input_len->CopyFromCpu(audio_len.data()); std::unique_ptr h_box = @@ -203,20 +197,12 @@ void PaddleNnet::FeedForward(const Matrix& features, std::unique_ptr output_tensor = predictor->GetOutputHandle(output_names[0]); std::vector output_shape = output_tensor->shape(); - row = output_shape[1]; - col = output_shape[2]; - vector inferences_result; - inferences->Resize(row, col); - inferences_result.resize(row * col); - output_tensor->CopyToCpu(inferences_result.data()); + int32 row = output_shape[1]; + int32 col = output_shape[2]; + inferences->Resize(row * col); + *inference_dim = col; + output_tensor->CopyToCpu(inferences->Data()); ReleasePredictor(predictor); - - for (int row_idx = 0; row_idx < row; ++row_idx) { - for (int col_idx = 0; col_idx < col; ++col_idx) { - (*inferences)(row_idx, col_idx) = - inferences_result[col * row_idx + col_idx]; - } - } } } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/nnet/paddle_nnet.h b/speechx/speechx/nnet/paddle_nnet.h index aec27fd1..30fbac9f 100644 --- a/speechx/speechx/nnet/paddle_nnet.h +++ b/speechx/speechx/nnet/paddle_nnet.h @@ -39,12 +39,8 @@ struct ModelOptions { bool enable_fc_padding; bool enable_profile; ModelOptions() - : model_path( - "../../../../model/paddle_online_deepspeech/model/" - "avg_1.jit.pdmodel"), - params_path( - "../../../../model/paddle_online_deepspeech/model/" - "avg_1.jit.pdiparams"), + : model_path("avg_1.jit.pdmodel"), + params_path("avg_1.jit.pdiparams"), thread_num(2), use_gpu(false), input_names( @@ -107,8 +103,11 @@ class Tensor { class PaddleNnet : public NnetInterface { public: PaddleNnet(const ModelOptions& opts); - virtual void FeedForward(const kaldi::Matrix& features, - kaldi::Matrix* inferences); + virtual void FeedForward(const kaldi::Vector& features, + int32 feature_dim, + kaldi::Vector* inferences, + int32* inference_dim); + void Dim(); virtual void Reset(); std::shared_ptr> GetCacheEncoder( const std::string& name);