From 16b9894edf0978d16870151a2022ce6503c799a7 Mon Sep 17 00:00:00 2001 From: YangZhou <56786796+SmileGoat@users.noreply.github.com> Date: Wed, 18 Jan 2023 16:11:26 +0800 Subject: [PATCH] [speechx] add cls (#2839) * fix nnet thread crash && rescore cost time * add nnet thread main --- speechx/speechx/CMakeLists.txt | 3 + .../decoder/ctc_prefix_beam_search_decoder.cc | 9 +- speechx/speechx/asr/nnet/CMakeLists.txt | 27 +- speechx/speechx/asr/nnet/decodable.cc | 1 - speechx/speechx/asr/nnet/nnet_producer.cc | 54 ++- speechx/speechx/asr/nnet/nnet_producer.h | 34 +- .../speechx/asr/nnet/u2_nnet_thread_main.cc | 137 ++++++ .../speechx/asr/recognizer/u2_recognizer.cc | 39 +- .../speechx/asr/recognizer/u2_recognizer.h | 16 +- .../asr/recognizer/u2_recognizer_main.cc | 4 +- .../recognizer/u2_recognizer_thread_main.cc | 26 +- speechx/speechx/cls/CMakeLists.txt | 46 ++ speechx/speechx/cls/nnet/CMakeLists.txt | 12 + speechx/speechx/cls/nnet/cls_interface.cc | 59 +++ speechx/speechx/cls/nnet/cls_interface.h | 10 + speechx/speechx/cls/nnet/cls_nnet.cc | 405 ++++++++++++++++++ speechx/speechx/cls/nnet/cls_nnet.h | 73 ++++ speechx/speechx/cls/nnet/cls_nnet_main.cc | 45 ++ speechx/speechx/cls/nnet/config.h | 352 +++++++++++++++ speechx/speechx/cls/nnet/wav.h | 241 +++++++++++ .../common/frontend/compute_fbank_main.cc | 3 +- .../speechx/common/frontend/feature_cache.cc | 40 +- .../speechx/common/frontend/feature_cache.h | 20 +- .../common/frontend/feature_pipeline.cc | 2 +- .../common/frontend/feature_pipeline.h | 1 - 25 files changed, 1549 insertions(+), 110 deletions(-) create mode 100644 speechx/speechx/asr/nnet/u2_nnet_thread_main.cc create mode 100644 speechx/speechx/cls/CMakeLists.txt create mode 100644 speechx/speechx/cls/nnet/CMakeLists.txt create mode 100644 speechx/speechx/cls/nnet/cls_interface.cc create mode 100644 speechx/speechx/cls/nnet/cls_interface.h create mode 100644 speechx/speechx/cls/nnet/cls_nnet.cc create mode 100644 speechx/speechx/cls/nnet/cls_nnet.h create mode 100644 speechx/speechx/cls/nnet/cls_nnet_main.cc create mode 100644 speechx/speechx/cls/nnet/config.h create mode 100644 speechx/speechx/cls/nnet/wav.h diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt index b522e158c..88e231658 100644 --- a/speechx/speechx/CMakeLists.txt +++ b/speechx/speechx/CMakeLists.txt @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) project(speechx LANGUAGES CXX) +add_compile_options(-fPIC) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/kaldi) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/common) @@ -10,3 +12,4 @@ add_subdirectory(asr) add_subdirectory(common) add_subdirectory(kaldi) add_subdirectory(codelab) +add_subdirectory(cls) \ No newline at end of file diff --git a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc index 2cef4972d..8361f06d6 100644 --- a/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc +++ b/speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.cc @@ -63,8 +63,9 @@ void CTCPrefixBeamSearch::Reset() { times_.emplace_back(empty); } -void CTCPrefixBeamSearch::InitDecoder() { Reset(); } - +void CTCPrefixBeamSearch::InitDecoder() { + Reset(); +} void CTCPrefixBeamSearch::AdvanceDecode( const std::shared_ptr& decodable) { @@ -77,7 +78,7 @@ void CTCPrefixBeamSearch::AdvanceDecode( bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob); feat_nnet_cost += timer.Elapsed(); if (flag == false) { - VLOG(3) << "decoder advance decode exit." << frame_prob.size(); + VLOG(2) << "decoder advance decode exit." << frame_prob.size(); break; } @@ -87,7 +88,7 @@ void CTCPrefixBeamSearch::AdvanceDecode( AdvanceDecoding(likelihood); search_cost += timer.Elapsed(); - VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_; + VLOG(1) << "num_frame_decoded_: " << num_frame_decoded_; } VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost << " sec."; diff --git a/speechx/speechx/asr/nnet/CMakeLists.txt b/speechx/speechx/asr/nnet/CMakeLists.txt index 819cc2e89..7306ebf8a 100644 --- a/speechx/speechx/asr/nnet/CMakeLists.txt +++ b/speechx/speechx/asr/nnet/CMakeLists.txt @@ -8,14 +8,21 @@ target_link_libraries(nnet utils) target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS}) target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) -# test bin -#if(USING_U2) -# set(bin_name u2_nnet_main) -# add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) -# target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) -# target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) +# test bin +#set(bin_name u2_nnet_main) +#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +#target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +#target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet) -# target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) -# target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) -# target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) -#endif() +#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) +#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) +#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) + +set(bin_name u2_nnet_thread_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet frontend) + +target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) +target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) +target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) diff --git a/speechx/speechx/asr/nnet/decodable.cc b/speechx/speechx/asr/nnet/decodable.cc index f01e90493..a140c376a 100644 --- a/speechx/speechx/asr/nnet/decodable.cc +++ b/speechx/speechx/asr/nnet/decodable.cc @@ -33,7 +33,6 @@ void Decodable::Acceptlikelihood(const Matrix& likelihood) { nnet_producer_->Acceptlikelihood(likelihood); } - // return the size of frame have computed. int32 Decodable::NumFramesReady() const { return frames_ready_; } diff --git a/speechx/speechx/asr/nnet/nnet_producer.cc b/speechx/speechx/asr/nnet/nnet_producer.cc index 6207a6b5a..b83b59767 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.cc +++ b/speechx/speechx/asr/nnet/nnet_producer.cc @@ -22,14 +22,43 @@ using kaldi::BaseFloat; NnetProducer::NnetProducer(std::shared_ptr nnet, std::shared_ptr frontend) - : nnet_(nnet), frontend_(frontend) {} + : nnet_(nnet), frontend_(frontend) { + abort_ = false; + Reset(); + thread_ = std::thread(RunNnetEvaluation, this); + } void NnetProducer::Accept(const std::vector& inputs) { frontend_->Accept(inputs); + condition_variable_.notify_one(); +} + +void NnetProducer::UnLock() { + std::unique_lock lock(read_mutex_); + while (frontend_->IsFinished() == false && cache_.empty()) { + condition_read_ready_.wait(lock); + } + return; +} + +void NnetProducer::RunNnetEvaluation(NnetProducer *me) { + me->RunNnetEvaluationInteral(); +} + +void NnetProducer::RunNnetEvaluationInteral() { bool result = false; - do { - result = Compute(); - } while (result); + LOG(INFO) << "NnetEvaluationInteral begin"; + while (!abort_) { + std::unique_lock lock(mutex_); + condition_variable_.wait(lock); + do { + result = Compute(); + } while (result); + if (frontend_->IsFinished() == true) { + if (cache_.empty()) finished_ = true; + } + } + LOG(INFO) << "NnetEvaluationInteral exit"; } void NnetProducer::Acceptlikelihood( @@ -39,12 +68,20 @@ void NnetProducer::Acceptlikelihood( for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) { for (size_t col = 0; col < likelihood.NumCols(); ++col) { prob[col] = likelihood(idx, col); - cache_.push_back(prob); } + cache_.push_back(prob); } } bool NnetProducer::Read(std::vector* nnet_prob) { + bool flag = cache_.pop(nnet_prob); + condition_variable_.notify_one(); + return flag; +} + +bool NnetProducer::ReadandCompute(std::vector* nnet_prob) { + Compute(); + if (frontend_->IsFinished() && cache_.empty()) finished_ = true; bool flag = cache_.pop(nnet_prob); return flag; } @@ -53,22 +90,23 @@ bool NnetProducer::Compute() { vector features; if (frontend_ == NULL || frontend_->Read(&features) == false) { // no feat or frontend_ not init. - VLOG(3) << "no feat avalible"; + VLOG(2) << "no feat avalible"; return false; } CHECK_GE(frontend_->Dim(), 0); - VLOG(2) << "Forward in " << features.size() / frontend_->Dim() << " feats."; + VLOG(1) << "Forward in " << features.size() / frontend_->Dim() << " feats."; NnetOut out; nnet_->FeedForward(features, frontend_->Dim(), &out); int32& vocab_dim = out.vocab_dim; size_t nframes = out.logprobs.size() / vocab_dim; - VLOG(2) << "Forward out " << nframes << " decoder frames."; + VLOG(1) << "Forward out " << nframes << " decoder frames."; for (size_t idx = 0; idx < nframes; ++idx) { std::vector logprob( out.logprobs.data() + idx * vocab_dim, out.logprobs.data() + (idx + 1) * vocab_dim); cache_.push_back(logprob); + condition_read_ready_.notify_one(); } return true; } diff --git a/speechx/speechx/asr/nnet/nnet_producer.h b/speechx/speechx/asr/nnet/nnet_producer.h index dd356f957..14c74d043 100644 --- a/speechx/speechx/asr/nnet/nnet_producer.h +++ b/speechx/speechx/asr/nnet/nnet_producer.h @@ -33,27 +33,38 @@ class NnetProducer { // nnet bool Read(std::vector* nnet_prob); + bool ReadandCompute(std::vector* nnet_prob); + static void RunNnetEvaluation(NnetProducer *me); + void RunNnetEvaluationInteral(); + void UnLock(); + + void Wait() { + abort_ = true; + condition_variable_.notify_one(); + if (thread_.joinable()) thread_.join(); + } bool Empty() const { return cache_.empty(); } - void SetFinished() { + void SetInputFinished() { LOG(INFO) << "set finished"; - // std::unique_lock lock(mutex_); frontend_->SetFinished(); - - // read the last chunk data - Compute(); - // ready_feed_condition_.notify_one(); - LOG(INFO) << "compute last feats done."; + condition_variable_.notify_one(); } - bool IsFinished() const { return frontend_->IsFinished(); } + // the compute thread exit + bool IsFinished() const { return finished_; } + + ~NnetProducer() { + if (thread_.joinable()) thread_.join(); + } void Reset() { frontend_->Reset(); nnet_->Reset(); VLOG(3) << "feature cache reset: cache size: " << cache_.size(); cache_.clear(); + finished_ = false; } void AttentionRescoring(const std::vector>& hyps, @@ -66,6 +77,13 @@ class NnetProducer { std::shared_ptr frontend_; std::shared_ptr nnet_; SafeQueue> cache_; + std::mutex mutex_; + std::mutex read_mutex_; + std::condition_variable condition_variable_; + std::condition_variable condition_read_ready_; + std::thread thread_; + bool finished_; + bool abort_; DISALLOW_COPY_AND_ASSIGN(NnetProducer); }; diff --git a/speechx/speechx/asr/nnet/u2_nnet_thread_main.cc b/speechx/speechx/asr/nnet/u2_nnet_thread_main.cc new file mode 100644 index 000000000..ce523e599 --- /dev/null +++ b/speechx/speechx/asr/nnet/u2_nnet_thread_main.cc @@ -0,0 +1,137 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "base/common.h" +#include "decoder/param.h" +#include "frontend/wave-reader.h" +#include "frontend/feature_pipeline.h" +#include "kaldi/util/table-types.h" +#include "nnet/decodable.h" +#include "nnet/u2_nnet.h" +#include "nnet/nnet_producer.h" + +DEFINE_string(wav_rspecifier, "", "test wav rspecifier"); +DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); + +using kaldi::BaseFloat; +using kaldi::Matrix; +using std::vector; + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int32 num_done = 0, num_err = 0; + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + + CHECK_GT(FLAGS_wav_rspecifier.size(), 0); + CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0); + CHECK_GT(FLAGS_model_path.size(), 0); + LOG(INFO) << "input rspecifier: " << FLAGS_wav_rspecifier; + LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier; + LOG(INFO) << "model path: " << FLAGS_model_path; + + kaldi::SequentialTableReader wav_reader( + FLAGS_wav_rspecifier); + kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier); + + ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags(); + ppspeech::FeaturePipelineOptions feature_opts = + ppspeech::FeaturePipelineOptions::InitFromFlags(); + feature_opts.assembler_opts.fill_zero = false; + + std::shared_ptr nnet(new ppspeech::U2Nnet(model_opts)); + std::shared_ptr feature_pipeline( + new ppspeech::FeaturePipeline(feature_opts)); + std::shared_ptr nnet_producer( + new ppspeech::NnetProducer(nnet, feature_pipeline)); + kaldi::Timer timer; + float tot_wav_duration = 0; + + for (; !wav_reader.Done(); wav_reader.Next()) { + std::string utt = wav_reader.Key(); + const kaldi::WaveData& wave_data = wav_reader.Value(); + LOG(INFO) << "utt: " << utt; + LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; + double dur = wave_data.Duration(); + tot_wav_duration += dur; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + kaldi::Timer timer; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + std::vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk[i] = waveform(sample_offset + i); + } + + nnet_producer->Accept(wav_chunk); + if (cur_chunk_size < chunk_sample_size) { + nnet_producer->SetInputFinished(); + } + + // no overlap + sample_offset += cur_chunk_size; + } + CHECK(sample_offset == tot_samples); + + std::vector> prob_vec; + while(1) { + std::vector logprobs; + bool isok = nnet_producer->Read(&logprobs); + if (nnet_producer->IsFinished()) break; + if (isok == false) continue; + prob_vec.push_back(logprobs); + } + { + // writer nnet output + kaldi::MatrixIndexT nrow = prob_vec.size(); + kaldi::MatrixIndexT ncol = prob_vec[0].size(); + LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol; + kaldi::Matrix nnet_out(nrow, ncol); + for (int32 row_idx = 0; row_idx < nrow; ++row_idx) { + for (int32 col_idx = 0; col_idx < ncol; ++col_idx) { + nnet_out(row_idx, col_idx) = prob_vec[row_idx][col_idx]; + } + } + nnet_out_writer.Write(utt, nnet_out); + } + nnet_producer->Reset(); + } + + nnet_producer->Wait(); + double elapsed = timer.Elapsed(); + LOG(INFO) << "Program cost:" << elapsed << " sec"; + + LOG(INFO) << "Done " << num_done << " utterances, " << num_err + << " with errors."; + return (num_done != 0 ? 0 : 1); +} diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.cc b/speechx/speechx/asr/recognizer/u2_recognizer.cc index a76444305..0c5a8941d 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer.cc @@ -39,12 +39,28 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource) unit_table_ = decoder_->VocabTable(); symbol_table_ = unit_table_; + global_frame_offset_ = 0; input_finished_ = false; + num_frames_ = 0; + result_.clear(); + +} + +U2Recognizer::~U2Recognizer() { + SetInputFinished(); + WaitDecodeFinished(); +} - Reset(); +void U2Recognizer::WaitDecodeFinished() { + if (thread_.joinable()) thread_.join(); } -void U2Recognizer::Reset() { +void U2Recognizer::WaitFinished() { + if (thread_.joinable()) thread_.join(); + nnet_producer_->Wait(); +} + +void U2Recognizer::InitDecoder() { global_frame_offset_ = 0; input_finished_ = false; num_frames_ = 0; @@ -52,6 +68,7 @@ void U2Recognizer::Reset() { decodable_->Reset(); decoder_->Reset(); + thread_ = std::thread(RunDecoderSearch, this); } void U2Recognizer::ResetContinuousDecoding() { @@ -63,6 +80,19 @@ void U2Recognizer::ResetContinuousDecoding() { decoder_->Reset(); } +void U2Recognizer::RunDecoderSearch(U2Recognizer* me) { + me->RunDecoderSearchInternal(); +} + +void U2Recognizer::RunDecoderSearchInternal() { + LOG(INFO) << "DecoderSearchInteral begin"; + while (!nnet_producer_->IsFinished()) { + nnet_producer_->UnLock(); + decoder_->AdvanceDecode(decodable_); + } + Decode(); + LOG(INFO) << "DecoderSearchInteral exit"; +} void U2Recognizer::Accept(const vector& waves) { kaldi::Timer timer; @@ -71,7 +101,6 @@ void U2Recognizer::Accept(const vector& waves) { << " samples."; } - void U2Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); UpdateResult(false); @@ -207,8 +236,8 @@ std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; } std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; } -void U2Recognizer::SetFinished() { - nnet_producer_->SetFinished(); +void U2Recognizer::SetInputFinished() { + nnet_producer_->SetInputFinished(); input_finished_ = true; } diff --git a/speechx/speechx/asr/recognizer/u2_recognizer.h b/speechx/speechx/asr/recognizer/u2_recognizer.h index a3bf8aeae..57f2c9c56 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer.h +++ b/speechx/speechx/asr/recognizer/u2_recognizer.h @@ -112,19 +112,21 @@ struct U2RecognizerResource { class U2Recognizer { public: explicit U2Recognizer(const U2RecognizerResource& resouce); - void Reset(); + ~U2Recognizer(); + void InitDecoder(); void ResetContinuousDecoding(); void Accept(const std::vector& waves); void Decode(); void Rescoring(); - std::string GetFinalResult(); std::string GetPartialResult(); - void SetFinished(); + void SetInputFinished(); bool IsFinished() { return input_finished_; } + void WaitDecodeFinished(); + void WaitFinished(); bool DecodedSomething() const { return !result_.empty() && !result_[0].sentence.empty(); @@ -137,18 +139,17 @@ class U2Recognizer { // feature_pipeline_->FrameShift(); } - const std::vector& Result() const { return result_; } + void AttentionRescoring(); private: - void AttentionRescoring(); + static void RunDecoderSearch(U2Recognizer *me); + void RunDecoderSearchInternal(); void UpdateResult(bool finish = false); private: U2RecognizerResource opts_; - // std::shared_ptr resource_; - // U2RecognizerResource resource_; std::shared_ptr nnet_producer_; std::shared_ptr decodable_; std::unique_ptr decoder_; @@ -167,6 +168,7 @@ class U2Recognizer { const int time_stamp_gap_ = 100; bool input_finished_; + std::thread thread_; }; } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_main.cc index 90c7cc063..178c91db1 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_main.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer_main.cc @@ -49,6 +49,7 @@ int main(int argc, char* argv[]) { ppspeech::U2Recognizer recognizer(resource); for (; !wav_reader.Done(); wav_reader.Next()) { + recognizer.InitDecoder(); std::string utt = wav_reader.Key(); const kaldi::WaveData& wave_data = wav_reader.Value(); LOG(INFO) << "utt: " << utt; @@ -79,7 +80,7 @@ int main(int argc, char* argv[]) { recognizer.Accept(wav_chunk); if (cur_chunk_size < chunk_sample_size) { - recognizer.SetFinished(); + recognizer.SetInputFinished(); } recognizer.Decode(); if (recognizer.DecodedSomething()) { @@ -100,7 +101,6 @@ int main(int argc, char* argv[]) { std::string result = recognizer.GetFinalResult(); - recognizer.Reset(); if (result.empty()) { // the TokenWriter can not write empty string. diff --git a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc index a53b45415..3f45294d1 100644 --- a/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc +++ b/speechx/speechx/asr/recognizer/u2_recognizer_thread_main.cc @@ -22,15 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_int32(sample_rate, 16000, "sample rate"); -void decode_func(std::shared_ptr recognizer) { - while (!recognizer->IsFinished()) { - recognizer->Decode(); - usleep(100); - } - recognizer->Decode(); - recognizer->Rescoring(); -} - int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -40,6 +31,7 @@ int main(int argc, char* argv[]) { int32 num_done = 0, num_err = 0; double tot_wav_duration = 0.0; + double tot_attention_rescore_time = 0.0; double tot_decode_time = 0.0; kaldi::SequentialTableReader wav_reader( @@ -59,7 +51,7 @@ int main(int argc, char* argv[]) { new ppspeech::U2Recognizer(resource)); for (; !wav_reader.Done(); wav_reader.Next()) { - std::thread recognizer_thread(decode_func, recognizer_ptr); + recognizer_ptr->InitDecoder(); std::string utt = wav_reader.Key(); const kaldi::WaveData& wave_data = wav_reader.Value(); LOG(INFO) << "utt: " << utt; @@ -74,7 +66,6 @@ int main(int argc, char* argv[]) { LOG(INFO) << "wav len (sample): " << tot_samples; int sample_offset = 0; - kaldi::Timer timer; kaldi::Timer local_timer; while (sample_offset < tot_samples) { @@ -85,21 +76,23 @@ int main(int argc, char* argv[]) { for (int i = 0; i < cur_chunk_size; ++i) { wav_chunk[i] = waveform(sample_offset + i); } - // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size); recognizer_ptr->Accept(wav_chunk); if (cur_chunk_size < chunk_sample_size) { - recognizer_ptr->SetFinished(); + recognizer_ptr->SetInputFinished(); } // no overlap sample_offset += cur_chunk_size; } CHECK(sample_offset == tot_samples); + recognizer_ptr->WaitDecodeFinished(); + + kaldi::Timer timer; + recognizer_ptr->AttentionRescoring(); + tot_attention_rescore_time += timer.Elapsed(); - recognizer_thread.join(); std::string result = recognizer_ptr->GetFinalResult(); - recognizer_ptr->Reset(); if (result.empty()) { // the TokenWriter can not write empty string. ++num_err; @@ -107,6 +100,7 @@ int main(int argc, char* argv[]) { continue; } + tot_decode_time += local_timer.Elapsed(); LOG(INFO) << utt << " " << result; LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur << " cost: " << local_timer.Elapsed(); @@ -115,9 +109,11 @@ int main(int argc, char* argv[]) { ++num_done; } + recognizer_ptr->WaitFinished(); LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec"; LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; } diff --git a/speechx/speechx/cls/CMakeLists.txt b/speechx/speechx/cls/CMakeLists.txt new file mode 100644 index 000000000..99d123b4e --- /dev/null +++ b/speechx/speechx/cls/CMakeLists.txt @@ -0,0 +1,46 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +project(cls) + +set(ARCH "mserver_x86_64" CACHE STRING "Target Architecture: +android_arm, android_armv7, android_armv8, android_x86, android_x86_64, +mserver_x86_64, ubuntu_x86_64, +ios_armv7, ios_armv7s, ios_armv8, ios_x86_64, ios_x86, +windows_x86") + +set(CMAKE_VERBOSE_MAKEFILE ON) + +set(FASTDEPLOY_DIR ${CMAKE_SOURCE_DIR}/fc_patch/fastdeploy) +if(NOT EXISTS ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.2.tgz) + exec_program("mkdir -p ${FASTDEPLOY_DIR} && + wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-1.0.2.tgz -P ${FASTDEPLOY_DIR} && + tar xzvf ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.2.tgz -C ${FASTDEPLOY_DIR} && + mv ${FASTDEPLOY_DIR}/fastdeploy-linux-x64-1.0.2 ${FASTDEPLOY_DIR}/linux-x64") +endif() + +if(NOT EXISTS ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.0-shared.tgz) + exec_program("mkdir -p ${FASTDEPLOY_DIR} && + wget https://bj.bcebos.com/fastdeploy/release/android/fastdeploy-android-1.0.0-shared.tgz -P ${FASTDEPLOY_DIR} && + tar xzvf ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.0-shared.tgz -C ${FASTDEPLOY_DIR} && + mv ${FASTDEPLOY_DIR}/fastdeploy-android-1.0.0-shared ${FASTDEPLOY_DIR}/android-armv7v8") +endif() + +if (ARCH STREQUAL "mserver_x86_64") + set(FASTDEPLOY_INSTALL_DIR ${FASTDEPLOY_DIR}/linux-x64) + add_definitions("-DUSE_PADDLE_INFERENCE_BACKEND") + # add_definitions("-DUSE_ORT_BACKEND") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -msse -msse2") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -msse -msse2 -mavx -O3") +elseif (ARCH STREQUAL "android_armv7") + set(FASTDEPLOY_INSTALL_DIR ${FASTDEPLOY_DIR}/android-armv7v8) + add_definitions("-DUSE_PADDLE_LITE_BAKEND") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -g0 -O3 -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE") +endif() + +# add_definitions("-DTEST_DEBUG") +# add_definitions("-DPRINT_TIME") + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +add_subdirectory(nnet) \ No newline at end of file diff --git a/speechx/speechx/cls/nnet/CMakeLists.txt b/speechx/speechx/cls/nnet/CMakeLists.txt new file mode 100644 index 000000000..414c11e01 --- /dev/null +++ b/speechx/speechx/cls/nnet/CMakeLists.txt @@ -0,0 +1,12 @@ +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) +# 添加FastDeploy依赖头文件 +include_directories(${FASTDEPLOY_INCS}) + +set(srcs cls_nnet.cc cls_interface.cc) + +add_library(cls SHARED ${srcs}) +target_link_libraries(cls -static-libstdc++;-Wl,-Bsymbolic ${FASTDEPLOY_LIBS} kaldi-native-fbank-core) + +set(bin_name cls_nnet_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_link_libraries(${bin_name} -static-libstdc++;-Wl,-Bsymbolic cls) \ No newline at end of file diff --git a/speechx/speechx/cls/nnet/cls_interface.cc b/speechx/speechx/cls/nnet/cls_interface.cc new file mode 100644 index 000000000..82f1404ec --- /dev/null +++ b/speechx/speechx/cls/nnet/cls_interface.cc @@ -0,0 +1,59 @@ +#include "cls/nnet/cls_nnet.h" +#include "cls/nnet/config.h" + +namespace ppspeech{ + +void* cls_create_instance(const char* conf_path){ + Config conf(conf_path); + //cls init + ppspeech::ClsNnetConf cls_nnet_conf; + cls_nnet_conf.wav_normal_ = conf.Read("wav_normal", true); + cls_nnet_conf.wav_normal_type_ = conf.Read("wav_normal_type", std::string("linear")); + cls_nnet_conf.wav_norm_mul_factor_ = conf.Read("wav_norm_mul_factor", 1.0); + cls_nnet_conf.model_file_path_ = conf.Read("model_path", std::string("")); + cls_nnet_conf.param_file_path_ = conf.Read("param_path", std::string("")); + cls_nnet_conf.dict_file_path_ = conf.Read("dict_path", std::string("")); + cls_nnet_conf.num_cpu_thread_ = conf.Read("num_cpu_thread", 12); + cls_nnet_conf.samp_freq = conf.Read("samp_freq", 32000); + cls_nnet_conf.frame_length_ms = conf.Read("frame_length_ms", 32); + cls_nnet_conf.frame_shift_ms = conf.Read("frame_shift_ms", 10); + cls_nnet_conf.num_bins = conf.Read("num_bins", 64); + cls_nnet_conf.low_freq = conf.Read("low_freq", 50); + cls_nnet_conf.high_freq = conf.Read("high_freq", 14000); + cls_nnet_conf.dither = conf.Read("dither", 0.0); + + ppspeech::ClsNnet* cls_model = new ppspeech::ClsNnet(); + int ret = cls_model->init(cls_nnet_conf); + return (void*)cls_model; +}; + +int cls_destroy_instance(void* instance){ + ppspeech::ClsNnet* cls_model = (ppspeech::ClsNnet*)instance; + if(cls_model != NULL){ + delete cls_model; + cls_model = NULL; + } + return 0; +}; + +int cls_feedforward(void* instance, const char* wav_path, int topk, char* result, int result_max_len){ + ppspeech::ClsNnet* cls_model = (ppspeech::ClsNnet*)instance; + if(cls_model == NULL){ + printf("instance is null\n"); + return -1; + } + int ret = cls_model->forward(wav_path, topk, result, result_max_len); + return 0; +}; + +int cls_reset(void* instance){ + ppspeech::ClsNnet* cls_model = (ppspeech::ClsNnet*)instance; + if(cls_model == NULL){ + printf("instance is null\n"); + return -1; + } + cls_model->reset(); + return 0; +}; + +} \ No newline at end of file diff --git a/speechx/speechx/cls/nnet/cls_interface.h b/speechx/speechx/cls/nnet/cls_interface.h new file mode 100644 index 000000000..5c6256a6e --- /dev/null +++ b/speechx/speechx/cls/nnet/cls_interface.h @@ -0,0 +1,10 @@ +#pragma once + +namespace ppspeech{ + +void* cls_create_instance(const char* conf_path); +int cls_destroy_instance(void* instance); +int cls_feedforward(void* instance, const char* wav_path, int topk, char* result, int result_max_len); +int cls_reset(void* instance); + +} \ No newline at end of file diff --git a/speechx/speechx/cls/nnet/cls_nnet.cc b/speechx/speechx/cls/nnet/cls_nnet.cc new file mode 100644 index 000000000..be56bf310 --- /dev/null +++ b/speechx/speechx/cls/nnet/cls_nnet.cc @@ -0,0 +1,405 @@ +#include "cls/nnet/cls_nnet.h" +#ifdef PRINT_TIME +#include +#endif + +namespace ppspeech { + +ClsNnet::ClsNnet(){ + // wav_reader_ = NULL; + runtime_ = NULL; +}; + +void ClsNnet::reset(){ + // wav_reader_->Clear(); + ss_.str(""); +}; + +int ClsNnet::init(ClsNnetConf& conf){ + conf_ = conf; + //init fbank opts + fbank_opts_.frame_opts.samp_freq = conf.samp_freq; + fbank_opts_.frame_opts.frame_length_ms = conf.frame_length_ms; + fbank_opts_.frame_opts.frame_shift_ms = conf.frame_shift_ms; + fbank_opts_.mel_opts.num_bins = conf.num_bins; + fbank_opts_.mel_opts.low_freq = conf.low_freq; + fbank_opts_.mel_opts.high_freq = conf.high_freq; + fbank_opts_.frame_opts.dither = conf.dither; + fbank_opts_.use_log_fbank = false; + + //init dict + if (conf.dict_file_path_ != ""){ + init_dict(conf.dict_file_path_); + } + + // init model + fastdeploy::RuntimeOption runtime_option; + +#ifdef USE_ORT_BACKEND + runtime_option.SetModelPath(conf.model_file_path_, "", fastdeploy::ModelFormat::ONNX); // onnx + runtime_option.UseOrtBackend(); // onnx +#endif +#ifdef USE_PADDLE_LITE_BACKEND + runtime_option.SetModelPath(conf.model_file_path_, conf.param_file_path_, fastdeploy::ModelFormat::PADDLE); + runtime_option.UseLiteBackend(); +#endif +#ifdef USE_PADDLE_INFERENCE_BACKEND + runtime_option.SetModelPath(conf.model_file_path_, conf.param_file_path_, fastdeploy::ModelFormat::PADDLE); + runtime_option.UsePaddleInferBackend(); +#endif + runtime_option.SetCpuThreadNum(conf.num_cpu_thread_); + runtime_option.DeletePaddleBackendPass("simplify_with_basic_ops_pass"); + runtime_ = std::unique_ptr(new fastdeploy::Runtime()); + if (!runtime_->Init(runtime_option)) { + std::cerr << "--- Init FastDeploy Runitme Failed! " + << "\n--- Model: " << conf.model_file_path_ << std::endl; + return -1; + } else { + std::cout << "--- Init FastDeploy Runitme Done! " + << "\n--- Model: " << conf.model_file_path_ << std::endl; + } + + reset(); + return 0; +}; + +int ClsNnet::init_dict(std::string& dict_path){ + std::ifstream fp(dict_path); + std::string line = ""; + while(getline(fp, line)){ + dict_.push_back(line); + } + return 0; +}; + +int ClsNnet::forward(const char* wav_path, int topk, char* result, int result_max_len){ +#ifdef PRINT_TIME + double duration = 0; + std::chrono::high_resolution_clock::time_point start_time = std::chrono::high_resolution_clock::now(); +#endif + //read wav + WavReader wav_reader(wav_path); + int wavform_len = wav_reader.num_samples(); + std::vector wavform(wavform_len); + memcpy(wavform.data(), wav_reader.data(), wavform_len * sizeof(float)); + waveform_float_normal(wavform); + waveform_normal(wavform, conf_.wav_normal_, conf_.wav_normal_type_, conf_.wav_norm_mul_factor_); +#ifdef TEST_DEBUG + { + std::ofstream fp("cls.wavform", std::ios::out); + for (int i = 0; i < wavform.size(); ++i) { + fp << std::setprecision(18) << wavform[i] << " "; + } + fp << "\n"; + } +#endif +#ifdef PRINT_TIME + std::chrono::high_resolution_clock::time_point end_time = std::chrono::high_resolution_clock::now(); + duration = std::chrono::duration(end_time - start_time).count(); + printf("wav read consume: %fs\n", duration / 1000000); +#endif + +#ifdef PRINT_TIME + start_time = std::chrono::high_resolution_clock::now(); +#endif + + std::vector feats; + std::unique_ptr data_source(new ppspeech::DataCache()); + ppspeech::Fbank fbank(fbank_opts_, std::move(data_source)); + fbank.Accept(wavform); + fbank.SetFinished(); + fbank.Read(&feats); + + int feat_dim = fbank_opts_.mel_opts.num_bins; + int num_frames = feats.size() / feat_dim; + + for (int i = 0; i < num_frames; ++i){ + for(int j = 0; j < feat_dim; ++j){ + feats[i * feat_dim + j] = power_to_db(feats[i * feat_dim + j]); + } + } +#ifdef TEST_DEBUG + { + std::ofstream fp("cls.feat", std::ios::out); + for (int i = 0; i < num_frames; ++i) { + for (int j = 0; j < feat_dim; ++j){ + fp << std::setprecision(18) << feats[i * feat_dim + j] << " "; + } + fp << "\n"; + } + } +#endif +#ifdef PRINT_TIME + end_time = std::chrono::high_resolution_clock::now(); + duration = std::chrono::duration(end_time - start_time).count(); + printf("extract fbank consume: %fs\n", duration / 1000000); +#endif + + // model_forward_stream(feats); + + //infer + std::vector model_out; +#ifdef PRINT_TIME + start_time = std::chrono::high_resolution_clock::now(); +#endif + model_forward(feats.data(), num_frames, feat_dim, model_out); +#ifdef PRINT_TIME + end_time = std::chrono::high_resolution_clock::now(); + duration = std::chrono::duration(end_time - start_time).count(); + printf("fast deploy infer consume: %fs\n", duration / 1000000); +#endif +#ifdef TEST_DEBUG + { + std::ofstream fp("cls.logits", std::ios::out); + for (int i = 0; i < model_out.size(); ++i) { + fp << std::setprecision(18) << model_out[i] << "\n"; + } + } +#endif + + // construct result str + ss_ << "{"; + get_topk(topk, model_out); + ss_ << "}"; + + if (result_max_len <= ss_.str().size()){ + printf("result_max_len is short than result len\n"); + } + snprintf(result, result_max_len, "%s", ss_.str().c_str()); + return 0; +}; + +int ClsNnet::model_forward(const float* features, const int num_frames, const int feat_dim, std::vector& model_out){ + // init input tensor shape + fastdeploy::TensorInfo info = runtime_->GetInputInfo(0); + info.shape = {1, num_frames, feat_dim}; + + std::vector input_tensors(1); + std::vector output_tensors(1); + + input_tensors[0].SetExternalData({1, num_frames, feat_dim}, fastdeploy::FDDataType::FP32, (void*)features); + + //get input name + input_tensors[0].name = info.name; + + runtime_->Infer(input_tensors, &output_tensors); + + // output_tensors[0].PrintInfo(); + std::vector output_shape = output_tensors[0].Shape(); + model_out.resize(output_shape[0] * output_shape[1]); + memcpy((void*)model_out.data(), output_tensors[0].Data(), output_shape[0] * output_shape[1] * sizeof(float)); + return 0; +}; + +int ClsNnet::model_forward_stream(std::vector& feats){ + // init input tensor shape + std::vector input_infos = runtime_->GetInputInfos(); + std::vector output_infos = runtime_->GetOutputInfos(); + + std::vector input_tensors(14); + std::vector output_tensors(13); + { + std::vector feats_tmp(feats.begin(), feats.begin() + 400 * 64); + std::vector flag({0}); + std::vector block1_conv1_cache(1 * 64 * 400 * 64, 0); + std::vector block1_conv2_cache(1 * 64 * 400 * 64, 0); + std::vector block2_conv1_cache(1 * 128 * 200 * 32, 0); + std::vector block2_conv2_cache(1 * 128 * 200 * 32, 0); + std::vector block3_conv1_cache(1 * 256 * 100 * 16, 0); + std::vector block3_conv2_cache(1 * 256 * 100 * 16, 0); + std::vector block4_conv1_cache(1 * 512 * 50 * 8, 0); + std::vector block4_conv2_cache(1 * 512 * 50 * 8, 0); + std::vector block5_conv1_cache(1 * 1024 * 25 * 4, 0); + std::vector block5_conv2_cache(1 * 1024 * 25 * 4, 0); + std::vector block6_conv1_cache(1 * 2048 * 12 * 2, 0); + std::vector block6_conv2_cache(1 * 2048 * 12 * 2, 0); + input_tensors[0].name = input_infos[0].name; + input_tensors[0].SetExternalData({1, 400, 64}, fastdeploy::FDDataType::FP32, (void*)feats_tmp.data()); + input_tensors[1].name = input_infos[1].name; + input_tensors[1].SetExternalData({1}, fastdeploy::FDDataType::INT32, (void*)flag.data()); + input_tensors[2].name = input_infos[2].name; + input_tensors[2].SetExternalData({1, 64, 400, 64}, fastdeploy::FDDataType::FP32, (void*)block1_conv1_cache.data()); + input_tensors[3].name = input_infos[3].name; + input_tensors[3].SetExternalData({1, 64, 400, 64}, fastdeploy::FDDataType::FP32, (void*)block1_conv2_cache.data()); + input_tensors[4].name = input_infos[4].name; + input_tensors[4].SetExternalData({1, 128, 200, 32}, fastdeploy::FDDataType::FP32, (void*)block2_conv1_cache.data()); + input_tensors[5].name = input_infos[5].name; + input_tensors[5].SetExternalData({1, 128, 200, 32}, fastdeploy::FDDataType::FP32, (void*)block2_conv2_cache.data()); + input_tensors[6].name = input_infos[6].name; + input_tensors[6].SetExternalData({1, 256, 100, 16}, fastdeploy::FDDataType::FP32, (void*)block3_conv1_cache.data()); + input_tensors[7].name = input_infos[7].name; + input_tensors[7].SetExternalData({1, 256, 100, 16}, fastdeploy::FDDataType::FP32, (void*)block3_conv2_cache.data()); + input_tensors[8].name = input_infos[8].name; + input_tensors[8].SetExternalData({1, 512, 50, 8}, fastdeploy::FDDataType::FP32, (void*)block4_conv1_cache.data()); + input_tensors[9].name = input_infos[9].name; + input_tensors[9].SetExternalData({1, 512, 50, 8}, fastdeploy::FDDataType::FP32, (void*)block4_conv2_cache.data()); + input_tensors[10].name = input_infos[10].name; + input_tensors[10].SetExternalData({1, 1024, 25, 4}, fastdeploy::FDDataType::FP32, (void*)block5_conv1_cache.data()); + input_tensors[11].name = input_infos[11].name; + input_tensors[11].SetExternalData({1, 1024, 25, 4}, fastdeploy::FDDataType::FP32, (void*)block5_conv2_cache.data()); + input_tensors[12].name = input_infos[12].name; + input_tensors[12].SetExternalData({1, 2048, 12, 2}, fastdeploy::FDDataType::FP32, (void*)block6_conv1_cache.data()); + input_tensors[13].name = input_infos[13].name; + input_tensors[13].SetExternalData({1, 2048, 12, 2}, fastdeploy::FDDataType::FP32, (void*)block6_conv2_cache.data()); + + std::vector model_out_tmp; +#ifdef PRINT_TIME + double duration = 0; + std::chrono::high_resolution_clock::time_point start_time = std::chrono::high_resolution_clock::now(); +#endif + runtime_->Infer(input_tensors, &output_tensors); +#ifdef PRINT_TIME + std::chrono::high_resolution_clock::time_point end_time = std::chrono::high_resolution_clock::now(); + duration = std::chrono::duration(end_time - start_time).count(); + printf("infer %d:%d consume: %fs\n", 0, 400, duration / 1000000); +#endif + // output_tensors[0].PrintInfo(); + std::vector output_shape = output_tensors[0].Shape(); + model_out_tmp.resize(output_shape[0] * output_shape[1]); + memcpy((void*)model_out_tmp.data(), output_tensors[0].Data(), output_shape[0] * output_shape[1] * sizeof(float)); +#ifdef TEST_DEBUG + { + std::stringstream ss; + ss << "cls.logits." << 0 << "-" << 400; + std::ofstream fp(ss.str(), std::ios::out); + for (int i = 0; i < model_out_tmp.size(); ++i) { + fp << std::setprecision(18) << model_out_tmp[i] << "\n"; + } + } +#endif + } + { + std::vector feats_tmp(feats.begin() + 32 * 64, feats.begin() + 432 * 64); + std::vector flag({1}); + input_tensors[0].SetExternalData({1, 400, 64}, fastdeploy::FDDataType::FP32, (void*)feats_tmp.data()); + input_tensors[1].SetExternalData({1}, fastdeploy::FDDataType::INT32, (void*)flag.data()); + input_tensors[2] = output_tensors[1]; + input_tensors[2].name = input_infos[2].name; + input_tensors[3] = output_tensors[2]; + input_tensors[3].name = input_infos[3].name; + input_tensors[4] = output_tensors[3]; + input_tensors[4].name = input_infos[4].name; + input_tensors[5] = output_tensors[4]; + input_tensors[5].name = input_infos[5].name; + input_tensors[6] = output_tensors[5]; + input_tensors[6].name = input_infos[6].name; + input_tensors[7] = output_tensors[6]; + input_tensors[7].name = input_infos[7].name; + input_tensors[8] = output_tensors[7]; + input_tensors[8].name = input_infos[8].name; + input_tensors[9] = output_tensors[8]; + input_tensors[9].name = input_infos[9].name; + input_tensors[10] = output_tensors[9]; + input_tensors[10].name = input_infos[10].name; + input_tensors[11] = output_tensors[10]; + input_tensors[11].name = input_infos[11].name; + input_tensors[12] = output_tensors[11]; + input_tensors[12].name = input_infos[12].name; + input_tensors[13] = output_tensors[12]; + input_tensors[13].name = input_infos[13].name; + std::vector model_out_tmp; +#ifdef PRINT_TIME + double duration = 0; + std::chrono::high_resolution_clock::time_point start_time = std::chrono::high_resolution_clock::now(); +#endif + runtime_->Infer(input_tensors, &output_tensors); +#ifdef PRINT_TIME + std::chrono::high_resolution_clock::time_point end_time = std::chrono::high_resolution_clock::now(); + duration = std::chrono::duration(end_time - start_time).count(); + printf("infer %d:%d consume: %fs\n", 32, 432, duration / 1000000); +#endif + // output_tensors[0].PrintInfo(); + std::vector output_shape = output_tensors[0].Shape(); + model_out_tmp.resize(output_shape[0] * output_shape[1]); + memcpy((void*)model_out_tmp.data(), output_tensors[0].Data(), output_shape[0] * output_shape[1] * sizeof(float)); +#ifdef TEST_DEBUG + { + std::stringstream ss; + ss << "cls.logits." << 32 << "-" << 432; + std::ofstream fp(ss.str(), std::ios::out); + for (int i = 0; i < model_out_tmp.size(); ++i) { + fp << std::setprecision(18) << model_out_tmp[i] << "\n"; + } + } +#endif + } + exit(1); + return 0; +}; + +int ClsNnet::get_topk(int k, std::vector& model_out){ + std::vector> sort_vec; + for (int i = 0; i < model_out.size(); ++i){ + sort_vec.push_back({-1 * model_out[i], i}); + } + std::sort(sort_vec.begin(), sort_vec.end()); + for (int i = 0; i < k; ++i){ + if (i != 0){ + ss_ << ","; + } + ss_ << "\"" << dict_[sort_vec[i].second] << "\":\"" << -1 * sort_vec[i].first << "\""; + } + return 0; +}; + +int ClsNnet::waveform_float_normal(std::vector& waveform){ + int tot_samples = waveform.size(); + for (int i = 0; i < tot_samples; i++){ + waveform[i] = waveform[i] / 32768.0; + } + return 0; +} + +int ClsNnet::waveform_normal(std::vector& waveform, bool wav_normal, std::string& wav_normal_type, float wav_norm_mul_factor){ + if (wav_normal == false){ + return 0; + } + if (wav_normal_type == "linear"){ + float amax = INT32_MIN; + for (int i = 0; i < waveform.size(); ++i){ + float tmp = std::abs(waveform[i]); + amax = std::max(amax, tmp); + } + float factor = 1.0 / (amax + 1e-8); + for (int i = 0; i < waveform.size(); ++i){ + waveform[i] = waveform[i] * factor * wav_norm_mul_factor; + } + } else if (wav_normal_type == "gaussian") { + double sum = std::accumulate(waveform.begin(), waveform.end(), 0.0); + double mean = sum / waveform.size(); //均值 + + double accum = 0.0; + std::for_each (waveform.begin(), waveform.end(), [&](const double d) { + accum += (d-mean)*(d-mean); + }); + + double stdev = sqrt(accum/(waveform.size()-1)); //方差 + stdev = std::max(stdev, 1e-8); + + for (int i = 0; i < waveform.size(); ++i){ + waveform[i] = wav_norm_mul_factor * (waveform[i] - mean) / stdev; + } + } else { + printf("don't support\n"); + return -1; + } + return 0; +} + +float ClsNnet::power_to_db(float in, float ref_value, float amin, float top_db){ + if(amin <= 0){ + printf("amin must be strictly positive\n"); + return -1; + }; + + if(ref_value <= 0){ + printf("ref_value must be strictly positive\n"); + return -1; + } + + float out = 10.0 * log10(std::max(amin, in)); + out -= 10.0 * log10(std::max(ref_value, amin)); + return out; +} + +} \ No newline at end of file diff --git a/speechx/speechx/cls/nnet/cls_nnet.h b/speechx/speechx/cls/nnet/cls_nnet.h new file mode 100644 index 000000000..7fefb5811 --- /dev/null +++ b/speechx/speechx/cls/nnet/cls_nnet.h @@ -0,0 +1,73 @@ +// Copyright 2022 Horizon Robotics. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// modified from +// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/asr_model.h +#pragma once + +#include +#include +#include +#include "fastdeploy/runtime.h" +#include "cls/nnet/wav.h" +#include "common/frontend/frontend_itf.h" +#include "common/frontend/data_cache.h" +#include "common/frontend/feature-fbank.h" +#include "frontend/fbank.h" + +namespace ppspeech { +struct ClsNnetConf { + //wav + bool wav_normal_; + std::string wav_normal_type_; + float wav_norm_mul_factor_; + //model + std::string model_file_path_; + std::string param_file_path_; + std::string dict_file_path_; + int num_cpu_thread_; + //fbank + float samp_freq; + float frame_length_ms; + float frame_shift_ms; + int num_bins; + float low_freq; + float high_freq; + float dither; +}; + +class ClsNnet { +public: + ClsNnet(); + int init(ClsNnetConf& conf); + int forward(const char* wav_path, int topk, char* result, int result_max_len); + void reset(); +private: + int init_dict(std::string& dict_path); + int model_forward(const float* features, const int num_frames, const int feat_dim, std::vector& model_out); + int model_forward_stream(std::vector& feats); + int get_topk(int k, std::vector& model_out); + int waveform_float_normal(std::vector& waveform); + int waveform_normal(std::vector& waveform, bool wav_normal, std::string& wav_normal_type, float wav_norm_mul_factor); + float power_to_db(float in, float ref_value = 1.0, float amin = 1e-10, float top_db = 80.0); + + ClsNnetConf conf_; + knf::FbankOptions fbank_opts_; + std::unique_ptr runtime_; + std::vector dict_; + std::stringstream ss_; +}; + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/cls/nnet/cls_nnet_main.cc b/speechx/speechx/cls/nnet/cls_nnet_main.cc new file mode 100644 index 000000000..24669279a --- /dev/null +++ b/speechx/speechx/cls/nnet/cls_nnet_main.cc @@ -0,0 +1,45 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "nnet/cls_interface.h" + +int main(int argc, char* argv[]) { + if (argc != 4){ + printf("usage : cls_nnet_main conf_path wav_path topk\n"); + return 0; + } + const char* conf_path = argv[1]; + const char* wav_path = argv[2]; + int topk = std::atoi(argv[3]); + void* instance = ppspeech::cls_create_instance(conf_path); + int ret = 0; + //read wav + std::ifstream ifs(wav_path); + std::string line = ""; + while(getline(ifs, line)){ + //read wav + char result[1024] = {0}; + ret = ppspeech::cls_feedforward(instance, line.c_str(), topk, result, 1024); + printf("%s %s\n", line.c_str(), result); + ret = ppspeech::cls_reset(instance); + } + ret = ppspeech::cls_destroy_instance(instance); + return 0; +} diff --git a/speechx/speechx/cls/nnet/config.h b/speechx/speechx/cls/nnet/config.h new file mode 100644 index 000000000..c566893cc --- /dev/null +++ b/speechx/speechx/cls/nnet/config.h @@ -0,0 +1,352 @@ +#include +#include +#include +#include +#include +using namespace std; + +#pragma once + +#pragma region ParseIniFile +/* +* \brief Generic configuration Class +* +*/ +class Config { + // Data +protected: + std::string m_Delimiter; //!< separator between key and value + std::string m_Comment; //!< separator between value and comments + std::map m_Contents; //!< extracted keys and values + + typedef std::map::iterator mapi; + typedef std::map::const_iterator mapci; + // Methods +public: + + Config(std::string filename, std::string delimiter = "=", std::string comment = "#"); + Config(); + template T Read(const std::string& in_key) const; //! + template T Read(const std::string& in_key, const T& in_value) const; + template bool ReadInto(T& out_var, const std::string& in_key) const; + template + bool ReadInto(T& out_var, const std::string& in_key, const T& in_value) const; + bool FileExist(std::string filename); + void ReadFile(std::string filename, std::string delimiter = "=", std::string comment = "#"); + + // Check whether key exists in configuration + bool KeyExists(const std::string& in_key) const; + + // Modify keys and values + template void Add(const std::string& in_key, const T& in_value); + void Remove(const std::string& in_key); + + // Check or change configuration syntax + std::string GetDelimiter() const { return m_Delimiter; } + std::string GetComment() const { return m_Comment; } + std::string SetDelimiter(const std::string& in_s) + { + std::string old = m_Delimiter; m_Delimiter = in_s; return old; + } + std::string SetComment(const std::string& in_s) + { + std::string old = m_Comment; m_Comment = in_s; return old; + } + + // Write or read configuration + friend std::ostream& operator<<(std::ostream& os, const Config& cf); + friend std::istream& operator >> (std::istream& is, Config& cf); + +protected: + template static std::string T_as_string(const T& t); + template static T string_as_T(const std::string& s); + static void Trim(std::string& inout_s); + + + // Exception types +public: + struct File_not_found { + std::string filename; + File_not_found(const std::string& filename_ = std::string()) + : filename(filename_) {} + }; + struct Key_not_found { // thrown only by T read(key) variant of read() + std::string key; + Key_not_found(const std::string& key_ = std::string()) + : key(key_) {} + }; +}; + +/* static */ +template +std::string Config::T_as_string(const T& t) +{ + // Convert from a T to a string + // Type T must support << operator + std::ostringstream ost; + ost << t; + return ost.str(); +} + + +/* static */ +template +T Config::string_as_T(const std::string& s) +{ + // Convert from a string to a T + // Type T must support >> operator + T t; + std::istringstream ist(s); + ist >> t; + return t; +} + + +/* static */ +template<> +inline std::string Config::string_as_T(const std::string& s) +{ + // Convert from a string to a string + // In other words, do nothing + return s; +} + + +/* static */ +template<> +inline bool Config::string_as_T(const std::string& s) +{ + // Convert from a string to a bool + // Interpret "false", "F", "no", "n", "0" as false + // Interpret "true", "T", "yes", "y", "1", "-1", or anything else as true + bool b = true; + std::string sup = s; + for (std::string::iterator p = sup.begin(); p != sup.end(); ++p) + *p = toupper(*p); // make string all caps + if (sup == std::string("FALSE") || sup == std::string("F") || + sup == std::string("NO") || sup == std::string("N") || + sup == std::string("0") || sup == std::string("NONE")) + b = false; + return b; +} + + +template +T Config::Read(const std::string& key) const +{ + // Read the value corresponding to key + mapci p = m_Contents.find(key); + if (p == m_Contents.end()) throw Key_not_found(key); + return string_as_T(p->second); +} + + +template +T Config::Read(const std::string& key, const T& value) const +{ + // Return the value corresponding to key or given default value + // if key is not found + mapci p = m_Contents.find(key); + if (p == m_Contents.end()) { + printf("%s = %s(default)\n", key.c_str(), T_as_string(value).c_str()); + return value; + } else { + printf("%s = %s\n", key.c_str(), T_as_string(p->second).c_str()); + return string_as_T(p->second); + } + +} + + +template +bool Config::ReadInto(T& var, const std::string& key) const +{ + // Get the value corresponding to key and store in var + // Return true if key is found + // Otherwise leave var untouched + mapci p = m_Contents.find(key); + bool found = (p != m_Contents.end()); + if (found) var = string_as_T(p->second); + return found; +} + + +template +bool Config::ReadInto(T& var, const std::string& key, const T& value) const +{ + // Get the value corresponding to key and store in var + // Return true if key is found + // Otherwise set var to given default + mapci p = m_Contents.find(key); + bool found = (p != m_Contents.end()); + if (found) + var = string_as_T(p->second); + else + var = value; + return found; +} + + +template +void Config::Add(const std::string& in_key, const T& value) +{ + // Add a key with given value + std::string v = T_as_string(value); + std::string key = in_key; + Trim(key); + Trim(v); + m_Contents[key] = v; + return; +} + +Config::Config(string filename, string delimiter, + string comment) + : m_Delimiter(delimiter), m_Comment(comment) +{ + // Construct a Config, getting keys and values from given file + + std::ifstream in(filename.c_str()); + + if (!in) throw File_not_found(filename); + + in >> (*this); +} + + +Config::Config() + : m_Delimiter(string(1, '=')), m_Comment(string(1, '#')) +{ + // Construct a Config without a file; empty +} + + + +bool Config::KeyExists(const string& key) const +{ + // Indicate whether key is found + mapci p = m_Contents.find(key); + return (p != m_Contents.end()); +} + + +/* static */ +void Config::Trim(string& inout_s) +{ + // Remove leading and trailing whitespace + static const char whitespace[] = " \n\t\v\r\f"; + inout_s.erase(0, inout_s.find_first_not_of(whitespace)); + inout_s.erase(inout_s.find_last_not_of(whitespace) + 1U); +} + + +std::ostream& operator<<(std::ostream& os, const Config& cf) +{ + // Save a Config to os + for (Config::mapci p = cf.m_Contents.begin(); + p != cf.m_Contents.end(); + ++p) + { + os << p->first << " " << cf.m_Delimiter << " "; + os << p->second << std::endl; + } + return os; +} + +void Config::Remove(const string& key) +{ + // Remove key and its value + m_Contents.erase(m_Contents.find(key)); + return; +} + +std::istream& operator >> (std::istream& is, Config& cf) +{ + // Load a Config from is + // Read in keys and values, keeping internal whitespace + typedef string::size_type pos; + const string& delim = cf.m_Delimiter; // separator + const string& comm = cf.m_Comment; // comment + const pos skip = delim.length(); // length of separator + + string nextline = ""; // might need to read ahead to see where value ends + + while (is || nextline.length() > 0) + { + // Read an entire line at a time + string line; + if (nextline.length() > 0) + { + line = nextline; // we read ahead; use it now + nextline = ""; + } + else + { + std::getline(is, line); + } + + // Ignore comments + line = line.substr(0, line.find(comm)); + + // Parse the line if it contains a delimiter + pos delimPos = line.find(delim); + if (delimPos < string::npos) + { + // Extract the key + string key = line.substr(0, delimPos); + line.replace(0, delimPos + skip, ""); + + // See if value continues on the next line + // Stop at blank line, next line with a key, end of stream, + // or end of file sentry + bool terminate = false; + while (!terminate && is) + { + std::getline(is, nextline); + terminate = true; + + string nlcopy = nextline; + Config::Trim(nlcopy); + if (nlcopy == "") continue; + + nextline = nextline.substr(0, nextline.find(comm)); + if (nextline.find(delim) != string::npos) + continue; + + nlcopy = nextline; + Config::Trim(nlcopy); + if (nlcopy != "") line += "\n"; + line += nextline; + terminate = false; + } + + // Store key and value + Config::Trim(key); + Config::Trim(line); + cf.m_Contents[key] = line; // overwrites if key is repeated + } + } + + return is; +} +bool Config::FileExist(std::string filename) +{ + bool exist = false; + std::ifstream in(filename.c_str()); + if (in) + exist = true; + return exist; +} + +void Config::ReadFile(string filename, string delimiter, + string comment) +{ + m_Delimiter = delimiter; + m_Comment = comment; + std::ifstream in(filename.c_str()); + + if (!in) throw File_not_found(filename); + + in >> (*this); +} + +#pragma endregion ParseIniFIle diff --git a/speechx/speechx/cls/nnet/wav.h b/speechx/speechx/cls/nnet/wav.h new file mode 100644 index 000000000..841dc2283 --- /dev/null +++ b/speechx/speechx/cls/nnet/wav.h @@ -0,0 +1,241 @@ +// Copyright (c) 2016 Personal (Binbin Zhang) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#ifndef FRONTEND_WAV_H_ +#define FRONTEND_WAV_H_ + +#include +#include +#include +#include +#include + +#include + +#pragma once + +namespace ppspeech { + +struct WavHeader { + char riff[4] = {'R', 'I', 'F', 'F'}; + unsigned int size = 0; + char wav[4] = {'W', 'A', 'V', 'E'}; + char fmt[4] = {'f', 'm', 't', ' '}; + unsigned int fmt_size = 16; + uint16_t format = 1; + uint16_t channels = 0; + unsigned int sample_rate = 0; + unsigned int bytes_per_second = 0; + uint16_t block_size = 0; + uint16_t bit = 0; + char data[4] = {'d', 'a', 't', 'a'}; + unsigned int data_size = 0; + + WavHeader() {} + + WavHeader(int num_samples, int num_channel, int sample_rate, + int bits_per_sample) { + data_size = num_samples * num_channel * (bits_per_sample / 8); + size = sizeof(WavHeader) - 8 + data_size; + channels = num_channel; + this->sample_rate = sample_rate; + bytes_per_second = sample_rate * num_channel * (bits_per_sample / 8); + block_size = num_channel * (bits_per_sample / 8); + bit = bits_per_sample; + } +}; + +class WavReader { + public: + WavReader() : data_(nullptr) {} + explicit WavReader(const std::string& filename) { Open(filename); } + + bool Open(const std::string& filename) { + FILE* fp = fopen(filename.c_str(), "rb"); + if (NULL == fp) { + //LOG(WARNING) << "Error in read " << filename; + return false; + } + + WavHeader header; + fread(&header, 1, sizeof(header), fp); + if (header.fmt_size < 16) { + fprintf(stderr, + "WaveData: expect PCM format data " + "to have fmt chunk of at least size 16.\n"); + return false; + } else if (header.fmt_size > 16) { + int offset = 44 - 8 + header.fmt_size - 16; + fseek(fp, offset, SEEK_SET); + fread(header.data, 8, sizeof(char), fp); + } + // check "RIFF" "WAVE" "fmt " "data" + + // Skip any sub-chunks between "fmt" and "data". Usually there will + // be a single "fact" sub chunk, but on Windows there can also be a + // "list" sub chunk. + while (0 != strncmp(header.data, "data", 4)) { + // We will just ignore the data in these chunks. + fseek(fp, header.data_size, SEEK_CUR); + // read next sub chunk + fread(header.data, 8, sizeof(char), fp); + } + + num_channel_ = header.channels; + sample_rate_ = header.sample_rate; + bits_per_sample_ = header.bit; + int num_data = header.data_size / (bits_per_sample_ / 8); + data_ = new float[num_data]; + num_samples_ = num_data / num_channel_; + + for (int i = 0; i < num_data; ++i) { + switch (bits_per_sample_) { + case 8: { + char sample; + fread(&sample, 1, sizeof(char), fp); + data_[i] = static_cast(sample); + break; + } + case 16: { + int16_t sample; + fread(&sample, 1, sizeof(int16_t), fp); + data_[i] = static_cast(sample); + break; + } + case 32: { + int sample; + fread(&sample, 1, sizeof(int), fp); + data_[i] = static_cast(sample); + break; + } + default: + fprintf(stderr, "unsupported quantization bits"); + exit(1); + } + } + fclose(fp); + return true; + } + + int num_channel() const { return num_channel_; } + int sample_rate() const { return sample_rate_; } + int bits_per_sample() const { return bits_per_sample_; } + int num_samples() const { return num_samples_; } + + ~WavReader() { + delete[] data_; + } + + const float* data() const { return data_; } + + private: + int num_channel_; + int sample_rate_; + int bits_per_sample_; + int num_samples_; // sample points per channel + float* data_; +}; + +class WavWriter { + public: + WavWriter(const float* data, int num_samples, int num_channel, + int sample_rate, int bits_per_sample) + : data_(data), + num_samples_(num_samples), + num_channel_(num_channel), + sample_rate_(sample_rate), + bits_per_sample_(bits_per_sample) {} + + void Write(const std::string& filename) { + FILE* fp = fopen(filename.c_str(), "w"); + WavHeader header(num_samples_, num_channel_, sample_rate_, + bits_per_sample_); + fwrite(&header, 1, sizeof(header), fp); + + for (int i = 0; i < num_samples_; ++i) { + for (int j = 0; j < num_channel_; ++j) { + switch (bits_per_sample_) { + case 8: { + char sample = static_cast(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + case 16: { + int16_t sample = static_cast(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + case 32: { + int sample = static_cast(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + } + } + } + fclose(fp); + } + + private: + const float* data_; + int num_samples_; // total float points in data_ + int num_channel_; + int sample_rate_; + int bits_per_sample_; +}; + +class StreamWavWriter { + public: + StreamWavWriter(int num_channel, int sample_rate, int bits_per_sample) + : num_channel_(num_channel), + sample_rate_(sample_rate), + bits_per_sample_(bits_per_sample), + total_num_samples_(0) {} + + StreamWavWriter(const std::string& filename, int num_channel, + int sample_rate, int bits_per_sample) + : StreamWavWriter(num_channel, sample_rate, bits_per_sample) { + Open(filename); + } + + void Open(const std::string& filename) { + fp_ = fopen(filename.c_str(), "wb"); + fseek(fp_, sizeof(WavHeader), SEEK_SET); + } + + void Write(const int16_t* sample_data, size_t num_samples) { + fwrite(sample_data, sizeof(int16_t), num_samples, fp_); + total_num_samples_ += num_samples; + } + + void Close() { + WavHeader header(total_num_samples_, num_channel_, sample_rate_, + bits_per_sample_); + fseek(fp_, 0L, SEEK_SET); + fwrite(&header, 1, sizeof(header), fp_); + fclose(fp_); + } + + private: + FILE* fp_; + int num_channel_; + int sample_rate_; + int bits_per_sample_; + size_t total_num_samples_; +}; + +} // namespace wenet + +#endif // FRONTEND_WAV_H_ diff --git a/speechx/speechx/common/frontend/compute_fbank_main.cc b/speechx/speechx/common/frontend/compute_fbank_main.cc index d7d5165ca..e022207d9 100644 --- a/speechx/speechx/common/frontend/compute_fbank_main.cc +++ b/speechx/speechx/common/frontend/compute_fbank_main.cc @@ -73,8 +73,7 @@ int main(int argc, char* argv[]) { new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank))); // the feature cache output feature chunk by chunk. - ppspeech::FeatureCacheOptions feat_cache_opts; - ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); + ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn)); LOG(INFO) << "fbank: " << true; LOG(INFO) << "feat dim: " << feature_cache.Dim(); diff --git a/speechx/speechx/common/frontend/feature_cache.cc b/speechx/speechx/common/frontend/feature_cache.cc index e6ac3c23c..c166bd64b 100644 --- a/speechx/speechx/common/frontend/feature_cache.cc +++ b/speechx/speechx/common/frontend/feature_cache.cc @@ -20,10 +20,9 @@ using kaldi::BaseFloat; using std::unique_ptr; using std::vector; -FeatureCache::FeatureCache(FeatureCacheOptions opts, +FeatureCache::FeatureCache(size_t max_size, unique_ptr base_extractor) { - max_size_ = opts.max_size; - timeout_ = opts.timeout; // ms + max_size_ = max_size; base_extractor_ = std::move(base_extractor); dim_ = base_extractor_->Dim(); } @@ -31,34 +30,25 @@ FeatureCache::FeatureCache(FeatureCacheOptions opts, void FeatureCache::Accept(const std::vector& inputs) { // read inputs base_extractor_->Accept(inputs); - - // feed current data - bool result = false; - do { - result = Compute(); - } while (result); } // pop feature chunk bool FeatureCache::Read(std::vector* feats) { kaldi::Timer timer; - std::unique_lock lock(mutex_); - while (cache_.empty() && base_extractor_->IsFinished() == false) { - // todo refactor: wait - // ready_read_condition_.wait(lock); - int32 elapsed = static_cast(timer.Elapsed() * 1000); // ms - if (elapsed > timeout_) { - return false; - } - usleep(100); // sleep 0.1 ms + // feed current data + if (cache_.empty()) { + bool result = false; + do { + result = Compute(); + } while (result); } + if (cache_.empty()) return false; // read from cache *feats = cache_.front(); cache_.pop(); - ready_feed_condition_.notify_one(); VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec."; return true; } @@ -73,23 +63,15 @@ bool FeatureCache::Compute() { kaldi::Timer timer; int32 num_chunk = feature.size() / dim_; - nframe_ += num_chunk; VLOG(3) << "nframe computed: " << nframe_; for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { int32 start = chunk_idx * dim_; vector feature_chunk(feature.data() + start, feature.data() + start + dim_); - - std::unique_lock lock(mutex_); - while (cache_.size() >= max_size_) { - // cache full, wait - ready_feed_condition_.wait(lock); - } - // feed cache cache_.push(feature_chunk); - ready_read_condition_.notify_one(); + ++nframe_; } VLOG(1) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. " @@ -97,4 +79,4 @@ bool FeatureCache::Compute() { return true; } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/common/frontend/feature_cache.h b/speechx/speechx/common/frontend/feature_cache.h index 51816a1de..b87612d66 100644 --- a/speechx/speechx/common/frontend/feature_cache.h +++ b/speechx/speechx/common/frontend/feature_cache.h @@ -19,16 +19,10 @@ namespace ppspeech { -struct FeatureCacheOptions { - int32 max_size; - int32 timeout; // ms - FeatureCacheOptions() : max_size(kint16max), timeout(1) {} -}; - class FeatureCache : public FrontendInterface { public: explicit FeatureCache( - FeatureCacheOptions opts, + size_t max_size = kint16max, std::unique_ptr base_extractor = NULL); // Feed feats or waves @@ -41,13 +35,11 @@ class FeatureCache : public FrontendInterface { virtual size_t Dim() const { return dim_; } virtual void SetFinished() { + std::unique_lock lock(mutex_); LOG(INFO) << "set finished"; - // std::unique_lock lock(mutex_); - base_extractor_->SetFinished(); - // read the last chunk data Compute(); - // ready_feed_condition_.notify_one(); + base_extractor_->SetFinished(); LOG(INFO) << "compute last feats done."; } @@ -66,16 +58,10 @@ class FeatureCache : public FrontendInterface { int32 dim_; size_t max_size_; // cache capacity - int32 frame_chunk_size_; // window - int32 frame_chunk_stride_; // stride std::unique_ptr base_extractor_; - kaldi::int32 timeout_; // ms - std::vector remained_feature_; std::queue> cache_; // feature cache std::mutex mutex_; - std::condition_variable ready_feed_condition_; - std::condition_variable ready_read_condition_; int32 nframe_; // num of feature computed DISALLOW_COPY_AND_ASSIGN(FeatureCache); diff --git a/speechx/speechx/common/frontend/feature_pipeline.cc b/speechx/speechx/common/frontend/feature_pipeline.cc index 34e55a10c..f37b41807 100644 --- a/speechx/speechx/common/frontend/feature_pipeline.cc +++ b/speechx/speechx/common/frontend/feature_pipeline.cc @@ -33,7 +33,7 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature))); unique_ptr cache( - new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn))); + new ppspeech::FeatureCache(kint16max, std::move(cmvn))); base_extractor_.reset( new ppspeech::Assembler(opts.assembler_opts, std::move(cache))); diff --git a/speechx/speechx/common/frontend/feature_pipeline.h b/speechx/speechx/common/frontend/feature_pipeline.h index ea7e2bba3..c9a649fde 100644 --- a/speechx/speechx/common/frontend/feature_pipeline.h +++ b/speechx/speechx/common/frontend/feature_pipeline.h @@ -39,7 +39,6 @@ namespace ppspeech { struct FeaturePipelineOptions { std::string cmvn_file{}; knf::FbankOptions fbank_opts{}; - FeatureCacheOptions feature_cache_opts{}; AssemblerOptions assembler_opts{}; static FeaturePipelineOptions InitFromFlags() {