remove useless code

pull/2524/head
Hui Zhang 3 years ago
parent 36af34b293
commit 0a8ef58af0

@@ -25,65 +25,6 @@ using paddle::platform::TracerEventType;
namespace ppspeech {

int U2NnetBase::num_frames_for_chunk(bool start) const {
    int num_needed_frames = 0;  // num feat frames
    bool first = !start;        // start == false is first

    if (chunk_size_ > 0) {
        // streaming mode
        if (first) {
            // first chunk
            // 1 decoder frame need `context` feat frames
            int context = this->context();
            num_needed_frames = (chunk_size_ - 1) * subsampling_rate_ + context;
        } else {
            // after first chunk, we need stride this num frames.
            num_needed_frames = chunk_size_ * subsampling_rate_;
        }
    } else {
        // non-streaming mode. feed all feats once.
        num_needed_frames = std::numeric_limits<int>::max();
    }

    return num_needed_frames;
}
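To make the frame arithmetic in the removed num_frames_for_chunk() concrete, here is a minimal standalone sketch; the values chunk_size = 16, subsampling_rate = 4, and context = 7 are illustrative assumptions, not taken from this commit.

#include <cassert>

int main() {
    // Assumed configuration (hypothetical values, for illustration only).
    const int chunk_size = 16;       // decoder frames produced per chunk
    const int subsampling_rate = 4;  // feat frames consumed per decoder frame
    const int context = 7;           // right_context + 1

    // First chunk: every decoder frame needs `context` feat frames once.
    const int first_chunk_frames = (chunk_size - 1) * subsampling_rate + context;  // 67
    // Later chunks: only stride by chunk_size decoder frames.
    const int later_chunk_frames = chunk_size * subsampling_rate;  // 64

    assert(first_chunk_frames == 67);
    assert(later_chunk_frames == 64);
    return 0;
}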
// cache feats for next chunk
void U2NnetBase::CacheFeature(const std::vector<kaldi::BaseFloat>& chunk_feats,
                              int32 feat_dim) {
    // chunk_feats is nframes*feat_dim
    const int chunk_size = chunk_feats.size() / feat_dim;
    const int cached_feat_size = this->context() - subsampling_rate_;
    if (chunk_size >= cached_feat_size) {
        cached_feats_.resize(cached_feat_size);
        for (int i = 0; i < cached_feat_size; ++i) {
            auto start =
                chunk_feats.begin() + chunk_size - cached_feat_size + i;
            auto end = start + feat_dim;
            cached_feats_[i] = std::vector<float>(start, end);
        }
    }
}
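The removed CacheFeature() keeps the last context() - subsampling_rate frames so the next chunk can rebuild its left context. Below is a minimal sketch of that tail-caching idea over a flat [nframes x feat_dim] buffer; the name CacheTail and the (frame * feat_dim) indexing are assumptions about the intended layout, not a literal copy of the removed code.

#include <vector>

// Keep the last `cached_feat_size` frames of a flat nframes*feat_dim buffer.
std::vector<std::vector<float>> CacheTail(const std::vector<float>& chunk_feats,
                                          int feat_dim,
                                          int cached_feat_size) {
    const int nframes = static_cast<int>(chunk_feats.size()) / feat_dim;
    std::vector<std::vector<float>> cached;
    if (nframes < cached_feat_size) return cached;  // not enough frames to cache
    cached.resize(cached_feat_size);
    for (int i = 0; i < cached_feat_size; ++i) {
        // offset of the (nframes - cached_feat_size + i)-th frame in the flat buffer
        auto start =
            chunk_feats.begin() + (nframes - cached_feat_size + i) * feat_dim;
        cached[i] = std::vector<float>(start, start + feat_dim);
    }
    return cached;
}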
void U2NnetBase::ForwardEncoderChunk(
    const std::vector<kaldi::BaseFloat>& chunk_feats,
    const int32& feat_dim,
    std::vector<kaldi::BaseFloat>* ctc_probs,
    int32* vocab_dim) {
    ctc_probs->clear();

    // int num_frames = cached_feats_.size() + chunk_feats.size();
    int num_frames = chunk_feats.size() / feat_dim;
    VLOG(3) << "forward encoder chunk: " << num_frames << " frames";
    VLOG(3) << "context: " << this->context() << " frames";

    if (num_frames >= this->context()) {
        this->ForwardEncoderChunkImpl(
            chunk_feats, feat_dim, ctc_probs, vocab_dim);
        VLOG(3) << "after forward chunk";
        this->CacheFeature(chunk_feats, feat_dim);
    }
}

void U2Nnet::LoadModel(const std::string& model_path_w_prefix) {
    paddle::jit::utils::InitKernelSignatureMap();
@@ -188,7 +129,7 @@ U2Nnet::U2Nnet(const U2Nnet& other) {
    forward_attention_decoder_ = other.forward_attention_decoder_;
    ctc_activation_ = other.ctc_activation_;

    // offset_ = other.offset_; // TODO: not used in nnets
    offset_ = other.offset_;

    // copy model ptr
    model_ = other.model_;
@@ -204,8 +145,7 @@ std::shared_ptr<NnetBase> U2Nnet::Copy() const {
}

void U2Nnet::Reset() {
    // offset_ = 0;
    // cached_feats_.clear(); // TODO: not used in nnets
    offset_ = 0;

    att_cache_ =
        std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32));
@@ -263,16 +203,6 @@ void U2Nnet::ForwardEncoderChunkImpl(
        paddle::zeros({1, num_frames, feat_dim}, paddle::DataType::FLOAT32);
    float* feats_ptr = feats.mutable_data<float>();

    // for (size_t i = 0; i < cached_feats_.size(); ++i) {
    //     float* row = feats_ptr + i * feat_dim;
    //     std::memcpy(row, cached_feats_[i].data(), feat_dim * sizeof(float));
    // }
    // for (size_t i = 0; i < chunk_feats.size(); ++i) {
    //     float* row = feats_ptr + (cached_feats_.size() + i) * feat_dim;
    //     std::memcpy(row, chunk_feats[i].data(), feat_dim * sizeof(float));
    // }

    // not cache feature in nnet
    CHECK(cached_feats_.size() == 0);
    // CHECK_EQ(std::is_same<float, kaldi::BaseFloat>::value, true);
@@ -427,15 +357,6 @@ void U2Nnet::ForwardEncoderChunkImpl(
    float* ctc_log_probs_ptr = ctc_log_probs.data<float>();

    // // vector<vector<float>>
    // out_prob->resize(T);
    // for (int i = 0; i < T; i++) {
    //     (*out_prob)[i].resize(D);
    //     float* dst_ptr = (*out_prob)[i].data();
    //     float* src_ptr = ctc_log_probs_ptr + (i * D);
    //     std::memcpy(dst_ptr, src_ptr, D * sizeof(float));
    // }
    // CHECK(std::is_same<float, kaldi::BaseFloat>::value);

    out_prob->resize(T * D);
    std::memcpy(
        out_prob->data(), ctc_log_probs_ptr, T * D * sizeof(kaldi::BaseFloat));
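After this change, out_prob stays a single flat buffer of T * D CTC log-probabilities rather than T separate rows. A small sketch of reading it back in row-major order; T, D, and the values are made-up numbers for illustration:

#include <iostream>
#include <vector>

int main() {
    const int T = 2, D = 3;  // assumed decoder frames and vocab size
    const std::vector<float> out_prob = {-0.1f, -2.3f, -4.0f,   // frame 0
                                         -3.0f, -0.2f, -1.5f};  // frame 1
    // Row-major layout: log-prob of token d at decoder frame t is out_prob[t * D + d].
    for (int t = 0; t < T; ++t) {
        for (int d = 0; d < D; ++d) {
            std::cout << out_prob[t * D + d] << (d + 1 < D ? " " : "\n");
        }
    }
    return 0;
}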

@@ -28,29 +28,21 @@ namespace ppspeech {
class U2NnetBase : public NnetBase {
  public:
    virtual int context() const { return right_context_ + 1; }
    virtual int right_context() const { return right_context_; }
    virtual int Context() const { return right_context_ + 1; }
    virtual int RightContext() const { return right_context_; }
    virtual int eos() const { return eos_; }
    virtual int sos() const { return sos_; }
    virtual int is_bidecoder() const { return is_bidecoder_; }
    virtual int EOS() const { return eos_; }
    virtual int SOS() const { return sos_; }
    virtual int IsBidecoder() const { return is_bidecoder_; }
    // current offset in decoder frame
    virtual int offset() const { return offset_; }
    virtual void set_chunk_size(int chunk_size) { chunk_size_ = chunk_size; }
    virtual void set_num_left_chunks(int num_left_chunks) {
    virtual int Offset() const { return offset_; }
    virtual void SetChunkSize(int chunk_size) { chunk_size_ = chunk_size; }
    virtual void SetNumLeftChunks(int num_left_chunks) {
        num_left_chunks_ = num_left_chunks;
    }

    // start: false, it is the start chunk of one sentence, else true
    virtual int num_frames_for_chunk(bool start) const;

    virtual std::shared_ptr<NnetBase> Copy() const = 0;

    virtual void ForwardEncoderChunk(
        const std::vector<kaldi::BaseFloat>& chunk_feats,
        const int32& feat_dim,
        std::vector<kaldi::BaseFloat>* ctc_probs,
        int32* vocab_dim);

  protected:
    virtual void ForwardEncoderChunkImpl(
        const std::vector<kaldi::BaseFloat>& chunk_feats,
@@ -58,9 +50,6 @@ class U2NnetBase : public NnetBase {
        std::vector<kaldi::BaseFloat>* ctc_probs,
        int32* vocab_dim) = 0;

    virtual void CacheFeature(const std::vector<kaldi::BaseFloat>& chunk_feats,
                              int32 feat_dim);

  protected:
    // model specification
    int right_context_{0};
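For callers, the header change above swaps the snake_case accessors for the PascalCase ones that remain. A hypothetical usage sketch (not from this commit), assuming the U2NnetBase declaration is in scope and `nnet` is obtained elsewhere; the chunk values are illustrative:

#include <memory>

// Hypothetical helper showing the renamed API on an assumed nnet instance.
void ConfigureStreaming(const std::shared_ptr<U2NnetBase>& nnet) {
    nnet->SetChunkSize(16);                  // was set_chunk_size(16)
    nnet->SetNumLeftChunks(-1);              // was set_num_left_chunks(-1)
    const int context = nnet->Context();     // was context()
    const int right = nnet->RightContext();  // was right_context()
    (void)context;
    (void)right;
}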
