refactor cache

pull/1638/head
Hui Zhang 2 years ago
parent 3572cacfd3
commit 760e5d4418

@ -17,7 +17,7 @@
#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/raw_audio.h"
#include "frontend/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
@ -60,8 +60,7 @@ int main(int argc, char* argv[]) {
model_opts.params_path = model_params;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::RawDataCache> raw_data(
new ppspeech::RawDataCache());
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
LOG(INFO) << "Init decodeable.";

@ -17,10 +17,11 @@
#include "frontend/linear_spectrogram.h"
#include "base/flags.h"
#include "base/log.h"
#include "frontend/audio_cache.h"
#include "frontend/data_cache.h"
#include "frontend/feature_cache.h"
#include "frontend/feature_extractor_interface.h"
#include "frontend/normalizer.h"
#include "frontend/raw_audio.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
@ -170,9 +171,9 @@ int main(int argc, char* argv[]) {
// window -->linear_spectrogram --> global cmvn -> feat cache
// std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(new
// ppspeech::RawDataCache());
// ppspeech::DataCache());
std::unique_ptr<ppspeech::FeatureExtractorInterface> data_source(
new ppspeech::RawAudioCache());
new ppspeech::AudioCache());
ppspeech::DecibelNormalizerOptions db_norm_opt;
std::unique_ptr<ppspeech::FeatureExtractorInterface> db_norm(

@ -3,8 +3,8 @@ project(frontend)
add_library(frontend STATIC
normalizer.cc
linear_spectrogram.cc
raw_audio.cc
audio_cache.cc
feature_cache.cc
)
target_link_libraries(frontend PUBLIC kaldi-matrix)
target_link_libraries(frontend PUBLIC kaldi-matrix)

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/raw_audio.h"
#include "frontend/audio_cache.h"
#include "kaldi/base/timer.h"
namespace ppspeech {
@ -21,38 +21,43 @@ using kaldi::BaseFloat;
using kaldi::VectorBase;
using kaldi::Vector;
RawAudioCache::RawAudioCache(int buffer_size)
: finished_(false), data_length_(0), start_(0), timeout_(1) {
ring_buffer_.resize(buffer_size);
AudioCache::AudioCache(int buffer_size)
: finished_(false),
capacity_(buffer_size),
size_(0),
offset_(0),
timeout_(1) {
ring_buffer_.resize(capacity_);
}
void RawAudioCache::Accept(const VectorBase<BaseFloat>& waves) {
void AudioCache::Accept(const VectorBase<BaseFloat>& waves) {
std::unique_lock<std::mutex> lock(mutex_);
while (data_length_ + waves.Dim() > ring_buffer_.size()) {
while (size_ + waves.Dim() > ring_buffer_.size()) {
ready_feed_condition_.wait(lock);
}
for (size_t idx = 0; idx < waves.Dim(); ++idx) {
int32 buffer_idx = (idx + start_) % ring_buffer_.size();
int32 buffer_idx = (idx + offset_) % ring_buffer_.size();
ring_buffer_[buffer_idx] = waves(idx);
}
data_length_ += waves.Dim();
size_ += waves.Dim();
}
bool RawAudioCache::Read(Vector<BaseFloat>* waves) {
bool AudioCache::Read(Vector<BaseFloat>* waves) {
size_t chunk_size = waves->Dim();
kaldi::Timer timer;
std::unique_lock<std::mutex> lock(mutex_);
while (chunk_size > data_length_) {
while (chunk_size > size_) {
// when audio is empty and no more data feed
// ready_read_condition will block in dead lock. so replace with
// timeout_
// ready_read_condition will block in dead lock,
// so replace with timeout_
// ready_read_condition_.wait(lock);
int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000);
if (elapsed > timeout_) {
if (finished_ == true) { // read last chunk data
if (finished_ == true) {
// read last chunk data
break;
}
if (chunk_size > data_length_) {
if (chunk_size > size_) {
return false;
}
}
@ -60,17 +65,17 @@ bool RawAudioCache::Read(Vector<BaseFloat>* waves) {
}
// read last chunk data
if (chunk_size > data_length_) {
chunk_size = data_length_;
if (chunk_size > size_) {
chunk_size = size_;
waves->Resize(chunk_size);
}
for (size_t idx = 0; idx < chunk_size; ++idx) {
int buff_idx = (start_ + idx) % ring_buffer_.size();
int buff_idx = (offset_ + idx) % ring_buffer_.size();
waves->Data()[idx] = ring_buffer_[buff_idx];
}
data_length_ -= chunk_size;
start_ = (start_ + chunk_size) % ring_buffer_.size();
size_ -= chunk_size;
offset_ = (offset_ + chunk_size) % ring_buffer_.size();
ready_feed_condition_.notify_one();
return true;
}

@ -0,0 +1,61 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
namespace ppspeech {
// waves cache
class AudioCache : public FeatureExtractorInterface {
public:
explicit AudioCache(int buffer_size = kint16max);
virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
// the audio dim is 1, one sample
virtual size_t Dim() const { return 1; }
virtual void SetFinished() {
std::lock_guard<std::mutex> lock(mutex_);
finished_ = true;
}
virtual bool IsFinished() const { return finished_; }
virtual void Reset() {
offset_ = 0;
size_ = 0;
finished_ = false;
}
private:
std::vector<kaldi::BaseFloat> ring_buffer_;
size_t offset_; // offset in ring_buffer_
size_t size_; // samples in ring_buffer_ now
size_t capacity_; // capacity of ring_buffer_
bool finished_; // reach audio end
mutable std::mutex mutex_;
std::condition_variable ready_feed_condition_;
kaldi::int32 timeout_; // millisecond
DISALLOW_COPY_AND_ASSIGN(AudioCache);
};
} // namespace ppspeech

@ -15,51 +15,22 @@
#pragma once
#include "base/common.h"
#include "frontend/feature_extractor_interface.h"
#pragma once
namespace ppspeech {
class RawAudioCache : public FeatureExtractorInterface {
// A data source for testing different frontend module.
// It accepts waves or feats.
class DataCache : public FeatureExtractorInterface {
public:
explicit RawAudioCache(int buffer_size = kint16max);
virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
// the audio dim is 1
virtual size_t Dim() const { return 1; }
virtual void SetFinished() {
std::lock_guard<std::mutex> lock(mutex_);
finished_ = true;
}
virtual bool IsFinished() const { return finished_; }
virtual void Reset() {
start_ = 0;
data_length_ = 0;
finished_ = false;
}
private:
std::vector<kaldi::BaseFloat> ring_buffer_;
size_t start_;
size_t data_length_;
bool finished_;
mutable std::mutex mutex_;
std::condition_variable ready_feed_condition_;
kaldi::int32 timeout_;
DISALLOW_COPY_AND_ASSIGN(RawAudioCache);
};
explicit DataCache() { finished_ = false; }
// it is a datasource for testing different frontend module.
// it accepts waves or feats.
class RawDataCache : public FeatureExtractorInterface {
public:
explicit RawDataCache() { finished_ = false; }
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
data_ = inputs;
}
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
if (data_.Dim() == 0) {
return false;
@ -80,7 +51,6 @@ class RawDataCache : public FeatureExtractorInterface {
bool finished_;
int32 dim_;
DISALLOW_COPY_AND_ASSIGN(RawDataCache);
DISALLOW_COPY_AND_ASSIGN(DataCache);
};
} // namespace ppspeech
}
Loading…
Cancel
Save