Fix Assembler buffer: clear the feature cache on reset, and make zero-filling of the last chunk optional (off by default)

pull/2524/head
Hui Zhang 2 years ago
parent f9fc32e89e
commit 7e334ce890

@ -5,7 +5,7 @@ set -e
data=data data=data
exp=exp exp=exp
nj=20 nj=40
mkdir -p $exp mkdir -p $exp

@ -23,9 +23,11 @@ using std::unique_ptr;
Assembler::Assembler(AssemblerOptions opts, Assembler::Assembler(AssemblerOptions opts,
unique_ptr<FrontendInterface> base_extractor) { unique_ptr<FrontendInterface> base_extractor) {
fill_zero_ = opts.fill_zero;
frame_chunk_stride_ = opts.subsampling_rate * opts.nnet_decoder_chunk; frame_chunk_stride_ = opts.subsampling_rate * opts.nnet_decoder_chunk;
frame_chunk_size_ = (opts.nnet_decoder_chunk - 1) * opts.subsampling_rate + frame_chunk_size_ = (opts.nnet_decoder_chunk - 1) * opts.subsampling_rate +
opts.receptive_filed_length; opts.receptive_filed_length;
cache_size_ = frame_chunk_size_ - frame_chunk_stride_;
receptive_filed_length_ = opts.receptive_filed_length; receptive_filed_length_ = opts.receptive_filed_length;
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
dim_ = base_extractor_->Dim(); dim_ = base_extractor_->Dim();
@ -38,14 +40,13 @@ void Assembler::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
// pop feature chunk // pop feature chunk
bool Assembler::Read(kaldi::Vector<kaldi::BaseFloat>* feats) { bool Assembler::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
feats->Resize(dim_ * frame_chunk_size_);
bool result = Compute(feats); bool result = Compute(feats);
return result; return result;
} }
// read all data from base_feature_extractor_ into cache_ // read frame by frame from base_feature_extractor_ into cache_
bool Assembler::Compute(Vector<BaseFloat>* feats) { bool Assembler::Compute(Vector<BaseFloat>* feats) {
// compute and feed // compute and feed frame by frame
bool result = false; bool result = false;
while (feature_cache_.size() < frame_chunk_size_) { while (feature_cache_.size() < frame_chunk_size_) {
Vector<BaseFloat> feature; Vector<BaseFloat> feature;
@ -54,33 +55,58 @@ bool Assembler::Compute(Vector<BaseFloat>* feats) {
if (IsFinished() == false) return false; if (IsFinished() == false) return false;
break; break;
} }
CHECK(feature.Dim() == dim_);
nframes_ += 1;
VLOG(1) << "nframes: " << nframes_;
feature_cache_.push(feature); feature_cache_.push(feature);
} }
if (feature_cache_.size() < receptive_filed_length_) { if (feature_cache_.size() < receptive_filed_length_) {
VLOG(1) << "feature_cache less than receptive_filed_lenght. " << feature_cache_.size() << ": " << receptive_filed_length_;
return false; return false;
} }
while (feature_cache_.size() < frame_chunk_size_) {
Vector<BaseFloat> feature(dim_, kaldi::kSetZero); if (fill_zero_){
feature_cache_.push(feature); while (feature_cache_.size() < frame_chunk_size_) {
Vector<BaseFloat> feature(dim_, kaldi::kSetZero);
nframes_ += 1;
feature_cache_.push(feature);
}
} }
int32 this_chunk_size = std::min(static_cast<int32>(feature_cache_.size()), frame_chunk_size_);
feats->Resize(dim_ * this_chunk_size);
int32 counter = 0; int32 counter = 0;
int32 cache_size = frame_chunk_size_ - frame_chunk_stride_; while (counter < this_chunk_size) {
int32 elem_dim = base_extractor_->Dim();
while (counter < frame_chunk_size_) {
Vector<BaseFloat>& val = feature_cache_.front(); Vector<BaseFloat>& val = feature_cache_.front();
int32 start = counter * elem_dim; CHECK(val.Dim() == dim_) << val.Dim();
feats->Range(start, elem_dim).CopyFromVec(val);
if (frame_chunk_size_ - counter <= cache_size) { int32 start = counter * dim_;
feats->Range(start, dim_).CopyFromVec(val);
if (this_chunk_size - counter <= cache_size_) {
feature_cache_.push(val); feature_cache_.push(val);
} }
// val is reference, so we should pop here
feature_cache_.pop(); feature_cache_.pop();
counter++; counter++;
} }
return result; return result;
} }
// Drops every cached frame, zeroes the frame counter, and resets the
// wrapped base extractor.
void Assembler::Reset() {
// std::queue has no clear(); swapping with an empty queue discards the
// cached frames.
std::queue<kaldi::Vector<kaldi::BaseFloat>> empty;
std::swap(feature_cache_, empty);
// Restart the count of frames computed.
nframes_ = 0;
// Forward the reset to the upstream extractor.
base_extractor_->Reset();
}
} // namespace ppspeech } // namespace ppspeech

