diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index ed5c38f0..45bf5419 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -45,7 +45,7 @@ option(USE_PROFILING "enable c++ profling" OFF) option(WITH_TESTING "unit test" ON) option(USING_U2 "compile u2 model." ON) -option(USING_DS2 "compile with ds2 model." ON) +option(USING_DS2 "compile with ds2 model." OFF) option(USING_GPU "u2 compute on GPU." OFF) diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc index 31c8b19e..31276895 100644 --- a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc +++ b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -18,6 +18,7 @@ #include "fst/symbol-table.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" +#include "nnet/nnet_producer.h" #include "nnet/u2_nnet.h" DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); @@ -39,7 +40,7 @@ using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; -// test ds2 online decoder by feeding speech feature +// test u2 online decoder by feeding speech feature int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -69,8 +70,10 @@ int main(int argc, char* argv[]) { // decodeable std::shared_ptr raw_data = std::make_shared(); + std::shared_ptr nnet_producer = + std::make_shared(nnet, raw_data); std::shared_ptr decodable = - std::make_shared(nnet, raw_data); + std::make_shared(nnet_producer); // decoder ppspeech::CTCBeamSearchOptions opts; @@ -114,9 +117,9 @@ int main(int argc, char* argv[]) { ori_feature_len - chunk_idx * chunk_stride, chunk_size); } if (this_chunk_size < receptive_field_length) { - LOG(WARNING) - << "utt: " << utt << " skip last " << this_chunk_size - << " frames, expect is " << receptive_field_length; + LOG(WARNING) << "utt: " << utt << " skip last " + << this_chunk_size << " frames, expect is " + << receptive_field_length; break; } diff --git a/speechx/speechx/asr/nnet/CMakeLists.txt b/speechx/speechx/asr/nnet/CMakeLists.txt index 27081086..2846540e 100644 --- a/speechx/speechx/asr/nnet/CMakeLists.txt +++ b/speechx/speechx/asr/nnet/CMakeLists.txt @@ -1,4 +1,4 @@ -set(srcs decodable.cc) +set(srcs decodable.cc nnet_producer.cc) if(USING_DS2) list(APPEND srcs ds2_nnet.cc) @@ -27,13 +27,13 @@ if(USING_DS2) endif() # test bin -if(USING_U2) - set(bin_name u2_nnet_main) - add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) - target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) - target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) - - target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) - target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) - target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) -endif() +#if(USING_U2) +# set(bin_name u2_nnet_main) +# add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +# target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +# target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) + +# target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) +# target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) +# target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) +#endif() diff --git a/speechx/speechx/asr/nnet/decodable.cc b/speechx/speechx/asr/nnet/decodable.cc index 5fe2b984..f01e9049 100644 --- a/speechx/speechx/asr/nnet/decodable.cc +++ b/speechx/speechx/asr/nnet/decodable.cc @@ -21,19 +21,16 @@ using kaldi::Matrix; using kaldi::Vector; using std::vector; -Decodable::Decodable(const std::shared_ptr& nnet, - const std::shared_ptr& frontend, +Decodable::Decodable(const std::shared_ptr& nnet_producer, kaldi::BaseFloat acoustic_scale) - : frontend_(frontend), - nnet_(nnet), + : nnet_producer_(nnet_producer), frame_offset_(0), frames_ready_(0), acoustic_scale_(acoustic_scale) {} // for debug void Decodable::Acceptlikelihood(const Matrix& likelihood) { - nnet_out_cache_ = likelihood; - frames_ready_ += likelihood.NumRows(); + nnet_producer_->Acceptlikelihood(likelihood); } @@ -43,7 +40,7 @@ int32 Decodable::NumFramesReady() const { return frames_ready_; } // frame idx is from 0 to frame_ready_ -1; bool Decodable::IsLastFrame(int32 frame) { - bool flag = EnsureFrameHaveComputed(frame); + EnsureFrameHaveComputed(frame); return frame >= frames_ready_; } @@ -64,32 +61,10 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) { bool Decodable::AdvanceChunk() { kaldi::Timer timer; - // read feats - Vector features; - if (frontend_ == NULL || frontend_->Read(&features) == false) { - // no feat or frontend_ not init. - VLOG(3) << "decodable exit;"; - return false; - } - CHECK_GE(frontend_->Dim(), 0); - VLOG(1) << "AdvanceChunk feat cost: " << timer.Elapsed() << " sec."; - VLOG(2) << "Forward in " << features.Dim() / frontend_->Dim() << " feats."; - - // forward feats - NnetOut out; - nnet_->FeedForward(features, frontend_->Dim(), &out); - int32& vocab_dim = out.vocab_dim; - Vector& logprobs = out.logprobs; - - VLOG(2) << "Forward out " << logprobs.Dim() / vocab_dim - << " decoder frames."; - // cache nnet outupts - nnet_out_cache_.Resize(logprobs.Dim() / vocab_dim, vocab_dim); - nnet_out_cache_.CopyRowsFromVec(logprobs); - - // update state, decoding frame. + bool flag = nnet_producer_->Read(&framelikelihood_); + if (flag == false) return false; frame_offset_ = frames_ready_; - frames_ready_ += nnet_out_cache_.NumRows(); + frames_ready_ += 1; VLOG(1) << "AdvanceChunk feat + forward cost: " << timer.Elapsed() << " sec."; return true; @@ -101,17 +76,17 @@ bool Decodable::AdvanceChunk(kaldi::Vector* logprobs, return false; } - int nrows = nnet_out_cache_.NumRows(); - CHECK(nrows == (frames_ready_ - frame_offset_)); - if (nrows <= 0) { + if (framelikelihood_.empty()) { LOG(WARNING) << "No new nnet out in cache."; return false; } - logprobs->Resize(nnet_out_cache_.NumRows() * nnet_out_cache_.NumCols()); - logprobs->CopyRowsFromMat(nnet_out_cache_); - - *vocab_dim = nnet_out_cache_.NumCols(); + size_t dim = framelikelihood_.size(); + logprobs->Resize(framelikelihood_.size()); + std::memcpy(logprobs->Data(), + framelikelihood_.data(), + dim * sizeof(kaldi::BaseFloat)); + *vocab_dim = framelikelihood_.size(); return true; } @@ -122,19 +97,8 @@ bool Decodable::FrameLikelihood(int32 frame, vector* likelihood) { return false; } - int nrows = nnet_out_cache_.NumRows(); - CHECK(nrows == (frames_ready_ - frame_offset_)); - int vocab_size = nnet_out_cache_.NumCols(); - likelihood->resize(vocab_size); - - for (int32 idx = 0; idx < vocab_size; ++idx) { - (*likelihood)[idx] = - nnet_out_cache_(frame - frame_offset_, idx) * acoustic_scale_; - - VLOG(4) << "nnet out: " << frame << " offset:" << frame_offset_ << " " - << nnet_out_cache_.NumRows() - << " logprob: " << nnet_out_cache_(frame - frame_offset_, idx); - } + CHECK_EQ(1, (frames_ready_ - frame_offset_)); + *likelihood = framelikelihood_; return true; } @@ -143,37 +107,31 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) { return false; } - CHECK_LE(index, nnet_out_cache_.NumCols()); + CHECK_LE(index, framelikelihood_.size()); CHECK_LE(frame, frames_ready_); // the nnet output is prob ranther than log prob // the index - 1, because the ilabel BaseFloat logprob = 0.0; int32 frame_idx = frame - frame_offset_; - BaseFloat nnet_out = nnet_out_cache_(frame_idx, TokenId2NnetId(index)); - if (nnet_->IsLogProb()) { - logprob = nnet_out; - } else { - logprob = std::log(nnet_out + std::numeric_limits::epsilon()); - } - CHECK(!std::isnan(logprob) && !std::isinf(logprob)); + CHECK_EQ(frame_idx, 0); + logprob = framelikelihood_[TokenId2NnetId(index)]; return acoustic_scale_ * logprob; } void Decodable::Reset() { - if (frontend_ != nullptr) frontend_->Reset(); - if (nnet_ != nullptr) nnet_->Reset(); + if (nnet_producer_ != nullptr) nnet_producer_->Reset(); frame_offset_ = 0; frames_ready_ = 0; - nnet_out_cache_.Resize(0, 0); + framelikelihood_.clear(); } void Decodable::AttentionRescoring(const std::vector>& hyps, float reverse_weight, std::vector* rescoring_score) { kaldi::Timer timer; - nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score); + nnet_producer_->AttentionRescoring(hyps, reverse_weight, rescoring_score); VLOG(1) << "Attention Rescoring cost: " << timer.Elapsed() << " sec."; } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/asr/nnet/decodable.h b/speechx/speechx/asr/nnet/decodable.h index dd7b329e..cd498e42 100644 --- a/speechx/speechx/asr/nnet/decodable.h +++ b/speechx/speechx/asr/nnet/decodable.h @@ -13,10 +13,10 @@ // limitations under the License. #include "base/common.h" -#include "frontend/audio/frontend_itf.h" #include "kaldi/decoder/decodable-itf.h" #include "kaldi/matrix/kaldi-matrix.h" #include "nnet/nnet_itf.h" +#include "nnet/nnet_producer.h" namespace ppspeech { @@ -24,8 +24,7 @@ struct DecodableOpts; class Decodable : public kaldi::DecodableInterface { public: - explicit Decodable(const std::shared_ptr& nnet, - const std::shared_ptr& frontend, + explicit Decodable(const std::shared_ptr& nnet_producer, kaldi::BaseFloat acoustic_scale = 1.0); // void Init(DecodableOpts config); @@ -57,23 +56,17 @@ class Decodable : public kaldi::DecodableInterface { void Reset(); - bool IsInputFinished() const { return frontend_->IsFinished(); } + bool IsInputFinished() const { return nnet_producer_->IsFinished(); } bool EnsureFrameHaveComputed(int32 frame); int32 TokenId2NnetId(int32 token_id); - std::shared_ptr Nnet() { return nnet_; } - // for offline test void Acceptlikelihood(const kaldi::Matrix& likelihood); private: - std::shared_ptr frontend_; - std::shared_ptr nnet_; - - // nnet outputs' cache - kaldi::Matrix nnet_out_cache_; + std::shared_ptr nnet_producer_; // the frame is nnet prob frame rather than audio feature frame // nnet frame subsample the feature frame @@ -85,6 +78,7 @@ class Decodable : public kaldi::DecodableInterface { // so use subsampled_frame int32 current_log_post_subsampled_offset_; int32 num_chunk_computed_; + std::vector framelikelihood_; kaldi::BaseFloat acoustic_scale_; }; diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/speechx/speechx/asr/nnet/nnet_producer.cc new file mode 100644 index 00000000..3a0c4f18 --- /dev/null +++ b/speechx/speechx/asr/nnet/nnet_producer.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "nnet/nnet_producer.h" + +namespace ppspeech { + +using kaldi::Vector; +using kaldi::BaseFloat; + +NnetProducer::NnetProducer(std::shared_ptr nnet, + std::shared_ptr frontend) + : nnet_(nnet), frontend_(frontend) {} + +void NnetProducer::Accept(const kaldi::VectorBase& inputs) { + frontend_->Accept(inputs); + bool result = false; + do { + result = Compute(); + } while (result); +} + +void NnetProducer::Acceptlikelihood( + const kaldi::Matrix& likelihood) { + std::vector prob; + prob.resize(likelihood.NumCols()); + for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) { + for (size_t col = 0; col < likelihood.NumCols(); ++col) { + prob[col] = likelihood(idx, col); + cache_.push_back(prob); + } + } +} + +bool NnetProducer::Read(std::vector* nnet_prob) { + bool flag = cache_.pop(nnet_prob); + return flag; +} + +bool NnetProducer::Compute() { + Vector features; + if (frontend_ == NULL || frontend_->Read(&features) == false) { + // no feat or frontend_ not init. + VLOG(3) << "no feat avalible"; + return false; + } + CHECK_GE(frontend_->Dim(), 0); + VLOG(2) << "Forward in " << features.Dim() / frontend_->Dim() << " feats."; + + NnetOut out; + nnet_->FeedForward(features, frontend_->Dim(), &out); + int32& vocab_dim = out.vocab_dim; + Vector& logprobs = out.logprobs; + size_t nframes = logprobs.Dim() / vocab_dim; + VLOG(2) << "Forward out " << nframes << " decoder frames."; + std::vector logprob(vocab_dim); + // remove later. + for (size_t idx = 0; idx < nframes; ++idx) { + for (size_t prob_idx = 0; prob_idx < vocab_dim; ++prob_idx) { + logprob[prob_idx] = logprobs(idx * vocab_dim + prob_idx); + } + cache_.push_back(logprob); + } + return true; +} + +void NnetProducer::AttentionRescoring(const std::vector>& hyps, + float reverse_weight, + std::vector* rescoring_score) { + nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score); +} + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/nnet/nnet_producer.h b/speechx/speechx/asr/nnet/nnet_producer.h new file mode 100644 index 00000000..65e9116f --- /dev/null +++ b/speechx/speechx/asr/nnet/nnet_producer.h @@ -0,0 +1,73 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "base/common.h" +#include "base/safe_queue.h" +#include "frontend/audio/frontend_itf.h" +#include "nnet/nnet_itf.h" + +namespace ppspeech { + +class NnetProducer { + public: + explicit NnetProducer(std::shared_ptr nnet, + std::shared_ptr frontend = NULL); + + // Feed feats or waves + void Accept(const kaldi::VectorBase& inputs); + + void Acceptlikelihood(const kaldi::Matrix& likelihood); + + // nnet + bool Read(std::vector* nnet_prob); + + bool Empty() const { return cache_.empty(); } + + void SetFinished() { + LOG(INFO) << "set finished"; + // std::unique_lock lock(mutex_); + frontend_->SetFinished(); + + // read the last chunk data + Compute(); + // ready_feed_condition_.notify_one(); + LOG(INFO) << "compute last feats done."; + } + + bool IsFinished() const { return frontend_->IsFinished(); } + + void Reset() { + frontend_->Reset(); + nnet_->Reset(); + VLOG(3) << "feature cache reset: cache size: " << cache_.size(); + cache_.clear(); + } + + void AttentionRescoring(const std::vector>& hyps, + float reverse_weight, + std::vector* rescoring_score); + + private: + bool Compute(); + + std::shared_ptr frontend_; + std::shared_ptr nnet_; + SafeQueue> cache_; + + DISALLOW_COPY_AND_ASSIGN(NnetProducer); +}; + +} // namespace ppspeech diff --git a/speechx/speechx/asr/recognizer/CMakeLists.txt b/speechx/speechx/asr/recognizer/CMakeLists.txt index 05078873..53e2e58d 100644 --- a/speechx/speechx/asr/recognizer/CMakeLists.txt +++ b/speechx/speechx/asr/recognizer/CMakeLists.txt @@ -30,6 +30,7 @@ endif() if (USING_U2) set(TEST_BINS u2_recognizer_main + u2_recognizer_thread_main ) foreach(bin_name IN LISTS TEST_BINS) diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.cc b/speechx/speechx/asr/recognizer/u2_recognizer.cc index d1d308eb..ea62ae1a 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer.cc @@ -27,13 +27,13 @@ using std::vector; U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) : opts_(resource) { + BaseFloat am_scale = resource.acoustic_scale; const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; - feature_pipeline_.reset(new FeaturePipeline(feature_opts)); - + std::shared_ptr feature_pipeline( + new FeaturePipeline(feature_opts)); std::shared_ptr nnet(new U2Nnet(resource.model_opts)); - - BaseFloat am_scale = resource.acoustic_scale; - decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale)); + nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline)); + decodable_.reset(new Decodable(nnet_producer_, am_scale)); CHECK_NE(resource.vocab_path, ""); decoder_.reset(new CTCPrefixBeamSearch( @@ -49,6 +49,7 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) void U2Recognizer::Reset() { global_frame_offset_ = 0; + input_finished_ = false; num_frames_ = 0; result_.clear(); @@ -68,7 +69,7 @@ void U2Recognizer::ResetContinuousDecoding() { void U2Recognizer::Accept(const VectorBase& waves) { kaldi::Timer timer; - feature_pipeline_->Accept(waves); + nnet_producer_->Accept(waves); VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " << waves.Dim() << " samples."; } @@ -210,7 +211,7 @@ std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; } std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; } void U2Recognizer::SetFinished() { - feature_pipeline_->SetFinished(); + nnet_producer_->SetFinished(); input_finished_ = true; } diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h index 25850863..855d161a 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/speechx/speechx/asr/recognizer/u2_recognizer.h @@ -130,11 +130,11 @@ class U2Recognizer { return !result_.empty() && !result_[0].sentence.empty(); } - int FrameShiftInMs() const { - // one decoder frame length in ms - return decodable_->Nnet()->SubsamplingRate() * - feature_pipeline_->FrameShift(); + // one decoder frame length in ms, todo + return 1; + // return decodable_->Nnet()->SubsamplingRate() * + // feature_pipeline_->FrameShift(); } @@ -149,7 +149,7 @@ class U2Recognizer { // std::shared_ptr resource_; // U2RecognizerResource resource_; - std::shared_ptr feature_pipeline_; + std::shared_ptr nnet_producer_; std::shared_ptr decodable_; std::unique_ptr decoder_; diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc new file mode 100644 index 00000000..e73efef1 --- /dev/null +++ b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "recognizer/u2_recognizer.h" +#include "decoder/param.h" +#include "kaldi/feat/wave-reader.h" +#include "kaldi/util/table-types.h" + +DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); + +void decode_func(std::shared_ptr recognizer) { + while (!recognizer->IsFinished()) { + recognizer->Decode(); + usleep(100); + } + recognizer->Decode(); + recognizer->Rescoring(); +} + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int32 num_done = 0, num_err = 0; + double tot_wav_duration = 0.0; + double tot_decode_time = 0.0; + + kaldi::SequentialTableReader wav_reader( + FLAGS_wav_rspecifier); + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + LOG(INFO) << "sr: " << sample_rate; + LOG(INFO) << "chunk size (s): " << streaming_chunk; + LOG(INFO) << "chunk size (sample): " << chunk_sample_size; + + ppspeech::U2RecognizerResource resource = + ppspeech::U2RecognizerResource::InitFromFlags(); + std::shared_ptr recognizer_ptr( + new ppspeech::U2Recognizer(resource)); + + for (; !wav_reader.Done(); wav_reader.Next()) { + std::thread recognizer_thread(decode_func, recognizer_ptr); + std::string utt = wav_reader.Key(); + const kaldi::WaveData& wave_data = wav_reader.Value(); + LOG(INFO) << "utt: " << utt; + LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; + double dur = wave_data.Duration(); + tot_wav_duration += dur; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + kaldi::Timer timer; + kaldi::Timer local_timer; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + kaldi::Vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk(i) = waveform(sample_offset + i); + } + // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); + + recognizer_ptr->Accept(wav_chunk); + if (cur_chunk_size < chunk_sample_size) { + recognizer_ptr->SetFinished(); + } + + // no overlap + sample_offset += cur_chunk_size; + } + CHECK(sample_offset == tot_samples); + + recognizer_thread.join(); + std::string result = recognizer_ptr->GetFinalResult(); + recognizer_ptr->Reset(); + if (result.empty()) { + // the TokenWriter can not write empty string. + ++num_err; + LOG(INFO) << " the result of " << utt << " is empty"; + continue; + } + + LOG(INFO) << utt << " " << result; + LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur + << " cost: " << local_timer.Elapsed(); + + result_writer.Write(utt, result); + + ++num_done; + } + + LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); + LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; + LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; +} diff --git a/speechx/speechx/asr/server/CMakeLists.txt b/speechx/speechx/asr/server/CMakeLists.txt index 71b33daa..566b42ee 100644 --- a/speechx/speechx/asr/server/CMakeLists.txt +++ b/speechx/speechx/asr/server/CMakeLists.txt @@ -1 +1 @@ -add_subdirectory(websocket) +#add_subdirectory(websocket) diff --git a/speechx/speechx/common/base/common.h b/speechx/speechx/common/base/common.h index 97bff966..2a066ee6 100644 --- a/speechx/speechx/common/base/common.h +++ b/speechx/speechx/common/base/common.h @@ -48,4 +48,4 @@ #include "base/log.h" #include "base/macros.h" #include "utils/file_utils.h" -#include "utils/math.h" \ No newline at end of file +#include "utils/math.h" diff --git a/speechx/speechx/common/base/safe_queue.h b/speechx/speechx/common/base/safe_queue.h new file mode 100644 index 00000000..25a012af --- /dev/null +++ b/speechx/speechx/common/base/safe_queue.h @@ -0,0 +1,71 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/common.h" + +namespace ppspeech { + +template +class SafeQueue { + public: + explicit SafeQueue(size_t capacity = 0); + void push_back(const T& in); + bool pop(T* out); + bool empty() const { return buffer_.empty(); } + size_t size() const { return buffer_.size(); } + void clear(); + + + private: + std::mutex mutex_; + std::condition_variable condition_; + std::deque buffer_; + size_t capacity_; +}; + +template +SafeQueue::SafeQueue(size_t capacity) : capacity_(capacity) {} + +template +void SafeQueue::push_back(const T& in) { + std::unique_lock lock(mutex_); + if (capacity_ > 0 && buffer_.size() == capacity_) { + condition_.wait(lock, [this] { return capacity_ >= buffer_.size(); }); + } + + buffer_.push_back(in); + condition_.notify_one(); +} + +template +bool SafeQueue::pop(T* out) { + if (buffer_.empty()) { + return false; + } + + std::unique_lock lock(mutex_); + condition_.wait(lock, [this] { return buffer_.size() > 0; }); + *out = std::move(buffer_.front()); + buffer_.pop_front(); + condition_.notify_one(); + return true; +} + +template +void SafeQueue::clear() { + std::unique_lock lock(mutex_); + buffer_.clear(); + condition_.notify_one(); +} +} // namespace ppspeech