From 0a8ef58af088d58ae882044640eba5fcb64ccf13 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Tue, 18 Oct 2022 09:10:49 +0000
Subject: [PATCH] remove uesless code

---
 speechx/speechx/nnet/u2_nnet.cc | 83 +--------------------------------
 speechx/speechx/nnet/u2_nnet.h  | 27 ++++-------
 2 files changed, 10 insertions(+), 100 deletions(-)

diff --git a/speechx/speechx/nnet/u2_nnet.cc b/speechx/speechx/nnet/u2_nnet.cc
index ff6a4dc37..baae2ce8f 100644
--- a/speechx/speechx/nnet/u2_nnet.cc
+++ b/speechx/speechx/nnet/u2_nnet.cc
@@ -25,65 +25,6 @@ using paddle::platform::TracerEventType;
 
 namespace ppspeech {
 
-int U2NnetBase::num_frames_for_chunk(bool start) const {
-    int num_needed_frames = 0;  // num feat frames
-    bool first = !start;        // start == false is first
-
-    if (chunk_size_ > 0) {
-        // streaming mode
-        if (first) {
-            // first chunk
-            // 1 decoder frame need `context` feat frames
-            int context = this->context();
-            num_needed_frames = (chunk_size_ - 1) * subsampling_rate_ + context;
-        } else {
-            // after first chunk, we need stride this num frames.
-            num_needed_frames = chunk_size_ * subsampling_rate_;
-        }
-    } else {
-        // non-streaming mode. feed all feats once.
-        num_needed_frames = std::numeric_limits<int>::max();
-    }
-
-    return num_needed_frames;
-}
-
-// cache feats for next chunk
-void U2NnetBase::CacheFeature(const std::vector<kaldi::BaseFloat>& chunk_feats,
-                              int32 feat_dim) {
-    // chunk_feats is nframes*feat_dim
-    const int chunk_size = chunk_feats.size() / feat_dim;
-    const int cached_feat_size = this->context() - subsampling_rate_;
-    if (chunk_size >= cached_feat_size) {
-        cached_feats_.resize(cached_feat_size);
-        for (int i = 0; i < cached_feat_size; ++i) {
-            auto start =
-                chunk_feats.begin() + chunk_size - cached_feat_size + i;
-            auto end = start + feat_dim;
-            cached_feats_[i] = std::vector<kaldi::BaseFloat>(start, end);
-        }
-    }
-}
-
-void U2NnetBase::ForwardEncoderChunk(
-    const std::vector<kaldi::BaseFloat>& chunk_feats,
-    const int32& feat_dim,
-    std::vector<kaldi::BaseFloat>* ctc_probs,
-    int32* vocab_dim) {
-    ctc_probs->clear();
-    // int num_frames = cached_feats_.size() + chunk_feats.size();
-    int num_frames = chunk_feats.size() / feat_dim;
-    VLOG(3) << "foward encoder chunk: " << num_frames << " frames";
-    VLOG(3) << "context: " << this->context() << " frames";
-
-    if (num_frames >= this->context()) {
-        this->ForwardEncoderChunkImpl(
-            chunk_feats, feat_dim, ctc_probs, vocab_dim);
-        VLOG(3) << "after forward chunk";
-        this->CacheFeature(chunk_feats, feat_dim);
-    }
-}
-
 void U2Nnet::LoadModel(const std::string& model_path_w_prefix) {
     paddle::jit::utils::InitKernelSignatureMap();
 
@@ -188,7 +129,7 @@ U2Nnet::U2Nnet(const U2Nnet& other) {
     forward_attention_decoder_ = other.forward_attention_decoder_;
     ctc_activation_ = other.ctc_activation_;
 
-    // offset_ = other.offset_;  // TODO: not used in nnets
+    offset_ = other.offset_;
 
     // copy model ptr
     model_ = other.model_;
@@ -204,8 +145,7 @@ std::shared_ptr<NnetBase> U2Nnet::Copy() const {
 }
 
 void U2Nnet::Reset() {
-    // offset_ = 0;
-    // cached_feats_.clear();  // TODO: not used in nnets
+    offset_ = 0;
 
     att_cache_ =
         std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32));
@@ -263,16 +203,6 @@ void U2Nnet::ForwardEncoderChunkImpl(
         paddle::zeros({1, num_frames, feat_dim}, paddle::DataType::FLOAT32);
     float* feats_ptr = feats.mutable_data<float>();
 
-    // for (size_t i = 0; i < cached_feats_.size(); ++i) {
-    //     float* row = feats_ptr + i * feat_dim;
-    //     std::memcpy(row, cached_feats_[i].data(), feat_dim * sizeof(float));
-    // }
-
-    // for (size_t i = 0; i < chunk_feats.size(); ++i) {
-    //     float* row = feats_ptr + (cached_feats_.size() + i) * feat_dim;
-    //     std::memcpy(row, chunk_feats[i].data(), feat_dim * sizeof(float));
-    // }
-
     // not cache feature in nnet
     CHECK(cached_feats_.size() == 0);
     // CHECK_EQ(std::is_same::value, true);
@@ -427,15 +357,6 @@ void U2Nnet::ForwardEncoderChunkImpl(
     float* ctc_log_probs_ptr = ctc_log_probs.data<float>();
 
-    // // vector<vector<float>>
-    // out_prob->resize(T);
-    // for (int i = 0; i < T; i++) {
-    //     (*out_prob)[i].resize(D);
-    //     float* dst_ptr = (*out_prob)[i].data();
-    //     float* src_ptr = ctc_log_probs_ptr + (i * D);
-    //     std::memcpy(dst_ptr, src_ptr, D * sizeof(float));
-    // }
-
     // CHECK(std::is_same::value);
     out_prob->resize(T * D);
     std::memcpy(
         out_prob->data(), ctc_log_probs_ptr, T * D * sizeof(kaldi::BaseFloat));
diff --git a/speechx/speechx/nnet/u2_nnet.h b/speechx/speechx/nnet/u2_nnet.h
index 48dd8193b..6cbc05706 100644
--- a/speechx/speechx/nnet/u2_nnet.h
+++ b/speechx/speechx/nnet/u2_nnet.h
@@ -28,29 +28,21 @@ namespace ppspeech {
 
 class U2NnetBase : public NnetBase {
   public:
-    virtual int context() const { return right_context_ + 1; }
-    virtual int right_context() const { return right_context_; }
+    virtual int Context() const { return right_context_ + 1; }
+    virtual int RightContext() const { return right_context_; }
 
-    virtual int eos() const { return eos_; }
-    virtual int sos() const { return sos_; }
-    virtual int is_bidecoder() const { return is_bidecoder_; }
+    virtual int EOS() const { return eos_; }
+    virtual int SOS() const { return sos_; }
+    virtual int IsBidecoder() const { return is_bidecoder_; }
     // current offset in decoder frame
-    virtual int offset() const { return offset_; }
-    virtual void set_chunk_size(int chunk_size) { chunk_size_ = chunk_size; }
-    virtual void set_num_left_chunks(int num_left_chunks) {
+    virtual int Offset() const { return offset_; }
+    virtual void SetChunkSize(int chunk_size) { chunk_size_ = chunk_size; }
+    virtual void SetNumLeftChunks(int num_left_chunks) {
         num_left_chunks_ = num_left_chunks;
     }
 
-    // start: false, it is the start chunk of one sentence, else true
-    virtual int num_frames_for_chunk(bool start) const;
     virtual std::shared_ptr<NnetBase> Copy() const = 0;
 
-    virtual void ForwardEncoderChunk(
-        const std::vector<kaldi::BaseFloat>& chunk_feats,
-        const int32& feat_dim,
-        std::vector<kaldi::BaseFloat>* ctc_probs,
-        int32* vocab_dim);
-
   protected:
     virtual void ForwardEncoderChunkImpl(
         const std::vector<kaldi::BaseFloat>& chunk_feats,
@@ -58,9 +50,6 @@ class U2NnetBase : public NnetBase {
         std::vector<kaldi::BaseFloat>* ctc_probs,
         int32* vocab_dim) = 0;
 
-    virtual void CacheFeature(const std::vector<kaldi::BaseFloat>& chunk_feats,
-                              int32 feat_dim);
-
   protected:
     // model specification
     int right_context_{0};
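
Note on usage after this change: the chunk-sizing helper deleted from U2NnetBase (num_frames_for_chunk) has no replacement inside the nnet, so any caller that still needs that bookkeeping has to carry it itself. The sketch below only restates the arithmetic of the deleted function in terms of the renamed Context() accessor; the free-function name and parameter list are illustrative assumptions, not code from the repository.

    // Sketch only: mirrors the num_frames_for_chunk() logic removed by this
    // patch, kept on the caller side. `context` corresponds to
    // U2NnetBase::Context() (right_context_ + 1) after the rename;
    // chunk_size <= 0 means non-streaming mode.
    #include <limits>

    int NumFeatFramesForChunk(int chunk_size,
                              int subsampling_rate,
                              int context,
                              bool is_first_chunk) {
        if (chunk_size <= 0) {
            // non-streaming mode: feed all feats once
            return std::numeric_limits<int>::max();
        }
        if (is_first_chunk) {
            // the first chunk: 1 decoder frame needs `context` feat frames
            return (chunk_size - 1) * subsampling_rate + context;
        }
        // after the first chunk, stride by this many feat frames
        return chunk_size * subsampling_rate;
    }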