diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc index 2cef4972d..8361f06d6 100644 --- a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc +++ b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc @@ -63,8 +63,9 @@ void CTCPrefixBeamSearch::Reset() { times_.emplace_back(empty); } -void CTCPrefixBeamSearch::InitDecoder() { Reset(); } - +void CTCPrefixBeamSearch::InitDecoder() { + Reset(); +} void CTCPrefixBeamSearch::AdvanceDecode( const std::shared_ptr& decodable) { @@ -77,7 +78,7 @@ void CTCPrefixBeamSearch::AdvanceDecode( bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); feat_nnet_cost += timer.Elapsed(); if (flag == false) { - VLOG(3) << "decoder advance decode exit." << frame_prob.size(); + VLOG(2) << "decoder advance decode exit." << frame_prob.size(); break; } @@ -87,7 +88,7 @@ void CTCPrefixBeamSearch::AdvanceDecode( AdvanceDecoding(likelihood); search_cost += timer.Elapsed(); - VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_; + VLOG(1) << "num_frame_decoded_: " << num_frame_decoded_; } VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost << " sec."; diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/speechx/speechx/asr/nnet/nnet_producer.cc index 1db93ce32..31a5555c9 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.cc +++ b/speechx/speechx/asr/nnet/nnet_producer.cc @@ -33,6 +33,14 @@ void NnetProducer::Accept(const std::vector& inputs) { condition_variable_.notify_one(); } +void NnetProducer::UnLock() { + std::unique_lock lock(read_mutex_); + while (frontend_->IsFinished() == false && cache_.empty()) { + condition_read_ready_.wait(lock); + } + return; +} + void NnetProducer::RunNnetEvaluation(NnetProducer *me) { me->RunNnetEvaluationInteral(); } @@ -47,7 +55,7 @@ void NnetProducer::RunNnetEvaluationInteral() { result = Compute(); } while (result); if (frontend_->IsFinished() == true) { - Compute(); + //Compute(); if (cache_.empty()) finished_ = true; } } @@ -61,8 +69,8 @@ void NnetProducer::Acceptlikelihood( for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) { for (size_t col = 0; col < likelihood.NumCols(); ++col) { prob[col] = likelihood(idx, col); - cache_.push_back(prob); } + cache_.push_back(prob); } } @@ -100,6 +108,7 @@ bool NnetProducer::Compute() { out.logprobs.data() + idx * vocab_dim, out.logprobs.data() + (idx + 1) * vocab_dim); cache_.push_back(logprob); + condition_read_ready_.notify_one(); } return true; } diff --git a/speechx/speechx/asr/nnet/nnet_producer.h b/speechx/speechx/asr/nnet/nnet_producer.h index 35406f5fc..14c74d043 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.h +++ b/speechx/speechx/asr/nnet/nnet_producer.h @@ -36,6 +36,7 @@ class NnetProducer { bool ReadandCompute(std::vector* nnet_prob); static void RunNnetEvaluation(NnetProducer *me); void RunNnetEvaluationInteral(); + void UnLock(); void Wait() { abort_ = true; @@ -49,7 +50,6 @@ class NnetProducer { LOG(INFO) << "set finished"; frontend_->SetFinished(); condition_variable_.notify_one(); - LOG(INFO) << "compute last feats done."; } // the compute thread exit @@ -60,13 +60,11 @@ class NnetProducer { } void Reset() { - //if (thread_.joinable()) thread_.join(); frontend_->Reset(); nnet_->Reset(); VLOG(3) << "feature cache reset: cache size: " << cache_.size(); cache_.clear(); finished_ = false; - //thread_ = std::thread(RunNnetEvaluation, this); } void AttentionRescoring(const std::vector>& hyps, @@ -80,7 +78,9 @@ class NnetProducer { std::shared_ptr nnet_; SafeQueue> cache_; std::mutex mutex_; + std::mutex read_mutex_; std::condition_variable condition_variable_; + std::condition_variable condition_read_ready_; std::thread thread_; bool finished_; bool abort_; diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.cc b/speechx/speechx/asr/recognizer/u2_recognizer.cc index 03acf0595..0c5a8941d 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer.cc @@ -85,11 +85,12 @@ void U2Recognizer::RunDecoderSearch(U2Recognizer* me) { } void U2Recognizer::RunDecoderSearchInternal() { - while(!nnet_producer_->IsFinished()) { - Decode(); + LOG(INFO) << "DecoderSearchInteral begin"; + while (!nnet_producer_->IsFinished()) { + nnet_producer_->UnLock(); + decoder_->AdvanceDecode(decodable_); } Decode(); - Rescoring(); LOG(INFO) << "DecoderSearchInteral exit"; } diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h index 8b5add872..57f2c9c56 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/speechx/speechx/asr/recognizer/u2_recognizer.h @@ -140,18 +140,16 @@ class U2Recognizer { } const std::vector& Result() const { return result_; } + void AttentionRescoring(); private: static void RunDecoderSearch(U2Recognizer *me); void RunDecoderSearchInternal(); - void AttentionRescoring(); void UpdateResult(bool finish = false); private: U2RecognizerResource opts_; - // std::shared_ptr resource_; - // U2RecognizerResource resource_; std::shared_ptr nnet_producer_; std::shared_ptr decodable_; std::unique_ptr decoder_; diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc index 891b2012a..3f45294d1 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc @@ -22,15 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_int32(sample_rate, 16000, "sample rate"); -/*void decode_func(std::shared_ptr recognizer) { - while (!recognizer->IsFinished()) { - recognizer->Decode(); - } - recognizer->Decode(); - recognizer->Rescoring(); - LOG(INFO) << "decode thread exit!!!"; -}*/ - int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -40,6 +31,7 @@ int main(int argc, char* argv[]) { int32 num_done = 0, num_err = 0; double tot_wav_duration = 0.0; + double tot_attention_rescore_time = 0.0; double tot_decode_time = 0.0; kaldi::SequentialTableReader wav_reader( @@ -74,7 +66,6 @@ int main(int argc, char* argv[]) { LOG(INFO) << "wav len (sample): " << tot_samples; int sample_offset = 0; - kaldi::Timer timer; kaldi::Timer local_timer; while (sample_offset < tot_samples) { @@ -85,7 +76,6 @@ int main(int argc, char* argv[]) { for (int i = 0; i < cur_chunk_size; ++i) { wav_chunk[i] = waveform(sample_offset + i); } - // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); recognizer_ptr->Accept(wav_chunk); if (cur_chunk_size < chunk_sample_size) { @@ -97,6 +87,11 @@ int main(int argc, char* argv[]) { } CHECK(sample_offset == tot_samples); recognizer_ptr->WaitDecodeFinished(); + + kaldi::Timer timer; + recognizer_ptr->AttentionRescoring(); + tot_attention_rescore_time += timer.Elapsed(); + std::string result = recognizer_ptr->GetFinalResult(); if (result.empty()) { // the TokenWriter can not write empty string. @@ -119,5 +114,6 @@ int main(int argc, char* argv[]) { LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec"; LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; } diff --git a/speechx/speechx/common/frontend/compute_fbank_main.cc b/speechx/speechx/common/frontend/compute_fbank_main.cc index d7d5165ca..e022207d9 100644 --- a/speechx/speechx/common/frontend/compute_fbank_main.cc +++ b/speechx/speechx/common/frontend/compute_fbank_main.cc @@ -73,8 +73,7 @@ int main(int argc, char* argv[]) { new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank))); // the feature cache output feature chunk by chunk. - ppspeech::FeatureCacheOptions feat_cache_opts; - ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); + ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn)); LOG(INFO) << "fbank: " << true; LOG(INFO) << "feat dim: " << feature_cache.Dim(); diff --git a/speechx/speechx/common/frontend/feature_cache.cc b/speechx/speechx/common/frontend/feature_cache.cc index bf76aaff4..6bce4464e 100644 --- a/speechx/speechx/common/frontend/feature_cache.cc +++ b/speechx/speechx/common/frontend/feature_cache.cc @@ -20,10 +20,9 @@ using kaldi::BaseFloat; using std::unique_ptr; using std::vector; -FeatureCache::FeatureCache(FeatureCacheOptions opts, +FeatureCache::FeatureCache(size_t max_size, unique_ptr base_extractor) { - max_size_ = opts.max_size; - timeout_ = opts.timeout; // ms + max_size_ = max_size; base_extractor_ = std::move(base_extractor); dim_ = base_extractor_->Dim(); } @@ -37,30 +36,20 @@ void FeatureCache::Accept(const std::vector& inputs) { bool FeatureCache::Read(std::vector* feats) { kaldi::Timer timer; std::unique_lock lock(mutex_); -// feed current data + // feed current data if (cache_.empty()) { bool result = false; do { result = Compute(); } while (result); } - - //while (cache_.empty() && base_extractor_->IsFinished() == false) { - //// todo refactor: wait - //// ready_read_condition_.wait(lock); - //int32 elapsed = static_cast(timer.Elapsed() * 1000); // ms - //if (elapsed > timeout_) { - //return false; - //} - //usleep(100); // sleep 0.1 ms - //} if (cache_.empty()) return false; // read from cache *feats = cache_.front(); cache_.pop(); //ready_feed_condition_.notify_one(); - VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec."; + VLOG(2) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec."; return true; } @@ -80,17 +69,9 @@ bool FeatureCache::Compute() { int32 start = chunk_idx * dim_; vector feature_chunk(feature.data() + start, feature.data() + start + dim_); - - // std::unique_lock lock(mutex_); - //while (cache_.size() >= max_size_) { - // cache full, wait - // ready_feed_condition_.wait(lock); - //} - // feed cache cache_.push(feature_chunk); ++nframe_; - //ready_read_condition_.notify_one(); } VLOG(1) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. " diff --git a/speechx/speechx/common/frontend/feature_cache.h b/speechx/speechx/common/frontend/feature_cache.h index 891e62e60..b87612d66 100644 --- a/speechx/speechx/common/frontend/feature_cache.h +++ b/speechx/speechx/common/frontend/feature_cache.h @@ -19,16 +19,10 @@ namespace ppspeech { -struct FeatureCacheOptions { - int32 max_size; - int32 timeout; // ms - FeatureCacheOptions() : max_size(kint16max), timeout(1) {} -}; - class FeatureCache : public FrontendInterface { public: explicit FeatureCache( - FeatureCacheOptions opts, + size_t max_size = kint16max, std::unique_ptr base_extractor = NULL); // Feed feats or waves @@ -64,11 +58,8 @@ class FeatureCache : public FrontendInterface { int32 dim_; size_t max_size_; // cache capacity - int32 frame_chunk_size_; // window - int32 frame_chunk_stride_; // stride std::unique_ptr base_extractor_; - kaldi::int32 timeout_; // ms std::queue> cache_; // feature cache std::mutex mutex_; diff --git a/speechx/speechx/common/frontend/feature_pipeline.cc b/speechx/speechx/common/frontend/feature_pipeline.cc index 34e55a10c..f37b41807 100644 --- a/speechx/speechx/common/frontend/feature_pipeline.cc +++ b/speechx/speechx/common/frontend/feature_pipeline.cc @@ -33,7 +33,7 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature))); unique_ptr cache( - new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn))); + new ppspeech::FeatureCache(kint16max, std::move(cmvn))); base_extractor_.reset( new ppspeech::Assembler(opts.assembler_opts, std::move(cache))); diff --git a/speechx/speechx/common/frontend/feature_pipeline.h b/speechx/speechx/common/frontend/feature_pipeline.h index ea7e2bba3..c9a649fde 100644 --- a/speechx/speechx/common/frontend/feature_pipeline.h +++ b/speechx/speechx/common/frontend/feature_pipeline.h @@ -39,7 +39,6 @@ namespace ppspeech { struct FeaturePipelineOptions { std::string cmvn_file{}; knf::FbankOptions fbank_opts{}; - FeatureCacheOptions feature_cache_opts{}; AssemblerOptions assembler_opts{}; static FeaturePipelineOptions InitFromFlags() {