From 6193064bd5afc6792ba7bf40aaffa5c7deff0c18 Mon Sep 17 00:00:00 2001 From: YangZhou Date: Wed, 22 Mar 2023 20:25:42 +0800 Subject: [PATCH] add recoginzer controller --- runtime/cmake/openfst.cmake | 2 +- runtime/engine/CMakeLists.txt | 2 +- runtime/engine/asr/decoder/CMakeLists.txt | 2 +- runtime/engine/asr/recognizer/CMakeLists.txt | 4 +- .../asr/recognizer/recognizer_batch_main.cc | 175 ++++++++++++++++++ .../asr/recognizer/recognizer_controller.cc | 71 +++++++ .../asr/recognizer/recognizer_controller.h | 39 ++++ 7 files changed, 291 insertions(+), 4 deletions(-) create mode 100644 runtime/engine/asr/recognizer/recognizer_batch_main.cc create mode 100644 runtime/engine/asr/recognizer/recognizer_controller.cc create mode 100644 runtime/engine/asr/recognizer/recognizer_controller.h diff --git a/runtime/cmake/openfst.cmake b/runtime/cmake/openfst.cmake index a859076fe..42299c88c 100644 --- a/runtime/cmake/openfst.cmake +++ b/runtime/cmake/openfst.cmake @@ -30,7 +30,7 @@ ExternalProject_Add(openfst CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR} "CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}" "LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}" - "LIBS=-lgflags_nothreads -lglog -lpthread" + "LIBS=-lgflags_nothreads -lglog -lpthread -fPIC" COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR} BUILD_COMMAND make -j 4 ) diff --git a/runtime/engine/CMakeLists.txt b/runtime/engine/CMakeLists.txt index 242a579b5..a1a92cd72 100644 --- a/runtime/engine/CMakeLists.txt +++ b/runtime/engine/CMakeLists.txt @@ -21,4 +21,4 @@ if(WITH_VAD) add_subdirectory(vad) endif() -add_subdirectory(codelab) \ No newline at end of file +add_subdirectory(codelab) diff --git a/runtime/engine/asr/decoder/CMakeLists.txt b/runtime/engine/asr/decoder/CMakeLists.txt index 07adda956..0cd9fc48a 100644 --- a/runtime/engine/asr/decoder/CMakeLists.txt +++ b/runtime/engine/asr/decoder/CMakeLists.txt @@ -19,6 +19,6 @@ foreach(bin_name IN LISTS TEST_BINS) target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) - target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl) endforeach() diff --git a/runtime/engine/asr/recognizer/CMakeLists.txt b/runtime/engine/asr/recognizer/CMakeLists.txt index f28c5fea8..54cf95e42 100644 --- a/runtime/engine/asr/recognizer/CMakeLists.txt +++ b/runtime/engine/asr/recognizer/CMakeLists.txt @@ -2,6 +2,7 @@ set(srcs) list(APPEND srcs u2_recognizer.cc + recognizer_controller.cc ) add_library(recognizer STATIC ${srcs}) @@ -11,6 +12,7 @@ set(TEST_BINS u2_recognizer_main u2_recognizer_thread_main u2_recognizer_batch_main + recognizer_batch_main ) foreach(bin_name IN LISTS TEST_BINS) @@ -19,5 +21,5 @@ foreach(bin_name IN LISTS TEST_BINS) target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) - target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl) endforeach() diff --git a/runtime/engine/asr/recognizer/recognizer_batch_main.cc b/runtime/engine/asr/recognizer/recognizer_batch_main.cc new file mode 100644 index 000000000..6f0f6cbcd --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_batch_main.cc @@ -0,0 +1,175 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/base/thread_pool.h" +#include "common/utils/file_utils.h" +#include "common/utils/strings.h" +#include "decoder/param.h" +#include "frontend/wave-reader.h" +#include "kaldi/util/table-types.h" +#include "nnet/u2_nnet.h" +#include "recognizer/u2_recognizer.h" +#include "recognizer/recognizer_controller.h" + +DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); +DEFINE_int32(njob, 3, "njob"); + +using std::string; +using std::vector; + +void SplitUtt(string wavlist_file, + vector>* uttlists, + vector>* wavlists, + int njob) { + vector wavlist; + wavlists->resize(njob); + uttlists->resize(njob); + ppspeech::ReadFileToVector(wavlist_file, &wavlist); + for (size_t idx = 0; idx < wavlist.size(); ++idx) { + string utt_str = wavlist[idx]; + vector utt_wav = ppspeech::StrSplit(utt_str, " \t"); + LOG(INFO) << utt_wav[0]; + CHECK_EQ(utt_wav.size(), size_t(2)); + uttlists->at(idx % njob).push_back(utt_wav[0]); + wavlists->at(idx % njob).push_back(utt_wav[1]); + } +} + +void recognizer_func(ppspeech::RecognizerController* recognizer_controller, + std::vector wavlist, + std::vector uttlist, + std::vector* results) { + int32 num_done = 0, num_err = 0; + double tot_wav_duration = 0.0; + double tot_attention_rescore_time = 0.0; + double tot_decode_time = 0.0; + int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate; + if (wavlist.empty()) return; + + results->reserve(wavlist.size()); + for (size_t idx = 0; idx < wavlist.size(); ++idx) { + std::string utt = uttlist[idx]; + std::string wav_file = wavlist[idx]; + std::ifstream infile; + infile.open(wav_file, std::ifstream::in); + kaldi::WaveData wave_data; + wave_data.Read(infile); + int32 recog_id = -1; + while (recog_id != -1) { + recog_id = recognizer_controller->GetRecognizerInstanceId(); + } + LOG(INFO) << "utt: " << utt; + LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; + double dur = wave_data.Duration(); + tot_wav_duration += dur; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + kaldi::Timer local_timer; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + std::vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk[i] = waveform(sample_offset + i); + } + + recognizer_controller->Accept(wav_chunk, recog_id); + if (cur_chunk_size < chunk_sample_size) { + recognizer_controller->SetInputFinished(recog_id); + } + + // no overlap + sample_offset += cur_chunk_size; + } + CHECK(sample_offset == tot_samples); + std::string result = recognizer_controller->GetFinalResult(recog_id); + if (result.empty()) { + // the TokenWriter can not write empty string. + ++num_err; + LOG(INFO) << " the result of " << utt << " is empty"; + result = " "; + } + + tot_decode_time += local_timer.Elapsed(); + LOG(INFO) << utt << " " << result; + LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur + << " cost: " << local_timer.Elapsed(); + + results->push_back(result); + ++num_done; + } + LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); + LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; + LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; +} + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + int njob = FLAGS_njob; + LOG(INFO) << "sr: " << sample_rate; + LOG(INFO) << "chunk size (s): " << streaming_chunk; + LOG(INFO) << "chunk size (sample): " << chunk_sample_size; + + ppspeech::U2RecognizerResource resource = + ppspeech::U2RecognizerResource::InitFromFlags(); + ppspeech::RecognizerController recognizer_controller(njob, resource); + ThreadPool threadpool(njob); + vector> wavlist; + vector> uttlist; + vector> resultlist(njob); + vector> futurelist; + SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob); + for (size_t i = 0; i < njob; ++i) { + std::future f = threadpool.enqueue(recognizer_func, + &recognizer_controller, + wavlist[i], + uttlist[i], + &resultlist[i]); + futurelist.push_back(std::move(f)); + } + + for (size_t i = 0; i < njob; ++i) { + futurelist[i].get(); + } + + for (size_t idx = 0; idx < njob; ++idx) { + for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) { + string utt = uttlist[idx][utt_idx]; + string result = resultlist[idx][utt_idx]; + result_writer.Write(utt, result); + } + } + return 0; +} diff --git a/runtime/engine/asr/recognizer/recognizer_controller.cc b/runtime/engine/asr/recognizer/recognizer_controller.cc new file mode 100644 index 000000000..b77381694 --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_controller.cc @@ -0,0 +1,71 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "recognizer/recognizer_controller.h" +#include "recognizer/u2_recognizer.h" +#include "nnet/u2_nnet.h" + +namespace ppspeech { + +RecognizerController::RecognizerController(int num_worker, U2RecognizerResource resource) { + nnet_ = std::make_shared(resource.model_opts); + recognizer_workers.resize(num_worker); + for (size_t i = 0; i < num_worker; ++i) { + recognizer_workers[i].reset(new ppspeech::U2Recognizer(resource, nnet_->Clone())); + recognizer_workers[i]->InitDecoder(); + waiting_workers.push(i); + } +} + +int RecognizerController::GetRecognizerInstanceId() { + if (waiting_workers.empty()) { + return -1; + } + int idx = -1; + { + std::unique_lock lock(mutex_); + idx = waiting_workers.front(); + waiting_workers.pop(); + } + return idx; +} + +RecognizerController::~RecognizerController() { + for (size_t i = 0; i < recognizer_workers.size(); ++i) { + recognizer_workers[i]->SetInputFinished(); + recognizer_workers[i]->WaitDecodeFinished(); + } +} + +std::string RecognizerController::GetFinalResult(int idx) { + recognizer_workers[idx]->WaitDecodeFinished(); + recognizer_workers[idx]->AttentionRescoring(); + std::string result = recognizer_workers[idx]->GetFinalResult(); + recognizer_workers[idx]->InitDecoder(); + { + std::unique_lock lock(mutex_); + waiting_workers.push(idx); + } + return result; +} + +void RecognizerController::Accept(std::vector data, int idx) { + recognizer_workers[idx]->Accept(data); +} + +void RecognizerController::SetInputFinished(int idx) { + recognizer_workers[idx]->SetInputFinished(); +} + +} \ No newline at end of file diff --git a/runtime/engine/asr/recognizer/recognizer_controller.h b/runtime/engine/asr/recognizer/recognizer_controller.h new file mode 100644 index 000000000..94a434121 --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_controller.h @@ -0,0 +1,39 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "recognizer/u2_recognizer.h" +#include "nnet/u2_nnet.h" + +namespace ppspeech { + +class RecognizerController { + public: + explicit RecognizerController(int num_worker, U2RecognizerResource resource); + ~RecognizerController(); + int GetRecognizerInstanceId(); + void Accept(std::vector data, int idx); + void SetInputFinished(int idx); + std::string GetFinalResult(int idx); + + private: + std::queue waiting_workers; + std::shared_ptr nnet_; + std::mutex mutex_; + std::vector> recognizer_workers; +}; + +} \ No newline at end of file