@ -22,14 +22,10 @@ namespace ppspeech {
struct AssemblerOptions { struct AssemblerOptions {
// refer:https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/paddlespeech/s2t/exps/deepspeech2/model.py // refer:https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/paddlespeech/s2t/exps/deepspeech2/model.py
// the nnet batch forward // the nnet batch forward
int32 receptive_filed_length; int32 receptive_filed_length{1};
int32 subsampling_rate; int32 subsampling_rate{1};
int32 nnet_decoder_chunk; int32 nnet_decoder_chunk{1};
bool fill_zero{false}; // whether fill zero when last chunk is not equal to frame_chunk_size_
AssemblerOptions()
: receptive_filed_length(1),
subsampling_rate(1),
nnet_decoder_chunk(1) {}
}; };
class Assembler : public FrontendInterface { class Assembler : public FrontendInterface {
@ -39,29 +35,34 @@ class Assembler : public FrontendInterface {
std::unique_ptr<FrontendInterface> base_extractor = NULL); std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves // Feed feats or waves
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs); void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) override;
// feats size = num_frames * feat_dim // feats size = num_frames * feat_dim
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats); bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) override;
// feat dim // feat dim
virtual size_t Dim() const { return dim_; } size_t Dim() const override { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); } void SetFinished() override { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } bool IsFinished() const override { return base_extractor_->IsFinished(); }
virtual void Reset() { base_extractor_->Reset(); } void Reset() override;
private: private:
bool Compute(kaldi::Vector<kaldi::BaseFloat>* feats); bool Compute(kaldi::Vector<kaldi::BaseFloat>* feats);
int32 dim_; bool fill_zero_{false};
int32 dim_; // feat dim
int32 frame_chunk_size_; // window int32 frame_chunk_size_; // window
int32 frame_chunk_stride_; // stride int32 frame_chunk_stride_; // stride
int32 cache_size_; // window - stride
int32 receptive_filed_length_; int32 receptive_filed_length_;
std::queue<kaldi::Vector<kaldi::BaseFloat>> feature_cache_; std::queue<kaldi::Vector<kaldi::BaseFloat>> feature_cache_;
std::unique_ptr<FrontendInterface> base_extractor_; std::unique_ptr<FrontendInterface> base_extractor_;
int32 nframes_; // num frame computed
DISALLOW_COPY_AND_ASSIGN(Assembler); DISALLOW_COPY_AND_ASSIGN(Assembler);
}; };

@ -83,6 +83,10 @@ bool AudioCache::Read(Vector<BaseFloat>* waves) {
} }
size_ -= chunk_size; size_ -= chunk_size;
offset_ = (offset_ + chunk_size) % ring_buffer_.size(); offset_ = (offset_ + chunk_size) % ring_buffer_.size();
nsamples_ += chunk_size;
VLOG(1) << "nsamples readed: " << nsamples_;
ready_feed_condition_.notify_one(); ready_feed_condition_.notify_one();
return true; return true;
} }

@ -41,10 +41,11 @@ class AudioCache : public FrontendInterface {
virtual bool IsFinished() const { return finished_; } virtual bool IsFinished() const { return finished_; }
virtual void Reset() { void Reset() override {
offset_ = 0; offset_ = 0;
size_ = 0; size_ = 0;
finished_ = false; finished_ = false;
nsamples_ = 0;
} }
private: private:
@ -61,6 +62,7 @@ class AudioCache : public FrontendInterface {
kaldi::int32 timeout_; // millisecond kaldi::int32 timeout_; // millisecond
bool to_float32_; // int16 -> float32. used in linear_spectrogram bool to_float32_; // int16 -> float32. used in linear_spectrogram
int32 nsamples_; // number samples readed.
DISALLOW_COPY_AND_ASSIGN(AudioCache); DISALLOW_COPY_AND_ASSIGN(AudioCache);
}; };

@ -73,6 +73,9 @@ bool FeatureCache::Compute() {
if (result == false || feature.Dim() == 0) return false; if (result == false || feature.Dim() == 0) return false;
int32 num_chunk = feature.Dim() / dim_; int32 num_chunk = feature.Dim() / dim_;
nframe_ += num_chunk;
VLOG(1) << "nframe computed: " << nframe_;
for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
int32 start = chunk_idx * dim_; int32 start = chunk_idx * dim_;
Vector<BaseFloat> feature_chunk(dim_); Vector<BaseFloat> feature_chunk(dim_);

@ -51,11 +51,12 @@ class FeatureCache : public FrontendInterface {
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { void Reset() override {
std::queue<kaldi::Vector<BaseFloat>> empty;
std::swap(cache_, empty);
nframe_ = 0;
base_extractor_->Reset(); base_extractor_->Reset();
while (!cache_.empty()) { VLOG(1) << "feature cache reset: cache size: " << cache_.size();
cache_.pop();
}
} }
private: private:
@ -74,6 +75,7 @@ class FeatureCache : public FrontendInterface {
std::condition_variable ready_feed_condition_; std::condition_variable ready_feed_condition_;
std::condition_variable ready_read_condition_; std::condition_variable ready_read_condition_;
int32 nframe_; // num of feature computed
DISALLOW_COPY_AND_ASSIGN(FeatureCache); DISALLOW_COPY_AND_ASSIGN(FeatureCache);
}; };

@ -153,6 +153,7 @@ void U2Nnet::Reset() {
std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32)); std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32));
encoder_outs_.clear(); encoder_outs_.clear();
VLOG(1) << "u2nnet reset";
} }
// Debug API // Debug API

@ -82,9 +82,13 @@ int main(int argc, char* argv[]) {
// no overlap // no overlap
sample_offset += cur_chunk_size; sample_offset += cur_chunk_size;
} }
CHECK(sample_offset == tot_samples);
// recognizer.SetFinished();
// second pass decoding // second pass decoding
recognizer.Rescoring(); recognizer.Rescoring();
std::string result = recognizer.GetFinalResult(); std::string result = recognizer.GetFinalResult();
recognizer.Reset(); recognizer.Reset();

Loading…
Cancel
Save