diff --git a/speechx/speechx/asr/nnet/CMakeLists.txt b/speechx/speechx/asr/nnet/CMakeLists.txt index 819cc2e89..7306ebf8a 100644 --- a/speechx/speechx/asr/nnet/CMakeLists.txt +++ b/speechx/speechx/asr/nnet/CMakeLists.txt @@ -8,14 +8,21 @@ target_link_libraries(nnet utils) target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS}) target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) -# test bin -#if(USING_U2) -# set(bin_name u2_nnet_main) -# add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -# target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -# target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) +# test bin +#set(bin_name u2_nnet_main) +#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +#target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +#target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) -# target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) -# target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) -# target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) -#endif() +#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) +#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) +#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) + +set(bin_name u2_nnet_thread_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet frontend) + +target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) +target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) +target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) diff --git a/speechx/speechx/asr/nnet/decodable.cc b/speechx/speechx/asr/nnet/decodable.cc index f01e90493..a140c376a 100644 --- a/speechx/speechx/asr/nnet/decodable.cc +++ b/speechx/speechx/asr/nnet/decodable.cc @@ -33,7 +33,6 @@ void Decodable::Acceptlikelihood(const Matrix& likelihood) { nnet_producer_->Acceptlikelihood(likelihood); } - // return the size of frame have computed. int32 Decodable::NumFramesReady() const { return frames_ready_; } diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/speechx/speechx/asr/nnet/nnet_producer.cc index 6207a6b5a..1db93ce32 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.cc +++ b/speechx/speechx/asr/nnet/nnet_producer.cc @@ -22,14 +22,36 @@ using kaldi::BaseFloat; NnetProducer::NnetProducer(std::shared_ptr nnet, std::shared_ptr frontend) - : nnet_(nnet), frontend_(frontend) {} + : nnet_(nnet), frontend_(frontend) { + abort_ = false; + Reset(); + thread_ = std::thread(RunNnetEvaluation, this); + } void NnetProducer::Accept(const std::vector& inputs) { frontend_->Accept(inputs); + condition_variable_.notify_one(); +} + +void NnetProducer::RunNnetEvaluation(NnetProducer *me) { + me->RunNnetEvaluationInteral(); +} + +void NnetProducer::RunNnetEvaluationInteral() { bool result = false; - do { - result = Compute(); - } while (result); + LOG(INFO) << "NnetEvaluationInteral begin"; + while (!abort_) { + std::unique_lock lock(mutex_); + condition_variable_.wait(lock); + do { + result = Compute(); + } while (result); + if (frontend_->IsFinished() == true) { + Compute(); + if (cache_.empty()) finished_ = true; + } + } + LOG(INFO) << "NnetEvaluationInteral exit"; } void NnetProducer::Acceptlikelihood( @@ -45,6 +67,14 @@ void NnetProducer::Acceptlikelihood( } bool NnetProducer::Read(std::vector* nnet_prob) { + bool flag = cache_.pop(nnet_prob); + condition_variable_.notify_one(); + return flag; +} + +bool NnetProducer::ReadandCompute(std::vector* nnet_prob) { + Compute(); + if (frontend_->IsFinished() && cache_.empty()) finished_ = true; bool flag = cache_.pop(nnet_prob); return flag; } @@ -53,17 +83,18 @@ bool NnetProducer::Compute() { vector features; if (frontend_ == NULL || frontend_->Read(&features) == false) { // no feat or frontend_ not init. - VLOG(3) << "no feat avalible"; + VLOG(2) << "no feat avalible"; return false; } CHECK_GE(frontend_->Dim(), 0); - VLOG(2) << "Forward in " << features.size() / frontend_->Dim() << " feats."; + VLOG(1) << "Forward in " << features.size() / frontend_->Dim() << " feats."; + VLOG(1) << features[0] << " . . . " << features[features.size()-1]; NnetOut out; nnet_->FeedForward(features, frontend_->Dim(), &out); int32& vocab_dim = out.vocab_dim; size_t nframes = out.logprobs.size() / vocab_dim; - VLOG(2) << "Forward out " << nframes << " decoder frames."; + VLOG(1) << "Forward out " << nframes << " decoder frames."; for (size_t idx = 0; idx < nframes; ++idx) { std::vector logprob( out.logprobs.data() + idx * vocab_dim, diff --git a/speechx/speechx/asr/nnet/nnet_producer.h b/speechx/speechx/asr/nnet/nnet_producer.h index dd356f957..35406f5fc 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.h +++ b/speechx/speechx/asr/nnet/nnet_producer.h @@ -33,27 +33,40 @@ class NnetProducer { // nnet bool Read(std::vector* nnet_prob); + bool ReadandCompute(std::vector* nnet_prob); + static void RunNnetEvaluation(NnetProducer *me); + void RunNnetEvaluationInteral(); + + void Wait() { + abort_ = true; + condition_variable_.notify_one(); + if (thread_.joinable()) thread_.join(); + } bool Empty() const { return cache_.empty(); } - void SetFinished() { + void SetInputFinished() { LOG(INFO) << "set finished"; - // std::unique_lock lock(mutex_); frontend_->SetFinished(); - - // read the last chunk data - Compute(); - // ready_feed_condition_.notify_one(); + condition_variable_.notify_one(); LOG(INFO) << "compute last feats done."; } - bool IsFinished() const { return frontend_->IsFinished(); } + // the compute thread exit + bool IsFinished() const { return finished_; } + + ~NnetProducer() { + if (thread_.joinable()) thread_.join(); + } void Reset() { + //if (thread_.joinable()) thread_.join(); frontend_->Reset(); nnet_->Reset(); VLOG(3) << "feature cache reset: cache size: " << cache_.size(); cache_.clear(); + finished_ = false; + //thread_ = std::thread(RunNnetEvaluation, this); } void AttentionRescoring(const std::vector>& hyps, @@ -66,6 +79,11 @@ class NnetProducer { std::shared_ptr frontend_; std::shared_ptr nnet_; SafeQueue> cache_; + std::mutex mutex_; + std::condition_variable condition_variable_; + std::thread thread_; + bool finished_; + bool abort_; DISALLOW_COPY_AND_ASSIGN(NnetProducer); }; diff --git a/speechx/speechx/asr/nnet/u2_nnet_thread_main.cc b/speechx/speechx/asr/nnet/u2_nnet_thread_main.cc new file mode 100644 index 000000000..c7376a22b --- /dev/null +++ b/speechx/speechx/asr/nnet/u2_nnet_thread_main.cc @@ -0,0 +1,133 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "base/common.h" +#include "decoder/param.h" +#include "frontend/wave-reader.h" +#include "frontend/feature_pipeline.h" +#include "kaldi/util/table-types.h" +#include "nnet/decodable.h" +#include "nnet/u2_nnet.h" +#include "nnet/nnet_producer.h" + +DEFINE_string(wav_rspecifier, "", "test wav rspecifier"); +DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); + +using kaldi::BaseFloat; +using kaldi::Matrix; +using std::vector; + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int32 num_done = 0, num_err = 0; + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + + CHECK_GT(FLAGS_wav_rspecifier.size(), 0); + CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0); + CHECK_GT(FLAGS_model_path.size(), 0); + LOG(INFO) << "input rspecifier: " << FLAGS_wav_rspecifier; + LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier; + LOG(INFO) << "model path: " << FLAGS_model_path; + + kaldi::SequentialTableReader wav_reader( + FLAGS_wav_rspecifier); + kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier); + + ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); + ppspeech::FeaturePipelineOptions feature_opts = + ppspeech::FeaturePipelineOptions::InitFromFlags(); + feature_opts.assembler_opts.fill_zero = false; + + std::shared_ptr nnet(new ppspeech::U2Nnet(model_opts)); + std::shared_ptr raw_data(new ppspeech::FeaturePipeline(feature_opts)); + std::shared_ptr nnet_producer(new ppspeech::NnetProducer(nnet, raw_data)); + kaldi::Timer timer; + float tot_wav_duration = 0; + + for (; !wav_reader.Done(); wav_reader.Next()) { + std::string utt = wav_reader.Key(); + const kaldi::WaveData& wave_data = wav_reader.Value(); + LOG(INFO) << "utt: " << utt; + LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; + double dur = wave_data.Duration(); + tot_wav_duration += dur; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + kaldi::Timer timer; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + std::vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk[i] = waveform(sample_offset + i); + } + // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); + + nnet_producer->Accept(wav_chunk); + if (cur_chunk_size < chunk_sample_size) { + nnet_producer->SetInputFinished(); + } + + // no overlap + sample_offset += cur_chunk_size; + } + CHECK(sample_offset == tot_samples); + + std::vector> prob_vec; + while(1) { + std::vector logprobs; + //nnet_producer->ReadandCompute(&logprobs); + bool isok = nnet_producer->Read(&logprobs); + //bool isok = nnet_producer->IsFinished(); + int vocab_dim = logprobs.size(); + if (nnet_producer->IsFinished()) break; + //for (int row_idx = 0; row_idx < logprobs.Dim() / vocab_dim; + //row_idx++) { + //kaldi::Vector vec_tmp(vocab_dim); + //std::memcpy(vec_tmp.Data(), + //logprobs.Data() + row_idx * vocab_dim, + //sizeof(kaldi::BaseFloat) * vocab_dim); + //prob_vec.push_back(vec_tmp); + //} + } + + nnet_producer->Reset(); + } + + nnet_producer->Wait(); + double elapsed = timer.Elapsed(); + LOG(INFO) << "Program cost:" << elapsed << " sec"; + + LOG(INFO) << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 0 : 1); +} diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.cc b/speechx/speechx/asr/recognizer/u2_recognizer.cc index a76444305..03acf0595 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer.cc @@ -39,12 +39,28 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) unit_table_ = decoder_->VocabTable(); symbol_table_ = unit_table_; + global_frame_offset_ = 0; input_finished_ = false; + num_frames_ = 0; + result_.clear(); + +} + +U2Recognizer::~U2Recognizer() { + SetInputFinished(); + WaitDecodeFinished(); +} - Reset(); +void U2Recognizer::WaitDecodeFinished() { + if (thread_.joinable()) thread_.join(); } -void U2Recognizer::Reset() { +void U2Recognizer::WaitFinished() { + if (thread_.joinable()) thread_.join(); + nnet_producer_->Wait(); +} + +void U2Recognizer::InitDecoder() { global_frame_offset_ = 0; input_finished_ = false; num_frames_ = 0; @@ -52,6 +68,7 @@ void U2Recognizer::Reset() { decodable_->Reset(); decoder_->Reset(); + thread_ = std::thread(RunDecoderSearch, this); } void U2Recognizer::ResetContinuousDecoding() { @@ -63,6 +80,18 @@ void U2Recognizer::ResetContinuousDecoding() { decoder_->Reset(); } +void U2Recognizer::RunDecoderSearch(U2Recognizer* me) { + me->RunDecoderSearchInternal(); +} + +void U2Recognizer::RunDecoderSearchInternal() { + while(!nnet_producer_->IsFinished()) { + Decode(); + } + Decode(); + Rescoring(); + LOG(INFO) << "DecoderSearchInteral exit"; +} void U2Recognizer::Accept(const vector& waves) { kaldi::Timer timer; @@ -71,7 +100,6 @@ void U2Recognizer::Accept(const vector& waves) { << " samples."; } - void U2Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); UpdateResult(false); @@ -207,8 +235,8 @@ std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; } std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; } -void U2Recognizer::SetFinished() { - nnet_producer_->SetFinished(); +void U2Recognizer::SetInputFinished() { + nnet_producer_->SetInputFinished(); input_finished_ = true; } diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h index a3bf8aeae..8b5add872 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/speechx/speechx/asr/recognizer/u2_recognizer.h @@ -112,19 +112,21 @@ struct U2RecognizerResource { class U2Recognizer { public: explicit U2Recognizer(const U2RecognizerResource& resouce); - void Reset(); + ~U2Recognizer(); + void InitDecoder(); void ResetContinuousDecoding(); void Accept(const std::vector& waves); void Decode(); void Rescoring(); - std::string GetFinalResult(); std::string GetPartialResult(); - void SetFinished(); + void SetInputFinished(); bool IsFinished() { return input_finished_; } + void WaitDecodeFinished(); + void WaitFinished(); bool DecodedSomething() const { return !result_.empty() && !result_[0].sentence.empty(); @@ -137,10 +139,11 @@ class U2Recognizer { // feature_pipeline_->FrameShift(); } - const std::vector& Result() const { return result_; } private: + static void RunDecoderSearch(U2Recognizer *me); + void RunDecoderSearchInternal(); void AttentionRescoring(); void UpdateResult(bool finish = false); @@ -167,6 +170,7 @@ class U2Recognizer { const int time_stamp_gap_ = 100; bool input_finished_; + std::thread thread_; }; } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_main.cc index 90c7cc063..178c91db1 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_main.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer_main.cc @@ -49,6 +49,7 @@ int main(int argc, char* argv[]) { ppspeech::U2Recognizer recognizer(resource); for (; !wav_reader.Done(); wav_reader.Next()) { + recognizer.InitDecoder(); std::string utt = wav_reader.Key(); const kaldi::WaveData& wave_data = wav_reader.Value(); LOG(INFO) << "utt: " << utt; @@ -79,7 +80,7 @@ int main(int argc, char* argv[]) { recognizer.Accept(wav_chunk); if (cur_chunk_size < chunk_sample_size) { - recognizer.SetFinished(); + recognizer.SetInputFinished(); } recognizer.Decode(); if (recognizer.DecodedSomething()) { @@ -100,7 +101,6 @@ int main(int argc, char* argv[]) { std::string result = recognizer.GetFinalResult(); - recognizer.Reset(); if (result.empty()) { // the TokenWriter can not write empty string. diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc index a53b45415..891b2012a 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc @@ -22,14 +22,14 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_int32(sample_rate, 16000, "sample rate"); -void decode_func(std::shared_ptr recognizer) { +/*void decode_func(std::shared_ptr recognizer) { while (!recognizer->IsFinished()) { recognizer->Decode(); - usleep(100); } recognizer->Decode(); recognizer->Rescoring(); -} + LOG(INFO) << "decode thread exit!!!"; +}*/ int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); @@ -59,7 +59,7 @@ int main(int argc, char* argv[]) { new ppspeech::U2Recognizer(resource)); for (; !wav_reader.Done(); wav_reader.Next()) { - std::thread recognizer_thread(decode_func, recognizer_ptr); + recognizer_ptr->InitDecoder(); std::string utt = wav_reader.Key(); const kaldi::WaveData& wave_data = wav_reader.Value(); LOG(INFO) << "utt: " << utt; @@ -89,17 +89,15 @@ int main(int argc, char* argv[]) { recognizer_ptr->Accept(wav_chunk); if (cur_chunk_size < chunk_sample_size) { - recognizer_ptr->SetFinished(); + recognizer_ptr->SetInputFinished(); } // no overlap sample_offset += cur_chunk_size; } CHECK(sample_offset == tot_samples); - - recognizer_thread.join(); + recognizer_ptr->WaitDecodeFinished(); std::string result = recognizer_ptr->GetFinalResult(); - recognizer_ptr->Reset(); if (result.empty()) { // the TokenWriter can not write empty string. ++num_err; @@ -107,6 +105,7 @@ int main(int argc, char* argv[]) { continue; } + tot_decode_time += local_timer.Elapsed(); LOG(INFO) << utt << " " << result; LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur << " cost: " << local_timer.Elapsed(); @@ -115,6 +114,7 @@ int main(int argc, char* argv[]) { ++num_done; } + recognizer_ptr->WaitFinished(); LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; diff --git a/speechx/speechx/common/frontend/feature_cache.cc b/speechx/speechx/common/frontend/feature_cache.cc index e6ac3c23c..bf76aaff4 100644 --- a/speechx/speechx/common/frontend/feature_cache.cc +++ b/speechx/speechx/common/frontend/feature_cache.cc @@ -31,34 +31,35 @@ FeatureCache::FeatureCache(FeatureCacheOptions opts, void FeatureCache::Accept(const std::vector& inputs) { // read inputs base_extractor_->Accept(inputs); - - // feed current data - bool result = false; - do { - result = Compute(); - } while (result); } // pop feature chunk bool FeatureCache::Read(std::vector* feats) { kaldi::Timer timer; - std::unique_lock lock(mutex_); - while (cache_.empty() && base_extractor_->IsFinished() == false) { - // todo refactor: wait - // ready_read_condition_.wait(lock); - int32 elapsed = static_cast(timer.Elapsed() * 1000); // ms - if (elapsed > timeout_) { - return false; - } - usleep(100); // sleep 0.1 ms +// feed current data + if (cache_.empty()) { + bool result = false; + do { + result = Compute(); + } while (result); } + + //while (cache_.empty() && base_extractor_->IsFinished() == false) { + //// todo refactor: wait + //// ready_read_condition_.wait(lock); + //int32 elapsed = static_cast(timer.Elapsed() * 1000); // ms + //if (elapsed > timeout_) { + //return false; + //} + //usleep(100); // sleep 0.1 ms + //} if (cache_.empty()) return false; // read from cache *feats = cache_.front(); cache_.pop(); - ready_feed_condition_.notify_one(); + //ready_feed_condition_.notify_one(); VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec."; return true; } @@ -73,7 +74,6 @@ bool FeatureCache::Compute() { kaldi::Timer timer; int32 num_chunk = feature.size() / dim_; - nframe_ += num_chunk; VLOG(3) << "nframe computed: " << nframe_; for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { @@ -81,15 +81,16 @@ bool FeatureCache::Compute() { vector feature_chunk(feature.data() + start, feature.data() + start + dim_); - std::unique_lock lock(mutex_); - while (cache_.size() >= max_size_) { + // std::unique_lock lock(mutex_); + //while (cache_.size() >= max_size_) { // cache full, wait - ready_feed_condition_.wait(lock); - } + // ready_feed_condition_.wait(lock); + //} // feed cache cache_.push(feature_chunk); - ready_read_condition_.notify_one(); + ++nframe_; + //ready_read_condition_.notify_one(); } VLOG(1) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. " diff --git a/speechx/speechx/common/frontend/feature_cache.h b/speechx/speechx/common/frontend/feature_cache.h index 51816a1de..891e62e60 100644 --- a/speechx/speechx/common/frontend/feature_cache.h +++ b/speechx/speechx/common/frontend/feature_cache.h @@ -41,13 +41,11 @@ class FeatureCache : public FrontendInterface { virtual size_t Dim() const { return dim_; } virtual void SetFinished() { + std::unique_lock lock(mutex_); LOG(INFO) << "set finished"; - // std::unique_lock lock(mutex_); - base_extractor_->SetFinished(); - // read the last chunk data Compute(); - // ready_feed_condition_.notify_one(); + base_extractor_->SetFinished(); LOG(INFO) << "compute last feats done."; } @@ -71,11 +69,8 @@ class FeatureCache : public FrontendInterface { std::unique_ptr base_extractor_; kaldi::int32 timeout_; // ms - std::vector remained_feature_; std::queue> cache_; // feature cache std::mutex mutex_; - std::condition_variable ready_feed_condition_; - std::condition_variable ready_read_condition_; int32 nframe_; // num of feature computed DISALLOW_COPY_AND_ASSIGN(FeatureCache);