From 5383dff250e4bc43113366ea4e8df166be1cba77 Mon Sep 17 00:00:00 2001 From: YangZhou <56786796+SmileGoat@users.noreply.github.com> Date: Mon, 14 Mar 2022 12:19:39 +0800 Subject: [PATCH] Revert "align nnet decoder & refactor" --- speechx/examples/decoder/CMakeLists.txt | 6 +- ...ecoder_main.cc => offline-decoder-main.cc} | 55 +++----- .../decoder/ctc_beam_search_decoder.cc | 18 +-- .../speechx/decoder/ctc_beam_search_decoder.h | 10 +- speechx/speechx/frontend/raw_audio.h | 16 ++- speechx/speechx/nnet/decodable-itf.h | 117 +++++++++++------- speechx/speechx/nnet/decodable.cc | 49 ++------ speechx/speechx/nnet/decodable.h | 24 ++-- speechx/speechx/nnet/nnet_interface.h | 6 +- speechx/speechx/nnet/paddle_nnet.cc | 44 ++++--- speechx/speechx/nnet/paddle_nnet.h | 15 +-- 11 files changed, 170 insertions(+), 190 deletions(-) rename speechx/examples/decoder/{offline_decoder_main.cc => offline-decoder-main.cc} (52%) diff --git a/speechx/examples/decoder/CMakeLists.txt b/speechx/examples/decoder/CMakeLists.txt index 4bd5c6cf..cf90d094 100644 --- a/speechx/examples/decoder/CMakeLists.txt +++ b/speechx/examples/decoder/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) -add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc) -target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) +add_executable(offline-decoder-main ${CMAKE_CURRENT_SOURCE_DIR}/offline-decoder-main.cc) +target_include_directories(offline-decoder-main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(offline-decoder-main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) \ No newline at end of file diff --git a/speechx/examples/decoder/offline_decoder_main.cc b/speechx/examples/decoder/offline-decoder-main.cc similarity index 52% rename from speechx/examples/decoder/offline_decoder_main.cc rename to speechx/examples/decoder/offline-decoder-main.cc index cffca39a..8e6e7850 100644 --- a/speechx/examples/decoder/offline_decoder_main.cc +++ b/speechx/examples/decoder/offline-decoder-main.cc @@ -17,75 +17,50 @@ #include "base/flags.h" #include "base/log.h" #include "decoder/ctc_beam_search_decoder.h" -#include "frontend/raw_audio.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" #include "nnet/paddle_nnet.h" -DEFINE_string(feature_respecifier, "", "test feature rspecifier"); -DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); -DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); -DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); -DEFINE_string(lm_path, "lm.klm", "language model"); - +DEFINE_string(feature_respecifier, "", "test nnet prob"); using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; +// void SplitFeature(kaldi::Matrix feature, +// int32 chunk_size, +// std::vector* feature_chunks) { + +//} + int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); kaldi::SequentialBaseFloatMatrixReader feature_reader( FLAGS_feature_respecifier); - std::string model_graph = FLAGS_model_path; - std::string model_params = FLAGS_param_path; - std::string dict_file = FLAGS_dict_file; - std::string lm_path = FLAGS_lm_path; + // test nnet_output --> decoder result int32 num_done = 0, num_err = 0; ppspeech::CTCBeamSearchOptions opts; - opts.dict_file = dict_file; - opts.lm_path = lm_path; ppspeech::CTCBeamSearch decoder(opts); ppspeech::ModelOptions model_opts; - model_opts.model_path = model_graph; - model_opts.params_path = model_params; std::shared_ptr nnet( new ppspeech::PaddleNnet(model_opts)); - std::shared_ptr raw_data( - new ppspeech::RawDataCache()); + std::shared_ptr decodable( - new ppspeech::Decodable(nnet, raw_data)); + new ppspeech::Decodable(nnet)); - int32 chunk_size = 35; + // int32 chunk_size = 35; decoder.InitDecoder(); for (; !feature_reader.Done(); feature_reader.Next()) { string utt = feature_reader.Key(); const kaldi::Matrix feature = feature_reader.Value(); - raw_data->SetDim(feature.NumCols()); - int32 row_idx = 0; - int32 num_chunks = feature.NumRows() / chunk_size; - for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { - kaldi::Vector feature_chunk(chunk_size * - feature.NumCols()); - for (int row_id = 0; row_id < chunk_size; ++row_id) { - kaldi::SubVector tmp(feature, row_idx); - kaldi::SubVector f_chunk_tmp( - feature_chunk.Data() + row_id * feature.NumCols(), - feature.NumCols()); - f_chunk_tmp.CopyFromVec(tmp); - row_idx++; - } - raw_data->Accept(feature_chunk); - if (chunk_idx == num_chunks - 1) { - raw_data->SetFinished(); - } - decoder.AdvanceDecode(decodable); - } + decodable->FeedFeatures(feature); + decoder.AdvanceDecode(decodable, 8); + decodable->InputFinished(); std::string result; result = decoder.GetFinalBestPath(); KALDI_LOG << " the result of " << utt << " is " << result; @@ -96,4 +71,4 @@ int main(int argc, char* argv[]) { KALDI_LOG << "Done " << num_done << " utterances, " << num_err << " with errors."; return (num_done != 0 ? 0 : 1); -} +} \ No newline at end of file diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.cc b/speechx/speechx/decoder/ctc_beam_search_decoder.cc index 8106b710..92c57858 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.cc +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.cc @@ -79,19 +79,21 @@ void CTCBeamSearch::Decode( return; } -int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 1; } +int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_; } // todo rename, refactor void CTCBeamSearch::AdvanceDecode( - const std::shared_ptr& decodable) { - while (1) { + const std::shared_ptr& decodable, + int max_frames) { + while (max_frames > 0) { vector> likelihood; - vector frame_prob; - bool flag = - decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob); - if (flag == false) break; - likelihood.push_back(frame_prob); + if (decodable->IsLastFrame(NumFrameDecoded() + 1)) { + break; + } + likelihood.push_back( + decodable->FrameLogLikelihood(NumFrameDecoded() + 1)); AdvanceDecoding(likelihood); + max_frames--; } } diff --git a/speechx/speechx/decoder/ctc_beam_search_decoder.h b/speechx/speechx/decoder/ctc_beam_search_decoder.h index 53700e27..1e6ac88b 100644 --- a/speechx/speechx/decoder/ctc_beam_search_decoder.h +++ b/speechx/speechx/decoder/ctc_beam_search_decoder.h @@ -32,8 +32,8 @@ struct CTCBeamSearchOptions { int cutoff_top_n; int num_proc_bsearch; CTCBeamSearchOptions() - : dict_file("vocab.txt"), - lm_path("lm.klm"), + : dict_file("./model/words.txt"), + lm_path("./model/lm.arpa"), alpha(1.9f), beta(5.0), beam_size(300), @@ -68,7 +68,8 @@ class CTCBeamSearch { int DecodeLikelihoods(const std::vector>& probs, std::vector& nbest_words); void AdvanceDecode( - const std::shared_ptr& decodable); + const std::shared_ptr& decodable, + int max_frames); void Reset(); private: @@ -82,7 +83,8 @@ class CTCBeamSearch { CTCBeamSearchOptions opts_; std::shared_ptr init_ext_scorer_; // todo separate later - std::vector vocabulary_; // todo remove later + // std::vector decoder_results_; + std::vector vocabulary_; // todo remove later size_t blank_id; int space_id; std::shared_ptr root; diff --git a/speechx/speechx/frontend/raw_audio.h b/speechx/speechx/frontend/raw_audio.h index 1a326b3c..996b6e78 100644 --- a/speechx/speechx/frontend/raw_audio.h +++ b/speechx/speechx/frontend/raw_audio.h @@ -18,8 +18,6 @@ #include "base/common.h" #include "frontend/feature_extractor_interface.h" -#pragma once - namespace ppspeech { class RawAudioCache : public FeatureExtractorInterface { @@ -47,12 +45,13 @@ class RawAudioCache : public FeatureExtractorInterface { DISALLOW_COPY_AND_ASSIGN(RawAudioCache); }; -// it is a datasource for testing different frontend module. -// it accepts waves or feats. -class RawDataCache : public FeatureExtractorInterface { +// it is a data source to test different frontend module. +// it Accepts waves or feats. +class RawDataCache: public FeatureExtractorInterface { public: explicit RawDataCache() { finished_ = false; } - virtual void Accept(const kaldi::VectorBase& inputs) { + virtual void Accept( + const kaldi::VectorBase& inputs) { data_ = inputs; } virtual bool Read(kaldi::Vector* feats) { @@ -63,15 +62,14 @@ class RawDataCache : public FeatureExtractorInterface { data_.Resize(0); return true; } - virtual size_t Dim() const { return dim_; } + //the dim is data_ length + virtual size_t Dim() const { return data_.Dim(); } virtual void SetFinished() { finished_ = true; } virtual bool IsFinished() const { return finished_; } - void SetDim(int32 dim) { dim_ = dim; } private: kaldi::Vector data_; bool finished_; - int32 dim_; DISALLOW_COPY_AND_ASSIGN(RawDataCache); }; diff --git a/speechx/speechx/nnet/decodable-itf.h b/speechx/speechx/nnet/decodable-itf.h index 37c3007b..3ea9b557 100644 --- a/speechx/speechx/nnet/decodable-itf.h +++ b/speechx/speechx/nnet/decodable-itf.h @@ -1,3 +1,17 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + // itf/decodable-itf.h // Copyright 2009-2011 Microsoft Corporation; Saarland University; @@ -42,8 +56,10 @@ namespace kaldi { For online decoding, where the features are coming in in real time, it is important to understand the IsLastFrame() and NumFramesReady() functions. - There are two ways these are used: the old online-decoding code, in ../online/, - and the new online-decoding code, in ../online2/. In the old online-decoding + There are two ways these are used: the old online-decoding code, in + ../online/, + and the new online-decoding code, in ../online2/. In the old + online-decoding code, the decoder would do: \code{.cc} for (int frame = 0; !decodable.IsLastFrame(frame); frame++) { @@ -52,13 +68,16 @@ namespace kaldi { \endcode and the call to IsLastFrame would block if the features had not arrived yet. The decodable object would have to know when to terminate the decoding. This - online-decoding mode is still supported, it is what happens when you call, for + online-decoding mode is still supported, it is what happens when you call, + for example, LatticeFasterDecoder::Decode(). We realized that this "blocking" mode of decoding is not very convenient because it forces the program to be multi-threaded and makes it complex to - control endpointing. In the "new" decoding code, you don't call (for example) - LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(), + control endpointing. In the "new" decoding code, you don't call (for + example) + LatticeFasterDecoder::Decode(), you call + LatticeFasterDecoder::InitDecoding(), and then each time you get more features, you provide them to the decodable object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does something like this: @@ -68,7 +87,8 @@ namespace kaldi { } \endcode So the decodable object never has IsLastFrame() called. For decoding where - you are starting with a matrix of features, the NumFramesReady() function will + you are starting with a matrix of features, the NumFramesReady() function + will always just return the number of frames in the file, and IsLastFrame() will return true for the last frame. @@ -80,45 +100,52 @@ namespace kaldi { frame of the file once we've decided to terminate decoding. */ class DecodableInterface { - public: - /// Returns the log likelihood, which will be negated in the decoder. - /// The "frame" starts from zero. You should verify that NumFramesReady() > frame - /// before calling this. - virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0; - - /// Returns true if this is the last frame. Frames are zero-based, so the - /// first frame is zero. IsLastFrame(-1) will return false, unless the file - /// is empty (which is a case that I'm not sure all the code will handle, so - /// be careful). Caution: the behavior of this function in an online setting - /// is being changed somewhat. In future it may return false in cases where - /// we haven't yet decided to terminate decoding, but later true if we decide - /// to terminate decoding. The plan in future is to rely more on - /// NumFramesReady(), and in future, IsLastFrame() would always return false - /// in an online-decoding setting, and would only return true in a - /// decoding-from-matrix setting where we want to allow the last delta or LDA - /// features to be flushed out for compatibility with the baseline setup. - virtual bool IsLastFrame(int32 frame) const = 0; - - /// The call NumFramesReady() will return the number of frames currently available - /// for this decodable object. This is for use in setups where you don't want the - /// decoder to block while waiting for input. This is newly added as of Jan 2014, - /// and I hope, going forward, to rely on this mechanism more than IsLastFrame to - /// know when to stop decoding. - virtual int32 NumFramesReady() const { - KALDI_ERR << "NumFramesReady() not implemented for this decodable type."; - return -1; - } - - /// Returns the number of states in the acoustic model - /// (they will be indexed one-based, i.e. from 1 to NumIndices(); - /// this is for compatibility with OpenFst). - virtual int32 NumIndices() const = 0; - - virtual bool FrameLogLikelihood(int32 frame, - std::vector* likelihood) = 0; - - - virtual ~DecodableInterface() {} + public: + /// Returns the log likelihood, which will be negated in the decoder. + /// The "frame" starts from zero. You should verify that NumFramesReady() > + /// frame + /// before calling this. + virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0; + + /// Returns true if this is the last frame. Frames are zero-based, so the + /// first frame is zero. IsLastFrame(-1) will return false, unless the file + /// is empty (which is a case that I'm not sure all the code will handle, so + /// be careful). Caution: the behavior of this function in an online + /// setting + /// is being changed somewhat. In future it may return false in cases where + /// we haven't yet decided to terminate decoding, but later true if we + /// decide + /// to terminate decoding. The plan in future is to rely more on + /// NumFramesReady(), and in future, IsLastFrame() would always return false + /// in an online-decoding setting, and would only return true in a + /// decoding-from-matrix setting where we want to allow the last delta or + /// LDA + /// features to be flushed out for compatibility with the baseline setup. + virtual bool IsLastFrame(int32 frame) const = 0; + + /// The call NumFramesReady() will return the number of frames currently + /// available + /// for this decodable object. This is for use in setups where you don't + /// want the + /// decoder to block while waiting for input. This is newly added as of Jan + /// 2014, + /// and I hope, going forward, to rely on this mechanism more than + /// IsLastFrame to + /// know when to stop decoding. + virtual int32 NumFramesReady() const { + KALDI_ERR + << "NumFramesReady() not implemented for this decodable type."; + return -1; + } + + /// Returns the number of states in the acoustic model + /// (they will be indexed one-based, i.e. from 1 to NumIndices(); + /// this is for compatibility with OpenFst). + virtual int32 NumIndices() const = 0; + + virtual std::vector FrameLogLikelihood(int32 frame) = 0; + + virtual ~DecodableInterface() {} }; /// @} } // namespace Kaldi diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index 79c896aa..d92f4fd3 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -18,16 +18,9 @@ namespace ppspeech { using kaldi::BaseFloat; using kaldi::Matrix; -using std::vector; -using kaldi::Vector; -Decodable::Decodable(const std::shared_ptr& nnet, - const std::shared_ptr& frontend) - : frontend_(frontend), - nnet_(nnet), - finished_(false), - frame_offset_(0), - frames_ready_(0) {} +Decodable::Decodable(const std::shared_ptr& nnet) + : frontend_(NULL), nnet_(nnet), finished_(false), frames_ready_(0) {} void Decodable::Acceptlikelihood(const Matrix& likelihood) { frames_ready_ += likelihood.NumRows(); @@ -38,46 +31,26 @@ void Decodable::Acceptlikelihood(const Matrix& likelihood) { bool Decodable::IsLastFrame(int32 frame) const { CHECK_LE(frame, frames_ready_); - return IsInputFinished() && (frame == frames_ready_ - 1); + return finished_ && (frame == frames_ready_ - 1); } int32 Decodable::NumIndices() const { return 0; } -BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { - CHECK_LE(index, nnet_cache_.NumCols()); - return 0; -} +BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return 0; } -bool Decodable::EnsureFrameHaveComputed(int32 frame) { - if (frame >= frames_ready_) { - return AdvanceChunk(); - } - return true; -} - -bool Decodable::AdvanceChunk() { - Vector features; - if (frontend_->Read(&features) == false) { - return false; - } - int32 nnet_dim = 0; - Vector inferences; - nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim); - nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim); - nnet_cache_.CopyRowsFromVec(inferences); - frame_offset_ = frames_ready_; +void Decodable::FeedFeatures(const Matrix& features) { + nnet_->FeedForward(features, &nnet_cache_); frames_ready_ += nnet_cache_.NumRows(); - return true; + return; } -bool Decodable::FrameLogLikelihood(int32 frame, vector* likelihood) { +std::vector Decodable::FrameLogLikelihood(int32 frame) { std::vector result; - if (EnsureFrameHaveComputed(frame) == false) return false; - likelihood->resize(nnet_cache_.NumCols()); + result.reserve(nnet_cache_.NumCols()); for (int32 idx = 0; idx < nnet_cache_.NumCols(); ++idx) { - (*likelihood)[idx] = nnet_cache_(frame - frame_offset_, idx); + result[idx] = nnet_cache_(frame, idx); } - return true; + return result; } void Decodable::Reset() { diff --git a/speechx/speechx/nnet/decodable.h b/speechx/speechx/nnet/decodable.h index 72d194b9..5a59d6ab 100644 --- a/speechx/speechx/nnet/decodable.h +++ b/speechx/speechx/nnet/decodable.h @@ -24,35 +24,25 @@ struct DecodableOpts; class Decodable : public kaldi::DecodableInterface { public: - explicit Decodable( - const std::shared_ptr& nnet, - const std::shared_ptr& frontend); + explicit Decodable(const std::shared_ptr& nnet); // void Init(DecodableOpts config); virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); virtual bool IsLastFrame(int32 frame) const; virtual int32 NumIndices() const; - virtual bool FrameLogLikelihood(int32 frame, - std::vector* likelihood); - // for offline test - void Acceptlikelihood(const kaldi::Matrix& likelihood); + virtual std::vector FrameLogLikelihood(int32 frame); + void Acceptlikelihood( + const kaldi::Matrix& likelihood); // remove later + void FeedFeatures(const kaldi::Matrix& + feature); // only for test, todo remove later void Reset(); - bool IsInputFinished() const { return frontend_->IsFinished(); } - bool EnsureFrameHaveComputed(int32 frame); + void InputFinished() { finished_ = true; } private: - bool AdvanceChunk(); std::shared_ptr frontend_; std::shared_ptr nnet_; kaldi::Matrix nnet_cache_; - // std::vector> nnet_cache_; bool finished_; - int32 frame_offset_; int32 frames_ready_; - // todo: feature frame mismatch with nnet inference frame - // eg: 35 frame features output 8 frame inferences - // so use subsampled_frame - int32 current_log_post_subsampled_offset_; - int32 num_chunk_computed_; }; } // namespace ppspeech diff --git a/speechx/speechx/nnet/nnet_interface.h b/speechx/speechx/nnet/nnet_interface.h index ac040fba..fe669f0a 100644 --- a/speechx/speechx/nnet/nnet_interface.h +++ b/speechx/speechx/nnet/nnet_interface.h @@ -23,10 +23,8 @@ namespace ppspeech { class NnetInterface { public: - virtual void FeedForward(const kaldi::Vector& features, - int32 feature_dim, - kaldi::Vector* inferences, - int32* inference_dim) = 0; + virtual void FeedForward(const kaldi::Matrix& features, + kaldi::Matrix* inferences) = 0; virtual void Reset() = 0; virtual ~NnetInterface() {} }; diff --git a/speechx/speechx/nnet/paddle_nnet.cc b/speechx/speechx/nnet/paddle_nnet.cc index c4b91cf6..5dea4e51 100644 --- a/speechx/speechx/nnet/paddle_nnet.cc +++ b/speechx/speechx/nnet/paddle_nnet.cc @@ -21,7 +21,6 @@ using std::vector; using std::string; using std::shared_ptr; using kaldi::Matrix; -using kaldi::Vector; void PaddleNnet::InitCacheEncouts(const ModelOptions& opts) { std::vector cache_names; @@ -144,27 +143,34 @@ shared_ptr> PaddleNnet::GetCacheEncoder(const string& name) { return cache_encouts_[iter->second]; } -void PaddleNnet::FeedForward(const Vector& features, - int32 feature_dim, - Vector* inferences, - int32* inference_dim) { +void PaddleNnet::FeedForward(const Matrix& features, + Matrix* inferences) { paddle_infer::Predictor* predictor = GetPredictor(); - int feat_row = features.Dim() / feature_dim; + int row = features.NumRows(); + int col = features.NumCols(); + std::vector feed_feature; + // todo refactor feed feature: SmileGoat + feed_feature.reserve(row * col); + for (size_t row_idx = 0; row_idx < features.NumRows(); ++row_idx) { + for (size_t col_idx = 0; col_idx < features.NumCols(); ++col_idx) { + feed_feature.push_back(features(row_idx, col_idx)); + } + } std::vector input_names = predictor->GetInputNames(); std::vector output_names = predictor->GetOutputNames(); - LOG(INFO) << "feat info: rows, cols: " << feat_row << ", " << feature_dim; + LOG(INFO) << "feat info: row=" << row << ", col= " << col; std::unique_ptr input_tensor = predictor->GetInputHandle(input_names[0]); - std::vector INPUT_SHAPE = {1, feat_row, feature_dim}; + std::vector INPUT_SHAPE = {1, row, col}; input_tensor->Reshape(INPUT_SHAPE); - input_tensor->CopyFromCpu(features.Data()); + input_tensor->CopyFromCpu(feed_feature.data()); std::unique_ptr input_len = predictor->GetInputHandle(input_names[1]); std::vector input_len_size = {1}; input_len->Reshape(input_len_size); std::vector audio_len; - audio_len.push_back(feat_row); + audio_len.push_back(row); input_len->CopyFromCpu(audio_len.data()); std::unique_ptr h_box = @@ -197,12 +203,20 @@ void PaddleNnet::FeedForward(const Vector& features, std::unique_ptr output_tensor = predictor->GetOutputHandle(output_names[0]); std::vector output_shape = output_tensor->shape(); - int32 row = output_shape[1]; - int32 col = output_shape[2]; - inferences->Resize(row * col); - *inference_dim = col; - output_tensor->CopyToCpu(inferences->Data()); + row = output_shape[1]; + col = output_shape[2]; + vector inferences_result; + inferences->Resize(row, col); + inferences_result.resize(row * col); + output_tensor->CopyToCpu(inferences_result.data()); ReleasePredictor(predictor); + + for (int row_idx = 0; row_idx < row; ++row_idx) { + for (int col_idx = 0; col_idx < col; ++col_idx) { + (*inferences)(row_idx, col_idx) = + inferences_result[col * row_idx + col_idx]; + } + } } } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/nnet/paddle_nnet.h b/speechx/speechx/nnet/paddle_nnet.h index 30fbac9f..aec27fd1 100644 --- a/speechx/speechx/nnet/paddle_nnet.h +++ b/speechx/speechx/nnet/paddle_nnet.h @@ -39,8 +39,12 @@ struct ModelOptions { bool enable_fc_padding; bool enable_profile; ModelOptions() - : model_path("avg_1.jit.pdmodel"), - params_path("avg_1.jit.pdiparams"), + : model_path( + "../../../../model/paddle_online_deepspeech/model/" + "avg_1.jit.pdmodel"), + params_path( + "../../../../model/paddle_online_deepspeech/model/" + "avg_1.jit.pdiparams"), thread_num(2), use_gpu(false), input_names( @@ -103,11 +107,8 @@ class Tensor { class PaddleNnet : public NnetInterface { public: PaddleNnet(const ModelOptions& opts); - virtual void FeedForward(const kaldi::Vector& features, - int32 feature_dim, - kaldi::Vector* inferences, - int32* inference_dim); - void Dim(); + virtual void FeedForward(const kaldi::Matrix& features, + kaldi::Matrix* inferences); virtual void Reset(); std::shared_ptr> GetCacheEncoder( const std::string& name);