[speechx] thread decode (#2839)

* fix nnet thread crash and rescore time cost

* add nnet thread main (u2_nnet_thread_main)
Author: YangZhou (committed by GitHub)
parent ee7c266f13
commit 8a225b1708
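The commit replaces the per-chunk synchronous decode call with two long-lived worker threads: NnetProducer now runs the frontend plus nnet forward pass in a thread spawned in its constructor, and U2Recognizer runs the CTC prefix beam search in a decode thread started by InitDecoder(); the two sides hand frame posteriors over a shared cache guarded by mutexes and condition variables. The sketch below shows only that handoff pattern under simplified, hypothetical names (ProbQueue, Push, Pop); it is not the speechx API.

// Minimal sketch of the producer/consumer handoff this commit introduces.
// ProbQueue, Push, Pop are illustrative names, not the speechx classes.
#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

struct ProbQueue {
    std::mutex mutex;
    std::condition_variable ready;  // "data available or input finished"
    std::deque<std::vector<float>> cache;
    bool finished = false;

    void Push(std::vector<float> probs) {
        {
            std::lock_guard<std::mutex> lock(mutex);
            cache.push_back(std::move(probs));
        }
        ready.notify_one();
    }
    void SetFinished() {
        {
            std::lock_guard<std::mutex> lock(mutex);
            finished = true;
        }
        ready.notify_one();
    }
    // Blocks like NnetProducer::UnLock(); returns false only when the
    // producer is done and nothing is left to read.
    bool Pop(std::vector<float>* probs) {
        std::unique_lock<std::mutex> lock(mutex);
        ready.wait(lock, [this] { return finished || !cache.empty(); });
        if (cache.empty()) return false;
        *probs = std::move(cache.front());
        cache.pop_front();
        return true;
    }
};

int main() {
    ProbQueue queue;
    // Stands in for the nnet forward thread spawned by NnetProducer.
    std::thread producer([&queue] {
        for (int frame = 0; frame < 5; ++frame) {
            queue.Push(std::vector<float>(10, 0.1f * frame));  // fake posteriors
        }
        queue.SetFinished();
    });
    // Stands in for the decoder-search thread owned by U2Recognizer.
    std::vector<float> probs;
    int consumed = 0;
    while (queue.Pop(&probs)) ++consumed;
    producer.join();
    std::cout << "consumed " << consumed << " frames" << std::endl;
    return 0;
}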

@ -63,8 +63,9 @@ void CTCPrefixBeamSearch::Reset() {
times_.emplace_back(empty);
}
void CTCPrefixBeamSearch::InitDecoder() { Reset(); }
void CTCPrefixBeamSearch::InitDecoder() {
Reset();
}
void CTCPrefixBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
@ -77,7 +78,7 @@ void CTCPrefixBeamSearch::AdvanceDecode(
bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
feat_nnet_cost += timer.Elapsed();
if (flag == false) {
VLOG(3) << "decoder advance decode exit." << frame_prob.size();
VLOG(2) << "decoder advance decode exit." << frame_prob.size();
break;
}
@ -87,7 +88,7 @@ void CTCPrefixBeamSearch::AdvanceDecode(
AdvanceDecoding(likelihood);
search_cost += timer.Elapsed();
VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_;
VLOG(1) << "num_frame_decoded_: " << num_frame_decoded_;
}
VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
<< " sec.";

@ -8,14 +8,21 @@ target_link_libraries(nnet utils)
target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
# test bin
#if(USING_U2)
# set(bin_name u2_nnet_main)
# add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
# target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
# target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
# test bin
#set(bin_name u2_nnet_main)
#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
#target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
#target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet)
# target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
# target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
# target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
#endif()
#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
set(bin_name u2_nnet_thread_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet frontend)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})

@ -33,7 +33,6 @@ void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
nnet_producer_->Acceptlikelihood(likelihood);
}
// return the size of frame have computed.
int32 Decodable::NumFramesReady() const { return frames_ready_; }

@ -22,14 +22,43 @@ using kaldi::BaseFloat;
NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend)
: nnet_(nnet), frontend_(frontend) {}
: nnet_(nnet), frontend_(frontend) {
abort_ = false;
Reset();
thread_ = std::thread(RunNnetEvaluation, this);
}
void NnetProducer::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
frontend_->Accept(inputs);
condition_variable_.notify_one();
}
void NnetProducer::UnLock() {
std::unique_lock<std::mutex> lock(read_mutex_);
while (frontend_->IsFinished() == false && cache_.empty()) {
condition_read_ready_.wait(lock);
}
return;
}
void NnetProducer::RunNnetEvaluation(NnetProducer *me) {
me->RunNnetEvaluationInteral();
}
void NnetProducer::RunNnetEvaluationInteral() {
bool result = false;
do {
result = Compute();
} while (result);
LOG(INFO) << "NnetEvaluationInteral begin";
while (!abort_) {
std::unique_lock<std::mutex> lock(mutex_);
condition_variable_.wait(lock);
do {
result = Compute();
} while (result);
if (frontend_->IsFinished() == true) {
if (cache_.empty()) finished_ = true;
}
}
LOG(INFO) << "NnetEvaluationInteral exit";
}
void NnetProducer::Acceptlikelihood(
@ -39,12 +68,20 @@ void NnetProducer::Acceptlikelihood(
for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) {
for (size_t col = 0; col < likelihood.NumCols(); ++col) {
prob[col] = likelihood(idx, col);
cache_.push_back(prob);
}
cache_.push_back(prob);
}
}
bool NnetProducer::Read(std::vector<kaldi::BaseFloat>* nnet_prob) {
bool flag = cache_.pop(nnet_prob);
condition_variable_.notify_one();
return flag;
}
bool NnetProducer::ReadandCompute(std::vector<kaldi::BaseFloat>* nnet_prob) {
Compute();
if (frontend_->IsFinished() && cache_.empty()) finished_ = true;
bool flag = cache_.pop(nnet_prob);
return flag;
}
@ -53,22 +90,23 @@ bool NnetProducer::Compute() {
vector<BaseFloat> features;
if (frontend_ == NULL || frontend_->Read(&features) == false) {
// no feat or frontend_ not init.
VLOG(3) << "no feat avalible";
VLOG(2) << "no feat avalible";
return false;
}
CHECK_GE(frontend_->Dim(), 0);
VLOG(2) << "Forward in " << features.size() / frontend_->Dim() << " feats.";
VLOG(1) << "Forward in " << features.size() / frontend_->Dim() << " feats.";
NnetOut out;
nnet_->FeedForward(features, frontend_->Dim(), &out);
int32& vocab_dim = out.vocab_dim;
size_t nframes = out.logprobs.size() / vocab_dim;
VLOG(2) << "Forward out " << nframes << " decoder frames.";
VLOG(1) << "Forward out " << nframes << " decoder frames.";
for (size_t idx = 0; idx < nframes; ++idx) {
std::vector<BaseFloat> logprob(
out.logprobs.data() + idx * vocab_dim,
out.logprobs.data() + (idx + 1) * vocab_dim);
cache_.push_back(logprob);
condition_read_ready_.notify_one();
}
return true;
}

@ -33,27 +33,38 @@ class NnetProducer {
// nnet
bool Read(std::vector<kaldi::BaseFloat>* nnet_prob);
bool ReadandCompute(std::vector<kaldi::BaseFloat>* nnet_prob);
static void RunNnetEvaluation(NnetProducer *me);
void RunNnetEvaluationInteral();
void UnLock();
void Wait() {
abort_ = true;
condition_variable_.notify_one();
if (thread_.joinable()) thread_.join();
}
bool Empty() const { return cache_.empty(); }
void SetFinished() {
void SetInputFinished() {
LOG(INFO) << "set finished";
// std::unique_lock<std::mutex> lock(mutex_);
frontend_->SetFinished();
// read the last chunk data
Compute();
// ready_feed_condition_.notify_one();
LOG(INFO) << "compute last feats done.";
condition_variable_.notify_one();
}
bool IsFinished() const { return frontend_->IsFinished(); }
// the compute thread exit
bool IsFinished() const { return finished_; }
~NnetProducer() {
if (thread_.joinable()) thread_.join();
}
void Reset() {
frontend_->Reset();
nnet_->Reset();
VLOG(3) << "feature cache reset: cache size: " << cache_.size();
cache_.clear();
finished_ = false;
}
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
@ -66,6 +77,13 @@ class NnetProducer {
std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetBase> nnet_;
SafeQueue<std::vector<kaldi::BaseFloat>> cache_;
std::mutex mutex_;
std::mutex read_mutex_;
std::condition_variable condition_variable_;
std::condition_variable condition_read_ready_;
std::thread thread_;
bool finished_;
bool abort_;
DISALLOW_COPY_AND_ASSIGN(NnetProducer);
};
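The header stores posteriors in SafeQueue<std::vector<kaldi::BaseFloat>>, whose definition is outside this diff. Below is a minimal mutex-guarded stand-in for the interface the new code calls (push_back, pop, empty, size, clear); the actual speechx SafeQueue may differ.

// Hedged stand-in for the SafeQueue<T> interface used above; not the real
// speechx implementation.
#include <deque>
#include <mutex>

template <typename T>
class SafeQueueSketch {
  public:
    void push_back(const T& item) {
        std::lock_guard<std::mutex> lock(mutex_);
        queue_.push_back(item);
    }
    // Non-blocking pop: returns false when the queue is empty.
    bool pop(T* item) {
        std::lock_guard<std::mutex> lock(mutex_);
        if (queue_.empty()) return false;
        *item = queue_.front();
        queue_.pop_front();
        return true;
    }
    bool empty() const {
        std::lock_guard<std::mutex> lock(mutex_);
        return queue_.empty();
    }
    size_t size() const {
        std::lock_guard<std::mutex> lock(mutex_);
        return queue_.size();
    }
    void clear() {
        std::lock_guard<std::mutex> lock(mutex_);
        queue_.clear();
    }

  private:
    mutable std::mutex mutex_;
    std::deque<T> queue_;
};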

@ -0,0 +1,137 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "base/common.h"
#include "decoder/param.h"
#include "frontend/wave-reader.h"
#include "frontend/feature_pipeline.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/u2_nnet.h"
#include "nnet/nnet_producer.h"
DEFINE_string(wav_rspecifier, "", "test wav rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
int32 num_done = 0, num_err = 0;
int sample_rate = FLAGS_sample_rate;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
CHECK_GT(FLAGS_wav_rspecifier.size(), 0);
CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0);
CHECK_GT(FLAGS_model_path.size(), 0);
LOG(INFO) << "input rspecifier: " << FLAGS_wav_rspecifier;
LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
LOG(INFO) << "model path: " << FLAGS_model_path;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
ppspeech::FeaturePipelineOptions feature_opts =
ppspeech::FeaturePipelineOptions::InitFromFlags();
feature_opts.assembler_opts.fill_zero = false;
std::shared_ptr<ppspeech::U2Nnet> nnet(new ppspeech::U2Nnet(model_opts));
std::shared_ptr<ppspeech::FeaturePipeline> feature_pipeline(
new ppspeech::FeaturePipeline(feature_opts));
std::shared_ptr<ppspeech::NnetProducer> nnet_producer(
new ppspeech::NnetProducer(nnet, feature_pipeline));
kaldi::Timer timer;
float tot_wav_duration = 0;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
double dur = wave_data.Duration();
tot_wav_duration += dur;
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
kaldi::Timer timer;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
nnet_producer->Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
nnet_producer->SetInputFinished();
}
// no overlap
sample_offset += cur_chunk_size;
}
CHECK(sample_offset == tot_samples);
std::vector<std::vector<kaldi::BaseFloat>> prob_vec;
while(1) {
std::vector<kaldi::BaseFloat> logprobs;
bool isok = nnet_producer->Read(&logprobs);
if (nnet_producer->IsFinished()) break;
if (isok == false) continue;
prob_vec.push_back(logprobs);
}
{
// writer nnet output
kaldi::MatrixIndexT nrow = prob_vec.size();
kaldi::MatrixIndexT ncol = prob_vec[0].size();
LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol;
kaldi::Matrix<kaldi::BaseFloat> nnet_out(nrow, ncol);
for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
nnet_out(row_idx, col_idx) = prob_vec[row_idx][col_idx];
}
}
nnet_out_writer.Write(utt, nnet_out);
}
nnet_producer->Reset();
}
nnet_producer->Wait();
double elapsed = timer.Elapsed();
LOG(INFO) << "Program cost:" << elapsed << " sec";
LOG(INFO) << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}

@ -39,12 +39,28 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
unit_table_ = decoder_->VocabTable();
symbol_table_ = unit_table_;
global_frame_offset_ = 0;
input_finished_ = false;
num_frames_ = 0;
result_.clear();
}
U2Recognizer::~U2Recognizer() {
SetInputFinished();
WaitDecodeFinished();
}
Reset();
void U2Recognizer::WaitDecodeFinished() {
if (thread_.joinable()) thread_.join();
}
void U2Recognizer::Reset() {
void U2Recognizer::WaitFinished() {
if (thread_.joinable()) thread_.join();
nnet_producer_->Wait();
}
void U2Recognizer::InitDecoder() {
global_frame_offset_ = 0;
input_finished_ = false;
num_frames_ = 0;
@ -52,6 +68,7 @@ void U2Recognizer::Reset() {
decodable_->Reset();
decoder_->Reset();
thread_ = std::thread(RunDecoderSearch, this);
}
void U2Recognizer::ResetContinuousDecoding() {
@ -63,6 +80,19 @@ void U2Recognizer::ResetContinuousDecoding() {
decoder_->Reset();
}
void U2Recognizer::RunDecoderSearch(U2Recognizer* me) {
me->RunDecoderSearchInternal();
}
void U2Recognizer::RunDecoderSearchInternal() {
LOG(INFO) << "DecoderSearchInteral begin";
while (!nnet_producer_->IsFinished()) {
nnet_producer_->UnLock();
decoder_->AdvanceDecode(decodable_);
}
Decode();
LOG(INFO) << "DecoderSearchInteral exit";
}
void U2Recognizer::Accept(const vector<BaseFloat>& waves) {
kaldi::Timer timer;
@ -71,7 +101,6 @@ void U2Recognizer::Accept(const vector<BaseFloat>& waves) {
<< " samples.";
}
void U2Recognizer::Decode() {
decoder_->AdvanceDecode(decodable_);
UpdateResult(false);
@ -207,8 +236,8 @@ std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; }
std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; }
void U2Recognizer::SetFinished() {
nnet_producer_->SetFinished();
void U2Recognizer::SetInputFinished() {
nnet_producer_->SetInputFinished();
input_finished_ = true;
}

@ -112,19 +112,21 @@ struct U2RecognizerResource {
class U2Recognizer {
public:
explicit U2Recognizer(const U2RecognizerResource& resouce);
void Reset();
~U2Recognizer();
void InitDecoder();
void ResetContinuousDecoding();
void Accept(const std::vector<kaldi::BaseFloat>& waves);
void Decode();
void Rescoring();
std::string GetFinalResult();
std::string GetPartialResult();
void SetFinished();
void SetInputFinished();
bool IsFinished() { return input_finished_; }
void WaitDecodeFinished();
void WaitFinished();
bool DecodedSomething() const {
return !result_.empty() && !result_[0].sentence.empty();
@ -137,18 +139,17 @@ class U2Recognizer {
// feature_pipeline_->FrameShift();
}
const std::vector<DecodeResult>& Result() const { return result_; }
void AttentionRescoring();
private:
void AttentionRescoring();
static void RunDecoderSearch(U2Recognizer *me);
void RunDecoderSearchInternal();
void UpdateResult(bool finish = false);
private:
U2RecognizerResource opts_;
// std::shared_ptr<U2RecognizerResource> resource_;
// U2RecognizerResource resource_;
std::shared_ptr<NnetProducer> nnet_producer_;
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<CTCPrefixBeamSearch> decoder_;
@ -167,6 +168,7 @@ class U2Recognizer {
const int time_stamp_gap_ = 100;
bool input_finished_;
std::thread thread_;
};
} // namespace ppspeech
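With the decode-search thread now owned by the recognizer, the per-utterance driver pattern becomes: InitDecoder() to reset state and spawn the search thread, Accept() per audio chunk, SetInputFinished() after the last chunk, WaitDecodeFinished() before attention rescoring, and one WaitFinished() after all utterances to also stop the nnet thread. Below is a condensed, hedged sketch mirroring the updated u2_recognizer_thread_main.cc; flag parsing, wav reading, and chunking are elided, and the include path is an assumption.

// Condensed per-utterance driver mirroring u2_recognizer_thread_main.cc.
// RecognizeOneUtt is an illustrative helper, not part of the speechx API.
#include <memory>
#include <string>
#include <vector>
#include "recognizer/u2_recognizer.h"  // include path assumed

std::string RecognizeOneUtt(
    std::shared_ptr<ppspeech::U2Recognizer> recognizer,
    const std::vector<std::vector<kaldi::BaseFloat>>& wav_chunks) {
    recognizer->InitDecoder();              // reset state, spawn the decode-search thread
    for (size_t i = 0; i < wav_chunks.size(); ++i) {
        recognizer->Accept(wav_chunks[i]);  // feed samples to the frontend/nnet producer
        if (i + 1 == wav_chunks.size()) {
            recognizer->SetInputFinished();  // flush the last feature chunk
        }
    }
    recognizer->WaitDecodeFinished();       // join the per-utterance decode thread
    recognizer->AttentionRescoring();       // rescore on the caller's thread
    return recognizer->GetFinalResult();
}
// After the last utterance, call recognizer->WaitFinished() to stop the
// nnet evaluation thread as well.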

@ -49,6 +49,7 @@ int main(int argc, char* argv[]) {
ppspeech::U2Recognizer recognizer(resource);
for (; !wav_reader.Done(); wav_reader.Next()) {
recognizer.InitDecoder();
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt;
@ -79,7 +80,7 @@ int main(int argc, char* argv[]) {
recognizer.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
recognizer.SetFinished();
recognizer.SetInputFinished();
}
recognizer.Decode();
if (recognizer.DecodedSomething()) {
@ -100,7 +101,6 @@ int main(int argc, char* argv[]) {
std::string result = recognizer.GetFinalResult();
recognizer.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.

@ -22,15 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
void decode_func(std::shared_ptr<ppspeech::U2Recognizer> recognizer) {
while (!recognizer->IsFinished()) {
recognizer->Decode();
usleep(100);
}
recognizer->Decode();
recognizer->Rescoring();
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
@ -40,6 +31,7 @@ int main(int argc, char* argv[]) {
int32 num_done = 0, num_err = 0;
double tot_wav_duration = 0.0;
double tot_attention_rescore_time = 0.0;
double tot_decode_time = 0.0;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
@ -59,7 +51,7 @@ int main(int argc, char* argv[]) {
new ppspeech::U2Recognizer(resource));
for (; !wav_reader.Done(); wav_reader.Next()) {
std::thread recognizer_thread(decode_func, recognizer_ptr);
recognizer_ptr->InitDecoder();
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "utt: " << utt;
@ -74,7 +66,6 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
kaldi::Timer timer;
kaldi::Timer local_timer;
while (sample_offset < tot_samples) {
@ -85,21 +76,23 @@ int main(int argc, char* argv[]) {
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = waveform(sample_offset + i);
}
// wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
recognizer_ptr->Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
recognizer_ptr->SetFinished();
recognizer_ptr->SetInputFinished();
}
// no overlap
sample_offset += cur_chunk_size;
}
CHECK(sample_offset == tot_samples);
recognizer_ptr->WaitDecodeFinished();
kaldi::Timer timer;
recognizer_ptr->AttentionRescoring();
tot_attention_rescore_time += timer.Elapsed();
recognizer_thread.join();
std::string result = recognizer_ptr->GetFinalResult();
recognizer_ptr->Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
@ -107,6 +100,7 @@ int main(int argc, char* argv[]) {
continue;
}
tot_decode_time += local_timer.Elapsed();
LOG(INFO) << utt << " " << result;
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
<< " cost: " << local_timer.Elapsed();
@ -115,9 +109,11 @@ int main(int argc, char* argv[]) {
++num_done;
}
recognizer_ptr->WaitFinished();
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec";
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
}

@ -73,8 +73,7 @@ int main(int argc, char* argv[]) {
new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank)));
// the feature cache output feature chunk by chunk.
ppspeech::FeatureCacheOptions feat_cache_opts;
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
LOG(INFO) << "fbank: " << true;
LOG(INFO) << "feat dim: " << feature_cache.Dim();

@ -20,10 +20,9 @@ using kaldi::BaseFloat;
using std::unique_ptr;
using std::vector;
FeatureCache::FeatureCache(FeatureCacheOptions opts,
FeatureCache::FeatureCache(size_t max_size,
unique_ptr<FrontendInterface> base_extractor) {
max_size_ = opts.max_size;
timeout_ = opts.timeout; // ms
max_size_ = max_size;
base_extractor_ = std::move(base_extractor);
dim_ = base_extractor_->Dim();
}
@ -31,34 +30,25 @@ FeatureCache::FeatureCache(FeatureCacheOptions opts,
void FeatureCache::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
// read inputs
base_extractor_->Accept(inputs);
// feed current data
bool result = false;
do {
result = Compute();
} while (result);
}
// pop feature chunk
bool FeatureCache::Read(std::vector<kaldi::BaseFloat>* feats) {
kaldi::Timer timer;
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.empty() && base_extractor_->IsFinished() == false) {
// todo refactor: wait
// ready_read_condition_.wait(lock);
int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000); // ms
if (elapsed > timeout_) {
return false;
}
usleep(100); // sleep 0.1 ms
// feed current data
if (cache_.empty()) {
bool result = false;
do {
result = Compute();
} while (result);
}
if (cache_.empty()) return false;
// read from cache
*feats = cache_.front();
cache_.pop();
ready_feed_condition_.notify_one();
VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec.";
return true;
}
@ -73,23 +63,15 @@ bool FeatureCache::Compute() {
kaldi::Timer timer;
int32 num_chunk = feature.size() / dim_;
nframe_ += num_chunk;
VLOG(3) << "nframe computed: " << nframe_;
for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
int32 start = chunk_idx * dim_;
vector<BaseFloat> feature_chunk(feature.data() + start,
feature.data() + start + dim_);
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.size() >= max_size_) {
// cache full, wait
ready_feed_condition_.wait(lock);
}
// feed cache
cache_.push(feature_chunk);
ready_read_condition_.notify_one();
++nframe_;
}
VLOG(1) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. "
@ -97,4 +79,4 @@ bool FeatureCache::Compute() {
return true;
}
} // namespace ppspeech
} // namespace ppspeech
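With the timeout and internal condition variables removed, FeatureCache::Read() is now synchronous: it runs Compute() itself when the cache is empty and returns false if nothing is available yet, leaving all blocking to NnetProducer. A small hedged helper illustrating that contract (the helper name and include path are assumptions, not speechx code):

// Drains every feature chunk currently available from a FeatureCache.
// DrainAvailableFeatures is illustrative only; include path assumed.
#include <vector>
#include "frontend/feature_cache.h"

size_t DrainAvailableFeatures(ppspeech::FeatureCache* cache,
                              std::vector<std::vector<kaldi::BaseFloat>>* out) {
    std::vector<kaldi::BaseFloat> chunk;
    size_t num_chunks = 0;
    // Read() no longer waits on a condition variable or timeout: it computes
    // pending features on demand and returns false once the cache stays empty.
    while (cache->Read(&chunk)) {
        out->push_back(chunk);
        ++num_chunks;
    }
    return num_chunks;
}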

@ -19,16 +19,10 @@
namespace ppspeech {
struct FeatureCacheOptions {
int32 max_size;
int32 timeout; // ms
FeatureCacheOptions() : max_size(kint16max), timeout(1) {}
};
class FeatureCache : public FrontendInterface {
public:
explicit FeatureCache(
FeatureCacheOptions opts,
size_t max_size = kint16max,
std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves
@ -41,13 +35,11 @@ class FeatureCache : public FrontendInterface {
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() {
std::unique_lock<std::mutex> lock(mutex_);
LOG(INFO) << "set finished";
// std::unique_lock<std::mutex> lock(mutex_);
base_extractor_->SetFinished();
// read the last chunk data
Compute();
// ready_feed_condition_.notify_one();
base_extractor_->SetFinished();
LOG(INFO) << "compute last feats done.";
}
@ -66,16 +58,10 @@ class FeatureCache : public FrontendInterface {
int32 dim_;
size_t max_size_; // cache capacity
int32 frame_chunk_size_; // window
int32 frame_chunk_stride_; // stride
std::unique_ptr<FrontendInterface> base_extractor_;
kaldi::int32 timeout_; // ms
std::vector<kaldi::BaseFloat> remained_feature_;
std::queue<std::vector<BaseFloat>> cache_; // feature cache
std::mutex mutex_;
std::condition_variable ready_feed_condition_;
std::condition_variable ready_read_condition_;
int32 nframe_; // num of feature computed
DISALLOW_COPY_AND_ASSIGN(FeatureCache);

@ -33,7 +33,7 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts)
new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature)));
unique_ptr<FrontendInterface> cache(
new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn)));
new ppspeech::FeatureCache(kint16max, std::move(cmvn)));
base_extractor_.reset(
new ppspeech::Assembler(opts.assembler_opts, std::move(cache)));

@ -39,7 +39,6 @@ namespace ppspeech {
struct FeaturePipelineOptions {
std::string cmvn_file{};
knf::FbankOptions fbank_opts{};
FeatureCacheOptions feature_cache_opts{};
AssemblerOptions assembler_opts{};
static FeaturePipelineOptions InitFromFlags() {
