diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc index 2cef4972..8361f06d 100644 --- a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc +++ b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc @@ -63,8 +63,9 @@ void CTCPrefixBeamSearch::Reset() { times_.emplace_back(empty); } -void CTCPrefixBeamSearch::InitDecoder() { Reset(); } - +void CTCPrefixBeamSearch::InitDecoder() { + Reset(); +} void CTCPrefixBeamSearch::AdvanceDecode( const std::shared_ptr& decodable) { @@ -77,7 +78,7 @@ void CTCPrefixBeamSearch::AdvanceDecode( bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); feat_nnet_cost += timer.Elapsed(); if (flag == false) { - VLOG(3) << "decoder advance decode exit." << frame_prob.size(); + VLOG(2) << "decoder advance decode exit." << frame_prob.size(); break; } @@ -87,7 +88,7 @@ void CTCPrefixBeamSearch::AdvanceDecode( AdvanceDecoding(likelihood); search_cost += timer.Elapsed(); - VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_; + VLOG(1) << "num_frame_decoded_: " << num_frame_decoded_; } VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost << " sec."; diff --git a/speechx/speechx/asr/nnet/CMakeLists.txt b/speechx/speechx/asr/nnet/CMakeLists.txt index 819cc2e8..7306ebf8 100644 --- a/speechx/speechx/asr/nnet/CMakeLists.txt +++ b/speechx/speechx/asr/nnet/CMakeLists.txt @@ -8,14 +8,21 @@ target_link_libraries(nnet utils) target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS}) target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) -# test bin -#if(USING_U2) -# set(bin_name u2_nnet_main) -# add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -# target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -# target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) +# test bin +#set(bin_name u2_nnet_main) +#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +#target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +#target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) -# target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) -# target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) -# target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) -#endif() +#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) +#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) +#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) + +set(bin_name u2_nnet_thread_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet frontend) + +target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) +target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) +target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) diff --git a/speechx/speechx/asr/nnet/decodable.cc b/speechx/speechx/asr/nnet/decodable.cc index f01e9049..a140c376 100644 --- a/speechx/speechx/asr/nnet/decodable.cc +++ b/speechx/speechx/asr/nnet/decodable.cc @@ -33,7 +33,6 @@ void Decodable::Acceptlikelihood(const Matrix& likelihood) { nnet_producer_->Acceptlikelihood(likelihood); } - // return the size of frame have computed. int32 Decodable::NumFramesReady() const { return frames_ready_; } diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/speechx/speechx/asr/nnet/nnet_producer.cc index 6207a6b5..b83b5976 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.cc +++ b/speechx/speechx/asr/nnet/nnet_producer.cc @@ -22,14 +22,43 @@ using kaldi::BaseFloat; NnetProducer::NnetProducer(std::shared_ptr nnet, std::shared_ptr frontend) - : nnet_(nnet), frontend_(frontend) {} + : nnet_(nnet), frontend_(frontend) { + abort_ = false; + Reset(); + thread_ = std::thread(RunNnetEvaluation, this); + } void NnetProducer::Accept(const std::vector& inputs) { frontend_->Accept(inputs); + condition_variable_.notify_one(); +} + +void NnetProducer::UnLock() { + std::unique_lock lock(read_mutex_); + while (frontend_->IsFinished() == false && cache_.empty()) { + condition_read_ready_.wait(lock); + } + return; +} + +void NnetProducer::RunNnetEvaluation(NnetProducer *me) { + me->RunNnetEvaluationInteral(); +} + +void NnetProducer::RunNnetEvaluationInteral() { bool result = false; - do { - result = Compute(); - } while (result); + LOG(INFO) << "NnetEvaluationInteral begin"; + while (!abort_) { + std::unique_lock lock(mutex_); + condition_variable_.wait(lock); + do { + result = Compute(); + } while (result); + if (frontend_->IsFinished() == true) { + if (cache_.empty()) finished_ = true; + } + } + LOG(INFO) << "NnetEvaluationInteral exit"; } void NnetProducer::Acceptlikelihood( @@ -39,12 +68,20 @@ void NnetProducer::Acceptlikelihood( for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) { for (size_t col = 0; col < likelihood.NumCols(); ++col) { prob[col] = likelihood(idx, col); - cache_.push_back(prob); } + cache_.push_back(prob); } } bool NnetProducer::Read(std::vector* nnet_prob) { + bool flag = cache_.pop(nnet_prob); + condition_variable_.notify_one(); + return flag; +} + +bool NnetProducer::ReadandCompute(std::vector* nnet_prob) { + Compute(); + if (frontend_->IsFinished() && cache_.empty()) finished_ = true; bool flag = cache_.pop(nnet_prob); return flag; } @@ -53,22 +90,23 @@ bool NnetProducer::Compute() { vector features; if (frontend_ == NULL || frontend_->Read(&features) == false) { // no feat or frontend_ not init. - VLOG(3) << "no feat avalible"; + VLOG(2) << "no feat avalible"; return false; } CHECK_GE(frontend_->Dim(), 0); - VLOG(2) << "Forward in " << features.size() / frontend_->Dim() << " feats."; + VLOG(1) << "Forward in " << features.size() / frontend_->Dim() << " feats."; NnetOut out; nnet_->FeedForward(features, frontend_->Dim(), &out); int32& vocab_dim = out.vocab_dim; size_t nframes = out.logprobs.size() / vocab_dim; - VLOG(2) << "Forward out " << nframes << " decoder frames."; + VLOG(1) << "Forward out " << nframes << " decoder frames."; for (size_t idx = 0; idx < nframes; ++idx) { std::vector logprob( out.logprobs.data() + idx * vocab_dim, out.logprobs.data() + (idx + 1) * vocab_dim); cache_.push_back(logprob); + condition_read_ready_.notify_one(); } return true; } diff --git a/speechx/speechx/asr/nnet/nnet_producer.h b/speechx/speechx/asr/nnet/nnet_producer.h index dd356f95..14c74d04 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.h +++ b/speechx/speechx/asr/nnet/nnet_producer.h @@ -33,27 +33,38 @@ class NnetProducer { // nnet bool Read(std::vector* nnet_prob); + bool ReadandCompute(std::vector* nnet_prob); + static void RunNnetEvaluation(NnetProducer *me); + void RunNnetEvaluationInteral(); + void UnLock(); + + void Wait() { + abort_ = true; + condition_variable_.notify_one(); + if (thread_.joinable()) thread_.join(); + } bool Empty() const { return cache_.empty(); } - void SetFinished() { + void SetInputFinished() { LOG(INFO) << "set finished"; - // std::unique_lock lock(mutex_); frontend_->SetFinished(); - - // read the last chunk data - Compute(); - // ready_feed_condition_.notify_one(); - LOG(INFO) << "compute last feats done."; + condition_variable_.notify_one(); } - bool IsFinished() const { return frontend_->IsFinished(); } + // the compute thread exit + bool IsFinished() const { return finished_; } + + ~NnetProducer() { + if (thread_.joinable()) thread_.join(); + } void Reset() { frontend_->Reset(); nnet_->Reset(); VLOG(3) << "feature cache reset: cache size: " << cache_.size(); cache_.clear(); + finished_ = false; } void AttentionRescoring(const std::vector>& hyps, @@ -66,6 +77,13 @@ class NnetProducer { std::shared_ptr frontend_; std::shared_ptr nnet_; SafeQueue> cache_; + std::mutex mutex_; + std::mutex read_mutex_; + std::condition_variable condition_variable_; + std::condition_variable condition_read_ready_; + std::thread thread_; + bool finished_; + bool abort_; DISALLOW_COPY_AND_ASSIGN(NnetProducer); }; diff --git a/speechx/speechx/asr/nnet/u2_nnet_thread_main.cc b/speechx/speechx/asr/nnet/u2_nnet_thread_main.cc new file mode 100644 index 00000000..ce523e59 --- /dev/null +++ b/speechx/speechx/asr/nnet/u2_nnet_thread_main.cc @@ -0,0 +1,137 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "base/common.h" +#include "decoder/param.h" +#include "frontend/wave-reader.h" +#include "frontend/feature_pipeline.h" +#include "kaldi/util/table-types.h" +#include "nnet/decodable.h" +#include "nnet/u2_nnet.h" +#include "nnet/nnet_producer.h" + +DEFINE_string(wav_rspecifier, "", "test wav rspecifier"); +DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); + +using kaldi::BaseFloat; +using kaldi::Matrix; +using std::vector; + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int32 num_done = 0, num_err = 0; + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + + CHECK_GT(FLAGS_wav_rspecifier.size(), 0); + CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0); + CHECK_GT(FLAGS_model_path.size(), 0); + LOG(INFO) << "input rspecifier: " << FLAGS_wav_rspecifier; + LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier; + LOG(INFO) << "model path: " << FLAGS_model_path; + + kaldi::SequentialTableReader wav_reader( + FLAGS_wav_rspecifier); + kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier); + + ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); + ppspeech::FeaturePipelineOptions feature_opts = + ppspeech::FeaturePipelineOptions::InitFromFlags(); + feature_opts.assembler_opts.fill_zero = false; + + std::shared_ptr nnet(new ppspeech::U2Nnet(model_opts)); + std::shared_ptr feature_pipeline( + new ppspeech::FeaturePipeline(feature_opts)); + std::shared_ptr nnet_producer( + new ppspeech::NnetProducer(nnet, feature_pipeline)); + kaldi::Timer timer; + float tot_wav_duration = 0; + + for (; !wav_reader.Done(); wav_reader.Next()) { + std::string utt = wav_reader.Key(); + const kaldi::WaveData& wave_data = wav_reader.Value(); + LOG(INFO) << "utt: " << utt; + LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; + double dur = wave_data.Duration(); + tot_wav_duration += dur; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + kaldi::Timer timer; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + std::vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk[i] = waveform(sample_offset + i); + } + + nnet_producer->Accept(wav_chunk); + if (cur_chunk_size < chunk_sample_size) { + nnet_producer->SetInputFinished(); + } + + // no overlap + sample_offset += cur_chunk_size; + } + CHECK(sample_offset == tot_samples); + + std::vector> prob_vec; + while(1) { + std::vector logprobs; + bool isok = nnet_producer->Read(&logprobs); + if (nnet_producer->IsFinished()) break; + if (isok == false) continue; + prob_vec.push_back(logprobs); + } + { + // writer nnet output + kaldi::MatrixIndexT nrow = prob_vec.size(); + kaldi::MatrixIndexT ncol = prob_vec[0].size(); + LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol; + kaldi::Matrix nnet_out(nrow, ncol); + for (int32 row_idx = 0; row_idx < nrow; ++row_idx) { + for (int32 col_idx = 0; col_idx < ncol; ++col_idx) { + nnet_out(row_idx, col_idx) = prob_vec[row_idx][col_idx]; + } + } + nnet_out_writer.Write(utt, nnet_out); + } + nnet_producer->Reset(); + } + + nnet_producer->Wait(); + double elapsed = timer.Elapsed(); + LOG(INFO) << "Program cost:" << elapsed << " sec"; + + LOG(INFO) << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 0 : 1); +} diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.cc b/speechx/speechx/asr/recognizer/u2_recognizer.cc index a7644430..0c5a8941 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer.cc @@ -39,12 +39,28 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) unit_table_ = decoder_->VocabTable(); symbol_table_ = unit_table_; + global_frame_offset_ = 0; input_finished_ = false; + num_frames_ = 0; + result_.clear(); + +} + +U2Recognizer::~U2Recognizer() { + SetInputFinished(); + WaitDecodeFinished(); +} - Reset(); +void U2Recognizer::WaitDecodeFinished() { + if (thread_.joinable()) thread_.join(); } -void U2Recognizer::Reset() { +void U2Recognizer::WaitFinished() { + if (thread_.joinable()) thread_.join(); + nnet_producer_->Wait(); +} + +void U2Recognizer::InitDecoder() { global_frame_offset_ = 0; input_finished_ = false; num_frames_ = 0; @@ -52,6 +68,7 @@ void U2Recognizer::Reset() { decodable_->Reset(); decoder_->Reset(); + thread_ = std::thread(RunDecoderSearch, this); } void U2Recognizer::ResetContinuousDecoding() { @@ -63,6 +80,19 @@ void U2Recognizer::ResetContinuousDecoding() { decoder_->Reset(); } +void U2Recognizer::RunDecoderSearch(U2Recognizer* me) { + me->RunDecoderSearchInternal(); +} + +void U2Recognizer::RunDecoderSearchInternal() { + LOG(INFO) << "DecoderSearchInteral begin"; + while (!nnet_producer_->IsFinished()) { + nnet_producer_->UnLock(); + decoder_->AdvanceDecode(decodable_); + } + Decode(); + LOG(INFO) << "DecoderSearchInteral exit"; +} void U2Recognizer::Accept(const vector& waves) { kaldi::Timer timer; @@ -71,7 +101,6 @@ void U2Recognizer::Accept(const vector& waves) { << " samples."; } - void U2Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); UpdateResult(false); @@ -207,8 +236,8 @@ std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; } std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; } -void U2Recognizer::SetFinished() { - nnet_producer_->SetFinished(); +void U2Recognizer::SetInputFinished() { + nnet_producer_->SetInputFinished(); input_finished_ = true; } diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h index a3bf8aea..57f2c9c5 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/speechx/speechx/asr/recognizer/u2_recognizer.h @@ -112,19 +112,21 @@ struct U2RecognizerResource { class U2Recognizer { public: explicit U2Recognizer(const U2RecognizerResource& resouce); - void Reset(); + ~U2Recognizer(); + void InitDecoder(); void ResetContinuousDecoding(); void Accept(const std::vector& waves); void Decode(); void Rescoring(); - std::string GetFinalResult(); std::string GetPartialResult(); - void SetFinished(); + void SetInputFinished(); bool IsFinished() { return input_finished_; } + void WaitDecodeFinished(); + void WaitFinished(); bool DecodedSomething() const { return !result_.empty() && !result_[0].sentence.empty(); @@ -137,18 +139,17 @@ class U2Recognizer { // feature_pipeline_->FrameShift(); } - const std::vector& Result() const { return result_; } + void AttentionRescoring(); private: - void AttentionRescoring(); + static void RunDecoderSearch(U2Recognizer *me); + void RunDecoderSearchInternal(); void UpdateResult(bool finish = false); private: U2RecognizerResource opts_; - // std::shared_ptr resource_; - // U2RecognizerResource resource_; std::shared_ptr nnet_producer_; std::shared_ptr decodable_; std::unique_ptr decoder_; @@ -167,6 +168,7 @@ class U2Recognizer { const int time_stamp_gap_ = 100; bool input_finished_; + std::thread thread_; }; } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_main.cc index 90c7cc06..178c91db 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_main.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer_main.cc @@ -49,6 +49,7 @@ int main(int argc, char* argv[]) { ppspeech::U2Recognizer recognizer(resource); for (; !wav_reader.Done(); wav_reader.Next()) { + recognizer.InitDecoder(); std::string utt = wav_reader.Key(); const kaldi::WaveData& wave_data = wav_reader.Value(); LOG(INFO) << "utt: " << utt; @@ -79,7 +80,7 @@ int main(int argc, char* argv[]) { recognizer.Accept(wav_chunk); if (cur_chunk_size < chunk_sample_size) { - recognizer.SetFinished(); + recognizer.SetInputFinished(); } recognizer.Decode(); if (recognizer.DecodedSomething()) { @@ -100,7 +101,6 @@ int main(int argc, char* argv[]) { std::string result = recognizer.GetFinalResult(); - recognizer.Reset(); if (result.empty()) { // the TokenWriter can not write empty string. diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc index a53b4541..3f45294d 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc @@ -22,15 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_int32(sample_rate, 16000, "sample rate"); -void decode_func(std::shared_ptr recognizer) { - while (!recognizer->IsFinished()) { - recognizer->Decode(); - usleep(100); - } - recognizer->Decode(); - recognizer->Rescoring(); -} - int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -40,6 +31,7 @@ int main(int argc, char* argv[]) { int32 num_done = 0, num_err = 0; double tot_wav_duration = 0.0; + double tot_attention_rescore_time = 0.0; double tot_decode_time = 0.0; kaldi::SequentialTableReader wav_reader( @@ -59,7 +51,7 @@ int main(int argc, char* argv[]) { new ppspeech::U2Recognizer(resource)); for (; !wav_reader.Done(); wav_reader.Next()) { - std::thread recognizer_thread(decode_func, recognizer_ptr); + recognizer_ptr->InitDecoder(); std::string utt = wav_reader.Key(); const kaldi::WaveData& wave_data = wav_reader.Value(); LOG(INFO) << "utt: " << utt; @@ -74,7 +66,6 @@ int main(int argc, char* argv[]) { LOG(INFO) << "wav len (sample): " << tot_samples; int sample_offset = 0; - kaldi::Timer timer; kaldi::Timer local_timer; while (sample_offset < tot_samples) { @@ -85,21 +76,23 @@ int main(int argc, char* argv[]) { for (int i = 0; i < cur_chunk_size; ++i) { wav_chunk[i] = waveform(sample_offset + i); } - // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); recognizer_ptr->Accept(wav_chunk); if (cur_chunk_size < chunk_sample_size) { - recognizer_ptr->SetFinished(); + recognizer_ptr->SetInputFinished(); } // no overlap sample_offset += cur_chunk_size; } CHECK(sample_offset == tot_samples); + recognizer_ptr->WaitDecodeFinished(); + + kaldi::Timer timer; + recognizer_ptr->AttentionRescoring(); + tot_attention_rescore_time += timer.Elapsed(); - recognizer_thread.join(); std::string result = recognizer_ptr->GetFinalResult(); - recognizer_ptr->Reset(); if (result.empty()) { // the TokenWriter can not write empty string. ++num_err; @@ -107,6 +100,7 @@ int main(int argc, char* argv[]) { continue; } + tot_decode_time += local_timer.Elapsed(); LOG(INFO) << utt << " " << result; LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur << " cost: " << local_timer.Elapsed(); @@ -115,9 +109,11 @@ int main(int argc, char* argv[]) { ++num_done; } + recognizer_ptr->WaitFinished(); LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec"; LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; } diff --git a/speechx/speechx/common/frontend/compute_fbank_main.cc b/speechx/speechx/common/frontend/compute_fbank_main.cc index d7d5165c..e022207d 100644 --- a/speechx/speechx/common/frontend/compute_fbank_main.cc +++ b/speechx/speechx/common/frontend/compute_fbank_main.cc @@ -73,8 +73,7 @@ int main(int argc, char* argv[]) { new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank))); // the feature cache output feature chunk by chunk. - ppspeech::FeatureCacheOptions feat_cache_opts; - ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); + ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn)); LOG(INFO) << "fbank: " << true; LOG(INFO) << "feat dim: " << feature_cache.Dim(); diff --git a/speechx/speechx/common/frontend/feature_cache.cc b/speechx/speechx/common/frontend/feature_cache.cc index e6ac3c23..c166bd64 100644 --- a/speechx/speechx/common/frontend/feature_cache.cc +++ b/speechx/speechx/common/frontend/feature_cache.cc @@ -20,10 +20,9 @@ using kaldi::BaseFloat; using std::unique_ptr; using std::vector; -FeatureCache::FeatureCache(FeatureCacheOptions opts, +FeatureCache::FeatureCache(size_t max_size, unique_ptr base_extractor) { - max_size_ = opts.max_size; - timeout_ = opts.timeout; // ms + max_size_ = max_size; base_extractor_ = std::move(base_extractor); dim_ = base_extractor_->Dim(); } @@ -31,34 +30,25 @@ FeatureCache::FeatureCache(FeatureCacheOptions opts, void FeatureCache::Accept(const std::vector& inputs) { // read inputs base_extractor_->Accept(inputs); - - // feed current data - bool result = false; - do { - result = Compute(); - } while (result); } // pop feature chunk bool FeatureCache::Read(std::vector* feats) { kaldi::Timer timer; - std::unique_lock lock(mutex_); - while (cache_.empty() && base_extractor_->IsFinished() == false) { - // todo refactor: wait - // ready_read_condition_.wait(lock); - int32 elapsed = static_cast(timer.Elapsed() * 1000); // ms - if (elapsed > timeout_) { - return false; - } - usleep(100); // sleep 0.1 ms + // feed current data + if (cache_.empty()) { + bool result = false; + do { + result = Compute(); + } while (result); } + if (cache_.empty()) return false; // read from cache *feats = cache_.front(); cache_.pop(); - ready_feed_condition_.notify_one(); VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec."; return true; } @@ -73,23 +63,15 @@ bool FeatureCache::Compute() { kaldi::Timer timer; int32 num_chunk = feature.size() / dim_; - nframe_ += num_chunk; VLOG(3) << "nframe computed: " << nframe_; for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { int32 start = chunk_idx * dim_; vector feature_chunk(feature.data() + start, feature.data() + start + dim_); - - std::unique_lock lock(mutex_); - while (cache_.size() >= max_size_) { - // cache full, wait - ready_feed_condition_.wait(lock); - } - // feed cache cache_.push(feature_chunk); - ready_read_condition_.notify_one(); + ++nframe_; } VLOG(1) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. " @@ -97,4 +79,4 @@ bool FeatureCache::Compute() { return true; } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/common/frontend/feature_cache.h b/speechx/speechx/common/frontend/feature_cache.h index 51816a1d..b87612d6 100644 --- a/speechx/speechx/common/frontend/feature_cache.h +++ b/speechx/speechx/common/frontend/feature_cache.h @@ -19,16 +19,10 @@ namespace ppspeech { -struct FeatureCacheOptions { - int32 max_size; - int32 timeout; // ms - FeatureCacheOptions() : max_size(kint16max), timeout(1) {} -}; - class FeatureCache : public FrontendInterface { public: explicit FeatureCache( - FeatureCacheOptions opts, + size_t max_size = kint16max, std::unique_ptr base_extractor = NULL); // Feed feats or waves @@ -41,13 +35,11 @@ class FeatureCache : public FrontendInterface { virtual size_t Dim() const { return dim_; } virtual void SetFinished() { + std::unique_lock lock(mutex_); LOG(INFO) << "set finished"; - // std::unique_lock lock(mutex_); - base_extractor_->SetFinished(); - // read the last chunk data Compute(); - // ready_feed_condition_.notify_one(); + base_extractor_->SetFinished(); LOG(INFO) << "compute last feats done."; } @@ -66,16 +58,10 @@ class FeatureCache : public FrontendInterface { int32 dim_; size_t max_size_; // cache capacity - int32 frame_chunk_size_; // window - int32 frame_chunk_stride_; // stride std::unique_ptr base_extractor_; - kaldi::int32 timeout_; // ms - std::vector remained_feature_; std::queue> cache_; // feature cache std::mutex mutex_; - std::condition_variable ready_feed_condition_; - std::condition_variable ready_read_condition_; int32 nframe_; // num of feature computed DISALLOW_COPY_AND_ASSIGN(FeatureCache); diff --git a/speechx/speechx/common/frontend/feature_pipeline.cc b/speechx/speechx/common/frontend/feature_pipeline.cc index 34e55a10..f37b4180 100644 --- a/speechx/speechx/common/frontend/feature_pipeline.cc +++ b/speechx/speechx/common/frontend/feature_pipeline.cc @@ -33,7 +33,7 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature))); unique_ptr cache( - new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn))); + new ppspeech::FeatureCache(kint16max, std::move(cmvn))); base_extractor_.reset( new ppspeech::Assembler(opts.assembler_opts, std::move(cache))); diff --git a/speechx/speechx/common/frontend/feature_pipeline.h b/speechx/speechx/common/frontend/feature_pipeline.h index ea7e2bba..c9a649fd 100644 --- a/speechx/speechx/common/frontend/feature_pipeline.h +++ b/speechx/speechx/common/frontend/feature_pipeline.h @@ -39,7 +39,6 @@ namespace ppspeech { struct FeaturePipelineOptions { std::string cmvn_file{}; knf::FbankOptions fbank_opts{}; - FeatureCacheOptions feature_cache_opts{}; AssemblerOptions assembler_opts{}; static FeaturePipelineOptions InitFromFlags() {